first commit
This commit is contained in:
91
Seg_All_In_One_MMSeg/projects/hssn/README.md
Normal file
91
Seg_All_In_One_MMSeg/projects/hssn/README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# HSSN
|
||||
|
||||
## Description
|
||||
|
||||
Author: AI-Tianlong
|
||||
|
||||
This project implements inference for `Deep Hierarchical Semantic Segmentation` (HSSN) on the `cityscapes` dataset.
|
||||
|
||||
## Usage
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.8
|
||||
- PyTorch 1.6 or higher
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
|
||||
- mmcv v2.0.0rc4
|
||||
- mmengine >=0.4.0
|
||||
|
||||
### Dataset preparing
|
||||
|
||||
Prepare the `cityscapes` dataset by following the [Dataset Preparation Guide](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets).
|
||||
|
||||
### Testing commands
|
||||
|
||||
Please download [`hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth`](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) and put it in `mmsegmentation/checkpoints`.
|
||||
|
||||
#### Multi-GPUs Test
|
||||
|
||||
```bash
|
||||
# --tta is optional: it enables multi-scale + flip test-time augmentation and requires mmengine >= 0.4.0
|
||||
bash tools/dist_test.sh [configs] [model weights] [number of gpu] --tta
|
||||
```
|
||||
|
||||
#### Example
|
||||
|
||||
```shell
|
||||
bash tools/dist_test.sh projects/hssn/configs/hssn/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py checkpoints/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth 2 --tta
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | mIoU | mIoU (ms+flip) | config | model |
|
||||
| :--------: | :------: | :-------: | :---: | :------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| DeeplabV3+ | R-101-D8 | 512x1024 | 81.61 | 82.71 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) |
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/50650583/210488953-e3e35ade-1132-47e1-9dfd-cf12b357ae80.png" width="50%"><img src="https://user-images.githubusercontent.com/50650583/210489746-e35ee229-3234-4292-a649-a8cd85f312ad.png" width="50%">
|
||||
|
||||
## Citation
|
||||
|
||||
This project is modified from [qhanghu/HSSN_pytorch](https://github.com/qhanghu/HSSN_pytorch)
|
||||
|
||||
```bibtex
|
||||
@article{li2022deep,
|
||||
title={Deep Hierarchical Semantic Segmentation},
|
||||
author={Li, Liulei and Zhou, Tianfei and Wang, Wenguan and Li, Jianwu and Yang, Yi},
|
||||
journal={CVPR},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
## Checklist
|
||||
|
||||
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
|
||||
|
||||
- [x] Finish the code
|
||||
|
||||
- [x] Basic docstrings & proper citation
|
||||
|
||||
- [x] Test-time correctness
|
||||
|
||||
- [x] A full README
|
||||
|
||||
- [ ] Milestone 2: Indicates a successful model implementation.
|
||||
|
||||
- [ ] Training-time correctness
|
||||
|
||||
- [ ] Milestone 3: Good to be a part of our core package!
|
||||
|
||||
- [ ] Type hints and docstrings
|
||||
|
||||
- [ ] Unit tests
|
||||
|
||||
- [ ] Code polishing
|
||||
|
||||
- [ ] Metafile.yml
|
||||
|
||||
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
|
||||
|
||||
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
|
||||
@@ -0,0 +1,67 @@
|
||||
# Dataset settings for Cityscapes.
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
crop_size = (512, 1024)

# Training pipeline: random rescale -> random crop -> flip -> colour jitter.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True,
    ),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs'),
]

# Single-scale evaluation pipeline.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    # Annotations are loaded after ``Resize`` on purpose: the ground truth
    # must not go through the resize transform.
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]

# Test-time augmentation: every scale in ``img_ratios`` is combined with
# both a non-flipped (prob=0.) and a flipped (prob=1.) variant.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=ratio, keep_ratio=True)
                for ratio in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal'),
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')],
        ],
    ),
]

train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='leftImg8bit/train',
            seg_map_path='gtFine/train',
        ),
        pipeline=train_pipeline,
    ),
)

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='leftImg8bit/val',
            seg_map_path='gtFine/val',
        ),
        pipeline=test_pipeline,
    ),
)
# Testing reuses the validation loader as-is.
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,15 @@
|
||||
# Default runtime settings shared by the experiment configs.
default_scope = 'mmseg'

# Environment: cuDNN autotuning, multiprocessing start method and the
# distributed backend.
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

# Visualisation.
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer',
    vis_backends=vis_backends,
    name='visualizer',
)

# Logging is iteration based.
log_processor = dict(by_epoch=False)
log_level = 'INFO'

# No checkpoint is loaded and no run is resumed by default.
load_from = None
resume = False

# Wrapper model used for test-time augmentation.
tta_model = dict(type='SegTTAModel')
|
||||
@@ -0,0 +1,55 @@
|
||||
# Model settings: ResNetV1d-50 backbone, a depthwise-separable ASPP decode
# head with a contrastive projection branch, and an auxiliary FCN head.
norm_cfg = dict(type='SyncBN', requires_grad=True)

data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
)

model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    backbone=dict(
        type='ResNetV1d',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True,
    ),
    # Main head: consumes backbone output 3 (2048 channels).
    decode_head=dict(
        type='DepthwiseSeparableASPPContrastHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        dilations=(1, 12, 24, 36),
        c1_in_channels=256,
        c1_channels=48,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        proj='convmlp',
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
        ),
    ),
    # Auxiliary head: consumes backbone output 2 (1024 channels).
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=0.4,
        ),
    ),
    # Model training and testing settings.
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
|
||||
@@ -0,0 +1,24 @@
|
||||
# Optimizer: plain SGD, no gradient clipping.
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=optimizer,
    clip_grad=None,
)

# Learning-rate policy: polynomial decay over the whole 80k-iteration run.
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=1e-4,
        power=0.9,
        begin=0,
        end=80000,
        by_epoch=False,
    ),
]

# Training schedule for 80k iterations, validating every 8000 iterations.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=80000,
    val_interval=8000,
)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Default hooks; checkpoints are written every 8000 iterations.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'),
)
|
||||
@@ -0,0 +1,21 @@
|
||||
# Experiment config: HSSN on Cityscapes, built on top of the base model,
# dataset, runtime and 80k schedule configs.
_base_ = [
    '../_base_/models/deeplabv3plus_r50-d8_vd_contrast.py',
    '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py',
]

# Import the project-local modules so that the decode head and the loss
# named below are available when the model is built.
custom_imports = dict(
    imports=[
        'projects.hssn.decode_head.sep_aspp_contrast_head',
        'projects.hssn.losses.hiera_triplet_loss_cityscape',
    ])

# Overrides of the base model: ResNet-101 backbone and a 26-channel decode
# head trained with the hierarchical triplet loss.
model = dict(
    pretrained=None,
    backbone=dict(depth=101),
    decode_head=dict(
        num_classes=26,
        loss_decode=dict(
            type='HieraTripletLossCityscape',
            num_classes=19,
            loss_weight=1.0,
        ),
    ),
    auxiliary_head=dict(num_classes=19),
    test_cfg=dict(mode='whole', is_hiera=True, hiera_num_classes=7),
)
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .sep_aspp_contrast_head import DepthwiseSeparableASPPContrastHead
|
||||
|
||||
__all__ = ['DepthwiseSeparableASPPContrastHead']
|
||||
@@ -0,0 +1,193 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.sep_aspp_head import DepthwiseSeparableASPPHead
|
||||
from mmseg.models.losses import accuracy
|
||||
from mmseg.models.utils import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
|
||||
|
||||
class ProjectionHead(nn.Module):
    """Project a feature map to ``proj_dim`` channels and L2-normalise it.

    Args:
        dim_in (int): Number of input channels.
        norm_cfg (dict): Config of the norm layer; only used by the
            ``'convmlp'`` projection.
        proj_dim (int): Number of output channels. Default: 256.
        proj (str): Projection type. ``'linear'`` is a single 1x1 conv,
            ``'convmlp'`` is 1x1 conv -> norm -> ReLU -> 1x1 conv.
            Default: 'convmlp'.
    """

    def __init__(self,
                 dim_in: int,
                 norm_cfg: dict,
                 proj_dim: int = 256,
                 proj: str = 'convmlp'):
        super().__init__()
        assert proj in ['convmlp', 'linear']

        if proj == 'convmlp':
            # Layers are created in forward order.
            hidden_conv = nn.Conv2d(dim_in, dim_in, kernel_size=1)
            hidden_norm = build_norm_layer(norm_cfg, dim_in)[1]
            activation = nn.ReLU(inplace=True)
            out_conv = nn.Conv2d(dim_in, proj_dim, kernel_size=1)
            self.proj = nn.Sequential(hidden_conv, hidden_norm, activation,
                                      out_conv)
        elif proj == 'linear':
            self.proj = nn.Conv2d(dim_in, proj_dim, kernel_size=1)

    def forward(self, x):
        """Return the projection of ``x`` with unit L2 norm along dim 1."""
        projected = self.proj(x)
        return nn.functional.normalize(projected, p=2, dim=1)
|
||||
|
||||
|
||||
@MODELS.register_module()
class DepthwiseSeparableASPPContrastHead(DepthwiseSeparableASPPHead):
    """Deep Hierarchical Semantic Segmentation. This head is the implementation
    of `<https://arxiv.org/abs/2203.14335>`_.

    Based on Encoder-Decoder with Atrous Separable Convolution for
    Semantic Image Segmentation.
    `DeepLabV3+ <https://arxiv.org/abs/1802.02611>`_.

    On top of the parent head's logits, ``forward`` also returns an
    L2-normalised embedding computed from the last input feature map.

    Args:
        proj (str): The type of ProjectionHead, 'linear' or 'convmlp',
            default 'convmlp'
    """

    # Supported hierarchical logit layouts, keyed by total channel count.
    # Each value is ``(number of parent channels, groups)`` where a group is
    # ``((start, stop), parent_channels)``: the logits of every parent
    # channel listed are added to fine channels ``start:stop``.
    _HIERA_LAYOUTS = {
        # Cityscapes: 19 fine classes + 7 parent classes.
        26: (7, (
            ((0, 2), (19, )),
            ((2, 5), (20, )),
            ((5, 8), (21, )),
            ((8, 10), (22, )),
            ((10, 11), (23, )),
            ((11, 13), (24, )),
            ((13, 19), (25, )),
        )),
        # Pascal-Person: 7 fine classes + 5 parent classes.
        12: (5, (
            ((0, 1), (7, 10)),
            ((1, 5), (8, 11)),
            ((5, 7), (9, 11)),
        )),
        # LIP: 20 fine classes + 5 parent classes.
        25: (5, (
            ((0, 1), (20, 23)),
            ((1, 8), (21, 24)),
            ((10, 12), (21, 24)),
            ((13, 16), (21, 24)),
            ((8, 10), (22, 24)),
            ((12, 13), (22, 24)),
            ((16, 20), (22, 24)),
        )),
        # Mapillary (124 + 16 + 4 = 144 channels) is not supported: the
        # unofficial repository had not released it as of 2023/2/6.
    }

    def __init__(self, proj: str = 'convmlp', **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the projection input width is fixed to 2048 and is
        # applied to ``inputs[-1]`` in ``forward`` -- confirm this matches
        # the backbone if anything other than the shipped config is used.
        self.proj_head = ProjectionHead(
            dim_in=2048, norm_cfg=self.norm_cfg, proj=proj)
        # Forward-call counter; handed to the loss in ``loss_by_feat``.
        self.register_buffer('step', torch.zeros(1))

    def forward(self, inputs) -> Tuple[Tensor]:
        """Forward function.

        Returns:
            tuple: ``(output, embedding)`` -- the parent head's output and
            the projection of ``inputs[-1]``.
        """
        output = super().forward(inputs)

        # Incremented on every forward call.
        self.step += 1
        embedding = self.proj_head(inputs[-1])

        return output, embedding

    def predict_by_feat(self, seg_logits: Tuple[Tensor],
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        The parent-class logits are added to the logits of their child
        classes (modifying the logit tensor in place), the parent channels
        are dropped, and the result is resized.

        Args:
            seg_logits (Tensor | tuple): The output from decode head forward
                function, either the logits or ``(logits, embedding)``.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.

        Raises:
            ValueError: If the number of logit channels does not match any
                supported hierarchy layout.
        """
        # HSSN decode_head output is (out, embedding); only 'out' is needed.
        # A bare tensor is accepted as well instead of leaving the local
        # name unbound.
        if isinstance(seg_logits, tuple):
            seg_logit = seg_logits[0]
        else:
            seg_logit = seg_logits

        num_channels = seg_logit.size(1)
        layout = self._HIERA_LAYOUTS.get(num_channels)
        if layout is None:
            raise ValueError(
                f'Unsupported number of hierarchical logit channels: '
                f'{num_channels}; expected one of '
                f'{sorted(self._HIERA_LAYOUTS)}')
        hiera_num_classes, groups = layout

        for (start, stop), parents in groups:
            # ``p:p + 1`` keeps the channel axis so the parent logit
            # broadcasts over the fine channels for any batch size. Indexing
            # with a bare integer drops that axis and only lines up when the
            # batch size is 1.
            parent_logit = sum(seg_logit[:, p:p + 1] for p in parents)
            seg_logit[:, start:stop] += parent_logit

        if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
            # slide inference
            size = batch_img_metas[0]['img_shape']
        elif 'pad_shape' in batch_img_metas[0]:
            size = batch_img_metas[0]['pad_shape'][:2]
        else:
            size = batch_img_metas[0]['img_shape']

        # Drop the parent channels; only fine-class logits are returned.
        seg_logit = seg_logit[:, :-hiera_num_classes]
        seg_logit = resize(
            input=seg_logit,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)

        return seg_logit

    def loss_by_feat(
            self,
            seg_logits: Tuple[Tensor],  # (out, embedding)
            batch_data_samples: SampleList) -> dict:
        """Compute segmentation loss. Will fix in future.

        Args:
            seg_logits (Tuple[Tensor]): The output from decode head
                forward function.
                For this decode_head output are (out, embedding): tuple
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logit_before = seg_logits[0]
        embedding = seg_logits[1]
        seg_label = self._stack_batch_gt(batch_data_samples)

        loss = dict()
        # Logits upsampled to the label resolution.
        seg_logit = resize(
            input=seg_logit_before,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)
        # Half-resolution copy of the raw logits, also passed to the loss.
        seg_logit_before = resize(
            input=seg_logit_before,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)

        loss['loss_seg'] = self.loss_decode(
            self.step,
            embedding,
            seg_logit_before,
            seg_logit,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        loss['acc_seg'] = accuracy(seg_logit, seg_label)
        return loss
|
||||
4
Seg_All_In_One_MMSeg/projects/hssn/losses/__init__.py
Normal file
4
Seg_All_In_One_MMSeg/projects/hssn/losses/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hiera_triplet_loss_cityscape import HieraTripletLossCityscape
|
||||
|
||||
__all__ = ['HieraTripletLossCityscape']
|
||||
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmseg.models.builder import LOSSES
|
||||
from mmseg.models.losses.cross_entropy_loss import CrossEntropyLoss
|
||||
from .tree_triplet_loss import TreeTripletLoss
|
||||
|
||||
# Cityscapes class hierarchy: 19 fine classes grouped into 7 parent classes.

# hiera_map[c] is the parent-class index of fine class c.
hiera_map = [
    0, 0,  # flat
    1, 1, 1,  # construction
    2, 2, 2,  # object
    3, 3,  # nature
    4,  # sky
    5, 5,  # human
    6, 6, 6, 6, 6, 6,  # vehicle
]

# hiera_index[p] is the half-open [start, stop) range of fine classes that
# belong to parent class p.
hiera_index = [
    [0, 2],
    [2, 5],
    [5, 8],
    [8, 10],
    [10, 11],
    [11, 13],
    [13, 19],
]

# The same ranges keyed by parent-class name; the insertion order of the
# keys defines the parent-class index.
hiera = {
    'hiera_high': {
        'flat': [0, 2],
        'construction': [2, 5],
        'object': [5, 8],
        'nature': [8, 10],
        'sky': [10, 11],
        'human': [11, 13],
        'vehicle': [13, 19],
    },
}
|
||||
|
||||
|
||||
def prepare_targets(targets):
    """Build parent-level targets from fine-level targets.

    Args:
        targets (torch.Tensor): Fine-class label map with three dimensions.

    Returns:
        tuple: ``(targets, targets_high, indices_high)`` where ``targets`` is
        the input tensor itself (not a copy), ``targets_high`` holds the
        parent-class index of every pixel (255 where the fine label falls in
        no group) and ``indices_high`` lists the ``[start, stop)`` fine-class
        range of each parent class.
    """
    batch, height, width = targets.shape
    # Every pixel starts out as 255 and is overwritten once its fine label
    # is found in one of the groups.
    targets_high = 255 * torch.ones(
        (batch, height, width), dtype=targets.dtype, device=targets.device)

    indices_high = []
    for parent_id, fine_range in enumerate(hiera['hiera_high'].values()):
        first, last = fine_range[0], fine_range[1]
        for fine_id in range(first, last):
            targets_high[targets == fine_id] = parent_id
        indices_high.append(fine_range)

    return targets, targets_high, indices_high
|
||||
|
||||
|
||||
def losses_hiera(predictions,
                 targets,
                 targets_top,
                 num_classes,
                 indices_high,
                 eps=1e-8):
    """Implementation of hiera loss.

    NOTE: ``targets`` and ``targets_top`` are modified in place -- every
    entry equal to 255 is overwritten with 0 before one-hot encoding. The
    caller's tensors therefore no longer contain 255 afterwards.

    Args:
        predictions (torch.Tensor): seg logits produced by decode head.
        targets (torch.Tensor): The learning label of the prediction.
        targets_top (torch.Tensor): The hierarchy ground truth of the learning
            label.
        num_classes (int): Number of categories.
        indices_high (List[List[int]]): Hierarchy indices of each hierarchy.
        eps (float):Term added to the Logarithm to improve numerical stability.

    Returns:
        torch.Tensor: The scalar loss, scaled by 5.
    """
    num_parents = 7
    batch, _, height, width = predictions.shape
    probs = torch.sigmoid(predictions.float())

    # Remember which pixels are void (255), then overwrite them in place so
    # the labels can be one-hot encoded.
    void_fine = targets == 255
    targets[void_fine] = 0
    fine_one_hot = F.one_hot(
        targets, num_classes=num_classes).permute(0, 3, 1, 2)
    void_top = targets_top == 255
    targets_top[void_top] = 0
    top_one_hot = F.one_hot(
        targets_top, num_classes=num_parents).permute(0, 3, 1, 2)

    fine_probs = probs[:, :num_classes, :, :]
    parent_probs = probs[:, num_classes:num_classes + num_parents, :, :]

    # parent_max: per parent, max over its own score and its children's.
    # fine_min: per fine class, min of its own score and its parent's.
    parent_max = torch.zeros(
        (batch, num_parents, height, width)).to(probs)
    fine_min = fine_probs.clone()
    for parent in range(num_parents):
        first = indices_high[parent][0]
        last = indices_high[parent][1]
        own = probs[:, num_classes + parent:num_classes + parent + 1, :, :]
        children = probs[:, first:last, :, :]
        parent_max[:, parent:parent + 1, :, :] = torch.cat(
            [children, own], dim=1).max(dim=1, keepdim=True)[0]
        for child in range(first, last):
            pair = torch.cat([
                probs[:, child:child + 1, :, :],
                parent_probs[:, parent:parent + 1, :, :]
            ],
                             dim=1)
            fine_min[:, child:child + 1, :, :] = pair.min(
                dim=1, keepdim=True)[0]

    keep_fine = (~void_fine).unsqueeze(1)
    keep_top = (~void_top).unsqueeze(1)
    count_fine = keep_fine.sum()
    count_top = keep_top.sum()

    # channel_num*sum()/one_channel_valid already has a weight
    fine_term = (-fine_one_hot * torch.log(fine_min + eps) -
                 (1.0 - fine_one_hot) * torch.log(1.0 - fine_probs + eps))
    top_term = (-top_one_hot * torch.log(parent_probs + eps) -
                (1.0 - top_one_hot) * torch.log(1.0 - parent_max + eps))

    loss = (fine_term * keep_fine).sum() / count_fine / num_classes
    loss += (top_term * keep_top).sum() / count_top / 7

    return 5 * loss
|
||||
|
||||
|
||||
def losses_hiera_focal(predictions,
                       targets,
                       targets_top,
                       num_classes,
                       indices_high,
                       eps=1e-8,
                       gamma=2):
    """Implementation of the focal variant of hiera loss.

    NOTE: ``targets`` and ``targets_top`` are modified in place -- every
    entry equal to 255 is overwritten with 0 before one-hot encoding. The
    caller's tensors therefore no longer contain 255 afterwards.

    Args:
        predictions (torch.Tensor): seg logits produced by decode head.
        targets (torch.Tensor): The learning label of the prediction.
        targets_top (torch.Tensor): The hierarchy ground truth of the learning
            label.
        num_classes (int): Number of categories.
        indices_high (List[List[int]]): Hierarchy indices of each hierarchy.
        eps (float):Term added to the Logarithm to improve numerical stability.
            Defaults: 1e-8.
        gamma (int): The exponent value. Defaults: 2.

    Returns:
        torch.Tensor: The scalar loss, scaled by 5.
    """
    num_parents = 7
    batch, _, height, width = predictions.shape
    probs = torch.sigmoid(predictions.float())

    # Remember which pixels are void (255), then overwrite them in place so
    # the labels can be one-hot encoded.
    void_fine = targets == 255
    targets[void_fine] = 0
    fine_one_hot = F.one_hot(
        targets, num_classes=num_classes).permute(0, 3, 1, 2)
    void_top = targets_top == 255
    targets_top[void_top] = 0
    top_one_hot = F.one_hot(
        targets_top, num_classes=num_parents).permute(0, 3, 1, 2)

    fine_probs = probs[:, :num_classes, :, :]
    parent_probs = probs[:, num_classes:num_classes + num_parents, :, :]

    # parent_max: per parent, max over its own score and its children's.
    # fine_min: per fine class, min of its own score and its parent's.
    parent_max = torch.zeros((batch, num_parents, height, width),
                             dtype=probs.dtype,
                             device=probs.device)
    fine_min = fine_probs.clone()
    for parent in range(num_parents):
        first = indices_high[parent][0]
        last = indices_high[parent][1]
        own = probs[:, num_classes + parent:num_classes + parent + 1, :, :]
        children = probs[:, first:last, :, :]
        parent_max[:, parent:parent + 1, :, :] = torch.cat(
            [children, own], dim=1).max(dim=1, keepdim=True)[0]
        for child in range(first, last):
            pair = torch.cat([
                probs[:, child:child + 1, :, :],
                parent_probs[:, parent:parent + 1, :, :]
            ],
                             dim=1)
            fine_min[:, child:child + 1, :, :] = pair.min(
                dim=1, keepdim=True)[0]

    keep_fine = (~void_fine).unsqueeze(1)
    keep_top = (~void_top).unsqueeze(1)
    count_fine = keep_fine.sum()
    count_top = keep_top.sum()

    # channel_num*sum()/one_channel_valid already has a weight
    # Each log term is weighted by a ``gamma``-power modulating factor.
    fine_term = (
        -fine_one_hot * torch.pow(
            (1.0 - fine_min), gamma) * torch.log(fine_min + eps) -
        (1.0 - fine_one_hot) * torch.pow(fine_probs, gamma) *
        torch.log(1.0 - fine_probs + eps))
    top_term = (
        -top_one_hot * torch.pow(
            (1.0 - parent_probs), gamma) * torch.log(parent_probs + eps) -
        (1.0 - top_one_hot) * torch.pow(parent_max, gamma) *
        torch.log(1.0 - parent_max + eps))

    loss = (fine_term * keep_fine).sum() / count_fine / num_classes
    loss += (top_term * keep_top).sum() / count_top / 7

    return 5 * loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
class HieraTripletLossCityscape(nn.Module):
    """Hierarchical segmentation loss plus tree-triplet loss for Cityscapes.

    Modified from https://github.com/qhanghu/HSSN_pytorch/blob/main/mmseg/mo
    dels/losses/hiera_triplet_loss_cityscape.py.

    Args:
        num_classes (int): Number of fine classes.
        use_sigmoid (bool): Accepted but not used by this loss.
            Default: False.
        loss_weight (float): Weight applied to the final loss. Default: 1.0.
    """

    def __init__(self, num_classes, use_sigmoid=False, loss_weight=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.loss_weight = loss_weight
        self.treetripletloss = TreeTripletLoss(num_classes, hiera_map,
                                               hiera_index)
        self.ce = CrossEntropyLoss()

    def forward(self,
                step,
                embedding,
                cls_score_before,
                cls_score,
                label,
                weight=None,
                **kwargs):
        """Compute the loss.

        Args:
            step (torch.Tensor): One-element tensor holding the current
                step; read with ``.item()`` to schedule the triplet weight.
            embedding (torch.Tensor): Embedding passed to the triplet loss.
            cls_score_before: Accepted but not used.
            cls_score (torch.Tensor): Logits whose last 7 channels are the
                parent classes.
            label (torch.Tensor): Fine-class label map.
            weight: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            torch.Tensor: The weighted loss.
        """
        targets, targets_top, indices_top = prepare_targets(label)

        # NOTE(review): ``prepare_targets`` returns ``label`` itself as
        # ``targets`` and ``losses_hiera`` overwrites 255 with 0 in place in
        # both ``targets`` and ``targets_top``. The two cross-entropy terms
        # and the triplet loss below therefore see the overwritten labels,
        # and the caller's ``label`` tensor is modified too. This ordering
        # is kept as-is; confirm whether void pixels are meant to be
        # ignored here instead.
        loss = losses_hiera(cls_score, targets, targets_top, self.num_classes,
                            indices_top)
        ce_loss = self.ce(cls_score[:, :-7], label)
        ce_loss2 = self.ce(cls_score[:, -7:], targets_top)
        loss = loss + ce_loss + ce_loss2

        loss_triplet, class_count = self.treetripletloss(embedding, label)

        # Collect the per-process class counts. Without an initialised
        # process group (e.g. a single non-distributed run) the collective
        # calls would raise, so this process is treated as a world of one.
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            class_counts = [
                torch.ones_like(class_count) for _ in range(world_size)
            ]
            dist.all_gather(class_counts, class_count, async_op=False)
            class_counts = torch.cat(class_counts, dim=0)
        else:
            world_size = 1
            class_counts = class_count

        # Add the triplet term only when every process produced one.
        if world_size == torch.nonzero(class_counts, as_tuple=False).size(0):
            current_step = step.item()
            # Cosine ramp from 0 at step 0 to 0.5 at step 80000, then
            # constant at 0.5.
            if current_step < 80000:
                factor = 1 / 4 * (1 + math.cos(
                    (current_step - 80000) / 80000 * math.pi))
            else:
                factor = 0.5
            loss += factor * loss_triplet

        return loss * self.loss_weight
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmseg.models.builder import LOSSES
|
||||
|
||||
|
||||
@LOSSES.register_module()
class TreeTripletLoss(nn.Module):
    """TreeTripletLoss. Modified from https://github.com/qhanghu/HSSN_pytorch/b
    lob/main/mmseg/models/losses/tree_triplet_loss.py.

    For every class present in the labels, anchors are pixels of that class,
    positives are pixels of a different class under the same parent, and
    negatives are pixels whose class lies under a different parent.

    Args:
        num_classes (int): Number of categories.
        hiera_map (List[int]): Hierarchy map of each category.
        hiera_index (List[List[int]]): Hierarchy indices of each hierarchy.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradients. Defaults: 255.

    Examples:
        >>> num_classes = 19
        >>> hiera_map = [
        ...     0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 6, 6, 6, 6, 6, 6]
        >>> hiera_index = [
        ...     [0, 2], [2, 5], [5, 8], [8, 10], [10, 11], [11, 13], [13, 19]]
    """

    def __init__(self, num_classes, hiera_map, hiera_index, ignore_index=255):
        super().__init__()

        self.ignore_label = ignore_index
        self.num_classes = num_classes
        self.hiera_map = hiera_map
        self.hiera_index = hiera_index

    def forward(self, feats: torch.Tensor, labels=None, max_triplet=200):
        """Compute the tree-triplet loss.

        Args:
            feats (torch.Tensor): Four-dimensional embedding; dimension 1 is
                treated as the feature axis.
            labels (torch.Tensor): Label map; it is resized to the spatial
                size of ``feats`` with nearest-neighbour interpolation.
            max_triplet (int): Upper bound on the number of triplets taken
                per class. Default: 200.

        Returns:
            tuple: ``(triplet_loss, class_count)``. ``triplet_loss`` is the
            mean over the contributing classes, or ``None`` when no class
            yields a triplet. ``class_count`` is a one-element tensor with
            the number of contributing classes.
        """
        # Bring the labels to the feature resolution.
        labels = labels.unsqueeze(1).float()
        labels = F.interpolate(
            labels, (feats.shape[2], feats.shape[3]), mode='nearest')
        labels = labels.squeeze(1).long()
        assert labels.shape[-1] == feats.shape[-1], '{} {}'.format(
            labels.shape, feats.shape)

        # Flatten to one label and one feature vector per pixel.
        labels = labels.view(-1)
        feats = feats.permute(0, 2, 3, 1)
        feats = feats.contiguous().view(-1, feats.shape[-1])

        # Use the configured ignore label rather than a literal 255.
        valid = labels != self.ignore_label

        triplet_loss = 0
        class_count = 0
        for cls_id in torch.unique(labels[valid]).tolist():
            index_range = self.hiera_index[self.hiera_map[cls_id]]
            first, last = index_range[0], index_range[-1]

            index_anchor = labels == cls_id
            index_pos = (labels >= first) & (labels < last) & (~index_anchor)
            # Ignored pixels must not serve as negatives: the ignore label
            # would otherwise pass the ``labels >= last`` test.
            index_neg = ((labels < first) | (labels >= last)) & valid

            min_size = min(
                int(index_anchor.sum()), int(index_pos.sum()),
                int(index_neg.sum()), max_triplet)
            if min_size == 0:
                # No complete triplet can be formed for this class.
                continue

            feats_anchor = feats[index_anchor][:min_size]
            feats_pos = feats[index_pos][:min_size]
            feats_neg = feats[index_neg][:min_size]

            # Distance is one minus the dot product of the feature vectors.
            dist_pos = 1 - (feats_anchor * feats_pos).sum(1)
            dist_neg = 1 - (feats_anchor * feats_neg).sum(1)

            # margin always 0.1 + (4-2)/4 since the hierarchy is three level
            # TODO: should include label of pos is the same as anchor
            margin = 0.6

            triplet_loss += F.relu(dist_pos - dist_neg + margin).mean()
            class_count += 1

        if class_count == 0:
            return None, torch.tensor([0]).to(feats)
        triplet_loss /= class_count
        return triplet_loss, torch.tensor([class_count]).to(feats)
|
||||
Reference in New Issue
Block a user