first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

View File

@@ -0,0 +1,91 @@
# HSSN
## Description
Author: AI-Tianlong
This project implements `Deep Hierarchical Semantic Segmentation` inference on `cityscapes` dataset
## Usage
### Prerequisites
- Python 3.8
- PyTorch 1.6 or higher
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
- mmcv v2.0.0rc4
- mmengine >=0.4.0
### Dataset preparing
Preparing `cityscapes` dataset following this [Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets)
### Testing commands
Please put [`hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth`](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) to `mmsegmentation/checkpoints`
#### Multi-GPUs Test
```bash
# --tta optional, multi-scale test, need mmengine >=0.4.0
bash tools/dist_test.sh [configs] [model weights] [number of gpu] --tta
```
#### Example
```shell
bash tools/dist_test.sh projects/hssn/configs/hssn/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py checkpoints/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth 2 --tta
```
## Results
### Cityscapes
| Method | Backbone | Crop Size | mIoU | mIoU (ms+flip) | config | model |
| :--------: | :------: | :-------: | :---: | :------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: |
| DeeplabV3+ | R-101-D8 | 512x1024 | 81.61 | 82.71 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) |
<img src="https://user-images.githubusercontent.com/50650583/210488953-e3e35ade-1132-47e1-9dfd-cf12b357ae80.png" width="50%"><img src="https://user-images.githubusercontent.com/50650583/210489746-e35ee229-3234-4292-a649-a8cd85f312ad.png" width="50%">
## Citation
This project is modified from [qhanghu/HSSN_pytorch](https://github.com/qhanghu/HSSN_pytorch)
```bibtex
@article{li2022deep,
title={Deep Hierarchical Semantic Segmentation},
author={Li, Liulei and Zhou, Tianfei and Wang, Wenguan and Li, Jianwu and Yang, Yi},
journal={CVPR},
year={2022}
}
```
## Checklist
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
- [x] Finish the code
- [x] Basic docstrings & proper citation
- [x] Test-time correctness
- [x] A full README
- [ ] Milestone 2: Indicates a successful model implementation.
- [ ] Training-time correctness
- [ ] Milestone 3: Good to be a part of our core package!
- [ ] Type hints and docstrings
- [ ] Unit tests
- [ ] Code polishing
- [ ] Metafile.yml
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.

View File

@@ -0,0 +1,67 @@
# dataset settings
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='leftImg8bit/train', seg_map_path='gtFine/train'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='leftImg8bit/val', seg_map_path='gtFine/val'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

View File

@@ -0,0 +1,15 @@
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')

View File

@@ -0,0 +1,55 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='ResNetV1d',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DepthwiseSeparableASPPContrastHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
c1_in_channels=256,
c1_channels=48,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
proj='convmlp',
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@@ -0,0 +1,24 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=80000,
by_epoch=False)
]
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

View File

@@ -0,0 +1,21 @@
_base_ = [
'../_base_/models/deeplabv3plus_r50-d8_vd_contrast.py',
'../_base_/datasets/cityscapes.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]
custom_imports = dict(imports=[
'projects.hssn.decode_head.sep_aspp_contrast_head',
'projects.hssn.losses.hiera_triplet_loss_cityscape'
])
model = dict(
pretrained=None,
backbone=dict(depth=101),
decode_head=dict(
num_classes=26,
loss_decode=dict(
type='HieraTripletLossCityscape', num_classes=19,
loss_weight=1.0)),
auxiliary_head=dict(num_classes=19),
test_cfg=dict(mode='whole', is_hiera=True, hiera_num_classes=7))

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .sep_aspp_contrast_head import DepthwiseSeparableASPPContrastHead
__all__ = ['DepthwiseSeparableASPPContrastHead']

View File

@@ -0,0 +1,193 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from torch import Tensor
from mmseg.models.decode_heads.sep_aspp_head import DepthwiseSeparableASPPHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import SampleList
class ProjectionHead(nn.Module):
"""ProjectionHead, project feature map to specific channels.
Args:
dim_in (int): Input channels.
norm_cfg (dict): config of norm layer.
proj_dim (int): Output channels. Default: 256.
proj (str): Projection type, 'linear' or 'convmlp'. Default: 'convmlp'
"""
def __init__(self,
dim_in: int,
norm_cfg: dict,
proj_dim: int = 256,
proj: str = 'convmlp'):
super().__init__()
assert proj in ['convmlp', 'linear']
if proj == 'linear':
self.proj = nn.Conv2d(dim_in, proj_dim, kernel_size=1)
elif proj == 'convmlp':
self.proj = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size=1),
build_norm_layer(norm_cfg, dim_in)[1], nn.ReLU(inplace=True),
nn.Conv2d(dim_in, proj_dim, kernel_size=1))
def forward(self, x):
return torch.nn.functional.normalize(self.proj(x), p=2, dim=1)
@MODELS.register_module()
class DepthwiseSeparableASPPContrastHead(DepthwiseSeparableASPPHead):
"""Deep Hierarchical Semantic Segmentation. This head is the implementation
of `<https://arxiv.org/abs/2203.14335>`_.
Based on Encoder-Decoder with Atrous Separable Convolution for
Semantic Image Segmentation.
`DeepLabV3+ <https://arxiv.org/abs/1802.02611>`_.
Args:
proj (str): The type of ProjectionHead, 'linear' or 'convmlp',
default 'convmlp'
"""
def __init__(self, proj: str = 'convmlp', **kwargs):
super().__init__(**kwargs)
self.proj_head = ProjectionHead(
dim_in=2048, norm_cfg=self.norm_cfg, proj=proj)
self.register_buffer('step', torch.zeros(1))
def forward(self, inputs) -> Tuple[Tensor]:
"""Forward function."""
output = super().forward(inputs)
self.step += 1
embedding = self.proj_head(inputs[-1])
return output, embedding
def predict_by_feat(self, seg_logits: Tuple[Tensor],
batch_img_metas: List[dict]) -> Tensor:
"""Transform a batch of output seg_logits to the input shape.
Args:
seg_logits (Tensor): The output from decode head forward function.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
Tensor: Outputs segmentation logits map.
"""
# HSSN decode_head output is: (out, embedding): tuple
# only need 'out' here.
if isinstance(seg_logits, tuple):
seg_logit = seg_logits[0]
if seg_logit.size(1) == 26: # For cityscapes dataset19 + 7
hiera_num_classes = 7
seg_logit[:, 0:2] += seg_logit[:, -7]
seg_logit[:, 2:5] += seg_logit[:, -6]
seg_logit[:, 5:8] += seg_logit[:, -5]
seg_logit[:, 8:10] += seg_logit[:, -4]
seg_logit[:, 10:11] += seg_logit[:, -3]
seg_logit[:, 11:13] += seg_logit[:, -2]
seg_logit[:, 13:19] += seg_logit[:, -1]
elif seg_logit.size(1) == 12: # For Pascal_person dataset, 7 + 5
hiera_num_classes = 5
seg_logit[:, 0:1] = seg_logit[:, 0:1] + \
seg_logit[:, 7] + seg_logit[:, 10]
seg_logit[:, 1:5] = seg_logit[:, 1:5] + \
seg_logit[:, 8] + seg_logit[:, 11]
seg_logit[:, 5:7] = seg_logit[:, 5:7] + \
seg_logit[:, 9] + seg_logit[:, 11]
elif seg_logit.size(1) == 25: # For LIP dataset, 20 + 5
hiera_num_classes = 5
seg_logit[:, 0:1] = seg_logit[:, 0:1] + \
seg_logit[:, 20] + seg_logit[:, 23]
seg_logit[:, 1:8] = seg_logit[:, 1:8] + \
seg_logit[:, 21] + seg_logit[:, 24]
seg_logit[:, 10:12] = seg_logit[:, 10:12] + \
seg_logit[:, 21] + seg_logit[:, 24]
seg_logit[:, 13:16] = seg_logit[:, 13:16] + \
seg_logit[:, 21] + seg_logit[:, 24]
seg_logit[:, 8:10] = seg_logit[:, 8:10] + \
seg_logit[:, 22] + seg_logit[:, 24]
seg_logit[:, 12:13] = seg_logit[:, 12:13] + \
seg_logit[:, 22] + seg_logit[:, 24]
seg_logit[:, 16:20] = seg_logit[:, 16:20] + \
seg_logit[:, 22] + seg_logit[:, 24]
# elif seg_logit.size(1) == 144 # For Mapillary dataset, 124+16+4
# unofficial repository not release mapillary until 2023/2/6
if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
# slide inference
size = batch_img_metas[0]['img_shape']
elif 'pad_shape' in batch_img_metas[0]:
size = batch_img_metas[0]['pad_shape'][:2]
else:
size = batch_img_metas[0]['img_shape']
seg_logit = seg_logit[:, :-hiera_num_classes]
seg_logit = resize(
input=seg_logit,
size=size,
mode='bilinear',
align_corners=self.align_corners)
return seg_logit
def loss_by_feat(
self,
seg_logits: Tuple[Tensor], # (out, embedding)
batch_data_samples: SampleList) -> dict:
"""Compute segmentation loss. Will fix in future.
Args:
seg_logits (Tuple[Tensor]): The output from decode head
forward function.
For this decode_head output are (out, embedding): tuple
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logit_before = seg_logits[0]
embedding = seg_logits[1]
seg_label = self._stack_batch_gt(batch_data_samples)
loss = dict()
seg_logit = resize(
input=seg_logit_before,
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logit, seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
seg_logit_before = resize(
input=seg_logit_before,
scale_factor=0.5,
mode='bilinear',
align_corners=self.align_corners)
loss['loss_seg'] = self.loss_decode(
self.step,
embedding,
seg_logit_before,
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
loss['acc_seg'] = accuracy(seg_logit, seg_label)
return loss

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hiera_triplet_loss_cityscape import HieraTripletLossCityscape
__all__ = ['HieraTripletLossCityscape']

View File

@@ -0,0 +1,218 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import LOSSES
from mmseg.models.losses.cross_entropy_loss import CrossEntropyLoss
from .tree_triplet_loss import TreeTripletLoss
hiera_map = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 6, 6, 6]
hiera_index = [[0, 2], [2, 5], [5, 8], [8, 10], [10, 11], [11, 13], [13, 19]]
hiera = {
'hiera_high': {
'flat': [0, 2],
'construction': [2, 5],
'object': [5, 8],
'nature': [8, 10],
'sky': [10, 11],
'human': [11, 13],
'vehicle': [13, 19]
}
}
def prepare_targets(targets):
b, h, w = targets.shape
targets_high = torch.ones(
(b, h, w), dtype=targets.dtype, device=targets.device) * 255
indices_high = []
for index, high in enumerate(hiera['hiera_high'].keys()):
indices = hiera['hiera_high'][high]
for ii in range(indices[0], indices[1]):
targets_high[targets == ii] = index
indices_high.append(indices)
return targets, targets_high, indices_high
def losses_hiera(predictions,
targets,
targets_top,
num_classes,
indices_high,
eps=1e-8):
"""Implementation of hiera loss.
Args:
predictions (torch.Tensor): seg logits produced by decode head.
targets (torch.Tensor): The learning label of the prediction.
targets_top (torch.Tensor): The hierarchy ground truth of the learning
label.
num_classes (int): Number of categories.
indices_high (List[List[int]]): Hierarchy indices of each hierarchy.
eps (float):Term added to the Logarithm to improve numerical stability.
"""
b, _, h, w = predictions.shape
predictions = torch.sigmoid(predictions.float())
void_indices = (targets == 255)
targets[void_indices] = 0
targets = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2)
void_indices2 = (targets_top == 255)
targets_top[void_indices2] = 0
targets_top = F.one_hot(targets_top, num_classes=7).permute(0, 3, 1, 2)
MCMA = predictions[:, :num_classes, :, :]
MCMB = torch.zeros((b, 7, h, w)).to(predictions)
for ii in range(7):
MCMB[:, ii:ii + 1, :, :] = torch.max(
torch.cat([
predictions[:, indices_high[ii][0]:indices_high[ii][1], :, :],
predictions[:, num_classes + ii:num_classes + ii + 1, :, :]
],
dim=1), 1, True)[0]
MCLB = predictions[:, num_classes:num_classes + 7, :, :]
MCLA = predictions[:, :num_classes, :, :].clone()
for ii in range(7):
for jj in range(indices_high[ii][0], indices_high[ii][1]):
MCLA[:, jj:jj + 1, :, :] = torch.min(
torch.cat([
predictions[:, jj:jj + 1, :, :], MCLB[:, ii:ii + 1, :, :]
],
dim=1), 1, True)[0]
valid_indices = (~void_indices).unsqueeze(1)
num_valid = valid_indices.sum()
valid_indices2 = (~void_indices2).unsqueeze(1)
num_valid2 = valid_indices2.sum()
# channel_num*sum()/one_channel_valid already has a weight
loss = (
(-targets[:, :num_classes, :, :] * torch.log(MCLA + eps) -
(1.0 - targets[:, :num_classes, :, :]) * torch.log(1.0 - MCMA + eps))
* valid_indices).sum() / num_valid / num_classes
loss += ((-targets_top[:, :, :, :] * torch.log(MCLB + eps) -
(1.0 - targets_top[:, :, :, :]) * torch.log(1.0 - MCMB + eps)) *
valid_indices2).sum() / num_valid2 / 7
return 5 * loss
def losses_hiera_focal(predictions,
targets,
targets_top,
num_classes,
indices_high,
eps=1e-8,
gamma=2):
"""Implementation of hiera loss.
Args:
predictions (torch.Tensor): seg logits produced by decode head.
targets (torch.Tensor): The learning label of the prediction.
targets_top (torch.Tensor): The hierarchy ground truth of the learning
label.
num_classes (int): Number of categories.
indices_high (List[List[int]]): Hierarchy indices of each hierarchy.
eps (float):Term added to the Logarithm to improve numerical stability.
Defaults: 1e-8.
gamma (int): The exponent value. Defaults: 2.
"""
b, _, h, w = predictions.shape
predictions = torch.sigmoid(predictions.float())
void_indices = (targets == 255)
targets[void_indices] = 0
targets = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2)
void_indices2 = (targets_top == 255)
targets_top[void_indices2] = 0
targets_top = F.one_hot(targets_top, num_classes=7).permute(0, 3, 1, 2)
MCMA = predictions[:, :num_classes, :, :]
MCMB = torch.zeros((b, 7, h, w),
dtype=predictions.dtype,
device=predictions.device)
for ii in range(7):
MCMB[:, ii:ii + 1, :, :] = torch.max(
torch.cat([
predictions[:, indices_high[ii][0]:indices_high[ii][1], :, :],
predictions[:, num_classes + ii:num_classes + ii + 1, :, :]
],
dim=1), 1, True)[0]
MCLB = predictions[:, num_classes:num_classes + 7, :, :]
MCLA = predictions[:, :num_classes, :, :].clone()
for ii in range(7):
for jj in range(indices_high[ii][0], indices_high[ii][1]):
MCLA[:, jj:jj + 1, :, :] = torch.min(
torch.cat([
predictions[:, jj:jj + 1, :, :], MCLB[:, ii:ii + 1, :, :]
],
dim=1), 1, True)[0]
valid_indices = (~void_indices).unsqueeze(1)
num_valid = valid_indices.sum()
valid_indices2 = (~void_indices2).unsqueeze(1)
num_valid2 = valid_indices2.sum()
# channel_num*sum()/one_channel_valid already has a weight
loss = ((-targets[:, :num_classes, :, :] * torch.pow(
(1.0 - MCLA), gamma) * torch.log(MCLA + eps) -
(1.0 - targets[:, :num_classes, :, :]) * torch.pow(MCMA, gamma) *
torch.log(1.0 - MCMA + eps)) *
valid_indices).sum() / num_valid / num_classes
loss += (
(-targets_top[:, :, :, :] * torch.pow(
(1.0 - MCLB), gamma) * torch.log(MCLB + eps) -
(1.0 - targets_top[:, :, :, :]) * torch.pow(MCMB, gamma) *
torch.log(1.0 - MCMB + eps)) * valid_indices2).sum() / num_valid2 / 7
return 5 * loss
@LOSSES.register_module()
class HieraTripletLossCityscape(nn.Module):
"""Modified from https://github.com/qhanghu/HSSN_pytorch/blob/main/mmseg/mo
dels/losses/hiera_triplet_loss_cityscape.py."""
def __init__(self, num_classes, use_sigmoid=False, loss_weight=1.0):
super().__init__()
self.num_classes = num_classes
self.loss_weight = loss_weight
self.treetripletloss = TreeTripletLoss(num_classes, hiera_map,
hiera_index)
self.ce = CrossEntropyLoss()
def forward(self,
step,
embedding,
cls_score_before,
cls_score,
label,
weight=None,
**kwargs):
targets, targets_top, indices_top = prepare_targets(label)
loss = losses_hiera(cls_score, targets, targets_top, self.num_classes,
indices_top)
ce_loss = self.ce(cls_score[:, :-7], label)
ce_loss2 = self.ce(cls_score[:, -7:], targets_top)
loss = loss + ce_loss + ce_loss2
loss_triplet, class_count = self.treetripletloss(embedding, label)
class_counts = [
torch.ones_like(class_count)
for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(class_counts, class_count, async_op=False)
class_counts = torch.cat(class_counts, dim=0)
if torch.distributed.get_world_size() == torch.nonzero(
class_counts, as_tuple=False).size(0):
factor = 1 / 4 * (1 + torch.cos(
torch.tensor((step.item() - 80000) / 80000 *
math.pi))) if step.item() < 80000 else 0.5
loss += factor * loss_triplet
return loss * self.loss_weight

View File

@@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import LOSSES
@LOSSES.register_module()
class TreeTripletLoss(nn.Module):
"""TreeTripletLoss. Modified from https://github.com/qhanghu/HSSN_pytorch/b
lob/main/mmseg/models/losses/tree_triplet_loss.py.
Args:
num_classes (int): Number of categories.
hiera_map (List[int]): Hierarchy map of each category.
hiera_index (List[List[int]]): Hierarchy indices of each hierarchy.
ignore_index (int): Specifies a target value that is ignored and
does not contribute to the input gradients. Defaults: 255.
Examples:
>>> num_classes = 19
>>> hiera_map = [
0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 6, 6, 6]
>>> hiera_index = [
0, 2], [2, 5], [5, 8], [8, 10], [10, 11], [11, 13], [13, 19]]
"""
def __init__(self, num_classes, hiera_map, hiera_index, ignore_index=255):
super().__init__()
self.ignore_label = ignore_index
self.num_classes = num_classes
self.hiera_map = hiera_map
self.hiera_index = hiera_index
def forward(self, feats: torch.Tensor, labels=None, max_triplet=200):
labels = labels.unsqueeze(1).float().clone()
labels = torch.nn.functional.interpolate(
labels, (feats.shape[2], feats.shape[3]), mode='nearest')
labels = labels.squeeze(1).long()
assert labels.shape[-1] == feats.shape[-1], '{} {}'.format(
labels.shape, feats.shape)
labels = labels.view(-1)
feats = feats.permute(0, 2, 3, 1)
feats = feats.contiguous().view(-1, feats.shape[-1])
triplet_loss = 0
exist_classes = torch.unique(labels)
exist_classes = [x for x in exist_classes if x != 255]
class_count = 0
for ii in exist_classes:
index_range = self.hiera_index[self.hiera_map[ii]]
index_anchor = labels == ii
index_pos = (labels >= index_range[0]) & (
labels < index_range[-1]) & (~index_anchor)
index_neg = (labels < index_range[0]) | (labels >= index_range[-1])
min_size = min(
torch.sum(index_anchor), torch.sum(index_pos),
torch.sum(index_neg), max_triplet)
feats_anchor = feats[index_anchor][:min_size]
feats_pos = feats[index_pos][:min_size]
feats_neg = feats[index_neg][:min_size]
distance = torch.zeros(min_size, 2).to(feats)
distance[:, 0:1] = 1 - (feats_anchor * feats_pos).sum(1, True)
distance[:, 1:2] = 1 - (feats_anchor * feats_neg).sum(1, True)
# margin always 0.1 + (4-2)/4 since the hierarchy is three level
# TODO: should include label of pos is the same as anchor
margin = 0.6 * torch.ones(min_size).to(feats)
tl = distance[:, 0] - distance[:, 1] + margin
tl = F.relu(tl)
if tl.size(0) > 0:
triplet_loss += tl.mean()
class_count += 1
if class_count == 0:
return None, torch.tensor([0]).to(feats)
triplet_loss /= class_count
return triplet_loss, torch.tensor([class_count]).to(feats)