first commit
This commit is contained in:
1
Seg_All_In_One_MMSeg/mmseg/.mim/configs
Symbolic link
1
Seg_All_In_One_MMSeg/mmseg/.mim/configs
Symbolic link
@@ -0,0 +1 @@
|
||||
../../configs
|
||||
1
Seg_All_In_One_MMSeg/mmseg/.mim/dataset-index.yml
Symbolic link
1
Seg_All_In_One_MMSeg/mmseg/.mim/dataset-index.yml
Symbolic link
@@ -0,0 +1 @@
|
||||
../../dataset-index.yml
|
||||
1
Seg_All_In_One_MMSeg/mmseg/.mim/model-index.yml
Symbolic link
1
Seg_All_In_One_MMSeg/mmseg/.mim/model-index.yml
Symbolic link
@@ -0,0 +1 @@
|
||||
../../model-index.yml
|
||||
1
Seg_All_In_One_MMSeg/mmseg/.mim/tools
Symbolic link
1
Seg_All_In_One_MMSeg/mmseg/.mim/tools
Symbolic link
@@ -0,0 +1 @@
|
||||
../../tools
|
||||
73
Seg_All_In_One_MMSeg/mmseg/__init__.py
Normal file
73
Seg_All_In_One_MMSeg/mmseg/__init__.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
from packaging.version import parse
|
||||
|
||||
from .version import __version__, version_info
|
||||
|
||||
# Supported MMCV version range. The import-time check in this module treats
# both ends as inclusive (min <= installed <= max).
MMCV_MIN = '2.0.0rc4'
MMCV_MAX = '2.2.0'
# Supported MMEngine version range. The import-time check in this module
# treats the upper bound as exclusive (min <= installed < max).
MMENGINE_MIN = '0.5.0'
MMENGINE_MAX = '1.0.0'
|
||||
|
||||
|
||||
def digit_version(version_str: str, length: int = 4):
    """Turn a version string into a comparable tuple of integers.

    The release part is truncated or zero-padded to ``length`` items and two
    extra items are appended that encode the release stage, so tuples sort
    as: pre-release (alpha < beta < rc) < final release < post release.

    Args:
        version_str (str): The version string.
        length (int): Number of release levels to keep. Default: 4.

    Returns:
        tuple[int]: ``length`` release digits followed by a stage marker and
        the stage number.
    """
    parsed = parse(version_str)
    assert parsed.release, f'failed to parse version {version_str}'

    # Keep at most ``length`` release levels, right-padding with zeros.
    digits = list(parsed.release)[:length]
    digits += [0] * (length - len(digits))

    if parsed.is_prerelease:
        stage_rank = {'a': -3, 'b': -2, 'rc': -1}
        pre = parsed.pre
        if not pre:
            # version.pre can be None; fall back to the lowest rank.
            digits.extend([-4, 0])
        else:
            stage = pre[0]
            if stage in stage_rank:
                rank = stage_rank[stage]
            else:
                warnings.warn(f'unknown prerelease version {stage}, '
                              'version checking may go wrong')
                rank = -4
            digits.extend([rank, pre[-1]])
    elif parsed.is_postrelease:
        digits.extend([1, parsed.post])
    else:
        digits.extend([0, 0])
    return tuple(digits)
|
||||
|
||||
|
||||
# Import-time compatibility checks: fail early with an actionable message
# when the installed MMCV / MMEngine is outside the supported range.
mmcv_min_version = digit_version(MMCV_MIN)
mmcv_max_version = digit_version(MMCV_MAX)
mmcv_version = digit_version(mmcv.__version__)

# The MMCV upper bound is inclusive here.
assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={MMCV_MIN}, <={MMCV_MAX}.'

mmengine_min_version = digit_version(MMENGINE_MIN)
mmengine_max_version = digit_version(MMENGINE_MAX)
mmengine_version = digit_version(mmengine.__version__)

# The MMEngine upper bound is exclusive. The message prints the readable
# version strings rather than the digit tuples used for the comparison.
assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \
    f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
    f'Please install mmengine>={MMENGINE_MIN}, ' \
    f'<{MMENGINE_MAX}.'

__all__ = ['__version__', 'version_info', 'digit_version']
|
||||
9
Seg_All_In_One_MMSeg/mmseg/apis/__init__.py
Normal file
9
Seg_All_In_One_MMSeg/mmseg/apis/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .inference import inference_model, init_model, show_result_pyplot
|
||||
from .mmseg_inferencer import MMSegInferencer
|
||||
from .remote_sense_inferencer import RSImage, RSInferencer
|
||||
|
||||
__all__ = [
|
||||
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
|
||||
'RSInferencer', 'RSImage'
|
||||
]
|
||||
189
Seg_All_In_One_MMSeg/mmseg/apis/inference.py
Normal file
189
Seg_All_In_One_MMSeg/mmseg/apis/inference.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import load_checkpoint
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
from mmseg.models import BaseSegmentor
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
from .utils import ImageType, _preprare_data
|
||||
|
||||
|
||||
def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a segmentor from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights and ``model.dataset_meta`` is not set
            by this function.
        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
            Use 'cpu' for loading model on CPU.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed segmentor, moved to ``device`` and set to
        eval mode, with the config stored as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config`` object.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # Drop weight-initialisation settings from the config before building.
    if config.model.type == 'EncoderDecoder':
        if 'init_cfg' in config.model.backbone:
            config.model.backbone.init_cfg = None
    elif config.model.type == 'MultimodalEncoderDecoder':
        for k, v in config.model.items():
            if isinstance(v, dict) and 'init_cfg' in v:
                config.model[k].init_cfg = None
    config.model.pretrained = None
    config.model.train_cfg = None
    init_default_scope(config.get('default_scope', 'mmseg'))

    model = MODELS.build(config.model)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # Checkpoints without a 'meta' entry must fall through to the
        # num_classes-based guess below instead of raising KeyError.
        meta = checkpoint.get('meta') or {}
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in meta:
            # mmseg 1.x
            model.dataset_meta = meta['dataset_meta']
        elif 'CLASSES' in meta:
            # < mmseg 1.x; PALETTE may be absent, in which case it is None
            # (same handling as MMSegInferencer._load_weights_to_model).
            model.dataset_meta = {
                'classes': meta['CLASSES'],
                'palette': meta.get('PALETTE', None)
            }
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, classes and palette will be '
                'set according to num_classes ')
            num_classes = model.decode_head.num_classes
            # Pick the first known dataset whose class count matches.
            dataset_name = None
            for name in dataset_aliases.keys():
                if len(get_classes(name)) == num_classes:
                    dataset_name = name
                    break
            if dataset_name is None:
                warnings.warn(
                    'No suitable dataset found, use Cityscapes by default')
                dataset_name = 'cityscapes'
            model.dataset_meta = {
                'classes': get_classes(dataset_name),
                'palette': get_palette(dataset_name)
            }
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
|
||||
|
||||
|
||||
def inference_model(model: BaseSegmentor,
                    img: ImageType) -> Union[SegDataSample, SampleList]:
    """Run the segmentor on one image or a batch of images.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
        The full result list when the input is treated as a batch by
        ``_preprare_data``, otherwise only the first (single) result.
    """
    batch, batched = _preprare_data(img, model)

    # Gradients are not needed for inference.
    with torch.no_grad():
        outputs = model.test_step(batch)

    if batched:
        return outputs
    return outputs[0]
|
||||
|
||||
|
||||
def show_result_pyplot(model: BaseSegmentor,
                       img: Union[str, np.ndarray],
                       result: SegDataSample,
                       opacity: float = 0.5,
                       title: str = '',
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       wait_time: float = 0,
                       show: bool = True,
                       with_labels: Optional[bool] = True,
                       save_dir=None,
                       out_file=None):
    """Draw a segmentation result on top of its image.

    Args:
        model (nn.Module): The loaded segmentor; its ``dataset_meta``
            supplies the classes and palette. A wrapper exposing ``.module``
            is unwrapped first.
        img (str or np.ndarray): Image filename or loaded image. A filename
            is read in RGB channel order; an array is used as given.
        result (SegDataSample): The prediction SegDataSample result.
        opacity (float): Opacity of painted segmentation map.
            Default 0.5. Must be in (0, 1] range.
        title (str): The title of pyplot figure. Default is ''.
        draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
        draw_pred (bool): Whether to draw Prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). 0 is the special value
            that means "forever". Defaults to 0.
        show (bool): Whether to display the drawn image. Default to True.
        with_labels (bool, optional): Add semantic labels in visualization
            result, Default to True.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        out_file (str, optional): Path to output file. Default to None.

    Returns:
        np.ndarray: the drawn image which channel is RGB.
    """
    segmentor = model.module if hasattr(model, 'module') else model
    image = mmcv.imread(
        img, channel_order='rgb') if isinstance(img, str) else img

    if save_dir is not None:
        mkdir_or_exist(save_dir)

    # Build a local visualizer configured with the model's dataset metadata.
    painter = SegLocalVisualizer(
        vis_backends=[dict(type='LocalVisBackend')],
        save_dir=save_dir,
        alpha=opacity)
    painter.dataset_meta = dict(
        classes=segmentor.dataset_meta['classes'],
        palette=segmentor.dataset_meta['palette'])

    painter.add_datasample(
        name=title,
        image=image,
        data_sample=result,
        draw_gt=draw_gt,
        draw_pred=draw_pred,
        wait_time=wait_time,
        out_file=out_file,
        show=show,
        with_labels=with_labels)
    return painter.get_image()
|
||||
382
Seg_All_In_One_MMSeg/mmseg/apis/mmseg_inferencer.py
Normal file
382
Seg_All_In_One_MMSeg/mmseg/apis/mmseg_inferencer.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.transforms import Compose
|
||||
from mmengine.infer.infer import BaseInferencer, ModelType
|
||||
from mmengine.model import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner.checkpoint import _load_checkpoint_to_model
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
InputType = Union[str, np.ndarray]
|
||||
InputsType = Union[InputType, Sequence[InputType]]
|
||||
PredType = Union[SegDataSample, SampleList]
|
||||
|
||||
|
||||
class MMSegInferencer(BaseInferencer):
    """Semantic segmentation inferencer, provides inference and visualization
    interfaces. Note: MMEngine >= 0.5.0 is required.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/metafile.yaml>`_
            as an example the `model` could be
            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model
            will be downloaded automatically. If use config file, like
            "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the
            `weights` should be defined.
        weights (str, optional): Path to the checkpoint. If it is not specified
            and model is a model name of metafile, the weights will be loaded
            from metafile. Defaults to None.
        classes (list, optional): Input classes for result rendering, as the
            prediction of segmentation model is a segment map with label
            indices, `classes` is a list which includes items responding to the
            label indices. If classes is not defined, visualizer will take
            `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which is
            a list of color palette responding to the classes. If palette is
            not defined, visualizer will take `cityscapes` palette by default.
            Defaults to None.
        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
            visualizer will use the meta information of the dataset i.e. classes
            and palette, but the `classes` and `palette` have higher priority.
            Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to 'mmseg'.
    """ # noqa

    # Keyword names routed to each stage of the inference call.
    preprocess_kwargs: set = set()
    forward_kwargs: set = {'mode', 'out_dir'}
    visualize_kwargs: set = {
        'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
        'with_labels'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 model: Union[ModelType, str],
                 weights: Optional[str] = None,
                 classes: Optional[Union[str, List]] = None,
                 palette: Optional[Union[str, List]] = None,
                 dataset_name: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmseg') -> None:
        """Build the model and visualizer and set the rendering metadata."""
        # A global counter tracking the number of images processes, for
        # naming of the output images
        self.num_visualized_imgs = 0
        self.num_pred_imgs = 0
        init_default_scope(scope if scope else 'mmseg')
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)

        # SyncBN layers are reverted when running on CPU or without CUDA.
        if device == 'cpu' or not torch.cuda.is_available():
            self.model = revert_sync_batchnorm(self.model)

        assert isinstance(self.visualizer, SegLocalVisualizer)
        self.visualizer.set_dataset_meta(classes, palette, dataset_name)

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Loading model weights and meta information from cfg and checkpoint.

        Subclasses could override this method to load extra meta information
        from ``checkpoint`` and ``cfg`` to model.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint.
            cfg (Config or ConfigDict, optional): The loaded config.
        """

        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})
            # save the dataset_meta in the model for convenience
            if 'dataset_meta' in checkpoint_meta:
                # mmsegmentation 1.x
                model.dataset_meta = {
                    'classes': checkpoint_meta['dataset_meta'].get('classes'),
                    'palette': checkpoint_meta['dataset_meta'].get('palette')
                }
            elif 'CLASSES' in checkpoint_meta:
                # mmsegmentation 0.x
                classes = checkpoint_meta['CLASSES']
                palette = checkpoint_meta.get('PALETTE', None)
                model.dataset_meta = {'classes': classes, 'palette': palette}
            else:
                warnings.warn(
                    'dataset_meta or class names are not saved in the '
                    'checkpoint\'s meta data, use classes of Cityscapes by '
                    'default.')
                model.dataset_meta = {
                    'classes': get_classes('cityscapes'),
                    'palette': get_palette('cityscapes')
                }
        else:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')
            warnings.warn(
                'weights is None, use cityscapes classes by default.')
            model.dataset_meta = {
                'classes': get_classes('cityscapes'),
                'palette': get_palette('cityscapes')
            }

    def __call__(self,
                 inputs: InputsType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 out_dir: str = '',
                 img_out_dir: str = 'vis',
                 pred_out_dir: str = 'pred',
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`SegDataSample`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            return_vis (bool): Whether to return the rendered images.
                Defaults to False.
            show (bool): Whether to display the rendering color segmentation
                mask in a popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            out_dir (str): Output directory of inference results. Defaults
                to ''. When empty, nothing is written to disk regardless of
                ``img_out_dir`` and ``pred_out_dir``.
            img_out_dir (str): Subdirectory of `out_dir`, used to save
                rendering color segmentation mask, so `out_dir` must be defined
                if you would like to save predicted mask. Defaults to 'vis'.
            pred_out_dir (str): Subdirectory of `out_dir`, used to save
                predicted mask file, so `out_dir` must be defined if you would
                like to save predicted mask. Defaults to 'pred'.

            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.


        Returns:
            dict: Inference and visualization results.
        """

        # Both sub-directories live under ``out_dir``; an empty ``out_dir``
        # disables saving by passing empty paths downstream.
        if out_dir != '':
            pred_out_dir = osp.join(out_dir, pred_out_dir)
            img_out_dir = osp.join(out_dir, img_out_dir)
        else:
            pred_out_dir = ''
            img_out_dir = ''

        return super().__call__(
            inputs=inputs,
            return_datasamples=return_datasamples,
            batch_size=batch_size,
            show=show,
            wait_time=wait_time,
            img_out_dir=img_out_dir,
            pred_out_dir=pred_out_dir,
            return_vis=return_vis,
            **kwargs)

    def visualize(self,
                  inputs: list,
                  preds: List[dict],
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  img_out_dir: str = '',
                  opacity: float = 0.8,
                  with_labels: Optional[bool] = True) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to collect and return the rendered
                images. Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            img_out_dir (str): Output directory of rendering prediction i.e.
                color segmentation mask. Defaults: ''
            opacity (int, float): The transparency of segmentation mask.
                Defaults to 0.8.
            with_labels (bool, optional): Forwarded to the visualizer's
                ``add_datasample``. Defaults to True.

        Returns:
            List[np.ndarray]: Visualization results, or None when
            ``return_vis`` is False.

        Raises:
            ValueError: If no visualizer is available or an input is neither
                a ``str`` nor an ``np.ndarray``.
        """
        # Nothing to show, save or return: skip rendering entirely.
        if not show and img_out_dir == '' and not return_vis:
            return None
        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.set_dataset_meta(**self.model.dataset_meta)
        self.visualizer.alpha = opacity

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                # Reverse the channel axis of the decoded image.
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()
                # Arrays have no filename; name them by the running counter.
                img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\
                else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=True,
                out_file=out_file,
                with_labels=with_labels)
            if return_vis:
                results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results if return_vis else None

    def postprocess(self,
                    preds: PredType,
                    visualization: List[np.ndarray],
                    return_datasample: bool = False,
                    pred_out_dir: str = '') -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Pack the predictions and visualization results and return them.
        2. Save the predictions, if it needed.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (List[np.ndarray]): The list of rendering color
                segmentation mask.
            return_datasample (bool): Whether to return results as datasamples.
                Defaults to False. When True the predictions are returned
                as-is (unwrapped if there is exactly one) and nothing is
                saved.
            pred_out_dir: File to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``

            - ``visualization (Any)``: Returned by :meth:`visualize`
            - ``predictions`` (List[np.ndarray], np.ndarray): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it will be the segmentation mask
              with label indice.

        Raises:
            ValueError: If a prediction has neither ``pred_sem_seg`` nor
                ``pred_depth_map``.
        """
        if return_datasample:
            if len(preds) == 1:
                return preds[0]
            else:
                return preds

        results_dict = {}

        results_dict['predictions'] = []
        results_dict['visualization'] = []

        for i, pred in enumerate(preds):
            pred_data = dict()
            if 'pred_sem_seg' in pred.keys():
                pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
            elif 'pred_depth_map' in pred.keys():
                pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]
            else:
                # Fail with a clear message instead of a bare StopIteration
                # from ``next(iter(...))`` further down.
                raise ValueError(
                    'Prediction contains neither "pred_sem_seg" nor '
                    '"pred_depth_map".')

            if visualization is not None:
                vis = visualization[i]
                results_dict['visualization'].append(vis)
            if pred_out_dir != '':
                mmengine.mkdir_or_exist(pred_out_dir)
                for key, data in pred_data.items():
                    # Segmentation maps are saved as PNG, anything else as
                    # a NumPy file.
                    post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
                    img_name = str(self.num_pred_imgs).zfill(8) + post_fix
                    img_path = osp.join(pred_out_dir, img_name)
                    if key == 'sem_seg':
                        output = Image.fromarray(data.astype(np.uint8))
                        output.save(img_path)
                    else:
                        np.save(img_path, data)
            pred_data = next(iter(pred_data.values()))
            results_dict['predictions'].append(pred_data)
            self.num_pred_imgs += 1

        # A single input yields bare values rather than one-element lists.
        if len(results_dict['predictions']) == 1:
            results_dict['predictions'] = results_dict['predictions'][0]
            if visualization is not None:
                results_dict['visualization'] = \
                    results_dict['visualization'][0]
        return results_dict

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Return a pipeline to handle various input data, such as ``str``,
        ``np.ndarray``. It is an abstract method in BaseInferencer, and should
        be implemented in subclasses.

        The returned pipeline will be used to process a single data.
        It will be used in :meth:`preprocess` like this:

        .. code-block:: python
            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataset = map(self.pipeline, dataset)
                ...

        Note:
            The pipeline list inside ``cfg`` is modified in place.

        Raises:
            ValueError: If the pipeline has no ``LoadImageFromFile`` step.
        """
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        # Loading annotations is also not applicable
        for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
            idx = self._get_transform_idx(pipeline_cfg, transform)
            if idx != -1:
                del pipeline_cfg[idx]

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        # Swap the file loader for the inferencer's own loader type.
        pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1
|
||||
279
Seg_All_In_One_MMSeg/mmseg/apis/remote_sense_inferencer.py
Normal file
279
Seg_All_In_One_MMSeg/mmseg/apis/remote_sense_inferencer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import threading
|
||||
from queue import Queue
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import load_checkpoint
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
except ImportError:
|
||||
gdal = None
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import _preprare_data
|
||||
|
||||
|
||||
class RSImage:
    """Remote sensing image class.

    Wraps a GDAL dataset and caches its size, band handles, geo-transform
    and projection. ``grids`` collects the windows produced by
    :meth:`create_grids`.

    Args:
        image (str or gdal.Dataset): Image file path or gdal.Dataset.

    Raises:
        ImportError: If GDAL could not be imported.
        AssertionError: If ``image`` does not resolve to a ``gdal.Dataset``.
    """

    def __init__(self, image):
        # ``gdal`` is None when the optional ``osgeo`` import failed; report
        # that clearly instead of an AttributeError on None.
        if gdal is None:
            raise ImportError(
                'GDAL is required by RSImage but `osgeo.gdal` could not be '
                'imported.')
        self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
            image, str) else image
        assert isinstance(self.dataset, gdal.Dataset), \
            f'{image} is not a image'
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.channel = self.dataset.RasterCount
        self.trans = self.dataset.GetGeoTransform()
        self.proj = self.dataset.GetProjection()
        # GDAL band indices are 1-based.
        self.band_list = [
            self.dataset.GetRasterBand(c + 1) for c in range(self.channel)
        ]
        self.grids = []

    def read(self, grid: Optional[List] = None) -> np.ndarray:
        """Read image data. If grid is None, read the whole image.

        Args:
            grid (Optional[List], optional): Grid to read; only its first
                four items are passed to ``ReadAsArray``. Defaults to None.

        Returns:
            np.ndarray: Image data with the first axis of the raw array
            moved to the last position. A 2-D (single-band) raw array gets
            a leading axis of size 1 first, so the result is always 3-D.
        """
        if grid is None:
            data = self.dataset.ReadAsArray()
        else:
            assert len(
                grid) >= 4, 'grid must be a list containing at least 4 elements'
            data = self.dataset.ReadAsArray(*grid[:4])
        # Single-band reads come back 2-D; einsum below needs 3 axes. This
        # applies to whole-image reads as well as windowed reads.
        if data.ndim == 2:
            data = data[np.newaxis, ...]
        return np.einsum('ijk->jki', data)

    def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
        """Write image data.

        Args:
            data (Optional[np.ndarray]): Data to write.
            grid (Optional[List], optional): Grid to write, in the 8-item
                layout produced by :meth:`create_grids`. When given, the
                cropped region of ``data`` is written to every band at the
                grid offset plus the crop offset. Defaults to None.

        Raises:
            ValueError: Either grid or data must be provided.
        """
        if grid is not None:
            assert len(grid) == 8, 'grid must be a list of 8 elements'
            for band in self.band_list:
                band.WriteArray(
                    data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
                    grid[0] + grid[4], grid[1] + grid[5])
        elif data is not None:
            # No grid: write the last-axis slice ``i`` of data to band ``i``.
            for i in range(self.channel):
                self.band_list[i].WriteArray(data[..., i])
        else:
            raise ValueError('Either grid or data must be provided.')

    def create_seg_map(self, output_path: Optional[str] = None):
        """Create a single-band byte GeoTIFF matching this image.

        The new raster has the same width, height, geo-transform and
        projection as this image.

        Args:
            output_path (Optional[str]): Destination file. Defaults to
                ``'output_label.tif'`` when None.

        Returns:
            RSImage: Wrapper around the new dataset, with ``path`` set to
            ``output_path``.
        """
        if output_path is None:
            output_path = 'output_label.tif'
        driver = gdal.GetDriverByName('GTiff')
        seg_map = driver.Create(output_path, self.width, self.height, 1,
                                gdal.GDT_Byte)
        seg_map.SetGeoTransform(self.trans)
        seg_map.SetProjection(self.proj)
        seg_map_img = RSImage(seg_map)
        seg_map_img.path = output_path
        return seg_map_img

    def create_grids(self,
                     window_size: Tuple[int, int],
                     stride: Tuple[int, int] = (0, 0)):
        """Create grids for image inference.

        Each grid appended to ``self.grids`` is the list
        ``[x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
        x_crop_size, y_crop_size]``. Grids accumulate across calls.

        Args:
            window_size (Tuple[int, int]): the size of the sliding window.
            stride (Tuple[int, int], optional): the stride of the sliding
                window. A 0 component means "use the window size", i.e. no
                overlap on that axis. Defaults to (0, 0).

        Raises:
            AssertionError: window_size must be a tuple of 2 elements.
            AssertionError: stride must be a tuple of 2 elements.
        """
        assert len(
            window_size) == 2, 'window_size must be a tuple of 2 elements'
        assert len(stride) == 2, 'stride must be a tuple of 2 elements'
        win_w, win_h = window_size
        stride_x, stride_y = stride

        stride_x = win_w if stride_x == 0 else stride_x
        stride_y = win_h if stride_y == 0 else stride_y

        # Half of the overlap between neighbouring windows, rounded up.
        x_half_overlap = (win_w - stride_x + 1) // 2
        y_half_overlap = (win_h - stride_y + 1) // 2

        for y in range(0, self.height, stride_y):
            # Windows reaching past the border are shifted back inside.
            y_end = y + win_h >= self.height
            y_offset = self.height - win_h if y_end else y
            y_size = win_h
            y_crop_off = 0 if y_offset == 0 else y_half_overlap
            y_crop_size = y_size if y_end else win_h - y_crop_off

            for x in range(0, self.width, stride_x):
                x_end = x + win_w >= self.width
                x_offset = self.width - win_w if x_end else x
                x_size = win_w
                x_crop_off = 0 if x_offset == 0 else x_half_overlap
                x_crop_size = x_size if x_end else win_w - x_crop_off

                self.grids.append([
                    x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
                    x_crop_size, y_crop_size
                ])
|
||||
|
||||
|
||||
class RSInferencer:
    """Remote sensing inference class.

    The inferencer runs a three-stage pipeline connected by two bounded
    queues: a reader fills ``read_buffer`` with image windows, one or more
    inference workers move results into ``write_buffer`` and a writer
    stores them into the segmentation map. ``END_FLAG`` is the sentinel
    that tells each stage that no more data will arrive.

    Args:
        model (BaseModel): The loaded model.
        batch_size (int, optional): Batch size. It is used as the maximum
            size of both buffers. Defaults to 1.
        thread (int, optional): Number of inference threads.
            Defaults to 1.
    """

    def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
        self.model = model
        self.batch_size = batch_size
        self.END_FLAG = object()
        self.read_buffer = Queue(self.batch_size)
        self.write_buffer = Queue(self.batch_size)
        self.thread = thread

    @classmethod
    def from_config_path(cls,
                         config_path: str,
                         checkpoint_path: str,
                         batch_size: int = 1,
                         thread: int = 1,
                         device: Optional[str] = 'cpu'):
        """Initialize a segmentor from config file.

        The model is built from the config, the checkpoint is loaded on
        CPU, and the model is then moved to ``device`` and switched to
        eval mode.

        Args:
            config_path (str): Config file path.
            checkpoint_path (str): Checkpoint path.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of inference threads.
                Defaults to 1.
            device (Optional[str]): Device the model is moved to.
                Defaults to 'cpu'.
        """
        init_default_scope('mmseg')
        cfg = Config.fromfile(config_path)
        model = MODELS.build(cfg.model)
        model.cfg = cfg
        load_checkpoint(model, checkpoint_path, map_location='cpu')
        model.to(device)
        model.eval()
        return cls(model, batch_size, thread)

    @classmethod
    def from_model(cls,
                   model: BaseModel,
                   checkpoint_path: Optional[str] = None,
                   batch_size: int = 1,
                   thread: int = 1,
                   device: Optional[str] = 'cpu'):
        """Initialize a segmentor from model.

        Unlike :meth:`from_config_path`, this constructor does not call
        ``model.eval()``; the train/eval mode of ``model`` is left as the
        caller set it.

        Args:
            model (BaseModel): The loaded model.
            checkpoint_path (Optional[str]): Checkpoint path. When given,
                the checkpoint is loaded into ``model`` on CPU first.
            batch_size (int, optional): Batch size. Defaults to 1.
            thread (int, optional): Number of inference threads.
                Defaults to 1.
            device (Optional[str]): Device the model is moved to.
                Defaults to 'cpu'.
        """
        if checkpoint_path is not None:
            load_checkpoint(model, checkpoint_path, map_location='cpu')
        model.to(device)
        return cls(model, batch_size, thread)

    def read(self,
             image: RSImage,
             window_size: Tuple[int, int],
             strides: Tuple[int, int] = (0, 0)):
        """Load image data to read buffer.

        After all windows have been queued, one ``END_FLAG`` is queued per
        inference worker so that every worker receives its own stop
        signal.

        Args:
            image (RSImage): The image to read.
            window_size (Tuple[int, int]): The size of the sliding window.
            strides (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).
        """
        image.create_grids(window_size, strides)
        for grid in image.grids:
            self.read_buffer.put([grid, image.read(grid=grid)])
        for _ in range(max(self.thread, 1)):
            self.read_buffer.put(self.END_FLAG)

    def inference(self):
        """Inference image data from read buffer and put the result to write
        buffer.

        The worker stops when it receives ``END_FLAG`` and forwards exactly
        one ``END_FLAG`` to the write buffer. The flag is not put back into
        the read buffer, so no stale sentinel is left behind for a later
        run.
        """
        while True:
            item = self.read_buffer.get()
            # The sentinel is a unique object, so compare by identity.
            if item is self.END_FLAG:
                self.write_buffer.put(item)
                break
            data, _ = _preprare_data(item[1], self.model)
            with torch.no_grad():
                result = self.model.test_step(data)
            # Replace the raw window with its predicted label map.
            item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
            self.write_buffer.put(item)
            self.read_buffer.task_done()

    def write(self, image: RSImage, output_path: Optional[str] = None):
        """Write image data from write buffer.

        The writer keeps consuming results until every inference worker
        has forwarded its ``END_FLAG``; stopping at the first flag would
        drop results still held by slower workers.

        Args:
            image (RSImage): The image to write.
            output_path (Optional[str], optional): The path to save the
                segmentation map. Defaults to None.
        """
        seg_map = image.create_seg_map(output_path)
        expected_flags = max(self.thread, 1)
        finished_workers = 0
        while finished_workers < expected_flags:
            item = self.write_buffer.get()
            if item is self.END_FLAG:
                finished_workers += 1
                continue
            seg_map.write(data=item[1], grid=item[0])
            self.write_buffer.task_done()

    def run(self,
            image: RSImage,
            window_size: Tuple[int, int],
            strides: Tuple[int, int] = (0, 0),
            output_path: Optional[str] = None):
        """Run inference with multi-threading.

        Starts one reader thread, ``self.thread`` inference threads and
        one writer thread, then waits for all of them to finish.

        Args:
            image (RSImage): The image to inference.
            window_size (Tuple[int, int]): The size of the sliding window.
            strides (Tuple[int, int], optional): The stride of the sliding
                window. Defaults to (0, 0).
            output_path (Optional[str], optional): The path to save the
                segmentation map. Defaults to None.
        """
        read_thread = threading.Thread(
            target=self.read, args=(image, window_size, strides))
        read_thread.start()
        inference_threads = []
        for _ in range(self.thread):
            inference_thread = threading.Thread(target=self.inference)
            inference_thread.start()
            inference_threads.append(inference_thread)
        write_thread = threading.Thread(
            target=self.write, args=(image, output_path))
        write_thread.start()
        read_thread.join()
        for inference_thread in inference_threads:
            inference_thread.join()
        write_thread.join()
|
||||
41
Seg_All_In_One_MMSeg/mmseg/apis/utils.py
Normal file
41
Seg_All_In_One_MMSeg/mmseg/apis/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import defaultdict
|
||||
from typing import Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||
|
||||
|
||||
def _preprare_data(imgs: ImageType, model: BaseModel):
    """Run the model's test pipeline on one image or a batch of images.

    ``LoadAnnotations`` steps are dropped from the pipeline, and when the
    inputs are arrays the first transform is switched to
    ``LoadImageFromNDArray``. Both adjustments are made on a local copy,
    so ``model.cfg.test_pipeline`` is left untouched and a later call
    with a different input type still sees the original pipeline.

    Args:
        imgs (ImageType): An image path, an image array, or a list/tuple
            of either.
        model (BaseModel): Model whose ``cfg.test_pipeline`` describes the
            preprocessing.

    Returns:
        tuple: ``(data, is_batch)`` where ``data`` is a dict with the
        lists ``inputs`` and ``data_samples`` (one entry per image) and
        ``is_batch`` tells whether ``imgs`` was given as a list/tuple.
    """
    cfg = model.cfg
    # Filter into a new list: removing items from a list while iterating
    # over it skips the element that follows each removed one.
    test_pipeline = [
        t for t in cfg.test_pipeline if t.get('type') != 'LoadAnnotations'
    ]

    is_batch = True
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
        is_batch = False

    if imgs and isinstance(imgs[0], np.ndarray):
        # Copy the first transform before editing it so the shared config
        # keeps its original loader.
        first_transform = dict(test_pipeline[0])
        first_transform['type'] = 'LoadImageFromNDArray'
        test_pipeline[0] = first_transform

    # TODO: Consider using the singleton pattern to avoid building
    # a pipeline for each inference
    pipeline = Compose(test_pipeline)

    data = defaultdict(list)
    for img in imgs:
        if isinstance(img, np.ndarray):
            data_ = dict(img=img)
        else:
            data_ = dict(img_path=img)
        data_ = pipeline(data_)
        data['inputs'].append(data_['inputs'])
        data['data_samples'].append(data_['data_samples'])

    return data, is_batch
|
||||
79
Seg_All_In_One_MMSeg/mmseg/configs/_base_/datasets/loveda.py
Normal file
79
Seg_All_In_One_MMSeg/mmseg/configs/_base_/datasets/loveda.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.transforms.loading import LoadImageFromFile
|
||||
from mmcv.transforms.processing import (RandomFlip, RandomResize, Resize,
|
||||
TestTimeAug)
|
||||
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler
|
||||
|
||||
from mmseg.datasets.loveda import LoveDADataset
|
||||
from mmseg.datasets.transforms.formatting import PackSegInputs
|
||||
from mmseg.datasets.transforms.loading import LoadAnnotations
|
||||
from mmseg.datasets.transforms.transforms import (PhotoMetricDistortion,
|
||||
RandomCrop)
|
||||
from mmseg.evaluation import IoUMetric
|
||||
|
||||
# LoveDA dataset settings. Every top-level name below is a config key.
dataset_type = LoveDADataset
data_root = 'data/loveDA'
crop_size = (512, 512)

# Training pipeline: load image and labels, random rescale, random crop,
# random flip, photometric distortion, then pack the inputs.
train_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=LoadAnnotations, reduce_zero_label=True),
    dict(
        type=RandomResize,
        scale=(2048, 512),
        ratio_range=(0.5, 2.0),
        keep_ratio=True,
    ),
    dict(type=RandomCrop, crop_size=crop_size, cat_max_ratio=0.75),
    dict(type=RandomFlip, prob=0.5),
    dict(type=PhotoMetricDistortion),
    dict(type=PackSegInputs),
]

test_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=Resize, scale=(1024, 1024), keep_ratio=True),
    # Annotations are loaded after ``Resize`` because the ground truth
    # must not go through the resize transform.
    dict(type=LoadAnnotations, reduce_zero_label=True),
    dict(type=PackSegInputs),
]

# Test-time augmentation: every scale ratio combined with and without a
# horizontal flip.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type=LoadImageFromFile, backend_args=None),
    dict(
        type=TestTimeAug,
        transforms=[
            [
                dict(type=Resize, scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type=RandomFlip, prob=0., direction='horizontal'),
                dict(type=RandomFlip, prob=1., direction='horizontal'),
            ],
            [dict(type=LoadAnnotations)],
            [dict(type=PackSegInputs)],
        ],
    ),
]

train_dataloader = dict(
    batch_size=2,
    num_workers=12,
    persistent_workers=True,
    sampler=dict(type=InfiniteSampler, shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train',
            seg_map_path='ann_dir/train',
        ),
        pipeline=train_pipeline,
    ),
)

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type=DefaultSampler, shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/val',
            seg_map_path='ann_dir/val',
        ),
        pipeline=test_pipeline,
    ),
)

# Testing reuses the validation loader and evaluator.
test_dataloader = val_dataloader
val_evaluator = dict(type=IoUMetric, iou_metrics=['mIoU'])
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.transforms.loading import LoadImageFromFile
|
||||
from mmcv.transforms.processing import (RandomFlip, RandomResize, Resize,
|
||||
TestTimeAug)
|
||||
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler
|
||||
|
||||
from mmseg.datasets.potsdam import PotsdamDataset
|
||||
from mmseg.datasets.transforms.formatting import PackSegInputs
|
||||
from mmseg.datasets.transforms.loading import LoadAnnotations
|
||||
from mmseg.datasets.transforms.transforms import (PhotoMetricDistortion,
|
||||
RandomCrop)
|
||||
from mmseg.evaluation import IoUMetric
|
||||
|
||||
# Potsdam dataset settings. Every top-level name below is a config key.
dataset_type = PotsdamDataset
data_root = 'data/potsdam'
crop_size = (512, 512)

# Training pipeline: load image and labels, random rescale, random crop,
# random flip, photometric distortion, then pack the inputs.
train_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=LoadAnnotations, reduce_zero_label=True),
    dict(
        type=RandomResize,
        scale=(512, 512),
        ratio_range=(0.5, 2.0),
        keep_ratio=True,
    ),
    dict(type=RandomCrop, crop_size=crop_size, cat_max_ratio=0.75),
    dict(type=RandomFlip, prob=0.5),
    dict(type=PhotoMetricDistortion),
    dict(type=PackSegInputs),
]

test_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=Resize, scale=(512, 512), keep_ratio=True),
    # Annotations are loaded after ``Resize`` because the ground truth
    # must not go through the resize transform.
    dict(type=LoadAnnotations, reduce_zero_label=True),
    dict(type=PackSegInputs),
]

# Test-time augmentation: every scale ratio combined with and without a
# horizontal flip.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type=LoadImageFromFile, backend_args=None),
    dict(
        type=TestTimeAug,
        transforms=[
            [
                dict(type=Resize, scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type=RandomFlip, prob=0., direction='horizontal'),
                dict(type=RandomFlip, prob=1., direction='horizontal'),
            ],
            [dict(type=LoadAnnotations)],
            [dict(type=PackSegInputs)],
        ],
    ),
]

train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type=InfiniteSampler, shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train',
            seg_map_path='ann_dir/train',
        ),
        pipeline=train_pipeline,
    ),
)

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type=DefaultSampler, shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/val',
            seg_map_path='ann_dir/val',
        ),
        pipeline=test_pipeline,
    ),
)
# Testing reuses the validation loader and evaluator.
test_dataloader = val_dataloader

# Only mIoU is enabled; the original note lists 'mDice' and 'mFscore' as
# further candidates.
val_evaluator = dict(type=IoUMetric, iou_metrics=['mIoU'])
test_evaluator = val_evaluator
|
||||
22
Seg_All_In_One_MMSeg/mmseg/configs/_base_/default_runtime.py
Normal file
22
Seg_All_In_One_MMSeg/mmseg/configs/_base_/default_runtime.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmengine.visualization import LocalVisBackend
|
||||
|
||||
from mmseg.models import SegTTAModel
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
# Runtime environment: cuDNN benchmark off, fork start method for worker
# processes, OpenCV threading disabled, NCCL for distributed training.
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

# Visualisation results are stored through the local backend.
vis_backends = [dict(type=LocalVisBackend)]
visualizer = dict(
    type=SegLocalVisualizer,
    vis_backends=vis_backends,
    name='visualizer',
)

# Logging is iteration based.
log_processor = dict(by_epoch=False)
log_level = 'INFO'

# No checkpoint is loaded and training does not resume by default.
load_from = None
resume = False

tta_model = dict(type=SegTTAModel)
default_scope = None
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import PolyLR
|
||||
from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
|
||||
# Optimizer.
# NOTE(review): lr, momentum and weight_decay are not set here (the
# previous values were lr=0.01, momentum=0.9, weight_decay=0.0005);
# presumably they are supplied by the inheriting config - confirm.
optimizer = dict(type=SGD)

optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=None,
)

# Learning policy: polynomial decay over the whole run.
param_scheduler = [
    dict(
        type=PolyLR,
        eta_min=1e-4,
        power=0.9,
        begin=0,
        end=160000,
        by_epoch=False,
    ),
]

# Training schedule for 160k iterations, validating every 8k.
train_cfg = dict(
    type=IterBasedTrainLoop,
    max_iters=160000,
    val_interval=8000,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=8000),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import PolyLR
|
||||
from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
|
||||
# Optimizer: SGD with momentum and weight decay, no gradient clipping.
optimizer = dict(
    type=SGD,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=None,
)

# Learning policy: polynomial decay over the whole run.
param_scheduler = [
    dict(
        type=PolyLR,
        eta_min=1e-4,
        power=0.9,
        begin=0,
        end=20000,
        by_epoch=False,
    ),
]

# Training schedule for 20k iterations, validating every 2k.
train_cfg = dict(
    type=IterBasedTrainLoop,
    max_iters=20000,
    val_interval=2000,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=2000),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
@@ -0,0 +1,34 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import PolyLR
|
||||
from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop
|
||||
# from mmengine.runner.loops import EpochBasedTrainLoop
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
|
||||
# Optimizer: SGD with momentum and weight decay, no gradient clipping.
optimizer = dict(
    type=SGD,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=None,
)

# Learning policy: polynomial decay over the whole run.
param_scheduler = [
    dict(
        type=PolyLR,
        eta_min=1e-4,
        power=0.9,
        begin=0,
        end=240000,
        by_epoch=False,
    ),
]

# Training schedule for 240k iterations, validating every 24k.
train_cfg = dict(
    type=IterBasedTrainLoop,
    max_iters=240000,
    val_interval=24000,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=24000),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import ConstantLR, LinearLR
|
||||
from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop
|
||||
# from mmengine.runner.loops import EpochBasedTrainLoop
|
||||
from torch.optim.adamw import AdamW
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
from mmseg.engine.schedulers import PolyLRRatio
|
||||
|
||||
# Optimizer: AdamW, no gradient clipping.
optimizer = dict(type=AdamW, lr=0.01, weight_decay=0.1)
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=None,
)

# Learning policy in three consecutive phases:
#   iterations 0-12k:    linear schedule starting at factor 3e-2
#   iterations 12k-24k:  polynomial decay with a minimum ratio of 3e-2
#   iterations 24k-25k:  constant schedule with factor 1
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=3e-2,
        begin=0,
        end=12000,
        by_epoch=False,
    ),
    dict(
        type=PolyLRRatio,
        eta_min_ratio=3e-2,
        power=0.9,
        begin=12000,
        end=24000,
        by_epoch=False,
    ),
    dict(
        type=ConstantLR,
        by_epoch=False,
        factor=1,
        begin=24000,
        end=25000,
    ),
]

# Training schedule for 25k iterations, validating every 1k.
train_cfg = dict(
    type=IterBasedTrainLoop,
    max_iters=25000,
    val_interval=1000,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=1000),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim import AmpOptimWrapper # 导入 AmpOptimWrapper
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import PolyLR
|
||||
# 导入 EpochBasedTrainLoop
|
||||
from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
|
||||
from torch.optim.adamw import AdamW # 推荐使用 AdamW 优化器
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
|
||||
# --- Modified section ---

# 1. Optimizer
# AdamW is the recommended choice here; per the author it usually works
# better and is more stable than SGD.
optimizer = dict(
    type=AdamW,
    # The AdamW learning rate is usually set lower than the one for SGD.
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

# Optimizer wrapper.
# ``clip_grad`` enables gradient clipping to guard against exploding
# gradients; adjust it as needed.
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=dict(max_norm=1, norm_type=2),
)

# 2. Learning rate scheduler
# The total number of epochs is 200.
max_epochs = 200
param_scheduler = [
    dict(
        type=PolyLR,
        # Minimum learning rate.
        eta_min=1e-5,
        power=0.9,
        begin=0,
        # Key change: the schedule ends at the total number of epochs.
        end=max_epochs,
        # Key change: the learning rate is updated per epoch.
        by_epoch=True,
    ),
]

# 3. Training, validation and test loops
# Key change: training uses EpochBasedTrainLoop and validates after
# every epoch.
train_cfg = dict(
    type=EpochBasedTrainLoop,
    max_epochs=max_epochs,
    val_interval=1,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

# 4. Default hooks
default_hooks = dict(
    timer=dict(type=IterTimerHook),
    # Key change: metrics are logged by epoch.
    logger=dict(type=LoggerHook, interval=300, log_metric_by_epoch=True),
    param_scheduler=dict(type=ParamSchedulerHook),
    # Key change: checkpoints (model weights) are saved by epoch, once
    # every 10 epochs.
    checkpoint=dict(type=CheckpointHook, by_epoch=True, interval=10),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import PolyLR
|
||||
from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop
|
||||
# from mmengine.runner.loops import EpochBasedTrainLoop
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
|
||||
# Optimizer: SGD with momentum and weight decay, no gradient clipping.
optimizer = dict(
    type=SGD,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=None,
)

# Learning policy: polynomial decay over the whole run.
param_scheduler = [
    dict(
        type=PolyLR,
        eta_min=1e-4,
        power=0.9,
        begin=0,
        end=320000,
        by_epoch=False,
    ),
]

# Training schedule for 320k iterations, validating every 32k.
train_cfg = dict(
    type=IterBasedTrainLoop,
    max_iters=320000,
    val_interval=32000,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=32000),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
@@ -0,0 +1,34 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import PolyLR
|
||||
from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
|
||||
# Optimizer: SGD with momentum and weight decay, no gradient clipping.
optimizer = dict(
    type=SGD,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=None,
)

# Learning policy: polynomial decay over the whole run.
param_scheduler = [
    dict(
        type=PolyLR,
        eta_min=1e-4,
        power=0.9,
        begin=0,
        end=40000,
        by_epoch=False,
    ),
]

# Training schedule for 40k iterations, validating every 4k.
train_cfg = dict(
    type=IterBasedTrainLoop,
    max_iters=40000,
    val_interval=4000,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=4000),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
|
||||
from mmengine.optim.scheduler.lr_scheduler import PolyLR
|
||||
from mmengine.runner.loops import IterBasedTrainLoop, TestLoop, ValLoop
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from mmseg.engine import SegVisualizationHook
|
||||
|
||||
# Optimizer.
# NOTE(review): lr, momentum and weight_decay are not set here (the
# previous values were lr=0.01, momentum=0.9, weight_decay=0.0005);
# presumably they are supplied by the inheriting config - confirm.
optimizer = dict(type=SGD)

optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=optimizer,
    clip_grad=None,
)

# Learning policy: polynomial decay over the whole run.
param_scheduler = [
    dict(
        type=PolyLR,
        eta_min=1e-4,
        power=0.9,
        begin=0,
        end=80000,
        by_epoch=False,
    ),
]

# Training schedule for 80k iterations, validating every 8k.
train_cfg = dict(
    type=IterBasedTrainLoop,
    max_iters=80000,
    val_interval=8000,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=8000),
    sampler_seed=dict(type=DistSamplerSeedHook),
    visualization=dict(type=SegVisualizationHook),
)
|
||||
77
Seg_All_In_One_MMSeg/mmseg/datasets/__init__.py
Normal file
77
Seg_All_In_One_MMSeg/mmseg/datasets/__init__.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# yapf: disable
|
||||
from .ade import ADE20KDataset
|
||||
from .basesegdataset import BaseCDDataset, BaseSegDataset
|
||||
from .bdd100k import BDD100KDataset
|
||||
from .chase_db1 import ChaseDB1Dataset
|
||||
from .cityscapes import CityscapesDataset
|
||||
from .coco_stuff import COCOStuffDataset
|
||||
from .dark_zurich import DarkZurichDataset
|
||||
from .dataset_wrappers import MultiImageMixDataset
|
||||
from .decathlon import DecathlonDataset
|
||||
from .drive import DRIVEDataset
|
||||
from .dsdl import DSDLSegDataset
|
||||
from .hrf import HRFDataset
|
||||
from .hsi_drive import HSIDrive20Dataset
|
||||
from .isaid import iSAIDDataset
|
||||
from .isprs import ISPRSDataset
|
||||
from .levir import LEVIRCDDataset
|
||||
from .lip import LIPDataset
|
||||
from .loveda import LoveDADataset
|
||||
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
|
||||
from .night_driving import NightDrivingDataset
|
||||
from .nyu import NYUDataset
|
||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||
from .potsdam import PotsdamDataset
|
||||
from .refuge import REFUGEDataset
|
||||
from .stare import STAREDataset
|
||||
from .synapse import SynapseDataset
|
||||
from .publicdataset_cholecseg8k import PublicDataSet_CholecSeg8k # TODO
|
||||
from .my_dataset_model import MyDataset_model # TODO
|
||||
from .publicdataset_autolaparo import PublicDataSet_AutoLaparo # TODO
|
||||
from .publicdataset_endovis_2017 import PublicDataSet_Endovis_2017 # TODO
|
||||
from .publicdataset_dresden import PublicDataSet_Dresden # TODO
|
||||
from .publicdataset_endovis_2018 import PublicDataSet_Endovis_2018 # TODO
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
|
||||
LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray, LoadMultipleRSImageFromFile,
|
||||
LoadSingleRSImageFromFile, PackSegInputs,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
from .voc import PascalVOCDataset
|
||||
|
||||
# yapf: enable
|
||||
# Public names of the datasets package (dataset classes and transforms).
__all__ = [
    # Datasets added in this fork; each one was marked TODO by the author.
    'PublicDataSet_CholecSeg8k',
    'MyDataset_model',
    'PublicDataSet_AutoLaparo',
    'PublicDataSet_Endovis_2017',
    'PublicDataSet_Dresden',
    'PublicDataSet_Endovis_2018',
    # Remaining dataset classes and transforms, in their original order.
    'BaseSegDataset',
    'BioMedical3DRandomCrop',
    'BioMedical3DRandomFlip',
    'CityscapesDataset',
    'PascalVOCDataset',
    'ADE20KDataset',
    'PascalContextDataset',
    'PascalContextDataset59',
    'ChaseDB1Dataset',
    'DRIVEDataset',
    'HRFDataset',
    'STAREDataset',
    'DarkZurichDataset',
    'NightDrivingDataset',
    'COCOStuffDataset',
    'LoveDADataset',
    'MultiImageMixDataset',
    'iSAIDDataset',
    'ISPRSDataset',
    'PotsdamDataset',
    'LoadAnnotations',
    'RandomCrop',
    'SegRescale',
    'PhotoMetricDistortion',
    'RandomRotate',
    'AdjustGamma',
    'CLAHE',
    'Rerange',
    'RGB2Gray',
    'RandomCutOut',
    'RandomMosaic',
    'PackSegInputs',
    'ResizeToMultiple',
    'LoadImageFromNDArray',
    'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation',
    'LoadBiomedicalData',
    'GenerateEdge',
    'DecathlonDataset',
    'LIPDataset',
    'ResizeShortestEdge',
    'BioMedicalGaussianNoise',
    'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma',
    'BioMedical3DPad',
    'RandomRotFlip',
    'SynapseDataset',
    'REFUGEDataset',
    'MapillaryDataset_v1',
    'MapillaryDataset_v2',
    'Albu',
    'LEVIRCDDataset',
    'LoadMultipleRSImageFromFile',
    'LoadSingleRSImageFromFile',
    'ConcatCDInput',
    'BaseCDDataset',
    'DSDLSegDataset',
    'BDD100KDataset',
    'NYUDataset',
    'HSIDrive20Dataset',
]
|
||||
92
Seg_All_In_One_MMSeg/mmseg/datasets/ade.py
Normal file
92
Seg_All_In_One_MMSeg/mmseg/datasets/ade.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class ADE20KDataset(BaseSegDataset):
    """ADE20K dataset.

    In the ADE20K segmentation map annotation, 0 stands for background,
    which is not one of the 150 categories, so ``reduce_zero_label``
    defaults to True. ``img_suffix`` defaults to '.jpg' and
    ``seg_map_suffix`` defaults to '.png'.
    """

    # Class names and their display colours; both sequences have 150
    # entries and are index-aligned.
    METAINFO = {
        'classes':
        ('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
         'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk',
         'person', 'earth', 'door', 'table', 'mountain', 'plant',
         'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
         'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
         'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
         'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
         'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
         'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
         'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
         'screen door', 'stairway', 'river', 'bridge', 'bookcase',
         'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
         'bench', 'countertop', 'stove', 'palm', 'kitchen island',
         'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
         'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
         'chandelier', 'awning', 'streetlight', 'booth',
         'television receiver', 'airplane', 'dirt track', 'apparel',
         'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
         'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
         'conveyer belt', 'canopy', 'washer', 'plaything',
         'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
         'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
         'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
         'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
         'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
         'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
         'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
         'clock', 'flag'),
        'palette': [
            [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255],
        ],
    }

    def __init__(self,
                 img_suffix: str = '.jpg',
                 seg_map_suffix: str = '.png',
                 reduce_zero_label: bool = True,
                 **kwargs) -> None:
        """Forward the ADE20K defaults and any extra options to the base.

        Args:
            img_suffix (str): Suffix of images. Defaults to '.jpg'.
            seg_map_suffix (str): Suffix of segmentation maps.
                Defaults to '.png'.
            reduce_zero_label (bool): Passed through to the base class.
                Defaults to True.
            **kwargs: Remaining keyword arguments, passed to the base class
                unchanged.
        """
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
|
||||
552
Seg_All_In_One_MMSeg/mmseg/datasets/basesegdataset.py
Normal file
552
Seg_All_In_One_MMSeg/mmseg/datasets/basesegdataset.py
Normal file
@@ -0,0 +1,552 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import mmengine
|
||||
import mmengine.fileio as fileio
|
||||
import numpy as np
|
||||
from mmengine.dataset import BaseDataset, Compose
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
class BaseSegDataset(BaseDataset):
    """Custom dataset for semantic segmentation.

    An example of the expected file structure is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of BaseSegDataset should have the same
    name except for the suffix. A valid img/gt_semantic_seg filename pair
    looks like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the
    extension is part of the suffix). If a split file is given, ``xxx`` is
    read from that txt file. Otherwise, all files in ``img_dir/`` and
    ``ann_dir`` are loaded.
    Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.


    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as
            specify classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path='', seg_map_path='').
        img_suffix (str): Suffix of images. Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None which means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to defer loading annotations
            past instantiation. In some cases, such as visualization, only
            the meta information of the dataset is needed and loading the
            annotation file is unnecessary; set ``lazy_init=True`` to skip
            it. Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default to False.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix: str = '.jpg',
                 seg_map_suffix: str = '.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img_path='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        # Copy so later changes to the caller's dict do not leak in.
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        # The shared default dict is never stored directly, only a copy.
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Full initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and is used for changing pixel
        labels in load_annotations. ``label_map`` is not None if and only if
        the old classes in cls.METAINFO differ from the new classes in
        self._metainfo and neither of them is None.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.


        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo

        Raises:
            ValueError: If ``new_classes`` is not a subset of the classes in
                ``cls.METAINFO``.
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    # Old classes that were dropped map to 255.
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If length of palette is equal to classes, just return the palette.
        If palette is not defined, it will randomly generate a palette.
        If classes is updated by customer, it will return the subset of
        palette.

        Returns:
            Sequence: Palette for current dataset.

        Raises:
            ValueError: If a non-empty palette cannot be matched to the
                classes.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette does match classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before set seed, and restore
            # random state later.
            # It will prevent loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            # Keep the container type of the original palette.
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        When ``ann_file`` is a non-empty path that is not a directory, each
        of its lines is treated as an image name without suffix. Otherwise
        the image directory is scanned recursively for files ending with
        ``img_suffix`` and the result is sorted by image path.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if not osp.isdir(self.ann_file) and self.ann_file:
            assert osp.isfile(self.ann_file), \
                f'Failed to load `ann_file` {self.ann_file}'
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix))
                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            _suffix_len = len(self.img_suffix)
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                data_info = dict(img_path=osp.join(img_dir, img))
                if ann_dir is not None:
                    # Slice with an explicit end index: ``img[:-0]`` would
                    # be the empty string when ``img_suffix`` is ''.
                    stem = img[:len(img) - _suffix_len]
                    seg_map = stem + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
            data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
|
||||
|
||||
|
||||
@DATASETS.register_module()
class BaseCDDataset(BaseDataset):
    """Custom dataset for change detection.

    An example of the expected file structure is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── img_dir2
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The image names in img_dir and img_dir2 should be consistent.
    An img/gt_semantic_seg pair should share the same name except for the
    suffix, e.g. ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the
    extension is part of the suffix). If a split file is given, ``xxx`` is
    read from that txt file. Otherwise, all files in ``img_dir/`` and
    ``ann_dir`` are loaded.
    Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.


    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as
            specify classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path='', img_path2='', seg_map_path='').
        img_suffix (str): Suffix of images. Default: '.jpg'
        img_suffix2 (str): Suffix of the second images. Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None which means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to defer loading annotations
            past instantiation. When only the meta information is needed
            (e.g. visualization), set ``lazy_init=True`` to skip loading
            the annotation file. Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default to False.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 img_suffix2='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(
                     img_path='', img_path2='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        # Suffixes and label handling options.
        self.img_suffix = img_suffix
        self.img_suffix2 = img_suffix2
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        # Store a private copy of the backend arguments, or None when the
        # caller passed nothing (or an empty dict).
        self.backend_args = None
        if backend_args:
            self.backend_args = backend_args.copy()

        # Locations and loading options.
        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Meta information comes first: the label map and palette below are
        # both derived from it.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Label map for custom classes.
        self.label_map = self.get_label_map(
            self._metainfo.get('classes', None))
        self._metainfo.update({
            'label_map': self.label_map,
            'reduce_zero_label': self.reduce_zero_label,
        })

        # Subset the palette according to the label map, or generate one
        # when none is defined.
        self._metainfo.update({'palette': self._update_palette()})

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline, then fully initialize unless asked not to.
        self.pipeline = Compose(pipeline)
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Build the mapping from old label ids to new label ids.

        The keys of ``label_map`` are the label ids of ``cls.METAINFO`` and
        the values are the ids in ``new_classes``; old classes absent from
        ``new_classes`` map to 255. A mapping is returned only when both
        class lists exist and differ; otherwise the result is None.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo

        Raises:
            ValueError: If ``new_classes`` is not a subset of the classes in
                ``cls.METAINFO``.
        """
        old_classes = cls.METAINFO.get('classes', None)
        if new_classes is None or old_classes is None:
            return None
        if list(new_classes) == list(old_classes):
            return None

        if not set(new_classes).issubset(cls.METAINFO['classes']):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in METAINFO.')
        return {
            old_id: new_classes.index(name) if name in new_classes else 255
            for old_id, name in enumerate(old_classes)
        }

    def _update_palette(self) -> list:
        """Return the palette that matches the current classes.

        * Palette and classes have the same length: the palette is returned
          unchanged.
        * Palette is empty: a palette is generated from a fixed seed.
        * Classes were customised (``label_map`` is set): the matching
          subset of the palette is returned, ordered by new label id.

        Returns:
            Sequence: Palette for current dataset.

        Raises:
            ValueError: If a non-empty palette cannot be matched to the
                classes.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        num_classes = len(classes)

        if len(palette) == num_classes:
            return palette

        if len(palette) == 0:
            # Seed temporarily so the generated palette is reproducible,
            # then restore the global random state so the caller's
            # randomness is unaffected.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            saved_state = np.random.get_state()
            np.random.seed(42)
            generated = np.random.randint(
                0, 255, size=(num_classes, 3)).tolist()
            np.random.set_state(saved_state)
            return generated

        if len(palette) < num_classes or self.label_map is None:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')

        # Keep only colours of classes that survive the remapping.
        by_new_id = sorted(self.label_map.items(), key=lambda x: x[1])
        kept = [palette[old_id] for old_id, new_id in by_new_id
                if new_id != 255]
        return type(palette)(kept)

    def _build_pair_info(self, name, img_dir, img_dir2, ann_dir) -> dict:
        """Build the data info dict for one image pair entry."""
        if '.' in osp.basename(name):
            # NOTE(review): an entry that carries its own extension
            # overwrites both configured image suffixes on the instance, so
            # the change also applies to later entries - confirm this is
            # intended.
            name, img_ext = osp.splitext(name)
            self.img_suffix = img_ext
            self.img_suffix2 = img_ext
        info = {
            'img_path': osp.join(img_dir, name + self.img_suffix),
            'img_path2': osp.join(img_dir2, name + self.img_suffix2),
        }
        if ann_dir is not None:
            info['seg_map_path'] = osp.join(
                ann_dir, name + self.seg_map_suffix)
        info['label_map'] = self.label_map
        info['reduce_zero_label'] = self.reduce_zero_label
        info['seg_fields'] = []
        return info

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        When ``ann_file`` is an existing file, each of its lines names one
        image pair. Otherwise the first image directory is scanned
        recursively for files ending with ``img_suffix`` and the result is
        sorted by image path.

        Returns:
            list[dict]: All data info of dataset.
        """
        img_dir = self.data_prefix.get('img_path', None)
        img_dir2 = self.data_prefix.get('img_path2', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)

        if osp.isfile(self.ann_file):
            entries = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            return [
                self._build_pair_info(entry.strip(), img_dir, img_dir2,
                                      ann_dir) for entry in entries
            ]

        found = fileio.list_dir_or_file(
            dir_path=img_dir,
            list_dir=False,
            suffix=self.img_suffix,
            recursive=True,
            backend_args=self.backend_args)
        infos = [
            self._build_pair_info(img, img_dir, img_dir2, ann_dir)
            for img in found
        ]
        infos.sort(key=lambda info: info['img_path'])
        return infos
|
||||
30
Seg_All_In_One_MMSeg/mmseg/datasets/bdd100k.py
Normal file
30
Seg_All_In_One_MMSeg/mmseg/datasets/bdd100k.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmseg.datasets.basesegdataset import BaseSegDataset
|
||||
from mmseg.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
class BDD100KDataset(BaseSegDataset):
    """BDD100K dataset.

    ``img_suffix`` defaults to '.jpg', ``seg_map_suffix`` defaults to '.png'
    and ``reduce_zero_label`` defaults to False.
    """

    # 19 class names with index-aligned display colours.
    METAINFO = {
        'classes': (
            'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
            'traffic light', 'traffic sign', 'vegetation', 'terrain',
            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
            'motorcycle', 'bicycle',
        ),
        'palette': [
            [128, 64, 128],
            [244, 35, 232],
            [70, 70, 70],
            [102, 102, 156],
            [190, 153, 153],
            [153, 153, 153],
            [250, 170, 30],
            [220, 220, 0],
            [107, 142, 35],
            [152, 251, 152],
            [70, 130, 180],
            [220, 20, 60],
            [255, 0, 0],
            [0, 0, 142],
            [0, 0, 70],
            [0, 60, 100],
            [0, 80, 100],
            [0, 0, 230],
            [119, 11, 32],
        ],
    }

    def __init__(self,
                 img_suffix: str = '.jpg',
                 seg_map_suffix: str = '.png',
                 reduce_zero_label: bool = False,
                 **kwargs) -> None:
        """Forward the BDD100K defaults and any extra options to the base.

        Args:
            img_suffix (str): Suffix of images. Defaults to '.jpg'.
            seg_map_suffix (str): Suffix of segmentation maps.
                Defaults to '.png'.
            reduce_zero_label (bool): Passed through to the base class.
                Defaults to False.
            **kwargs: Remaining keyword arguments, passed to the base class
                unchanged.
        """
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
|
||||
32
Seg_All_In_One_MMSeg/mmseg/datasets/chase_db1.py
Normal file
32
Seg_All_In_One_MMSeg/mmseg/datasets/chase_db1.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class ChaseDB1Dataset(BaseSegDataset):
    """Chase_db1 dataset.

    In the Chase_db1 segmentation map annotation, 0 stands for background,
    which is one of the 2 categories, so ``reduce_zero_label`` defaults to
    False. ``img_suffix`` defaults to '.png' and ``seg_map_suffix`` defaults
    to '_1stHO.png'.
    """

    # Two classes with index-aligned display colours.
    METAINFO = {
        'classes': ('background', 'vessel'),
        'palette': [[120, 120, 120], [6, 230, 230]],
    }

    def __init__(self,
                 img_suffix: str = '.png',
                 seg_map_suffix: str = '_1stHO.png',
                 reduce_zero_label: bool = False,
                 **kwargs) -> None:
        """Initialize the base dataset, then check the image path exists.

        Args:
            img_suffix (str): Suffix of images. Defaults to '.png'.
            seg_map_suffix (str): Suffix of segmentation maps.
                Defaults to '_1stHO.png'.
            reduce_zero_label (bool): Passed through to the base class.
                Defaults to False.
            **kwargs: Remaining keyword arguments, passed to the base class
                unchanged.
        """
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # Fail early if the configured image location is missing.
        img_location = self.data_prefix['img_path']
        assert fileio.exists(img_location, backend_args=self.backend_args)
|
||||
30
Seg_All_In_One_MMSeg/mmseg/datasets/cityscapes.py
Normal file
30
Seg_All_In_One_MMSeg/mmseg/datasets/cityscapes.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class CityscapesDataset(BaseSegDataset):
    """Cityscapes dataset.

    ``img_suffix`` defaults to '_leftImg8bit.png' and ``seg_map_suffix``
    defaults to '_gtFine_labelTrainIds.png' for the Cityscapes dataset.
    """

    # 19 class names with index-aligned display colours.
    METAINFO = {
        'classes': (
            'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
            'traffic light', 'traffic sign', 'vegetation', 'terrain',
            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
            'motorcycle', 'bicycle',
        ),
        'palette': [
            [128, 64, 128],
            [244, 35, 232],
            [70, 70, 70],
            [102, 102, 156],
            [190, 153, 153],
            [153, 153, 153],
            [250, 170, 30],
            [220, 220, 0],
            [107, 142, 35],
            [152, 251, 152],
            [70, 130, 180],
            [220, 20, 60],
            [255, 0, 0],
            [0, 0, 142],
            [0, 0, 70],
            [0, 60, 100],
            [0, 80, 100],
            [0, 0, 230],
            [119, 11, 32],
        ],
    }

    def __init__(self,
                 img_suffix: str = '_leftImg8bit.png',
                 seg_map_suffix: str = '_gtFine_labelTrainIds.png',
                 **kwargs) -> None:
        """Forward the Cityscapes suffixes and any extra options to the base.

        Args:
            img_suffix (str): Suffix of images.
                Defaults to '_leftImg8bit.png'.
            seg_map_suffix (str): Suffix of segmentation maps.
                Defaults to '_gtFine_labelTrainIds.png'.
            **kwargs: Remaining keyword arguments, passed to the base class
                unchanged.
        """
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            **kwargs)
|
||||
99
Seg_All_In_One_MMSeg/mmseg/datasets/coco_stuff.py
Normal file
99
Seg_All_In_One_MMSeg/mmseg/datasets/coco_stuff.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class COCOStuffDataset(BaseSegDataset):
    """COCO-Stuff dataset.

    In the COCO-Stuff 10k annotations the Train-IDs run from 1 to 171 with 0
    as the ignore index; in COCO-Stuff 164k they run from 0 to 170 with 255 as
    the ignore index. Both versions therefore have 171 semantic categories,
    and ``reduce_zero_label`` should be True for 10k and False for 164k.

    By default ``img_suffix`` is '.jpg' and ``seg_map_suffix`` is
    '_labelTrainIds.png'.
    """

    # 171 class names followed by 171 RGB colours.
    METAINFO = {
        'classes': (
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
            'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
            'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
            'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
            'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
            'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
            'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
            'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
            'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
            'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
            'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
            'paper', 'pavement', 'pillow', 'plant-other', 'plastic',
            'platform', 'playingfield', 'railing', 'railroad', 'river', 'road',
            'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf',
            'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs',
            'stone', 'straw', 'structural-other', 'table', 'tent',
            'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick',
            'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone',
            'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
            'window-blind', 'window-other', 'wood',
        ),
        'palette': [
            [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
            [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
            [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
            [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
            [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
            [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
            [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
            [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
            [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
            [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
            [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
            [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
            [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
            [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
            [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
            [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
            [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],
            [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],
            [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],
            [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
            [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],
            [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],
            [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],
            [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],
            [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],
            [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],
            [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
            [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
            [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
            [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
            [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
            [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
            [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
            [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
            [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],
            [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],
            [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],
            [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],
            [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
            [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
            [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
            [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
            [64, 192, 96], [64, 160, 64], [64, 64, 0],
        ],
    }

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='_labelTrainIds.png',
                 **kwargs) -> None:
        """Forward the COCO-Stuff suffix defaults to the base dataset.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(img_suffix=img_suffix, seg_map_suffix=seg_map_suffix)
        super().__init__(**kwargs)
|
||||
15
Seg_All_In_One_MMSeg/mmseg/datasets/dark_zurich.py
Normal file
15
Seg_All_In_One_MMSeg/mmseg/datasets/dark_zurich.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .cityscapes import CityscapesDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class DarkZurichDataset(CityscapesDataset):
    """Dark Zurich dataset.

    Inherits everything from ``CityscapesDataset`` and only changes the
    default file suffixes: '_rgb_anon.png' for images and
    '_gt_labelTrainIds.png' for segmentation maps.
    """

    def __init__(self,
                 img_suffix='_rgb_anon.png',
                 seg_map_suffix='_gt_labelTrainIds.png',
                 **kwargs) -> None:
        """Forward the Dark Zurich suffix defaults to the parent dataset.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            **kwargs: Remaining options, passed to ``CityscapesDataset``.
        """
        kwargs.update(img_suffix=img_suffix, seg_map_suffix=seg_map_suffix)
        super().__init__(**kwargs)
|
||||
136
Seg_All_In_One_MMSeg/mmseg/datasets/dataset_wrappers.py
Normal file
136
Seg_All_In_One_MMSeg/mmseg/datasets/dataset_wrappers.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import collections
|
||||
import copy
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from mmengine.dataset import ConcatDataset, force_full_init
|
||||
|
||||
from mmseg.registry import DATASETS, TRANSFORMS
|
||||
|
||||
|
||||
@DATASETS.register_module()
class MultiImageMixDataset:
    """Dataset wrapper for augmentations that mix several images.

    Intended for multi-image augmentations such as mosaic and mixup: before a
    transform that exposes ``get_indices`` is applied, the extra samples it
    asks for are fetched from the wrapped dataset and handed over through the
    ``mix_results`` key.

    Args:
        dataset (ConcatDataset or dict): The dataset to be mixed, or a config
            dict from which it is built.
        pipeline (Sequence[dict]): Sequence of transform config dicts to be
            built and applied in order.
        skip_type_keys (list[str], optional): Transform type names to skip
            when running the pipeline. Default to None.
        lazy_init (bool): If False, ``full_init`` is called at the end of
            construction. Default to False.
    """

    def __init__(self,
                 dataset: Union[ConcatDataset, dict],
                 pipeline: Sequence[dict],
                 skip_type_keys: Optional[List[str]] = None,
                 lazy_init: bool = False) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)

        # Accept either a ready ConcatDataset or a config to build one from.
        if isinstance(dataset, dict):
            dataset = DATASETS.build(dataset)
        elif not isinstance(dataset, ConcatDataset):
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`ConcatDataset` instance, but got {type(dataset)}')
        self.dataset = dataset

        if skip_type_keys is not None:
            assert all(isinstance(key, str) for key in skip_type_keys)
        self._skip_type_keys = skip_type_keys

        # Build every transform and remember its configured type name so it
        # can later be matched against ``_skip_type_keys``.
        built_transforms = []
        transform_types = []
        for cfg in pipeline:
            if not isinstance(cfg, dict):
                raise TypeError('pipeline must be a dict')
            transform_types.append(cfg['type'])
            built_transforms.append(TRANSFORMS.build(cfg))
        self.pipeline = built_transforms
        self.pipeline_types = transform_types

        self._metainfo = self.dataset.metainfo
        self.num_samples = len(self.dataset)

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Return a deep copy of the wrapped dataset's meta information.

        Returns:
            dict: The meta information of multi-image-mixed dataset.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the wrapped dataset (no-op when already done)."""
        if not self._fully_initialized:
            self.dataset.full_init()
            self._ori_len = len(self.dataset)
            self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Return the annotation of the ``idx``-th sample.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        """Return the sample count recorded at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Fetch sample ``idx`` and run it through the mixing pipeline."""
        results = copy.deepcopy(self.dataset[idx])
        skipped = self._skip_type_keys
        for step, step_type in zip(self.pipeline, self.pipeline_types):
            if skipped is not None and step_type in skipped:
                continue

            if hasattr(step, 'get_indices'):
                # The transform decides which extra samples it wants mixed in.
                extra = step.get_indices(self.dataset)
                if not isinstance(extra, collections.abc.Sequence):
                    extra = [extra]
                results['mix_results'] = [
                    copy.deepcopy(self.dataset[i]) for i in extra
                ]

            results = step(results)

            # Do not leak the auxiliary samples into the next transform.
            if 'mix_results' in results:
                results.pop('mix_results')

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Replace the list of transform types to skip.

        It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of type
                string to be skip pipeline.
        """
        assert all(isinstance(key, str) for key in skip_type_keys)
        self._skip_type_keys = skip_type_keys
|
||||
96
Seg_All_In_One_MMSeg/mmseg/datasets/decathlon.py
Normal file
96
Seg_All_In_One_MMSeg/mmseg/datasets/decathlon.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from typing import List
|
||||
|
||||
from mmengine.fileio import load
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class DecathlonDataset(BaseSegDataset):
    """Dataset for Decathlon dataset.

    The dataset.json format is shown as follows

    .. code-block:: none

        {
            "name": "BRATS",
            "tensorImageSize": "4D",
            "modality":
            {
                "0": "FLAIR",
                "1": "T1w",
                "2": "t1gd",
                "3": "T2w"
            },
            "labels": {
                "0": "background",
                "1": "edema",
                "2": "non-enhancing tumor",
                "3": "enhancing tumour"
            },
            "numTraining": 484,
            "numTest": 266,
            "training":
            [
                {
                    "image": "./imagesTr/BRATS_306.nii.gz",
                    "label": "./labelsTr/BRATS_306.nii.gz"
                    ...
                }
            ],
            "test":
            [
                "./imagesTs/BRATS_557.nii.gz"
                ...
            ]
        }
    """

    @staticmethod
    def _strip_cur_dir(path: str) -> str:
        """Drop a leading './' from ``path``; other paths pass through."""
        return path[2:] if path.startswith('./') else path

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Entries of the ``training`` split are dicts with ``image`` and
        ``label`` paths; entries of the ``test`` split are plain image path
        strings. Every remaining top-level key of the annotation file is
        merged into the dataset meta information without overriding keys
        that are already set, and ``classes`` is derived from ``labels``.

        Returns:
            list[dict]: All data info of dataset.

        Raises:
            TypeError: If the annotation file does not contain a dict.
        """
        # `self.ann_file` denotes the absolute annotation file path if
        # `self.root=None` or relative path if `self.root=/path/to/data/`.
        annotations = load(self.ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')

        split = 'test' if self.test_mode else 'training'
        data_list = []
        for raw_data_info in annotations[split]:
            # The leading './' must be removed from the *relative* path before
            # joining, since it breaks loading from cloud storage. The
            # original code sliced the joined path for test entries, which
            # cut two characters off `data_root` instead.
            if isinstance(raw_data_info, dict):
                data_info = dict(
                    img_path=osp.join(
                        self.data_root,
                        self._strip_cur_dir(raw_data_info['image'])),
                    seg_map_path=osp.join(
                        self.data_root,
                        self._strip_cur_dir(raw_data_info['label'])))
            else:
                data_info = dict(
                    img_path=osp.join(self.data_root,
                                      self._strip_cur_dir(raw_data_info)))
            data_info['label_map'] = self.label_map
            data_info['reduce_zero_label'] = self.reduce_zero_label
            data_info['seg_fields'] = []
            data_list.append(data_info)

        # Everything except the sample lists is meta information. A missing
        # split is tolerated here: only the split selected above is required.
        metainfo = copy.deepcopy({
            key: value
            for key, value in annotations.items()
            if key not in ('training', 'test')
        })
        metainfo['classes'] = [*metainfo['labels'].values()]
        # Meta information load from annotation file will not influence the
        # existed meta information load from `BaseDataset.METAINFO` and
        # `metainfo` arguments defined in constructor.
        for k, v in metainfo.items():
            self._metainfo.setdefault(k, v)

        return data_list
|
||||
32
Seg_All_In_One_MMSeg/mmseg/datasets/drive.py
Normal file
32
Seg_All_In_One_MMSeg/mmseg/datasets/drive.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class DRIVEDataset(BaseSegDataset):
    """DRIVE dataset.

    The segmentation maps of DRIVE use 0 for background, and background is
    one of the 2 categories, so ``reduce_zero_label`` defaults to False.
    ``img_suffix`` defaults to '.png' and ``seg_map_suffix`` to
    '_manual1.png'.
    """

    METAINFO = {
        'classes': ('background', 'vessel'),
        'palette': [[120, 120, 120], [6, 230, 230]],
    }

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='_manual1.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        """Initialize the base dataset, then check the image location.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            reduce_zero_label (bool): Forwarded to ``BaseSegDataset``.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label)
        super().__init__(**kwargs)
        # Fail early when the configured image prefix does not exist.
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
116
Seg_All_In_One_MMSeg/mmseg/datasets/dsdl.py
Normal file
116
Seg_All_In_One_MMSeg/mmseg/datasets/dsdl.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
try:
|
||||
from dsdl.dataset import DSDLDataset
|
||||
except ImportError:
|
||||
DSDLDataset = None
|
||||
|
||||
|
||||
@DATASETS.register_module()
class DSDLSegDataset(BaseSegDataset):
    """Dataset for dsdl segmentation.

    Args:
        specific_key_path (dict, optional): Path of specific key which can
            not be loaded by its field name. ``None`` means an empty dict.
        pre_transform (dict, optional): pre-transform functions before
            loading. ``None`` means an empty dict.
        used_labels (sequence, optional): list of actual used classes in
            train steps, this must be subset of class domain.
    """

    METAINFO = {}

    def __init__(self,
                 specific_key_path: Optional[Dict] = None,
                 pre_transform: Optional[Dict] = None,
                 used_labels: Optional[Sequence] = None,
                 **kwargs) -> None:

        if DSDLDataset is None:
            raise RuntimeError(
                'Package dsdl is not installed. Please run "pip install dsdl".'
            )
        # A fresh dict per call: a literal `{}` default would be one object
        # shared by every instance and handed to third-party code.
        if specific_key_path is None:
            specific_key_path = {}
        if pre_transform is None:
            pre_transform = {}
        self.used_labels = used_labels

        loc_config = dict(type='LocalFileReader', working_dir='')
        if kwargs.get('data_root'):
            kwargs['ann_file'] = os.path.join(kwargs['data_root'],
                                              kwargs['ann_file'])
        required_fields = ['Image', 'LabelMap']

        self.dsdldataset = DSDLDataset(
            dsdl_yaml=kwargs['ann_file'],
            location_config=loc_config,
            required_fields=required_fields,
            specific_key_path=specific_key_path,
            transform=pre_transform,
        )
        BaseSegDataset.__init__(self, **kwargs)

    def load_data_list(self) -> List[Dict]:
        """Load data info from a dsdl yaml file named as ``self.ann_file``

        Also sets ``classes`` in the meta information: to ``used_labels``
        when given (updating ``self.label_map`` accordingly), otherwise to
        'background' followed by the class names of the dsdl dataset.

        Returns:
            List[dict]: A list of data list.
        """

        if self.used_labels:
            self._metainfo['classes'] = tuple(self.used_labels)
            self.label_map = self.get_label_map(self.used_labels)
        else:
            self._metainfo['classes'] = tuple(['background'] +
                                              self.dsdldataset.class_names)

        img_prefix = self.data_prefix['img_path']
        seg_prefix = self.data_prefix['seg_map_path']
        data_list = []
        for data in self.dsdldataset:
            data_list.append(
                dict(
                    img_path=os.path.join(img_prefix,
                                          data['Image'][0].location),
                    seg_map_path=os.path.join(seg_prefix,
                                              data['LabelMap'][0].location),
                    label_map=self.label_map,
                    reduce_zero_label=self.reduce_zero_label,
                    seg_fields=[],
                ))

        return data_list

    def get_label_map(self,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and is used for changing pixel
        labels in load_annotations. ``label_map`` is not None if and only if
        ``new_classes`` is given and differs from the old classes, i.e.
        'background' followed by the class names of the dsdl dataset. Old
        classes that are absent from ``new_classes`` are mapped to 255.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes to new classes.

        Raises:
            ValueError: If ``new_classes`` is not a subset of the old
                classes.
        """
        old_classes = ['background'] + self.dsdldataset.class_names
        if new_classes is None or list(new_classes) == list(old_classes):
            return None

        if not set(new_classes).issubset(old_classes):
            raise ValueError(
                f'new classes {new_classes} is not a '
                f'subset of classes {old_classes} in class_dom.')

        # Index of the first occurrence of each new class (what
        # ``Sequence.index`` returns), built once instead of per old class.
        new_index = {}
        for idx, name in enumerate(new_classes):
            new_index.setdefault(name, idx)

        return {
            old_id: new_index.get(name, 255)
            for old_id, name in enumerate(old_classes)
        }
|
||||
32
Seg_All_In_One_MMSeg/mmseg/datasets/hrf.py
Normal file
32
Seg_All_In_One_MMSeg/mmseg/datasets/hrf.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class HRFDataset(BaseSegDataset):
    """HRF dataset.

    The segmentation maps of HRF use 0 for background, and background is one
    of the 2 categories, so ``reduce_zero_label`` defaults to False. Both
    ``img_suffix`` and ``seg_map_suffix`` default to '.png'.
    """

    METAINFO = {
        'classes': ('background', 'vessel'),
        'palette': [[120, 120, 120], [6, 230, 230]],
    }

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        """Initialize the base dataset, then check the image location.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            reduce_zero_label (bool): Forwarded to ``BaseSegDataset``.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label)
        super().__init__(**kwargs)
        # Fail early when the configured image prefix does not exist.
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
42
Seg_All_In_One_MMSeg/mmseg/datasets/hsi_drive.py
Normal file
42
Seg_All_In_One_MMSeg/mmseg/datasets/hsi_drive.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.datasets import BaseSegDataset
|
||||
from mmseg.registry import DATASETS
|
||||
|
||||
# Class names of HSI-Drive v2.0 and the colours used to draw them (11 each).
classes_exp = (
    'unlabelled',
    'road',
    'road marks',
    'vegetation',
    'painted metal',
    'sky',
    'concrete',
    'pedestrian',
    'water',
    'unpainted metal',
    'glass',
)
palette_exp = [
    [0, 0, 0],
    [77, 77, 77],
    [255, 255, 255],
    [0, 255, 0],
    [255, 0, 0],
    [0, 0, 255],
    [102, 51, 0],
    [255, 255, 0],
    [0, 207, 250],
    [255, 166, 0],
    [0, 204, 204],
]


@DATASETS.register_module()
class HSIDrive20Dataset(BaseSegDataset):
    """HSI-Drive v2.0 dataset.

    HSI-Drive v2.0 (https://ieeexplore.ieee.org/document/10371793), the
    updated version of HSI-Drive
    (https://ieeexplore.ieee.org/document/9575298), is a structured dataset for
    the research and development of automated driving systems (ADS) supported
    by hyperspectral imaging (HSI). It contains per-pixel manually annotated
    images selected from videos recorded in real driving conditions and has
    been organized according to four parameters: season, daytime, road type,
    and weather conditions.

    The video sequences have been captured with a small-size 25-band VNIR
    (Visible-Near-InfraRed) snapshot hyperspectral camera mounted on a driving
    automobile. As a consequence, you need to modify the in_channels parameter
    of your model from 3 (RGB images) to 25 (HSI images) as it is done in
    configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py

    Apart from the abovementioned articles, additional information is provided
    in the website (https://ipaccess.ehu.eus/HSI-Drive/) from where you can
    download the dataset and also visualize some examples of segmented videos.

    ``img_suffix`` defaults to '.npy' and ``seg_map_suffix`` to '.png'.
    """

    METAINFO = {'classes': classes_exp, 'palette': palette_exp}

    def __init__(self,
                 img_suffix='.npy',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        """Forward the HSI-Drive suffix defaults to the base dataset.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(img_suffix=img_suffix, seg_map_suffix=seg_map_suffix)
        super().__init__(**kwargs)
|
||||
39
Seg_All_In_One_MMSeg/mmseg/datasets/isaid.py
Normal file
39
Seg_All_In_One_MMSeg/mmseg/datasets/isaid.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class iSAIDDataset(BaseSegDataset):
    """iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images.

    The segmentation map annotation of iSAID covers 16 categories, including
    background. ``img_suffix`` defaults to '.png', ``seg_map_suffix`` to
    '_instance_color_RGB.png' and ``ignore_index`` to 255.
    """

    METAINFO = {
        'classes': (
            'background', 'ship', 'store_tank', 'baseball_diamond',
            'tennis_court', 'basketball_court', 'Ground_Track_Field',
            'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
            'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
            'Harbor',
        ),
        'palette': [
            [0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0],
            [0, 63, 127], [0, 63, 191], [0, 63, 255], [0, 127, 63],
            [0, 127, 127], [0, 0, 127], [0, 0, 191], [0, 0, 255],
            [0, 191, 127], [0, 127, 191], [0, 127, 255], [0, 100, 155],
        ],
    }

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='_instance_color_RGB.png',
                 ignore_index=255,
                 **kwargs) -> None:
        """Initialize the base dataset, then check the image location.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            ignore_index (int): Forwarded to ``BaseSegDataset``.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ignore_index=ignore_index)
        super().__init__(**kwargs)
        # Fail early when the configured image prefix does not exist.
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
29
Seg_All_In_One_MMSeg/mmseg/datasets/isprs.py
Normal file
29
Seg_All_In_One_MMSeg/mmseg/datasets/isprs.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class ISPRSDataset(BaseSegDataset):
    """ISPRS dataset.

    In the segmentation map annotation for ISPRS, 0 is the ignore index, so
    ``reduce_zero_label`` defaults to True. Both ``img_suffix`` and
    ``seg_map_suffix`` default to '.png'.
    """

    METAINFO = {
        'classes': (
            'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
            'clutter',
        ),
        'palette': [
            [255, 255, 255], [0, 0, 255], [0, 255, 255],
            [0, 255, 0], [255, 255, 0], [255, 0, 0],
        ],
    }

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        """Forward the ISPRS defaults to the base dataset.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            reduce_zero_label (bool): Forwarded to ``BaseSegDataset``.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label)
        super().__init__(**kwargs)
|
||||
31
Seg_All_In_One_MMSeg/mmseg/datasets/levir.py
Normal file
31
Seg_All_In_One_MMSeg/mmseg/datasets/levir.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseCDDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class LEVIRCDDataset(BaseCDDataset):
    """LEVIR-CD change detection dataset.

    Two classes are defined, 'background' and 'changed'.
    ``reduce_zero_label`` defaults to False. ``img_suffix``, ``img_suffix2``
    and ``seg_map_suffix`` all default to '.png'.
    """

    METAINFO = {
        'classes': ('background', 'changed'),
        'palette': [[0, 0, 0], [255, 255, 255]],
    }

    def __init__(self,
                 img_suffix='.png',
                 img_suffix2='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        """Forward the LEVIR-CD defaults to the base dataset.

        Args:
            img_suffix (str): Suffix of the first image files.
            img_suffix2 (str): Suffix of the second image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            reduce_zero_label (bool): Forwarded to ``BaseCDDataset``.
            **kwargs: Remaining options, passed to ``BaseCDDataset``.
        """
        kwargs.update(
            img_suffix=img_suffix,
            img_suffix2=img_suffix2,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label)
        super().__init__(**kwargs)
|
||||
47
Seg_All_In_One_MMSeg/mmseg/datasets/lip.py
Normal file
47
Seg_All_In_One_MMSeg/mmseg/datasets/lip.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class LIPDataset(BaseSegDataset):
    """LIP dataset.

    ``img_suffix`` defaults to '.jpg' and ``seg_map_suffix`` defaults to
    '.png'.
    """

    # 20 class names; the palette is a tuple holding 20 RGB lists.
    METAINFO = {
        'classes': (
            'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses',
            'UpperClothes', 'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits',
            'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', 'Left-leg',
            'Right-leg', 'Left-shoe', 'Right-shoe',
        ),
        'palette': (
            [0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0],
            [170, 0, 51], [255, 85, 0], [0, 0, 85], [0, 119, 221],
            [85, 85, 0], [0, 85, 85], [85, 51, 0], [52, 86, 128],
            [0, 128, 0], [0, 0, 255], [51, 170, 221], [0, 255, 255],
            [85, 255, 170], [170, 255, 85], [255, 255, 0], [255, 170, 0],
        ),
    }

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        """Forward the LIP suffix defaults to the base dataset.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(img_suffix=img_suffix, seg_map_suffix=seg_map_suffix)
        super().__init__(**kwargs)
|
||||
29
Seg_All_In_One_MMSeg/mmseg/datasets/loveda.py
Normal file
29
Seg_All_In_One_MMSeg/mmseg/datasets/loveda.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class LoveDADataset(BaseSegDataset):
    """LoveDA dataset.

    In the segmentation map annotation for LoveDA, 0 is the ignore index, so
    ``reduce_zero_label`` defaults to True. Both ``img_suffix`` and
    ``seg_map_suffix`` default to '.png'.
    """

    METAINFO = {
        'classes': (
            'background', 'building', 'road', 'water', 'barren', 'forest',
            'agricultural',
        ),
        'palette': [
            [255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
            [159, 129, 183], [0, 255, 0], [255, 195, 128],
        ],
    }

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=True,
                 **kwargs) -> None:
        """Forward the LoveDA defaults to the base dataset.

        Args:
            img_suffix (str): Suffix of image files.
            seg_map_suffix (str): Suffix of segmentation map files.
            reduce_zero_label (bool): Forwarded to ``BaseSegDataset``.
            **kwargs: Remaining options, passed to ``BaseSegDataset``.
        """
        kwargs.update(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label)
        super().__init__(**kwargs)
|
||||
176
Seg_All_In_One_MMSeg/mmseg/datasets/mapillary.py
Normal file
176
Seg_All_In_One_MMSeg/mmseg/datasets/mapillary.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MapillaryDataset_v1(BaseSegDataset):
|
||||
"""Mapillary Vistas Dataset.
|
||||
|
||||
Dataset paper link:
|
||||
http://ieeexplore.ieee.org/document/8237796/
|
||||
|
||||
v1.2 contain 66 object classes.
|
||||
(37 instance-specific)
|
||||
|
||||
v2.0 contain 124 object classes.
|
||||
(70 instance-specific, 46 stuff, 8 void or crowd).
|
||||
|
||||
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
||||
fixed to '.png' for Mapillary Vistas Dataset.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
|
||||
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
|
||||
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
|
||||
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
|
||||
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
|
||||
'Other Rider', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
|
||||
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
|
||||
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
|
||||
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
|
||||
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
|
||||
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
|
||||
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
|
||||
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156],
|
||||
[128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||
[250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
|
||||
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
|
||||
[200, 128, 128], [255, 255, 255], [64, 170,
|
||||
64], [230, 160, 50],
|
||||
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||
[100, 140, 180], [220, 220, 220], [220, 128, 128],
|
||||
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
|
||||
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
|
||||
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
|
||||
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
|
||||
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
|
||||
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10,
|
||||
10], [0, 0, 0]])
|
||||
|
||||
def __init__(self, img_suffix='.jpg', seg_map_suffix='.png',
             **kwargs) -> None:
    """Initialize the dataset with default image and label suffixes.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.jpg'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '.png'.
        **kwargs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        **kwargs)
|
||||
|
||||
|
||||
@DATASETS.register_module()
class MapillaryDataset_v2(BaseSegDataset):
    """Mapillary Vistas dataset, v2.0 label set.

    Dataset paper link:
    http://ieeexplore.ieee.org/document/8237796/

    Mapillary Vistas v1.2 contains 66 object classes (37 instance-specific),
    while v2.0 contains 124 object classes (70 instance-specific, 46 stuff,
    8 void or crowd). This class carries the v2.0 class names and palette.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.jpg'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '.png'.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # ``classes`` and ``palette`` are index-aligned: the i-th RGB triple is
    # the display colour of the i-th class name.
    METAINFO = {
        'classes': (
            'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block',
            'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median',
            'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall',
            'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway',
            'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track',
            'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk',
            'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel',
            'Person', 'Person Group', 'Bicyclist', 'Motorcyclist',
            'Other Rider', 'Lane Marking - Dashed Line',
            'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line',
            'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)',
            'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)',
            'Lane Marking - Arrow (Split Left or Straight)',
            'Lane Marking - Arrow (Split Right or Straight)',
            'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
            'Lane Marking - Give Way (Row)',
            'Lane Marking - Give Way (Single)',
            'Lane Marking - Hatched (Chevron)',
            'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
            'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
            'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
            'Lane Marking (only) - Dashed Line',
            'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other',
            'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow',
            'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
            'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box',
            'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole',
            'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back',
            'Signage - Information', 'Signage - Other', 'Signage - Store',
            'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame',
            'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)',
            'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)',
            'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
            'Traffic Light - Other', 'Traffic Sign - Ambiguous',
            'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
            'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
            'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
            'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
            'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
            'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
            'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static',
            'Unlabeled'),
        'palette': [
            [165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
            [196, 196, 196], [190, 153, 153], [180, 165, 180],
            [90, 120, 150], [250, 170, 33], [250, 170, 34],
            [128, 128, 128], [250, 170, 35], [102, 102, 156],
            [128, 64, 255], [140, 140, 200], [170, 170, 170],
            [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
            [230, 150, 140], [128, 64, 128], [110, 110, 110],
            [110, 110, 110], [244, 35, 232], [128, 196, 128], [150, 100, 100],
            [70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60],
            [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
            [255, 255, 255], [255, 255, 255], [250, 170, 29],
            [250, 170, 28], [250, 170, 26], [250, 170, 25], [250, 170, 24],
            [250, 170, 22], [250, 170, 21], [250, 170, 20], [255, 255, 255],
            [250, 170, 19], [250, 170, 18], [250, 170, 12], [250, 170, 11],
            [255, 255, 255], [255, 255, 255], [250, 170, 16],
            [250, 170, 15], [250, 170, 15], [255, 255, 255],
            [255, 255, 255], [255, 255, 255], [255, 255, 255],
            [64, 170, 64], [230, 160, 50],
            [70, 130, 180], [190, 255, 255], [152, 251, 152],
            [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
            [100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30],
            [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
            [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
            [250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30],
            [210, 170, 100], [153, 153, 153], [153, 153, 153],
            [128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30],
            [250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30],
            [250, 170, 30], [192, 192, 192], [192, 192, 192],
            [192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196],
            [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
            [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
            [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
            [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
            [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
            [111, 111, 0], [0, 0, 0]],
    }

    def __init__(self, img_suffix='.jpg', seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            **kwargs)
|
||||
28
Seg_All_In_One_MMSeg/mmseg/datasets/my_dataset.py
Normal file
28
Seg_All_In_One_MMSeg/mmseg/datasets/my_dataset.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class MyDataset(BaseSegDataset):
    """User-defined segmentation dataset registered as ``MyDataset``.

    The class name is the identifier used to refer to this dataset; any
    name may be chosen. The class names in ``METAINFO`` are runtime strings
    and are kept in their original language.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of mask files. Defaults to
            '_gtFine_labelTrainIds.png'.
        reduce_zero_label (bool): Set to True when label 0 is a meaningless
            black border that must be kept apart from the content to be
            classified. Leave it False when label 0 is a real background
            class that should be separated from the other classes.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # TODO: adjust the annotation classes and colours to your data.
    # The background class is best placed first. ``classes`` and ``palette``
    # are index-aligned.
    METAINFO = {
        'classes': [
            '背景', '肝脏', '胆囊', '分离钳', '止血海绵', '肝总管', '胆总管',
            '吸引器', '剪刀', '止血纱布', '生物夹', '无损伤钳', '喷洒', '胆囊管',
            '胆囊动脉', '电凝', '标本袋', '引流管', '纱布', '金属钛夹', '术中超声',
            '吻合器', '乳胶管', '推结器', '肝带', '钳夹', '超声刀', '脂肪',
            '双极电凝', '棉球', '血管阻断夹', '肿瘤', '针', '线', '韧带', '胆囊静脉'],
        'palette': [
            [0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181],
            [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
            [29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178],
            [90, 120, 41], [255, 0, 0], [177, 0, 0], [167, 24, 233],
            [112, 113, 150], [0, 255, 0], [255, 255, 255], [0, 255, 255],
            [138, 251, 213], [136, 162, 196], [197, 83, 181], [202, 202, 200],
            [113, 102, 140], [66, 115, 82], [240, 16, 116], [155, 132, 0],
            [155, 62, 0], [146, 175, 236], [255, 172, 159], [245, 161, 0],
            [134, 124, 118], [0, 157, 142], [181, 85, 105], [42, 8, 66]],
    }

    def __init__(
            self,
            img_suffix='.png',  # TODO: image file suffix
            seg_map_suffix='_gtFine_labelTrainIds.png',  # TODO: mask suffix
            reduce_zero_label=False,  # TODO: see the class docstring
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # NOTE: the existence check on ``self.data_prefix['img_path']`` is
        # disabled for this dataset (it was left commented out).
|
||||
28
Seg_All_In_One_MMSeg/mmseg/datasets/my_dataset_model.py
Normal file
28
Seg_All_In_One_MMSeg/mmseg/datasets/my_dataset_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class MyDataset_model(BaseSegDataset):
    """User-defined segmentation dataset registered as ``MyDataset_model``.

    The class name is the identifier used to refer to this dataset; any
    name may be chosen. The class names in ``METAINFO`` are runtime strings
    and are kept in their original language.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of mask files. Defaults to
            '_gtFine_labelTrainIds.png'.
        reduce_zero_label (bool): Set to True when label 0 is a meaningless
            black border that must be kept apart from the content to be
            classified. Leave it False when label 0 is a real background
            class that should be separated from the other classes.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # TODO: adjust the annotation classes and colours to your data.
    # The background class is best placed first. ``classes`` and ``palette``
    # are index-aligned.
    METAINFO = {
        'classes': [
            '背景', '肝脏', '胆囊', '分离钳', '止血海绵', '肝总管', '胆总管',
            '吸引器', '剪刀', '止血纱布', '生物夹', '无损伤钳', '喷洒', '胆囊管',
            '胆囊动脉', '电凝', '标本袋', '引流管', '纱布', '金属钛夹', '术中超声',
            '吻合器', '乳胶管', '推结器', '肝带', '钳夹', '超声刀', '脂肪',
            '双极电凝', '棉球', '血管阻断夹', '肿瘤', '针', '线', '韧带', '胆囊静脉'],
        'palette': [
            [0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181],
            [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
            [29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178],
            [90, 120, 41], [255, 0, 0], [177, 0, 0], [167, 24, 233],
            [112, 113, 150], [0, 255, 0], [255, 255, 255], [0, 255, 255],
            [138, 251, 213], [136, 162, 196], [197, 83, 181], [202, 202, 200],
            [113, 102, 140], [66, 115, 82], [240, 16, 116], [155, 132, 0],
            [155, 62, 0], [146, 175, 236], [255, 172, 159], [245, 161, 0],
            [134, 124, 118], [0, 157, 142], [181, 85, 105], [42, 8, 66]],
    }

    def __init__(
            self,
            img_suffix='.png',  # TODO: image file suffix
            seg_map_suffix='_gtFine_labelTrainIds.png',  # TODO: mask suffix
            reduce_zero_label=False,  # TODO: see the class docstring
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # NOTE: the existence check on ``self.data_prefix['img_path']`` is
        # disabled for this dataset (it was left commented out).
|
||||
15
Seg_All_In_One_MMSeg/mmseg/datasets/night_driving.py
Normal file
15
Seg_All_In_One_MMSeg/mmseg/datasets/night_driving.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .cityscapes import CityscapesDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class NightDrivingDataset(CityscapesDataset):
    """Night Driving dataset built on top of ``CityscapesDataset``.

    This subclass only supplies its own default file suffixes; everything
    else is inherited from ``CityscapesDataset``.

    Args:
        img_suffix (str): Suffix of image files.
            Defaults to '_leftImg8bit.png'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '_gtCoarse_labelTrainIds.png'.
        **kwargs: Forwarded unchanged to ``CityscapesDataset``.
    """

    def __init__(
            self,
            img_suffix='_leftImg8bit.png',
            seg_map_suffix='_gtCoarse_labelTrainIds.png',
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            **kwargs)
|
||||
123
Seg_All_In_One_MMSeg/mmseg/datasets/nyu.py
Normal file
123
Seg_All_In_One_MMSeg/mmseg/datasets/nyu.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import List
|
||||
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class NYUDataset(BaseSegDataset):
    """NYU depth estimation dataset.

    The file structure should be:

    .. code-block:: none

        ├── data
        │   ├── nyu
        │   │   ├── images
        │   │   │   ├── train
        │   │   │   │   ├── scene_xxx.jpg
        │   │   │   │   ├── ...
        │   │   │   ├── test
        │   │   ├── annotations
        │   │   │   ├── train
        │   │   │   │   ├── scene_xxx.png
        │   │   │   │   ├── ...
        │   │   │   ├── test

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as
            specify classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path='images', depth_map_path='annotations').
        img_suffix (str): Suffix of images. Default: '.jpg'
        depth_map_suffix (str): Suffix of depth maps. It is handed to the
            parent class as ``seg_map_suffix``. Default: '.png'
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None which means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. In some cases, such as visualization, only the meta
            information of the dataset is needed, which is not necessary to
            load annotation file. ``Basedataset`` can skip load annotations to
            save time by set ``lazy_init=True``. Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default to False.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    # Scene categories; the position in this tuple is the ``category_id``
    # stored for each sample by ``load_data_list``.
    METAINFO = dict(
        classes=('printer_room', 'bathroom', 'living_room', 'study',
                 'conference_room', 'study_room', 'kitchen', 'home_office',
                 'bedroom', 'dinette', 'playroom', 'indoor_balcony',
                 'laundry_room', 'basement', 'excercise_room', 'foyer',
                 'home_storage', 'cafe', 'furniture_store', 'office_kitchen',
                 'student_lounge', 'dining_room', 'reception_room',
                 'computer_lab', 'classroom', 'office', 'bookstore'))

    def __init__(self,
                 data_prefix=dict(
                     img_path='images', depth_map_path='annotations'),
                 img_suffix='.jpg',
                 depth_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            data_prefix=data_prefix,
            img_suffix=img_suffix,
            # The parent class only knows ``seg_map_suffix``; the depth-map
            # suffix is stored under that name and read back in
            # ``load_data_list``.
            seg_map_suffix=depth_map_suffix,
            **kwargs)

    def _get_category_id_from_filename(self, image_fname: str) -> int:
        """Retrieve the category ID from the given image filename.

        The category name is the part of the base name that precedes the
        first digit, minus the single separator character right before that
        digit (e.g. ``'living_room_0012.jpg'`` -> ``'living_room'``).

        Args:
            image_fname (str): Image file name or path.

        Returns:
            int: Index of the category in ``self._metainfo['classes']``, or
            -1 when no known category can be derived from the file name.
        """
        image_fname = osp.basename(image_fname)
        # Index of the first digit; ``None`` when the name has no digit.
        # The previous ``next(filter(...))`` raised StopIteration here.
        position = next(
            (idx for idx, char in enumerate(image_fname) if char.isdigit()),
            None)
        # A leading digit means there is no category prefix at all. Without
        # this guard ``position - 1`` would be -1 and the slice below would
        # silently wrap around to "everything but the last character".
        if position is None or position == 0:
            return -1
        category_name = image_fname[:position - 1]
        classes = self._metainfo['classes']
        if category_name not in classes:
            return -1
        return classes.index(category_name)

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Image files ending with ``self.img_suffix`` are collected recursively
        from the ``img_path`` prefix. When a ``depth_map_path`` prefix is
        configured, the depth map path is derived by swapping the image
        suffix for ``self.seg_map_suffix``.

        Returns:
            list[dict]: All data info of dataset, sorted by ``img_path``.
        """
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('depth_map_path', None)
        suffix_len = len(self.img_suffix)

        data_list = []
        for img in fileio.list_dir_or_file(
                dir_path=img_dir,
                list_dir=False,
                suffix=self.img_suffix,
                recursive=True,
                backend_args=self.backend_args):
            data_info = dict(img_path=osp.join(img_dir, img))
            if ann_dir is not None:
                depth_map = img[:-suffix_len] + self.seg_map_suffix
                data_info['depth_map_path'] = osp.join(ann_dir, depth_map)
            data_info['seg_fields'] = []
            data_info['category_id'] = self._get_category_id_from_filename(img)
            data_list.append(data_info)
        return sorted(data_list, key=lambda x: x['img_path'])
|
||||
116
Seg_All_In_One_MMSeg/mmseg/datasets/pascal_context.py
Normal file
116
Seg_All_In_One_MMSeg/mmseg/datasets/pascal_context.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PascalContextDataset(BaseSegDataset):
    """PascalContext dataset with 60 categories (background included).

    In the segmentation map annotation for PascalContext, 0 stands for
    background, which is counted among the 60 categories, so
    ``reduce_zero_label`` defaults to False. ``img_suffix`` defaults to
    '.jpg' and ``seg_map_suffix`` defaults to '.png'.

    Args:
        ann_file (str): Annotation file path.
        img_suffix (str): Suffix of image files. Defaults to '.jpg'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '.png'.
        reduce_zero_label (bool): Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # ``classes`` and ``palette`` are index-aligned.
    METAINFO = {
        'classes': (
            'background', 'aeroplane', 'bag', 'bed', 'bedclothes',
            'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
            'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
            'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
            'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
            'horse', 'keyboard', 'light', 'motorbike', 'mountain',
            'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',
            'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',
            'sofa', 'table', 'track', 'train', 'tree', 'truck',
            'tvmonitor', 'wall', 'water', 'window', 'wood'),
        'palette': [
            [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]],
    }

    def __init__(
            self,
            ann_file='',
            img_suffix='.jpg',
            seg_map_suffix='.png',
            reduce_zero_label=False,
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ann_file=ann_file,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # Fail early when the configured image directory does not exist.
        img_dir = self.data_prefix['img_path']
        assert fileio.exists(img_dir, self.backend_args)
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PascalContextDataset59(BaseSegDataset):
    """PascalContext dataset with the 59 foreground categories.

    In the segmentation map annotation for PascalContext, 0 stands for
    background. This variant lists no 'background' class and defaults
    ``reduce_zero_label`` to True. ``img_suffix`` defaults to '.jpg' and
    ``seg_map_suffix`` defaults to '.png'.

    Note: if the background is 255 and the ids of categories are from 0 to
    58, ``reduce_zero_label`` needs to be set to False.

    Args:
        ann_file (str): Annotation file path.
        img_suffix (str): Suffix of image files. Defaults to '.jpg'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '.png'.
        reduce_zero_label (bool): Defaults to True.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # ``classes`` and ``palette`` are index-aligned.
    METAINFO = {
        'classes': (
            'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
            'bird', 'boat', 'book', 'bottle', 'building', 'bus',
            'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
            'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
            'floor', 'flower', 'food', 'grass', 'ground', 'horse',
            'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
            'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
            'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',
            'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',
            'wall', 'water', 'window', 'wood'),
        'palette': [
            [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
            [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]],
    }

    def __init__(
            self,
            ann_file='',
            img_suffix='.jpg',
            seg_map_suffix='.png',
            reduce_zero_label=True,
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ann_file=ann_file,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # Fail early when the configured image directory does not exist.
        img_dir = self.data_prefix['img_path']
        assert fileio.exists(img_dir, self.backend_args)
|
||||
29
Seg_All_In_One_MMSeg/mmseg/datasets/potsdam.py
Normal file
29
Seg_All_In_One_MMSeg/mmseg/datasets/potsdam.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PotsdamDataset(BaseSegDataset):
    """ISPRS Potsdam dataset.

    In the segmentation map annotation for the Potsdam dataset, 0 is the
    ignore index, so ``reduce_zero_label`` defaults to True. Both
    ``img_suffix`` and ``seg_map_suffix`` default to '.png'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '.png'.
        reduce_zero_label (bool): Defaults to True.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # ``classes`` and ``palette`` are index-aligned.
    METAINFO = {
        'classes': (
            'impervious_surface', 'building', 'low_vegetation', 'tree',
            'car', 'clutter'),
        'palette': [
            [255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
            [255, 255, 0], [255, 0, 0]],
    }

    def __init__(
            self,
            img_suffix='.png',
            seg_map_suffix='.png',
            reduce_zero_label=True,
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PublicDataSet_AutoLaparo(BaseSegDataset):
    """Segmentation dataset registered as ``PublicDataSet_AutoLaparo``.

    The class name is the identifier used to refer to this dataset; any
    name may be chosen. Class names are runtime strings: the background
    entry plus numeric labels '1' to '9'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of mask files. Defaults to '.png'.
        reduce_zero_label (bool): Set to True when label 0 is a meaningless
            black border that must be kept apart from the content to be
            classified. Leave it False when label 0 is a real background
            class that should be separated from the other classes.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # TODO: adjust the annotation classes and colours to your data.
    # The background class is best placed first. ``classes`` and ``palette``
    # are index-aligned.
    METAINFO = {
        'classes': ['背景', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
        'palette': [
            [0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181],
            [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
            [29, 32, 136], [160, 15, 95]],
    }

    def __init__(
            self,
            img_suffix='.png',  # TODO: image file suffix
            seg_map_suffix='.png',  # TODO: mask file suffix
            reduce_zero_label=False,  # TODO: see the class docstring
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # NOTE: the existence check on ``self.data_prefix['img_path']`` is
        # disabled for this dataset (it was left commented out).
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PublicDataSet_CholecSeg8k(BaseSegDataset):
    """Segmentation dataset registered as ``PublicDataSet_CholecSeg8k``.

    The class name is the identifier used to refer to this dataset; any
    name may be chosen. Class names are runtime strings: the background
    entry plus numeric labels '1' to '12'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of mask files. Defaults to '.png'.
        reduce_zero_label (bool): Set to True when label 0 is a meaningless
            black border that must be kept apart from the content to be
            classified. Leave it False when label 0 is a real background
            class that should be separated from the other classes.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # TODO: adjust the annotation classes and colours to your data.
    # The background class is best placed first. ``classes`` and ``palette``
    # are index-aligned.
    METAINFO = {
        'classes': [
            '背景', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11',
            '12'],
        'palette': [
            [0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181],
            [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
            [29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178],
            [90, 120, 41]],
    }

    def __init__(
            self,
            img_suffix='.png',  # TODO: image file suffix
            seg_map_suffix='.png',  # TODO: mask file suffix
            reduce_zero_label=False,  # TODO: see the class docstring
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # NOTE: the existence check on ``self.data_prefix['img_path']`` is
        # disabled for this dataset (it was left commented out).
|
||||
28
Seg_All_In_One_MMSeg/mmseg/datasets/publicdataset_dresden.py
Normal file
28
Seg_All_In_One_MMSeg/mmseg/datasets/publicdataset_dresden.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PublicDataSet_Dresden(BaseSegDataset):
    """Segmentation dataset registered as ``PublicDataSet_Dresden``.

    The class name is the identifier used to refer to this dataset; any
    name may be chosen. Class names are runtime strings: the background
    entry plus numeric labels '1' to '10'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of mask files. Defaults to '.png'.
        reduce_zero_label (bool): Set to True when label 0 is a meaningless
            black border that must be kept apart from the content to be
            classified. Leave it False when label 0 is a real background
            class that should be separated from the other classes.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # TODO: adjust the annotation classes and colours to your data.
    # The background class is best placed first. ``classes`` and ``palette``
    # are index-aligned.
    METAINFO = {
        'classes': [
            '背景', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'],
        'palette': [
            [0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181],
            [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
            [29, 32, 136], [160, 15, 95], [0, 160, 233]],
    }

    def __init__(
            self,
            img_suffix='.png',  # TODO: image file suffix
            seg_map_suffix='.png',  # TODO: mask file suffix
            reduce_zero_label=False,  # TODO: see the class docstring
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # NOTE: the existence check on ``self.data_prefix['img_path']`` is
        # disabled for this dataset (it was left commented out).
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PublicDataSet_Endovis_2017(BaseSegDataset):
    """Segmentation dataset registered as ``PublicDataSet_Endovis_2017``.

    The class name is the identifier used to refer to this dataset; any
    name may be chosen. Class names are runtime strings: the background
    entry plus numeric labels '1' to '7'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.bmp'.
        seg_map_suffix (str): Suffix of mask files. Defaults to '.bmp'.
        reduce_zero_label (bool): Set to True when label 0 is a meaningless
            black border that must be kept apart from the content to be
            classified. Leave it False when label 0 is a real background
            class that should be separated from the other classes.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # TODO: adjust the annotation classes and colours to your data.
    # The background class is best placed first. ``classes`` and ``palette``
    # are index-aligned.
    METAINFO = {
        'classes': ['背景', '1', '2', '3', '4', '5', '6', '7'],
        'palette': [
            [0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181],
            [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]],
    }

    def __init__(
            self,
            img_suffix='.bmp',  # TODO: image file suffix
            seg_map_suffix='.bmp',  # TODO: mask file suffix
            reduce_zero_label=False,  # TODO: see the class docstring
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # NOTE: the existence check on ``self.data_prefix['img_path']`` is
        # disabled for this dataset (it was left commented out).
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PublicDataSet_Endovis_2018(BaseSegDataset):
    """Segmentation dataset registered as ``PublicDataSet_Endovis_2018``.

    The class name is the identifier used to refer to this dataset; any
    name may be chosen. Class names are runtime strings: the background
    entry plus numeric labels '1' to '7'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.bmp'.
        seg_map_suffix (str): Suffix of mask files. Defaults to '.bmp'.
        reduce_zero_label (bool): Set to True when label 0 is a meaningless
            black border that must be kept apart from the content to be
            classified. Leave it False when label 0 is a real background
            class that should be separated from the other classes.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # TODO: adjust the annotation classes and colours to your data.
    # The background class is best placed first. ``classes`` and ``palette``
    # are index-aligned.
    METAINFO = {
        'classes': ['背景', '1', '2', '3', '4', '5', '6', '7'],
        'palette': [
            [0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181],
            [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]],
    }

    def __init__(
            self,
            img_suffix='.bmp',  # TODO: image file suffix
            seg_map_suffix='.bmp',  # TODO: mask file suffix
            reduce_zero_label=False,  # TODO: see the class docstring
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # NOTE: the existence check on ``self.data_prefix['img_path']`` is
        # disabled for this dataset (it was left commented out).
|
||||
28
Seg_All_In_One_MMSeg/mmseg/datasets/refuge.py
Normal file
28
Seg_All_In_One_MMSeg/mmseg/datasets/refuge.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class REFUGEDataset(BaseSegDataset):
    """REFUGE dataset.

    In the segmentation map annotation for REFUGE, 0 stands for background,
    which is listed as the first of the three classes in ``METAINFO``, so
    ``reduce_zero_label`` defaults to False. ``img_suffix`` and
    ``seg_map_suffix`` both default to '.png'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '.png'.
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # NOTE(review): ' Optic Cup' carries a leading space. It is kept as-is
    # because class names are runtime strings that configs may match
    # against; confirm with users before normalising it.
    METAINFO = dict(
        classes=('background', ' Optic Cup', 'Optic Disc'),
        palette=[[120, 120, 120], [6, 230, 230], [56, 59, 120]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        # These three options used to be hard-coded in the ``super()`` call,
        # so passing any of them raised ``TypeError`` (duplicate keyword).
        # They are now ordinary parameters with the same defaults, matching
        # the other dataset classes.
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # Fail early when the configured image directory does not exist.
        assert fileio.exists(
            self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
32
Seg_All_In_One_MMSeg/mmseg/datasets/stare.py
Normal file
32
Seg_All_In_One_MMSeg/mmseg/datasets/stare.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class STAREDataset(BaseSegDataset):
    """STARE dataset.

    In the segmentation map annotation for STARE, 0 stands for background,
    which is counted among the 2 categories, so ``reduce_zero_label``
    defaults to False. ``img_suffix`` defaults to '.png' and
    ``seg_map_suffix`` defaults to '.ah.png'.

    Args:
        img_suffix (str): Suffix of image files. Defaults to '.png'.
        seg_map_suffix (str): Suffix of segmentation map files.
            Defaults to '.ah.png'.
        reduce_zero_label (bool): Defaults to False.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    # ``classes`` and ``palette`` are index-aligned.
    METAINFO = {
        'classes': ('background', 'vessel'),
        'palette': [[120, 120, 120], [6, 230, 230]],
    }

    def __init__(
            self,
            img_suffix='.png',
            seg_map_suffix='.ah.png',
            reduce_zero_label=False,
            **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
        # Fail early when the configured image directory does not exist.
        img_dir = self.data_prefix['img_path']
        assert fileio.exists(img_dir, backend_args=self.backend_args)
|
||||
28
Seg_All_In_One_MMSeg/mmseg/datasets/synapse.py
Normal file
28
Seg_All_In_One_MMSeg/mmseg/datasets/synapse.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class SynapseDataset(BaseSegDataset):
    """Synapse segmentation dataset.

    The raw data has 13 foreground categories (background excluded). After
    preprocessing, 8 foreground categories remain and the other 5 are
    treated as background, which gives the 9 classes declared in
    ``METAINFO``.

    Args:
        img_suffix (str): Suffix of the image files. Defaults to '.jpg'.
        seg_map_suffix (str): Suffix of the segmentation map files.
            Defaults to '.png'.
        **kwargs: Remaining keyword arguments forwarded to
            ``BaseSegDataset``.
    """
    METAINFO = {
        'classes': ('background', 'aorta', 'gallbladder', 'left_kidney',
                    'right_kidney', 'liver', 'pancreas', 'spleen',
                    'stomach'),
        'palette': [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0],
                    [0, 255, 255], [255, 0, 255], [255, 255, 0],
                    [60, 255, 255], [240, 240, 240]],
    }

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            **kwargs)
|
||||
30
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/__init__.py
Normal file
30
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .formatting import PackSegInputs
|
||||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadDepthAnnotation, LoadImageFromNDArray,
|
||||
LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile)
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomDepthMix, RandomFlip, RandomMosaic,
|
||||
RandomRotate, RandomRotFlip, Rerange, Resize,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
|
||||
# yapf: enable
|
||||
__all__ = [
|
||||
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
|
||||
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
|
||||
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
|
||||
'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
|
||||
'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
|
||||
'RandomFlip', 'Resize'
|
||||
]
|
||||
112
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/formatting.py
Normal file
112
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/formatting.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from mmcv.transforms import to_tensor
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class PackSegInputs(BaseTransform):
    """Pack pipeline results into model inputs and a ``SegDataSample``.

    The image (if present) is converted to a channel-first tensor, the
    available ground-truth maps are wrapped in ``PixelData`` and attached to
    a ``SegDataSample``, and the entries of ``results`` named in
    ``meta_keys`` are stored as the sample's meta information.

    Args:
        meta_keys (Sequence[str], optional): Keys copied from ``results``
            into the sample's meta information; keys missing from
            ``results`` are skipped. Default: ``('img_path',
            'seg_map_path', 'ori_shape', 'img_shape', 'pad_shape',
            'scale_factor', 'flip', 'flip_direction',
            'reduce_zero_label')``
    """

    def __init__(self,
                 meta_keys=('img_path', 'seg_map_path', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'reduce_zero_label')):
        self.meta_keys = meta_keys

    def transform(self, results: dict) -> dict:
        """Pack the pipeline results.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The image in channel-first
              layout; only present when ``results`` contains 'img'.
            - 'data_samples' (obj:`SegDataSample`): The annotations and
              meta information of the sample.
        """
        packed = {}

        if 'img' in results:
            image = results['img']
            # A 2-D image gets a trailing channel axis first.
            if image.ndim < 3:
                image = np.expand_dims(image, -1)
            if image.flags.c_contiguous:
                tensor = to_tensor(image.transpose(2, 0, 1)).contiguous()
            else:
                tensor = to_tensor(
                    np.ascontiguousarray(image.transpose(2, 0, 1)))
            packed['inputs'] = tensor

        sample = SegDataSample()

        if 'gt_seg_map' in results:
            seg_map = results['gt_seg_map']
            if seg_map.ndim == 2:
                # Add a leading axis so the map is stored as (1, H, W).
                seg_tensor = to_tensor(seg_map[None, ...].astype(np.int64))
            else:
                warnings.warn('Please pay attention your ground truth '
                              'segmentation map, usually the segmentation '
                              'map is 2D, but got '
                              f'{seg_map.shape}')
                seg_tensor = to_tensor(seg_map.astype(np.int64))
            sample.gt_sem_seg = PixelData(data=seg_tensor)

        if 'gt_edge_map' in results:
            edge_tensor = to_tensor(
                results['gt_edge_map'][None, ...].astype(np.int64))
            sample.set_data(dict(gt_edge_map=PixelData(data=edge_tensor)))

        if 'gt_depth_map' in results:
            depth_tensor = to_tensor(results['gt_depth_map'][None, ...])
            sample.set_data(dict(gt_depth_map=PixelData(data=depth_tensor)))

        # Only copy the meta entries that the pipeline actually produced.
        meta = {key: results[key] for key in self.meta_keys if key in results}
        sample.set_metainfo(meta)
        packed['data_samples'] = sample

        return packed

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(meta_keys={self.meta_keys})'
|
||||
771
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/loading.py
Normal file
771
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/loading.py
Normal file
@@ -0,0 +1,771 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine.fileio as fileio
|
||||
import numpy as np
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import datafrombytes
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
except ImportError:
|
||||
gdal = None
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load the semantic segmentation map referenced by the dataset.

    Expected input:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    Output added by this transform:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.
    - reduce_zero_label (bool): The dataset-level setting; read while
      loading the segmentation map.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool, optional): Whether to reduce all label
            values by 1. When left as None, the value is taken from
            ``results['reduce_zero_label']`` on the first loaded sample.
            Defaults to None.
        backend_args (dict): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
        imdecode_backend (str): The image decoding backend passed to
            :func:`mmcv.imfrombytes`. Defaults to 'pillow'.
    """

    def __init__(
        self,
        reduce_zero_label=None,
        backend_args=None,
        imdecode_backend='pillow',
    ) -> None:
        # Only the segmentation branch of the mmcv loader is enabled.
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            backend_args=backend_args)
        self.reduce_zero_label = reduce_zero_label
        if self.reduce_zero_label is not None:
            warnings.warn('`reduce_zero_label` will be deprecated, '
                          'if you would like to ignore the zero label, please '
                          'set `reduce_zero_label=True` when dataset '
                          'initialized')
        self.imdecode_backend = imdecode_backend

    def _load_seg_map(self, results: dict) -> None:
        """Load the segmentation map into ``results['gt_seg_map']``.

        The map is decoded, squeezed and cast to uint8, then optionally
        shifted by ``reduce_zero_label`` and remapped through
        ``results['label_map']``. ``results`` is modified in place.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
        """
        raw_bytes = fileio.get(
            results['seg_map_path'], backend_args=self.backend_args)
        decoded = mmcv.imfrombytes(
            raw_bytes, flag='unchanged', backend=self.imdecode_backend)
        seg_map = decoded.squeeze().astype(np.uint8)

        # Adopt the dataset setting when none was given at construction;
        # note that this stores the value on the transform instance.
        if self.reduce_zero_label is None:
            self.reduce_zero_label = results['reduce_zero_label']
        assert self.reduce_zero_label == results['reduce_zero_label'], \
            'Initialize dataset with `reduce_zero_label` as ' \
            f'{results["reduce_zero_label"]} but when load annotation ' \
            f'the `reduce_zero_label` is {self.reduce_zero_label}'

        if self.reduce_zero_label:
            # Map 0 to 255 before subtracting so uint8 never underflows;
            # the resulting 254 is then restored to 255.
            seg_map[seg_map == 0] = 255
            seg_map = seg_map - 1
            seg_map[seg_map == 254] = 255

        label_map = results.get('label_map', None)
        if label_map is not None:
            # Compare against an untouched copy so already-replaced pixels
            # are not replaced again, see
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            reference = seg_map.copy()
            for old_id, new_id in label_map.items():
                seg_map[reference == old_id] = new_id

        results['gt_seg_map'] = seg_map
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        fields = ', '.join([
            f'reduce_zero_label={self.reduce_zero_label}',
            f"imdecode_backend='{self.imdecode_backend}'",
            f'backend_args={self.backend_args}',
        ])
        return f'{self.__class__.__name__}({fields})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Fill in image meta information for an already-loaded image.

    Works like :obj:`LoadImageFromFile`, except that the image is taken from
    ``results['img']`` (an :obj:`np.ndarray`) instead of being read from a
    path, e.g. for frames coming from a webcam.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Record the image and its shape in ``results``.

        Args:
            results (dict): Result dict holding the image array in
                ``results['img']``.

        Returns:
            dict: The same dict with ``img_path`` set to None and
            ``img_shape`` / ``ori_shape`` set to the first two dimensions
            of the image.
        """
        image = results['img']
        if self.to_float32:
            image = image.astype(np.float32)

        # There is no source file for an in-memory image.
        results['img_path'] = None
        results['img'] = image
        for shape_key in ('img_shape', 'ori_shape'):
            results[shape_key] = image.shape[:2]
        return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadBiomedicalImageFromFile(BaseTransform):
    """Load a biomedical image from file.

    Required Keys:

    - img_path

    Added Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
      N is the number of modalities, and data type is float32
      if set to_float32 = True, or float64 if decode_backend is 'nifti' and
      to_float32 is False.
    - img_shape
    - ori_shape

    Args:
        decode_backend (str): The data decoding backend type. Options are
            'numpy' and 'nifti'. By convention data loaded with 'nifti' has
            XYZ axis order and data loaded with 'numpy' has ZYX order, so
            the spatial axes are reversed when the backend is 'nifti'.
            Defaults to 'nifti'.
        to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
            Defaults to False.
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an float64 array.
            Defaults to True.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 decode_backend: str = 'nifti',
                 to_xyz: bool = False,
                 to_float32: bool = True,
                 backend_args: Optional[dict] = None) -> None:
        self.decode_backend = decode_backend
        self.to_xyz = to_xyz
        self.to_float32 = to_float32
        # Keep a private copy so later changes by the caller have no effect.
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Load the image and record its shape.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        raw_bytes = fileio.get(results['img_path'], self.backend_args)
        volume = datafrombytes(raw_bytes, backend=self.decode_backend)

        if self.to_float32:
            volume = volume.astype(np.float32)

        # A 3-D volume has no modality axis yet; add one in front.
        if len(volume.shape) == 3:
            volume = volume[None, ...]

        # Reversing the three spatial axes converts between XYZ and ZYX.
        reverse_spatial = (0, 3, 2, 1)
        if self.decode_backend == 'nifti':
            volume = volume.transpose(*reverse_spatial)
        if self.to_xyz:
            volume = volume.transpose(*reverse_spatial)

        results['img'] = volume
        results['img_shape'] = volume.shape[1:]
        results['ori_shape'] = volume.shape[1:]
        return results

    def __repr__(self):
        fields = ', '.join([
            f"decode_backend='{self.decode_backend}'",
            f'to_xyz={self.to_xyz}',
            f'to_float32={self.to_float32}',
            f'backend_args={self.backend_args}',
        ])
        return f'{self.__class__.__name__}({fields})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadBiomedicalAnnotation(BaseTransform):
    """Load the ``seg_map`` annotation of a biomedical dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X)
        }

    Required Keys:

    - seg_map_path

    Added Keys:

    - gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by
      default, and data type is float32 if set to_float32 = True, or
      float64 if decode_backend is 'nifti' and to_float32 is False.

    Args:
        decode_backend (str): The data decoding backend type. Options are
            'numpy' and 'nifti'. By convention data loaded with 'nifti' has
            XYZ axis order and data loaded with 'numpy' has ZYX order, so
            the axes are reversed when the backend is 'nifti'.
            Defaults to 'nifti'.
        to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
            Defaults to False.
        to_float32 (bool): Whether to convert the loaded seg map to a float32
            numpy array. If set to False, the loaded image is an float64 array.
            Defaults to True.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See :class:`mmengine.fileio` for details.
            Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 decode_backend: str = 'nifti',
                 to_xyz: bool = False,
                 to_float32: bool = True,
                 backend_args: Optional[dict] = None) -> None:
        super().__init__()
        self.decode_backend = decode_backend
        self.to_xyz = to_xyz
        self.to_float32 = to_float32
        # Keep a private copy so later changes by the caller have no effect.
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Load the segmentation map into ``results['gt_seg_map']``.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict with the loaded segmentation map added.
        """
        raw_bytes = fileio.get(results['seg_map_path'], self.backend_args)
        seg_map = datafrombytes(raw_bytes, backend=self.decode_backend)

        if self.to_float32:
            seg_map = seg_map.astype(np.float32)

        # Reversing all three axes converts between XYZ and ZYX.
        reverse_axes = (2, 1, 0)
        if self.decode_backend == 'nifti':
            seg_map = seg_map.transpose(*reverse_axes)
        if self.to_xyz:
            seg_map = seg_map.transpose(*reverse_axes)

        results['gt_seg_map'] = seg_map
        return results

    def __repr__(self):
        fields = ', '.join([
            f"decode_backend='{self.decode_backend}'",
            f'to_xyz={self.to_xyz}',
            f'to_float32={self.to_float32}',
            f'backend_args={self.backend_args}',
        ])
        return f'{self.__class__.__name__}({fields})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadBiomedicalData(BaseTransform):
    """Load a biomedical image and, optionally, its annotation from one file.

    The file holds a single array whose last entry along the first axis is
    the segmentation map and whose remaining entries are the image:

    .. code-block:: python

        {
            'img': np.ndarray data[:-1, X, Y, Z]
            'seg_map': np.ndarray data[-1, X, Y, Z]
        }

    Required Keys:

    - img_path

    Added Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
      N is the number of modalities.
    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
      (Z, Y, X) by default.
    - img_shape
    - ori_shape

    Args:
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Defaults to False.
        decode_backend (str): The data decoding backend type. Options are
            'numpy' and 'nifti'. By convention data loaded with 'nifti' has
            XYZ axis order and data loaded with 'numpy' has ZYX order, so
            the spatial axes are reversed when the backend is 'nifti'.
            Defaults to 'numpy'.
        to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
            Defaults to False.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 with_seg=False,
                 decode_backend: str = 'numpy',
                 to_xyz: bool = False,
                 backend_args: Optional[dict] = None) -> None:  # noqa
        self.with_seg = with_seg
        self.decode_backend = decode_backend
        self.to_xyz = to_xyz
        # Keep a private copy so later changes by the caller have no effect.
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Load the image (and segmentation map when ``with_seg`` is set).

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        raw_bytes = fileio.get(results['img_path'], self.backend_args)
        data = datafrombytes(raw_bytes, backend=self.decode_backend)
        from_nifti = self.decode_backend == 'nifti'

        # img is 4D data (N, X, Y, Z), N is the number of protocol
        image = data[:-1, :]
        if from_nifti:
            image = image.transpose(0, 3, 2, 1)
        if self.to_xyz:
            image = image.transpose(0, 3, 2, 1)

        results['img'] = image
        results['img_shape'] = image.shape[1:]
        results['ori_shape'] = image.shape[1:]

        if self.with_seg:
            # The last entry along the first axis is the segmentation map.
            seg_map = data[-1, :]
            if from_nifti:
                seg_map = seg_map.transpose(2, 1, 0)
            if self.to_xyz:
                seg_map = seg_map.transpose(2, 1, 0)
            results['gt_seg_map'] = seg_map

        return results

    def __repr__(self) -> str:
        fields = ', '.join([
            f'with_seg={self.with_seg}',
            f"decode_backend='{self.decode_backend}'",
            f'to_xyz={self.to_xyz}',
            f'backend_args={self.backend_args}',
        ])
        return f'{self.__class__.__name__}({fields})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
    """Load an image given as a path, an array or a result dict.

    Two loaders are built from the registry with the same keyword arguments:
    ``LoadImageFromFile`` for inputs that only carry a path and
    ``LoadImageFromNDArray`` for inputs that already carry the image in
    ``'img'``.

    Args:
        **kwargs: Keyword arguments used to build both
            ``LoadImageFromFile`` and ``LoadImageFromNDArray``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.from_file = TRANSFORMS.build(
            dict(type='LoadImageFromFile', **kwargs))
        self.from_ndarray = TRANSFORMS.build(
            dict(type='LoadImageFromNDArray', **kwargs))

    def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
        """Dispatch the input to the matching loader.

        Args:
            single_input (str | np.ndarray | dict): An image path, an image
                array, or a result dict that is passed on as-is.

        Returns:
            dict: The result of the selected loader.

        Raises:
            NotImplementedError: If ``single_input`` is none of the
                supported types.
        """
        if isinstance(single_input, str):
            inputs = dict(img_path=single_input)
        elif isinstance(single_input, np.ndarray):
            inputs = dict(img=single_input)
        elif isinstance(single_input, dict):
            inputs = single_input
        else:
            raise NotImplementedError

        # An 'img' entry means the image is already in memory.
        loader = self.from_ndarray if 'img' in inputs else self.from_file
        return loader(inputs)
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadSingleRSImageFromFile(BaseTransform):
    """Load a Remote Sensing image from file with GDAL.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64 array.
            Defaults to True.

    Raises:
        RuntimeError: If ``gdal`` could not be imported.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def transform(self, results: Dict) -> Dict:
        """Load the image and record its shape.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.

        Raises:
            Exception: If GDAL cannot open ``results['img_path']``.
        """
        filename = results['img_path']
        dataset = gdal.Open(filename)
        if dataset is None:
            raise Exception(f'Unable to open file: {filename}')

        # Move the first axis of the GDAL array to the end.
        image = np.einsum('ijk->jki', dataset.ReadAsArray())
        if self.to_float32:
            image = image.astype(np.float32)

        results['img'] = image
        for shape_key in ('img_shape', 'ori_shape'):
            results[shape_key] = image.shape[:2]
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(to_float32={self.to_float32})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadMultipleRSImageFromFile(BaseTransform):
    """Load two Remote Sensing images from file with GDAL.

    Required Keys:

    - img_path
    - img_path2

    Modified Keys:

    - img
    - img2
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64 array.
            Defaults to True.

    Raises:
        RuntimeError: If ``gdal`` could not be imported.
    """

    def __init__(self, to_float32: bool = True):
        if gdal is None:
            raise RuntimeError('gdal is not installed')
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Load both images and record the shared shape.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.

        Raises:
            Exception: If either file cannot be opened, or if the two
                images have different shapes.
        """
        filename = results['img_path']
        filename2 = results['img_path2']

        # Both files are opened before either handle is checked.
        dataset = gdal.Open(filename)
        dataset2 = gdal.Open(filename2)
        if dataset is None:
            raise Exception(f'Unable to open file: {filename}')
        if dataset2 is None:
            raise Exception(f'Unable to open file: {filename2}')

        # Move the first axis of each GDAL array to the end.
        image = np.einsum('ijk->jki', dataset.ReadAsArray())
        image2 = np.einsum('ijk->jki', dataset2.ReadAsArray())

        if self.to_float32:
            image = image.astype(np.float32)
            image2 = image2.astype(np.float32)

        if image.shape != image2.shape:
            raise Exception(f'Image shapes do not match:'
                            f' {image.shape} vs {image2.shape}')

        results['img'] = image
        results['img2'] = image2
        for shape_key in ('img_shape', 'ori_shape'):
            results[shape_key] = image.shape[:2]
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(to_float32={self.to_float32})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadDepthAnnotation(BaseTransform):
    """Load ``depth_map`` annotation provided by depth estimation dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            'gt_depth_map': np.ndarray [Y, X]
        }

    Required Keys:

    - depth_map_path
    - seg_fields

    Added Keys:

    - gt_depth_map (np.ndarray): Depth map with shape (Y, X) by
      default, and data type is float32 if set to_float32 = True.
    - depth_rescale_factor (float): The rescale factor of depth map, which
      can be used to recover the original value of depth map.

    Args:
        decode_backend (str): The data decoding backend type. Options are
            'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'.
        to_float32 (bool): Whether to convert the loaded depth map to a float32
            numpy array. If set to False, the decoded array keeps its dtype
            until rescaling; a non-floating array is promoted by the
            multiplication with ``depth_rescale_factor``.
            Defaults to True.
        depth_rescale_factor (float): Factor to rescale the depth value to
            limit the range. Defaults to 1.0.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See :class:`mmengine.fileio` for details.
            Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 decode_backend: str = 'cv2',
                 to_float32: bool = True,
                 depth_rescale_factor: float = 1.0,
                 backend_args: Optional[dict] = None) -> None:
        super().__init__()
        self.decode_backend = decode_backend
        self.to_float32 = to_float32
        self.depth_rescale_factor = depth_rescale_factor
        # Keep a private copy so later changes by the caller have no effect.
        self.backend_args = backend_args.copy() if backend_args else None

    def transform(self, results: Dict) -> Dict:
        """Functions to load depth map.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded depth map.
        """
        data_bytes = fileio.get(results['depth_map_path'], self.backend_args)
        gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend)

        if self.to_float32:
            gt_depth_map = gt_depth_map.astype(np.float32)

        if np.issubdtype(gt_depth_map.dtype, np.floating):
            # Floating arrays are scaled in place, keeping their dtype.
            gt_depth_map *= self.depth_rescale_factor
        else:
            # An in-place multiply of an integer array by a float factor
            # is rejected by NumPy's casting rules, so let the result be
            # promoted instead of raising.
            gt_depth_map = gt_depth_map * self.depth_rescale_factor

        results['gt_depth_map'] = gt_depth_map
        results['seg_fields'].append('gt_depth_map')
        results['depth_rescale_factor'] = self.depth_rescale_factor
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f"decode_backend='{self.decode_backend}', "
                    f'to_float32={self.to_float32}, '
                    f'depth_rescale_factor={self.depth_rescale_factor}, '
                    f'backend_args={self.backend_args})')
        return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class LoadImageFromNpyFile(LoadImageFromFile):
    """Load an image from ``results['img_path']``, with NumPy file support.

    Paths ending in '.npy' or '.npz' are read with ``np.load``; every other
    path is fetched as bytes and decoded with ``mmcv.imfrombytes``.

    NOTE(review): ``np.load`` on a '.npz' path returns an archive object
    rather than an array, so the later ``astype`` / ``shape`` accesses look
    like they would fail for that suffix - confirm.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
    """

    def transform(self, results: dict) -> Optional[dict]:
        """Load the image and record its shape.

        Args:
            results (dict): Result dict from
                :class:`mmengine.dataset.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information, or
            None when loading raised and ``ignore_empty`` is set.
        """
        filename = results['img_path']

        try:
            if Path(filename).suffix in ('.npy', '.npz'):
                img = np.load(filename)
            else:
                if self.file_client_args is None:
                    img_bytes = fileio.get(
                        filename, backend_args=self.backend_args)
                else:
                    client = fileio.FileClient.infer_client(
                        self.file_client_args, filename)
                    img_bytes = client.get(filename)
                img = mmcv.imfrombytes(
                    img_bytes,
                    flag=self.color_type,
                    backend=self.imdecode_backend)
        except Exception as e:
            # Any loading failure is re-raised unless the caller opted in
            # to skipping unreadable samples.
            if not self.ignore_empty:
                raise e
            return None

        # in some cases, images are not read successfully, the img would be
        # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427
        assert img is not None, f'failed to load image: {filename}'
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        for shape_key in ('img_shape', 'ori_shape'):
            results[shape_key] = img.shape[:2]
        return results
|
||||
2537
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/transforms.py
Normal file
2537
Seg_All_In_One_MMSeg/mmseg/datasets/transforms/transforms.py
Normal file
File diff suppressed because it is too large
Load Diff
40
Seg_All_In_One_MMSeg/mmseg/datasets/voc.py
Normal file
40
Seg_All_In_One_MMSeg/mmseg/datasets/voc.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class PascalVOCDataset(BaseSegDataset):
    """Pascal VOC dataset.

    ``METAINFO`` lists 21 class names (``background`` plus 20 object
    classes) and one RGB palette entry per class.

    Args:
        ann_file (str): Split txt file for Pascal VOC. It must be an
            existing local file.
        img_suffix (str): Suffix of the image files. Defaults to '.jpg'.
        seg_map_suffix (str): Suffix of the segmentation map files.
            Defaults to '.png'.
        **kwargs: Forwarded unchanged to :class:`BaseSegDataset`.
    """
    METAINFO = dict(
        classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat',
                 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
                 'sofa', 'train', 'tvmonitor'),
        palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                 [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                 [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                 [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                 [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                 [0, 64, 128]])

    def __init__(self,
                 ann_file,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ann_file=ann_file,
            **kwargs)
        # Same two checks, in the same order, as the former single assert;
        # they are split so a failure names the path that is missing.
        assert fileio.exists(
            self.data_prefix['img_path'], self.backend_args), (
                f'image path {self.data_prefix["img_path"]} does not exist')
        assert osp.isfile(self.ann_file), (
            f'split file {self.ann_file} is not an existing file')
|
||||
12
Seg_All_In_One_MMSeg/mmseg/engine/__init__.py
Normal file
12
Seg_All_In_One_MMSeg/mmseg/engine/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hooks import SegVisualizationHook
|
||||
from .optimizers import (ForceDefaultOptimWrapperConstructor,
|
||||
LayerDecayOptimizerConstructor,
|
||||
LearningRateDecayOptimizerConstructor)
|
||||
from .schedulers import PolyLRRatio
|
||||
|
||||
__all__ = [
|
||||
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
|
||||
'SegVisualizationHook', 'PolyLRRatio',
|
||||
'ForceDefaultOptimWrapperConstructor'
|
||||
]
|
||||
4
Seg_All_In_One_MMSeg/mmseg/engine/hooks/__init__.py
Normal file
4
Seg_All_In_One_MMSeg/mmseg/engine/hooks/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .visualization_hook import SegVisualizationHook
|
||||
|
||||
__all__ = ['SegVisualizationHook']
|
||||
129
Seg_All_In_One_MMSeg/mmseg/engine/hooks/visualization_hook.py
Normal file
129
Seg_All_In_One_MMSeg/mmseg/engine/hooks/visualization_hook.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import mmcv
|
||||
from mmengine.fileio import get
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmseg.registry import HOOKS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
@HOOKS.register_module()
class SegVisualizationHook(Hook):
    """Segmentation Visualization Hook. Used to visualize validation and
    testing process prediction results.

    In the testing phase:

    1. If ``show`` is True, it means that only the prediction results are
        visualized without storing data, so ``vis_backends`` needs to
        be excluded.

    Args:
        draw (bool): whether to draw prediction results. If it is False,
            it means that no drawing will be done. Defaults to False.
        interval (int): The interval of visualization. Defaults to 50.
        show (bool): Whether to display the drawn image. Default to False.
        wait_time (float): The interval of show (s). Defaults to 0.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 show: bool = False,
                 wait_time: float = 0.,
                 backend_args: Optional[dict] = None):
        self._visualizer: Visualizer = Visualizer.get_current_instance()
        self.interval = interval
        self.show = show
        if self.show:
            # No need to think about vis backends.
            self._visualizer._vis_backends = {}
            warnings.warn('The show is True, it means that only '
                          'the prediction results are visualized '
                          'without storing data, so vis_backends '
                          'needs to be excluded.')

        self.wait_time = wait_time
        # Copy so later changes to the caller's dict do not affect the hook.
        self.backend_args = backend_args.copy() if backend_args else None
        self.draw = draw
        if not self.draw:
            warnings.warn('The draw is False, it means that the '
                          'hook for visualization will not take '
                          'effect. The results will NOT be '
                          'visualized or stored.')
        # Running count of test samples, used as the visualizer step.
        self._test_index = 0

    def _load_rgb_image(self, img_path: str):
        """Read ``img_path`` through the file backend and decode it as RGB."""
        img_bytes = get(img_path, backend_args=self.backend_args)
        return mmcv.imfrombytes(img_bytes, channel_order='rgb')

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[SegDataSample]) -> None:
        """Run after every ``self.interval`` validation iterations.

        Only the first sample of the batch is visualized.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        # There is no guarantee that the same batch of images
        # is visualized for each evaluation.
        total_curr_iter = runner.iter + batch_idx

        # Check the interval before touching the file system so that the
        # image is only read and decoded when it will actually be drawn.
        if total_curr_iter % self.interval != 0:
            return

        # Visualize only the first data
        img_path = outputs[0].img_path
        img = self._load_rgb_image(img_path)
        window_name = f'val_{osp.basename(img_path)}'

        self._visualizer.add_datasample(
            window_name,
            img,
            data_sample=outputs[0],
            show=self.show,
            wait_time=self.wait_time,
            step=total_curr_iter)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[SegDataSample]) -> None:
        """Run after every testing iterations.

        Every sample of the batch is visualized.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        for data_sample in outputs:
            self._test_index += 1

            img_path = data_sample.img_path
            window_name = f'test_{osp.basename(img_path)}'
            img = self._load_rgb_image(img_path)

            self._visualizer.add_datasample(
                window_name,
                img,
                data_sample=data_sample,
                show=self.show,
                wait_time=self.wait_time,
                step=self._test_index)
|
||||
9
Seg_All_In_One_MMSeg/mmseg/engine/optimizers/__init__.py
Normal file
9
Seg_All_In_One_MMSeg/mmseg/engine/optimizers/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .force_default_constructor import ForceDefaultOptimWrapperConstructor
|
||||
from .layer_decay_optimizer_constructor import (
|
||||
LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
|
||||
|
||||
__all__ = [
|
||||
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
|
||||
'ForceDefaultOptimWrapperConstructor'
|
||||
]
|
||||
@@ -0,0 +1,255 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.optim import DefaultOptimWrapperConstructor
|
||||
from mmengine.utils.dl_utils import mmcv_full_available
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
from torch.nn import GroupNorm, LayerNorm
|
||||
|
||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Default constructor with forced optimizer settings.

    This constructor extends the default constructor to add an option for
    forcing default optimizer settings. This is useful for ensuring that
    certain parameters or layers strictly adhere to pre-defined default
    settings, regardless of any custom settings specified.

    By default, each parameter share the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict and may contain various fields like 'custom_keys',
    'bias_lr_mult', etc., as well as the additional field
    `force_default_settings` which allows for enforcing default settings on
    optimizer parameters.

    - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of one
      parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
      be ignored (unless ``force_default_settings`` is set, see below). It
      should be noted that the aforementioned ``key`` is the
      longest key that is a substring of the name of the parameter. If there
      are multiple matched keys with the same length, then the key with lower
      alphabet order will be chosen.
      ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
      and ``decay_mult``. See Example 2 below.
    - ``bias_lr_mult`` (float): It will be multiplied to the learning
      rate for all bias parameters (except for those in normalization
      layers and offset layers of DCN).
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in
      normalization layers, depthwise conv layers, offset layers of DCN).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization
      layers.
    - ``flat_decay_mult`` (float): It will be multiplied to the weight
      decay for all one-dimensional parameters
    - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of depthwise conv
      layers.
    - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
      rate for parameters of offset layer in the deformable convs
      of a model.
    - ``bypass_duplicate`` (bool): If true, the duplicate parameters
      would not be added into optimizer. Defaults to False.
    - ``force_default_settings`` (bool): If true, the default rules
      (``bias_lr_mult``, ``norm_decay_mult`` etc.) are also evaluated for
      parameters that matched ``custom_keys``. Wherever one of those rules
      applies, its value replaces the ``lr`` / ``weight_decay`` derived from
      ``custom_keys``; where no rule applies the custom value is kept.
      Defaults to False.

    Note:

        1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
        override the effect of ``bias_lr_mult`` in the bias of offset layer.
        So be careful when using both ``bias_lr_mult`` and
        ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
        layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
        ``dcn_offset_lr_mult`` * ``bias_lr_mult``.

        2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
        apply it to all the DCN layers in the model. So be careful when the
        model contains multiple DCN layers in places other than backbone.

        3. When the option ``force_default_settings`` is true, the default
        rules take precedence over the settings provided in ``custom_keys``
        for every parameter they apply to. See Example 3 below.

    Args:
        optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.

            Required fields of ``optim_wrapper_cfg`` are

            - ``type``: class name of the OptimizerWrapper
            - ``optimizer``: The configuration of optimizer.

            Optional fields of ``optim_wrapper_cfg`` are

            - any arguments of the corresponding optimizer wrapper type,
              e.g., accumulative_counts, clip_grad, etc.

            Required fields of ``optimizer`` are

            - `type`: class name of the optimizer.

            Optional fields of ``optimizer`` are

            - any arguments of the corresponding optimizer type, e.g.,
              lr, weight_decay, momentum, etc.

        paramwise_cfg (dict, optional): Parameter-wise options.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optim_wrapper_cfg = dict(
        >>>     type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
        >>>     momentum=0.9, weight_decay=0.0001))
        >>> paramwise_cfg = dict(norm_decay_mult=0.)
        >>> optim_wrapper_builder = ForceDefaultOptimWrapperConstructor(
        >>>     optim_wrapper_cfg, paramwise_cfg)
        >>> optim_wrapper = optim_wrapper_builder(model)

    Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
        >>>     type='SGD', lr=0.01, weight_decay=0.95))
        >>> paramwise_cfg = dict(custom_keys={
        >>>     'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_wrapper_builder = ForceDefaultOptimWrapperConstructor(
        >>>     optim_wrapper_cfg, paramwise_cfg)
        >>> optim_wrapper = optim_wrapper_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).

    Example 3:
        >>> # same model and optim_wrapper_cfg as in Example 2
        >>> paramwise_cfg = dict(
        >>>     custom_keys={'backbone': dict(lr_mult=0.1)},
        >>>     bias_lr_mult=2.,
        >>>     force_default_settings=True)
        >>> # Biases inside model.backbone that belong to neither a norm
        >>> # layer nor a DCN module get lr = 0.01 * 2 (the forced default
        >>> # rule); the other backbone parameters keep lr = 0.01 * 0.1.
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = '',
                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg. The method
        handles the direct parameters of ``module`` and then calls itself on
        every child module.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
                The recursive calls made by this method pass a ``bool``
                here; the value is only ever used in boolean context.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of str
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None)
        force_default_settings = self.paramwise_cfg.get(
            'force_default_settings', False)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
        # a Conv2d counts as depth-wise when in_channels equals groups
        is_dwconv = (
            isinstance(module, torch.nn.Conv2d)
            and module.in_channels == module.groups)

        # only the direct parameters; children are handled by the recursion
        # at the end of this method
        for name, param in module.named_parameters(recurse=False):
            param_group = {'params': [param]}
            if bypass_duplicate and self._is_in(param_group, params):
                print_log(
                    f'{prefix} is duplicate. It is skipped since '
                    f'bypass_duplicate={bypass_duplicate}',
                    logger='current',
                    level=logging.WARNING)
                continue
            # frozen parameters are still added, but without any
            # parameter-wise options
            if not param.requires_grad:
                params.append(param_group)
                continue

            # if the parameter match one of the custom keys, the default
            # rules below are skipped unless force_default_settings is set
            is_custom = False
            for key in sorted_keys:
                if key in f'{prefix}.{name}':
                    is_custom = True
                    lr_mult = custom_keys[key].get('lr_mult', 1.)
                    param_group['lr'] = self.base_lr * lr_mult
                    if self.base_wd is not None:
                        decay_mult = custom_keys[key].get('decay_mult', 1.)
                        param_group['weight_decay'] = self.base_wd * decay_mult
                    # add custom settings to param_group
                    for k, v in custom_keys[key].items():
                        param_group[k] = v
                    # sorted_keys is longest-first, so the first hit wins
                    break

            # with force_default_settings the default rules run even for
            # custom-key matches and overwrite 'lr' / 'weight_decay' wherever
            # one of them applies
            if not is_custom or force_default_settings:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias dcn.conv_offset.bias
                if name == 'bias' and not (
                        is_norm or is_dcn_module) and bias_lr_mult is not None:
                    param_group['lr'] = self.base_lr * bias_lr_mult

                if (prefix.find('conv_offset') != -1 and is_dcn_module
                        and dcn_offset_lr_mult is not None
                        and isinstance(module, torch.nn.Conv2d)):
                    # deal with both dcn_offset's bias & weight
                    param_group['lr'] = self.base_lr * dcn_offset_lr_mult

                # apply weight decay policies
                # (at most one of the branches below fires, first match wins)
                if self.base_wd is not None:
                    # norm decay
                    if is_norm and norm_decay_mult is not None:
                        param_group[
                            'weight_decay'] = self.base_wd * norm_decay_mult
                    # bias lr and decay
                    elif (name == 'bias' and not is_dcn_module
                          and bias_decay_mult is not None):
                        param_group[
                            'weight_decay'] = self.base_wd * bias_decay_mult
                    # depth-wise conv
                    elif is_dwconv and dwconv_decay_mult is not None:
                        param_group[
                            'weight_decay'] = self.base_wd * dwconv_decay_mult
                    # flatten parameters except dcn offset
                    elif (param.ndim == 1 and not is_dcn_module
                          and flat_decay_mult is not None):
                        param_group[
                            'weight_decay'] = self.base_wd * flat_decay_mult
            params.append(param_group)
            # log every option that was set for this parameter
            for key, value in param_group.items():
                if key == 'params':
                    continue
                full_name = f'{prefix}.{name}' if prefix else name
                print_log(
                    f'paramwise_options -- {full_name}:{key}={value}',
                    logger='current')

        # decide whether the children of this module live inside a DCN layer;
        # the DCN classes are only imported when mmcv_full_available() is true
        if mmcv_full_available():
            from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
            is_dcn_module = isinstance(module,
                                       (DeformConv2d, ModulatedDeformConv2d))
        else:
            is_dcn_module = False
        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)
|
||||
@@ -0,0 +1,207 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from mmengine.dist import get_dist_info
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.optim import DefaultOptimWrapperConstructor
|
||||
|
||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
||||
|
||||
|
||||
def get_layer_id_for_convnext(var_name, max_layer_id):
    """Get the layer id to set the different learning rates in ``layer_wise``
    decay_type.

    Mapping implemented below:

    - ``backbone.cls_token`` / ``backbone.mask_token`` /
      ``backbone.pos_embed``: 0
    - ``backbone.downsample_layers.<stage>``: stages 0, 1, 2, 3 map to
      0, 2, 3, ``max_layer_id``
    - ``backbone.stages.<stage>.<block>``: stages 0, 1 map to 1, 2;
      stage 2 maps to ``3 + block // 3``; stage 3 maps to ``max_layer_id``
    - everything else: ``max_layer_id + 1``

    Args:
        var_name (str): The key of the model.
        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.

    Raises:
        ValueError: If the stage index parsed from ``var_name`` is not one
            of 0, 1, 2, 3.
    """

    if var_name in ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed'):
        return 0
    elif var_name.startswith('backbone.downsample_layers'):
        stage_id = int(var_name.split('.')[2])
        stage_to_layer = {0: 0, 1: 2, 2: 3, 3: max_layer_id}
    elif var_name.startswith('backbone.stages'):
        stage_id = int(var_name.split('.')[2])
        block_id = int(var_name.split('.')[3])
        stage_to_layer = {
            0: 1,
            1: 2,
            # blocks of the third stage are grouped three at a time
            2: 3 + block_id // 3,
            3: max_layer_id,
        }
    else:
        return max_layer_id + 1

    if stage_id not in stage_to_layer:
        # Previously this fell through every branch and failed with an
        # UnboundLocalError on ``layer_id``.
        raise ValueError(
            f'unsupported stage index {stage_id} in parameter name '
            f'"{var_name}", expected one of 0, 1, 2, 3')
    return stage_to_layer[stage_id]
|
||||
|
||||
|
||||
def get_stage_id_for_convnext(var_name, max_stage_id):
    """Map a ConvNeXt parameter name to its id for ``stage_wise`` decay.

    Args:
        var_name (str): The key of the model.
        max_stage_id (int): Upper bound used for parameters that are neither
            token/embedding parameters, downsample layers nor backbone
            stages; those receive ``max_stage_id - 1``.

    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``: 0 for the token and
        position-embedding parameters and for downsample layers,
        ``stage index + 1`` for ``backbone.stages.<stage>...`` and
        ``max_stage_id - 1`` otherwise.
    """
    stem_names = ('backbone.cls_token', 'backbone.mask_token',
                  'backbone.pos_embed')
    belongs_to_stem = (
        var_name in stem_names
        or var_name.startswith('backbone.downsample_layers'))
    if belongs_to_stem:
        return 0

    if not var_name.startswith('backbone.stages'):
        return max_stage_id - 1

    # name layout: backbone.stages.<stage index>.<...>
    name_parts = var_name.split('.')
    return int(name_parts[2]) + 1
|
||||
|
||||
|
||||
def get_layer_id_for_vit(var_name, max_layer_id):
    """Map a ViT-style parameter name to its layer id.

    Args:
        var_name (str): The key of the model.
        max_layer_id (int): Upper bound used for parameters outside the
            recognised backbone names; those receive ``max_layer_id - 1``.

    Returns:
        int: 0 for the token, position-embedding and patch-embedding
        parameters, ``layer index + 1`` for ``backbone.layers.<index>...``
        and ``max_layer_id - 1`` otherwise.
    """
    embedding_names = ('backbone.cls_token', 'backbone.mask_token',
                       'backbone.pos_embed')
    if (var_name in embedding_names
            or var_name.startswith('backbone.patch_embed')):
        return 0

    if var_name.startswith('backbone.layers'):
        # name layout: backbone.layers.<layer index>.<...>
        layer_index = int(var_name.split('.')[2])
        return layer_index + 1

    return max_layer_id - 1
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
    """Different learning rates are set for different layers of backbone.

    ``paramwise_cfg`` is read for the following keys:

    - ``num_layers`` (int): required.
    - ``decay_rate`` (float): required; the learning rate of a group is
      ``base_lr * decay_rate ** (num_layers + 2 - layer_id - 1)``.
    - ``decay_type`` (str): any value containing ``'layer_wise'`` or exactly
      ``'stage_wise'``. Defaults to ``'layer_wise'``.

    Note: Currently, this optimizer constructor is built for ConvNeXt,
    BEiT and MAE.
    """

    def add_params(self, params, module, **kwargs):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added. Its ``backbone``
                attribute is inspected to choose the layer-id function.

        Raises:
            ValueError: If ``num_layers`` or ``decay_rate`` is missing from
                ``paramwise_cfg``.
            NotImplementedError: If the backbone class or ``decay_type`` is
                not supported.
        """

        parameter_groups = {}
        print_log(f'self.paramwise_cfg is {self.paramwise_cfg}')
        backbone_num_layers = self.paramwise_cfg.get('num_layers')
        decay_rate = self.paramwise_cfg.get('decay_rate')
        if backbone_num_layers is None or decay_rate is None:
            # Fail with a readable message instead of an opaque
            # "NoneType + int" / "NoneType ** int" TypeError further down.
            raise ValueError(
                'paramwise_cfg must define both "num_layers" and '
                f'"decay_rate", but got {self.paramwise_cfg}')
        num_layers = backbone_num_layers + 2
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        print_log('Build LearningRateDecayOptimizerConstructor  '
                  f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            # 1-D parameters, biases and the listed embeddings get no decay
            if len(param.shape) == 1 or name.endswith('.bias') or name in (
                    'pos_embed', 'cls_token'):
                group_name = 'no_decay'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay
            if 'layer_wise' in decay_type:
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    # the ConvNeXt helper receives the raw ``num_layers``
                    layer_id = get_layer_id_for_convnext(
                        name, backbone_num_layers)
                    print_log(f'set param {name} as id {layer_id}')
                elif 'BEiT' in module.backbone.__class__.__name__ or \
                     'MAE' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_vit(name, num_layers)
                    print_log(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            elif decay_type == 'stage_wise':
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_stage_id_for_convnext(name, num_layers)
                    print_log(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            else:
                # Without this branch ``layer_id`` stayed unbound and the
                # f-string below failed with an UnboundLocalError.
                raise NotImplementedError(
                    f'decay_type "{decay_type}" is not supported, use a '
                    'value containing "layer_wise" or exactly "stage_wise"')
            group_name = f'layer_{layer_id}_{group_name}'

            if group_name not in parameter_groups:
                scale = decay_rate**(num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            # log a JSON summary of the groups (without the tensors) on
            # rank 0 only
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            print_log(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for BEiT,
    and it will be deprecated.
    Please use ``LearningRateDecayOptimizerConstructor`` instead.
    """

    def __init__(self, optim_wrapper_cfg, paramwise_cfg):
        """Translate the legacy config keys and defer to the parent class.

        Args:
            optim_wrapper_cfg (dict): Passed through to the parent
                constructor unchanged.
            paramwise_cfg (dict): Modified in place: ``decay_type`` is set
                to ``'layer_wise_vit'`` and a legacy ``layer_decay_rate``
                entry, when present, is moved to ``decay_rate``.
        """
        warnings.warn('DeprecationWarning: Original '
                      'LayerDecayOptimizerConstructor of BEiT '
                      'will be deprecated. Please use '
                      'LearningRateDecayOptimizerConstructor instead, '
                      'and set decay_type = layer_wise_vit in paramwise_cfg.')
        paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
        # Only translate the legacy key when it is actually there, so a
        # config that already uses ``decay_rate`` (as the warning asks for)
        # no longer fails with a KeyError.
        if 'layer_decay_rate' in paramwise_cfg:
            warnings.warn('DeprecationWarning: Layer_decay_rate will '
                          'be deleted, please use decay_rate instead.')
            paramwise_cfg['decay_rate'] = paramwise_cfg.pop(
                'layer_decay_rate')
        super().__init__(optim_wrapper_cfg, paramwise_cfg)
|
||||
4
Seg_All_In_One_MMSeg/mmseg/engine/schedulers/__init__.py
Normal file
4
Seg_All_In_One_MMSeg/mmseg/engine/schedulers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .poly_ratio_scheduler import PolyLRRatio
|
||||
|
||||
__all__ = ['PolyLRRatio']
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.optim.scheduler import PolyLR
|
||||
|
||||
from mmseg.registry import PARAM_SCHEDULERS
|
||||
|
||||
|
||||
@PARAM_SCHEDULERS.register_module()
class PolyLRRatio(PolyLR):
    """Implements polynomial learning rate decay with ratio.

    This scheduler adjusts the learning rate of each parameter group
    following a polynomial decay equation. The decay can occur in
    conjunction with external parameter adjustments made outside this
    scheduler.

    Args:
        optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
        eta_min (float): Minimum learning rate at the end of scheduling.
            Defaults to 0.
        eta_min_ratio (float, optional): The ratio of the minimum parameter
            value to the base parameter value. When it is not None, the
            minimum of each parameter group is ``base_value * eta_min_ratio``
            and ``eta_min`` is ignored. Defaults to None.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.

    Note:
        ``eta_min_ratio`` is the first positional parameter of ``__init__``,
        so the optimizer and the remaining arguments should be passed by
        keyword.
    """

    def __init__(self, eta_min_ratio: Optional[float] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.eta_min_ratio = eta_min_ratio

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""

        if self.last_step == 0:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]

        # The step ratio does not depend on the parameter group, so it is
        # computed once instead of once per group.
        step_ratio = (1 - 1 /
                      (self.total_iters - self.last_step + 1))**self.power

        param_groups_value = []
        for base_value, param_group in zip(self.base_values,
                                           self.optimizer.param_groups):
            # per-group minimum: fixed eta_min, or a fraction of the base
            eta_min = self.eta_min if self.eta_min_ratio is None else \
                base_value * self.eta_min_ratio
            step_value = (param_group[self.param_name] -
                          eta_min) * step_ratio + eta_min
            param_groups_value.append(step_value)

        return param_groups_value
|
||||
4
Seg_All_In_One_MMSeg/mmseg/evaluation/__init__.py
Normal file
4
Seg_All_In_One_MMSeg/mmseg/evaluation/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .metrics import CityscapesMetric, DepthMetric, IoUMetric
|
||||
|
||||
__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .citys_metric import CityscapesMetric
|
||||
from .depth_metric import DepthMetric
|
||||
from .iou_metric import IoUMetric
|
||||
|
||||
__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
|
||||
158
Seg_All_In_One_MMSeg/mmseg/evaluation/metrics/citys_metric.py
Normal file
158
Seg_All_In_One_MMSeg/mmseg/evaluation/metrics/citys_metric.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
try:
|
||||
|
||||
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
except ImportError:
|
||||
CSLabels = None
|
||||
CSEval = None
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dist import is_main_process, master_only
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
class CityscapesMetric(BaseMetric):
    """Cityscapes evaluation metric.

    Predictions are written to ``output_dir`` as PNG files holding Cityscapes
    label ids and are then scored with the official ``cityscapesscripts``
    pixel-level evaluation.

    Args:
        output_dir (str): The directory for output prediction
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        format_only (bool): Only format result for results commit without
            perform evaluation. It is useful when you want to format the result
            to a specific format and submit it to the test server.
            Defaults to False.
        keep_results (bool): Whether to keep the results. When ``format_only``
            is True, ``keep_results`` must be True. Defaults to False.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Raises:
        ImportError: If ``cityscapesscripts`` is not installed.
    """

    def __init__(self,
                 output_dir: str,
                 ignore_index: int = 255,
                 format_only: bool = False,
                 keep_results: bool = False,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        if CSEval is None:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.')
        self.output_dir = output_dir
        self.ignore_index = ignore_index

        self.format_only = format_only
        if format_only:
            assert keep_results, (
                'When format_only is True, the results must be keep, please '
                f'set keep_results as True, but got {keep_results}')
        self.keep_results = keep_results
        self.prefix = prefix
        if is_main_process():
            mkdir_or_exist(self.output_dir)

    @master_only
    def __del__(self) -> None:
        """Remove the prediction directory unless results must be kept.

        The destructor also runs for instances whose ``__init__`` raised
        before ``keep_results`` / ``output_dir`` were assigned, so both
        attributes are read defensively and nothing is deleted in that case.
        """
        if getattr(self, 'keep_results', True):
            return
        output_dir = getattr(self, 'output_dir', None)
        if output_dir is not None:
            # The directory may already be gone; a destructor must not raise.
            shutil.rmtree(output_dir, ignore_errors=True)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        Each prediction is converted to label ids, saved as a PNG in
        ``self.output_dir`` and recorded in ``self.results`` as a
        ``(prediction_png_path, ground_truth_path)`` pair.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        mkdir_or_exist(self.output_dir)

        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
            # when evaluating with official cityscapesscripts,
            # labelIds should be used
            pred_label = self._convert_to_label_id(pred_label)
            basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
            png_filename = osp.abspath(
                osp.join(self.output_dir, f'{basename}.png'))
            output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
            output.save(png_filename)
            if self.format_only:
                # format_only always for test dataset without ground truth
                gt_filename = ''
            else:
                # when evaluating with official cityscapesscripts,
                # **_gtFine_labelIds.png is used
                gt_filename = data_sample['seg_map_path'].replace(
                    'labelTrainIds.png', 'labelIds.png')
            self.results.append((png_filename, gt_filename))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): Testing results of the dataset, i.e. the
                ``(prediction_png_path, ground_truth_path)`` pairs collected
                by :meth:`process`.

        Returns:
            dict[str: float]: Cityscapes evaluation results. Empty when
            ``format_only`` is True.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()

        msg = 'Evaluating in Cityscapes style'
        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        eval_results = dict()
        print_log(
            f'Evaluating results under {self.output_dir} ...', logger=logger)

        # Configure the module-level argument object of cityscapesscripts.
        CSEval.args.evalInstLevelScore = True
        CSEval.args.predictionPath = osp.abspath(self.output_dir)
        CSEval.args.evalPixelAccuracy = True
        CSEval.args.JSONOutput = False

        pred_list, gt_list = zip(*results)
        metric = dict()
        eval_results.update(
            CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args))
        metric['averageScoreCategories'] = eval_results[
            'averageScoreCategories']
        metric['averageScoreInstCategories'] = eval_results[
            'averageScoreInstCategories']
        return metric

    @staticmethod
    def _convert_to_label_id(result):
        """Convert trainId to id for cityscapes.

        Args:
            result (str | np.ndarray): Prediction map of train ids, or a path
                that is loaded with ``np.load``.

        Returns:
            np.ndarray: A copy of the map with train ids replaced by the
            corresponding Cityscapes label ids.
        """
        if isinstance(result, str):
            result = np.load(result)
        result_copy = result.copy()
        # Index with the untouched input so already-remapped pixels are not
        # remapped a second time.
        for trainId, label in CSLabels.trainId2label.items():
            result_copy[result == trainId] = label.id

        return result_copy
|
||||
212
Seg_All_In_One_MMSeg/mmseg/evaluation/metrics/depth_metric.py
Normal file
212
Seg_All_In_One_MMSeg/mmseg/evaluation/metrics/depth_metric.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dist import is_main_process
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from prettytable import PrettyTable
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
class DepthMetric(BaseMetric):
    """Depth estimation evaluation metric.

    Args:
        depth_metrics (List[str], optional): List of metrics to compute. If
            not specified, defaults to all metrics in self.METRICS.
        min_depth_eval (float): Minimum depth value for evaluation.
            Defaults to 0.0.
        max_depth_eval (float): Maximum depth value for evaluation.
            Defaults to infinity.
        crop_type (str, optional): Specifies the type of cropping to be used
            during evaluation. This option can affect how the evaluation mask
            is generated. Currently, 'nyu_crop' is supported, but other
            types can be added in future. Defaults to None if no cropping
            should be applied.
        depth_scale_factor (float): Factor to scale the depth values.
            Defaults to 1.0.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        output_dir (str): The directory for output prediction. Defaults to
            None.
        format_only (bool): Only format result for results commit without
            perform evaluation. It is useful when you want to save the result
            to a specific format and submit it to the test server.
            Defaults to False.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Raises:
        TypeError: If ``depth_metrics`` is neither None, a list nor a tuple.
    """
    METRICS = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
               'log10', 'silog')

    def __init__(self,
                 depth_metrics: Optional[List[str]] = None,
                 min_depth_eval: float = 0.0,
                 max_depth_eval: float = float('inf'),
                 crop_type: Optional[str] = None,
                 depth_scale_factor: float = 1.0,
                 collect_device: str = 'cpu',
                 output_dir: Optional[str] = None,
                 format_only: bool = False,
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        if depth_metrics is None:
            self.metrics = self.METRICS
        elif isinstance(depth_metrics, (tuple, list)):
            # isinstance() needs a tuple of types; passing a list of types
            # raises TypeError for every user-supplied metric list.
            for metric in depth_metrics:
                assert metric in self.METRICS, f'the metric {metric} is not ' \
                    f'supported. Please use metrics in {self.METRICS}'
            self.metrics = depth_metrics
        else:
            raise TypeError('depth_metrics must be None, a list or a tuple, '
                            f'but got {type(depth_metrics)}')

        # Validate crop_type, if provided
        assert crop_type in [
            None, 'nyu_crop'
        ], (f'Invalid value for crop_type: {crop_type}. Supported values are '
            'None or \'nyu_crop\'.')
        self.crop_type = crop_type
        self.min_depth_eval = min_depth_eval
        self.max_depth_eval = max_depth_eval
        self.output_dir = output_dir
        if self.output_dir and is_main_process():
            mkdir_or_exist(self.output_dir)
        self.format_only = format_only
        self.depth_scale_factor = depth_scale_factor

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_label = data_sample['pred_depth_map']['data'].squeeze()
            # format_only always for test dataset without ground truth
            if not self.format_only:
                gt_depth = data_sample['gt_depth_map']['data'].squeeze().to(
                    pred_label)

                eval_mask = self._get_eval_mask(gt_depth)
                # Only the pixels selected by the mask are kept for scoring.
                self.results.append(
                    (gt_depth[eval_mask], pred_label[eval_mask]))
            # format_result
            if self.output_dir is not None:
                basename = osp.splitext(osp.basename(
                    data_sample['img_path']))[0]
                png_filename = osp.abspath(
                    osp.join(self.output_dir, f'{basename}.png'))
                output_mask = pred_label.cpu().numpy(
                ) * self.depth_scale_factor

                # Saved as 16-bit PNG with compression level 0.
                cv2.imwrite(png_filename, output_mask.astype(np.uint16),
                            [cv2.IMWRITE_PNG_COMPRESSION, 0])

    def _get_eval_mask(self, gt_depth: Tensor):
        """Generates an evaluation mask based on ground truth depth and
        cropping.

        Args:
            gt_depth (Tensor): Ground truth depth map.

        Returns:
            Tensor: Boolean mask where evaluation should be performed.
        """
        valid_mask = torch.logical_and(gt_depth > self.min_depth_eval,
                                       gt_depth < self.max_depth_eval)

        if self.crop_type == 'nyu_crop':
            # this implementation is adapted from
            # https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/datasets/nyu.py  # noqa
            crop_mask = torch.zeros_like(valid_mask)
            crop_mask[45:471, 41:601] = 1
        else:
            crop_mask = torch.ones_like(valid_mask)

        eval_mask = torch.logical_and(valid_mask, crop_mask)
        return eval_mask

    @staticmethod
    def _calc_all_metrics(gt_depth, pred_depth):
        """Compute every supported depth metric for one sample.

        Args:
            gt_depth (Tensor): Ground truth depth values.
            pred_depth (Tensor): Predicted depth values with the same shape
                as ``gt_depth``.

        Returns:
            dict[str, float]: One Python float per name in ``METRICS``.
        """
        assert gt_depth.shape == pred_depth.shape

        thresh = torch.max((gt_depth / pred_depth), (pred_depth / gt_depth))
        diff = pred_depth - gt_depth
        diff_log = torch.log(pred_depth) - torch.log(gt_depth)

        d1 = torch.sum(thresh < 1.25).float() / len(thresh)
        d2 = torch.sum(thresh < 1.25**2).float() / len(thresh)
        d3 = torch.sum(thresh < 1.25**3).float() / len(thresh)

        abs_rel = torch.mean(torch.abs(diff) / gt_depth)
        sq_rel = torch.mean(torch.pow(diff, 2) / gt_depth)

        rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
        rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))

        log10 = torch.mean(
            torch.abs(torch.log10(pred_depth) - torch.log10(gt_depth)))
        silog = torch.sqrt(
            torch.pow(diff_log, 2).mean() -
            0.5 * torch.pow(diff_log.mean(), 2))

        return {
            'd1': d1.item(),
            'd2': d2.item(),
            'd3': d3.item(),
            'abs_rel': abs_rel.item(),
            'sq_rel': sq_rel.item(),
            'rmse': rmse.item(),
            'rmse_log': rmse_log.item(),
            'log10': log10.item(),
            'silog': silog.item()
        }

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
                the metrics, and the values are corresponding results. The keys
                are identical with self.metrics. Empty when ``format_only``
                is True.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            if self.output_dir is None:
                # Nothing was written by process(); osp.dirname(None) would
                # raise TypeError, so report the situation instead.
                logger.warning('format_only is True but output_dir is None, '
                               'no results were saved')
            else:
                logger.info(
                    f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()

        metrics = defaultdict(list)
        for gt_depth, pred_depth in results:
            for key, value in self._calc_all_metrics(gt_depth,
                                                     pred_depth).items():
                metrics[key].append(value)
        # Average each requested metric over all samples.
        metrics = {k: sum(metrics[k]) / len(metrics[k]) for k in self.metrics}

        table_data = PrettyTable()
        for key, val in metrics.items():
            table_data.add_column(key, [round(val, 5)])

        print_log('results:', logger)
        print_log('\n' + table_data.get_string(), logger=logger)

        return metrics
|
||||
286
Seg_All_In_One_MMSeg/mmseg/evaluation/metrics/iou_metric.py
Normal file
286
Seg_All_In_One_MMSeg/mmseg/evaluation/metrics/iou_metric.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dist import is_main_process
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from PIL import Image
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
class IoUMetric(BaseMetric):
    """IoU evaluation metric.

    Args:
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        iou_metrics (list[str] | str): Metrics to be calculated, the options
            includes 'mIoU', 'mDice' and 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        output_dir (str): The directory for output prediction. Defaults to
            None.
        format_only (bool): Only format result for results commit without
            perform evaluation. It is useful when you want to save the result
            to a specific format and submit it to the test server.
            Defaults to False.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 ignore_index: int = 255,
                 iou_metrics: List[str] = ['mIoU'],
                 nan_to_num: Optional[int] = None,
                 beta: int = 1,
                 collect_device: str = 'cpu',
                 output_dir: Optional[str] = None,
                 format_only: bool = False,
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ignore_index = ignore_index
        self.metrics = iou_metrics
        self.nan_to_num = nan_to_num
        self.beta = beta
        self.output_dir = output_dir
        if self.output_dir and is_main_process():
            mkdir_or_exist(self.output_dir)
        self.format_only = format_only

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        num_classes = len(self.dataset_meta['classes'])
        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'].squeeze()
            # format_only always for test dataset without ground truth
            if not self.format_only:
                label = data_sample['gt_sem_seg']['data'].squeeze().to(
                    pred_label)
                self.results.append(
                    self.intersect_and_union(pred_label, label, num_classes,
                                             self.ignore_index))
            # format_result
            if self.output_dir is not None:
                basename = osp.splitext(osp.basename(
                    data_sample['img_path']))[0]
                png_filename = osp.abspath(
                    osp.join(self.output_dir, f'{basename}.png'))
                output_mask = pred_label.cpu().numpy()
                # The index range of official ADE20k dataset is from 0 to 150.
                # But the index range of output is from 0 to 149.
                # That is because we set reduce_zero_label=True.
                if data_sample.get('reduce_zero_label', False):
                    output_mask = output_mask + 1
                output = Image.fromarray(output_mask.astype(np.uint8))
                output.save(png_filename)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
                the metrics, and the values are corresponding results. The key
                mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision,
                mRecall. Empty when ``format_only`` is True.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            if self.output_dir is None:
                # Nothing was written by process(); osp.dirname(None) would
                # raise TypeError, so report the situation instead.
                logger.warning('format_only is True but output_dir is None, '
                               'no results were saved')
            else:
                logger.info(
                    f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()
        # convert list of tuples to tuple of lists, e.g.
        # [(A_1, B_1, C_1, D_1), ...,  (A_n, B_n, C_n, D_n)] to
        # ([A_1, ..., A_n], ..., [D_1, ..., D_n])
        results = tuple(zip(*results))
        assert len(results) == 4

        total_area_intersect = sum(results[0])
        total_area_union = sum(results[1])
        total_area_pred_label = sum(results[2])
        total_area_label = sum(results[3])
        ret_metrics = self.total_area_to_metrics(
            total_area_intersect, total_area_union, total_area_pred_label,
            total_area_label, self.metrics, self.nan_to_num, self.beta)

        class_names = self.dataset_meta['classes']

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        metrics = dict()
        for key, val in ret_metrics_summary.items():
            # 'aAcc' keeps its name; every other key gets the 'm' (mean)
            # prefix.
            if key == 'aAcc':
                metrics[key] = val
            else:
                metrics['m' + key] = val

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)

        return metrics

    @staticmethod
    def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
                            num_classes: int, ignore_index: int):
        """Calculate Intersection and Union.

        Args:
            pred_label (torch.tensor): Prediction segmentation map
                or predict result filename. The shape is (H, W).
            label (torch.tensor): Ground truth segmentation map
                or label filename. The shape is (H, W).
            num_classes (int): Number of categories.
            ignore_index (int): Index that will be ignored in evaluation.

        Returns:
            torch.Tensor: The intersection of prediction and ground truth
                histogram on all classes.
            torch.Tensor: The union of prediction and ground truth histogram on
                all classes.
            torch.Tensor: The prediction histogram on all classes.
            torch.Tensor: The ground truth histogram on all classes.
        """

        # Drop every pixel whose ground truth equals ignore_index.
        mask = (label != ignore_index)
        pred_label = pred_label[mask]
        label = label[mask]

        intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_pred_label = torch.histc(
            pred_label.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_label = torch.histc(
            label.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label

    @staticmethod
    def total_area_to_metrics(total_area_intersect: np.ndarray,
                              total_area_union: np.ndarray,
                              total_area_pred_label: np.ndarray,
                              total_area_label: np.ndarray,
                              metrics: List[str] = ['mIoU'],
                              nan_to_num: Optional[int] = None,
                              beta: int = 1):
        """Calculate evaluation metrics

        NOTE(review): the annotations say ``np.ndarray``, but the results are
        converted with ``.numpy()`` below, so the area arguments are expected
        to be torch tensors such as those returned by
        :meth:`intersect_and_union`.

        Args:
            total_area_intersect (np.ndarray): The intersection of prediction
                and ground truth histogram on all classes.
            total_area_union (np.ndarray): The union of prediction and ground
                truth histogram on all classes.
            total_area_pred_label (np.ndarray): The prediction histogram on
                all classes.
            total_area_label (np.ndarray): The ground truth histogram on
                all classes.
            metrics (List[str] | str): Metrics to be evaluated, 'mIoU',
                'mDice' and 'mFscore'.
            nan_to_num (int, optional): If specified, NaN values will be
                replaced by the numbers defined by the user. Default: None.
            beta (int): Determines the weight of recall in the combined score.
                Default: 1.

        Returns:
            Dict[str, np.ndarray]: per category evaluation metrics,
                shape (num_classes, ).

        Raises:
            KeyError: If ``metrics`` contains an unsupported metric name.
        """

        def f_score(precision, recall, beta=1):
            """calculate the f-score value.

            Args:
                precision (float | torch.Tensor): The precision value.
                recall (float | torch.Tensor): The recall value.
                beta (int): Determines the weight of recall in the combined
                    score. Default: 1.

            Returns:
                [torch.tensor]: The f-score value.
            """
            score = (1 + beta**2) * (precision * recall) / (
                (beta**2 * precision) + recall)
            return score

        if isinstance(metrics, str):
            metrics = [metrics]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError(f'metrics {metrics} is not supported')

        all_acc = total_area_intersect.sum() / total_area_label.sum()
        ret_metrics = OrderedDict({'aAcc': all_acc})
        for metric in metrics:
            if metric == 'mIoU':
                iou = total_area_intersect / total_area_union
                acc = total_area_intersect / total_area_label
                ret_metrics['IoU'] = iou
                ret_metrics['Acc'] = acc
            elif metric == 'mDice':
                dice = 2 * total_area_intersect / (
                    total_area_pred_label + total_area_label)
                acc = total_area_intersect / total_area_label
                ret_metrics['Dice'] = dice
                ret_metrics['Acc'] = acc
            elif metric == 'mFscore':
                precision = total_area_intersect / total_area_pred_label
                recall = total_area_intersect / total_area_label
                f_value = torch.tensor([
                    f_score(x[0], x[1], beta) for x in zip(precision, recall)
                ])
                ret_metrics['Fscore'] = f_value
                ret_metrics['Precision'] = precision
                ret_metrics['Recall'] = recall

        ret_metrics = {
            metric: value.numpy()
            for metric, value in ret_metrics.items()
        }
        if nan_to_num is not None:
            ret_metrics = OrderedDict({
                metric: np.nan_to_num(metric_value, nan=nan_to_num)
                for metric, metric_value in ret_metrics.items()
            })
        return ret_metrics
|
||||
16
Seg_All_In_One_MMSeg/mmseg/models/__init__.py
Normal file
16
Seg_All_In_One_MMSeg/mmseg/models/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .assigners import * # noqa: F401,F403
|
||||
from .backbones import * # noqa: F401,F403
|
||||
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
|
||||
build_head, build_loss, build_segmentor)
|
||||
from .data_preprocessor import SegDataPreProcessor
|
||||
from .decode_heads import * # noqa: F401,F403
|
||||
from .losses import * # noqa: F401,F403
|
||||
from .necks import * # noqa: F401,F403
|
||||
from .segmentors import * # noqa: F401,F403
|
||||
from .text_encoder import * # noqa: F401,F403
|
||||
|
||||
__all__ = [
|
||||
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
|
||||
'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor'
|
||||
]
|
||||
12
Seg_All_In_One_MMSeg/mmseg/models/assigners/__init__.py
Normal file
12
Seg_All_In_One_MMSeg/mmseg/models/assigners/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_assigner import BaseAssigner
|
||||
from .hungarian_assigner import HungarianAssigner
|
||||
from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost
|
||||
|
||||
__all__ = [
|
||||
'BaseAssigner',
|
||||
'HungarianAssigner',
|
||||
'ClassificationCost',
|
||||
'CrossEntropyLossCost',
|
||||
'DiceCost',
|
||||
]
|
||||
18
Seg_All_In_One_MMSeg/mmseg/models/assigners/base_assigner.py
Normal file
18
Seg_All_In_One_MMSeg/mmseg/models/assigners/base_assigner.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.structures import InstanceData
|
||||
|
||||
|
||||
class BaseAssigner(metaclass=ABCMeta):
    """Abstract interface for assigners that match predicted masks to ground
    truth class labels.

    Subclasses must implement :meth:`assign`.
    """

    @abstractmethod
    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs):
        """Assign each mask to a ground truth class label or to a negative
        label.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
            gt_instances (InstanceData): Ground truth instances.
            gt_instances_ignore (InstanceData, optional): Instances to be
                ignored. Defaults to None.
        """
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from mmengine import ConfigDict
|
||||
from mmengine.structures import InstanceData
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
from .base_assigner import BaseAssigner
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class HungarianAssigner(BaseAssigner):
    """One-to-one matcher between prediction masks and ground truth.

    Every configured match cost is evaluated, the resulting cost matrices are
    summed, and the Hungarian algorithm picks the assignment with minimum
    total cost. Prediction masks left unmatched are treated as background.

    Args:
        match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
    """

    def __init__(
        self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
                                 ConfigDict]
    ) -> None:
        # A single config is wrapped so the builder below always iterates.
        cost_cfgs = [match_costs] if isinstance(match_costs,
                                                dict) else match_costs
        if isinstance(cost_cfgs, list):
            assert len(cost_cfgs) > 0, \
                'match_costs must not be a empty list.'

        self.match_costs = [TASK_UTILS.build(cfg) for cfg in cost_cfgs]

    def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
               **kwargs):
        """Compute the one-to-one matching from the summed match costs.

        Each query prediction ends up assigned to a ground truth entry or to
        background: the cost of pairing every query with every ground truth
        is computed first, then the Hungarian algorithm selects the pairing
        with minimum cost.

        Args:
            pred_instances (InstanceData): Instances of model
                predictions. It includes "masks", with shape
                (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
            gt_instances (InstanceData): Ground truth of instance
                annotations. It includes "labels", with shape (k, ),
                and "masks", with shape (k, h, w) or (k, l).

        Returns:
            matched_query_inds (Tensor): The indexes of matched queries.
            matched_label_inds (Tensor): The indexes of matched labels.
        """
        # Costs are evaluated and summed with autocast disabled.
        with autocast(enabled=False):
            per_cost = [
                cost_fn(
                    pred_instances=pred_instances, gt_instances=gt_instances)
                for cost_fn in self.match_costs
            ]
            total_cost = torch.stack(per_cost).sum(dim=0)

        # Remember the original device; the solver itself runs on CPU.
        origin_device = total_cost.device
        cpu_cost = total_cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        query_inds, label_inds = linear_sum_assignment(cpu_cost)
        return (torch.from_numpy(query_inds).to(origin_device),
                torch.from_numpy(label_inds).to(origin_device))
|
||||
231
Seg_All_In_One_MMSeg/mmseg/models/assigners/match_cost.py
Normal file
231
Seg_All_In_One_MMSeg/mmseg/models/assigners/match_cost.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
|
||||
|
||||
class BaseMatchCost:
    """Common base for match cost callables.

    Stores the weight of the cost term. The class does not use ``ABCMeta``,
    so ``@abstractmethod`` on ``__call__`` only marks intent and does not
    block instantiation.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        self.weight = weight

    @abstractmethod
    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute the match cost between predictions and ground truth.

        Args:
            pred_instances (InstanceData): Instances of model predictions.
                It often includes "labels" and "scores".
            gt_instances (InstanceData): Ground truth of instance
                annotations. It usually includes "labels".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        return None
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class ClassificationCost(BaseMatchCost):
    """Softmax classification match cost (ClsSoftmaxCost).

    The cost of matching a query to a ground-truth instance is the negative
    softmax probability the query assigns to that instance's label.

    Args:
        weight (Union[float, int]): Cost weight. Defaults to 1.

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData
        >>> from mmseg.models.assigners import ClassificationCost
        >>> self = ClassificationCost()
        >>> pred_instances = InstanceData(scores=torch.rand(4, 3))
        >>> gt_instances = InstanceData(labels=torch.tensor([0, 1, 2]))
        >>> self(pred_instances, gt_instances)
        tensor([[-0.3430, -0.3525, -0.3045],
                [-0.3077, -0.2931, -0.3992],
                [-0.3664, -0.3455, -0.2881],
                [-0.3343, -0.2701, -0.3956]])
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): "scores" inside is
                predicted classification logits, of shape
                (num_queries, num_class).
            gt_instances (InstanceData): "labels" inside should have
                shape (num_gt, ).

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).

        Raises:
            AssertionError: If ``pred_instances`` has no ``scores`` or
                ``gt_instances`` has no ``labels``.
        """
        assert hasattr(pred_instances, 'scores'), \
            "pred_instances must contain 'scores'"
        assert hasattr(gt_instances, 'labels'), \
            "gt_instances must contain 'labels'"
        pred_scores = pred_instances.scores
        gt_labels = gt_instances.labels

        # Turn logits into per-class probabilities, then pick, for every
        # query, the probability of each ground-truth label. Negated so
        # that a higher probability gives a lower cost.
        pred_scores = pred_scores.softmax(-1)
        cls_cost = -pred_scores[:, gt_labels]

        return cls_cost * self.weight
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class DiceCost(BaseMatchCost):
    """Mask assignment cost derived from the dice loss.

    Args:
        pred_act (bool): Whether to apply sigmoid to mask_pred.
            Defaults to False.
        eps (float): Smoothing term added to numerator and denominator.
            Defaults to 1e-3.
        naive_dice (bool): If True, use the naive dice loss
            in which the power of the number in the denominator is
            the first power. If False, use the second power that
            is adopted by K-Net and SOLO. Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 pred_act: bool = False,
                 eps: float = 1e-3,
                 naive_dice: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.pred_act = pred_act
        self.eps = eps
        self.naive_dice = naive_dice

    def _binary_mask_dice_loss(self, mask_preds: Tensor,
                               gt_masks: Tensor) -> Tensor:
        """Pairwise dice loss between every prediction and every target.

        Args:
            mask_preds (Tensor): Mask prediction in shape (num_queries, *).
            gt_masks (Tensor): Ground truth in shape (num_gt, *)
                store 0 or 1, 0 for negative class and 1 for
                positive class.

        Returns:
            Tensor: Dice cost matrix in shape (num_queries, num_gt).
        """
        preds = mask_preds.flatten(1)
        targets = gt_masks.flatten(1).float()

        # Overlap of every (prediction, target) pair in one contraction.
        overlap = torch.einsum('nc,mc->nm', preds, targets)
        numerator = 2 * overlap

        if self.naive_dice:
            pred_term = preds.sum(-1)
            gt_term = targets.sum(-1)
        else:
            pred_term = preds.pow(2).sum(1)
            gt_term = targets.pow(2).sum(1)
        # Broadcast the per-prediction and per-target terms to (n, m).
        denominator = pred_term[:, None] + gt_term[None, :]

        return 1 - (numerator + self.eps) / (denominator + self.eps)

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (InstanceData): Predicted instances which
                must contain "masks".
            gt_instances (InstanceData): Ground truth which must contain
                "masks".

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        gt_masks = gt_instances.masks
        pred_masks = pred_instances.masks
        if self.pred_act:
            pred_masks = pred_masks.sigmoid()

        return self._binary_mask_dice_loss(pred_masks,
                                           gt_masks) * self.weight
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class CrossEntropyLossCost(BaseMatchCost):
    """Match cost based on (binary) cross entropy.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Only the sigmoid variant is implemented.
            Defaults to True.
        weight (Union[float, int]): Cost weight. Defaults to 1.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 weight: Union[float, int] = 1.) -> None:
        super().__init__(weight=weight)
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred: Tensor,
                              gt_labels: Tensor) -> Tensor:
        """Pairwise mean binary cross entropy.

        Args:
            cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
                (num_queries, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
        """
        logits = cls_pred.flatten(1).float()
        targets = gt_labels.flatten(1).float()
        num_elements = logits.shape[1]

        # Per-element loss of every query against an all-ones and an
        # all-zeros target; the real targets then select which of the two
        # applies at each position.
        loss_if_pos = F.binary_cross_entropy_with_logits(
            logits, torch.ones_like(logits), reduction='none')
        loss_if_neg = F.binary_cross_entropy_with_logits(
            logits, torch.zeros_like(logits), reduction='none')

        pos_part = torch.einsum('nc,mc->nm', loss_if_pos, targets)
        neg_part = torch.einsum('nc,mc->nm', loss_if_neg, 1 - targets)
        total = pos_part + neg_part

        return total / num_elements

    def __call__(self, pred_instances: InstanceData,
                 gt_instances: InstanceData, **kwargs) -> Tensor:
        """Compute match cost.

        Args:
            pred_instances (:obj:`InstanceData`): Predicted instances which
                must contain ``masks``.
            gt_instances (:obj:`InstanceData`): Ground truth which must contain
                ``masks``.

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).

        Raises:
            NotImplementedError: If ``use_sigmoid`` is False.
        """
        assert hasattr(pred_instances, 'masks'), \
            "pred_instances must contain 'masks'"
        assert hasattr(gt_instances, 'masks'), \
            "gt_instances must contain 'masks'"
        pred_masks = pred_instances.masks
        gt_masks = gt_instances.masks

        if not self.use_sigmoid:
            raise NotImplementedError

        return self._binary_cross_entropy(pred_masks, gt_masks) * self.weight
|
||||
40
Seg_All_In_One_MMSeg/mmseg/models/backbones/__init__.py
Normal file
40
Seg_All_In_One_MMSeg/mmseg/models/backbones/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .beit import BEiT
|
||||
from .bisenetv1 import BiSeNetV1
|
||||
from .bisenetv2 import BiSeNetV2
|
||||
from .en_bisenetv2 import EnBiSeNetV2 # TODO
|
||||
from .my_bisnetv2_A1 import My_BiSeNetV2_A1 # TODO
|
||||
from .my_bisnetv2_A2 import My_BiSeNetV2_A2 # TODO
|
||||
from .my_bisnetv2_A1_add_A2 import My_BiSeNetV2_A1_add_A2 # TODO
|
||||
from .cgnet import CGNet
|
||||
from .ddrnet import DDRNet
|
||||
from .erfnet import ERFNet
|
||||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
from .icnet import ICNet
|
||||
from .mae import MAE
|
||||
from .mit import MixVisionTransformer
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .mobilenet_v3 import MobileNetV3
|
||||
from .mscan import MSCAN
|
||||
from .pidnet import PIDNet
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
from .stdc import STDCContextPathNet, STDCNet
|
||||
from .swin import SwinTransformer
|
||||
from .timm_backbone import TIMMBackbone
|
||||
from .twins import PCPVT, SVT
|
||||
from .unet import UNet
|
||||
from .vit import VisionTransformer
|
||||
from .vpd import VPD
|
||||
|
||||
__all__ = [
|
||||
'EnBiSeNetV2', 'My_BiSeNetV2_A1', 'My_BiSeNetV2_A2', 'My_BiSeNetV2_A1_add_A2', # TODO
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
|
||||
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
|
||||
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
|
||||
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
|
||||
'DDRNet', 'VPD'
|
||||
]
|
||||
554
Seg_All_In_One_MMSeg/mmseg/models/backbones/beit.py
Normal file
554
Seg_All_In_One_MMSeg/mmseg/models/backbones/beit.py
Normal file
@@ -0,0 +1,554 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
from scipy import interpolate
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed
|
||||
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
||||
|
||||
|
||||
class BEiTAttention(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    The bias table holds one entry per relative offset inside the window
    plus three extra entries used for the class token (cls to token,
    token to cls and cls to cls).

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        bias (bool | str): The option to add learnable bias for q, k, v.
            If bias is True, it will add learnable bias. If bias is
            'qv_bias', it will only add learnable bias for q, v. If bias is
            False, it will not add bias for q, k, v. Default to 'qv_bias'.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float): Dropout ratio of output. Default: 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 bias='qv_bias',
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.bias = bias
        dims_per_head = embed_dims // num_heads
        self.scale = qk_scale or dims_per_head**-0.5

        # With 'qv_bias' the q/v biases live in separate parameters and the
        # fused projection itself is created without a bias.
        if bias == 'qv_bias':
            self._init_qv_bias()
            linear_bias = False
        else:
            linear_bias = bias

        self.window_size = window_size
        self._init_rel_pos_embedding()

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=linear_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

    def _init_qv_bias(self):
        """Create zero-initialized bias parameters for q and v."""
        self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
        self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))

    def _init_rel_pos_embedding(self):
        """Create the bias table and the index buffer that addresses it."""
        win_h, win_w = self.window_size
        # The "+ 3" covers cls to token, token to cls and cls to cls.
        self.num_relative_distance = (2 * win_h - 1) * (2 * win_w - 1) + 3
        # Table shape: ((2*Wh-1) * (2*Ww-1) + 3, nH)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, self.num_heads))

        # Pair-wise relative position index for each token in the window.
        rows = torch.arange(win_h)
        cols = torch.arange(win_w)
        # (2, Wh, Ww) -> (2, Wh*Ww)
        grid = torch.stack(torch.meshgrid([rows, cols]))
        flat_grid = torch.flatten(grid, 1)
        offsets = flat_grid[:, :, None] - flat_grid[:, None, :]
        # (Wh*Ww, Wh*Ww, 2)
        offsets = offsets.permute(1, 2, 0).contiguous()
        # Shift both axes to start from 0, then scale the row offset so
        # that the sum of the two is a unique flat index.
        offsets[:, :, 0] += win_h - 1
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1

        # Index shape: (Wh*Ww + 1, Wh*Ww + 1); row/column 0 is the cls token.
        index = torch.zeros(
            size=(win_h * win_w + 1, ) * 2, dtype=offsets.dtype)
        index[1:, 1:] = offsets.sum(-1)
        index[0, 0:] = self.num_relative_distance - 3
        index[0:, 0] = self.num_relative_distance - 2
        index[0, 0] = self.num_relative_distance - 1

        self.register_buffer('relative_position_index', index)

    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C).
        """
        B, N, C = x.shape

        if self.bias == 'qv_bias':
            # k has no bias: splice a constant zero block between q and v.
            zero_k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            fused_bias = torch.cat((self.q_bias, zero_k_bias, self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        else:
            qkv = self.qkv(x)

        # (3, B, nH, N, head_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            side = self.window_size[0] * self.window_size[1] + 1
            flat_index = self.relative_position_index.view(-1)
            rel_bias = self.relative_position_bias_table[flat_index].view(
                side, side, -1)
            # nH, Wh*Ww + 1, Wh*Ww + 1
            rel_bias = rel_bias.permute(2, 0, 1).contiguous()
            attn = attn + rel_bias.unsqueeze(0)

        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))
|
||||
|
||||
|
||||
class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        bias (bool | str): The option to add learnable bias for q, k, v.
            If bias is True, it will add learnable bias. If bias is
            'qv_bias', it will only add learnable bias for q, v. If bias is
            False, it will not add bias for q, k, v. Default to 'qv_bias'.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        window_size (tuple[int], optional): The height and width of the window.
            Default: None.
        attn_cfg (dict): Extra attention config forwarded to the parent
            constructor. ``window_size`` and ``qk_scale`` are added to a
            copy of it; the passed dict is left untouched.
            Default: dict().
        ffn_cfg (dict): FFN config forwarded (as a copy) to the parent
            constructor. Default: dict(add_identity=False).
        init_values (float, optional): Initialize the values of BEiTAttention
            and FFN with learnable scaling. If None, no learnable scaling
            is created and the branches are added unscaled. Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 bias='qv_bias',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=None,
                 attn_cfg=dict(),
                 ffn_cfg=dict(add_identity=False),
                 init_values=None):
        # Work on copies: the defaults are shared between all calls and a
        # caller-provided dict must not be modified behind the caller's back.
        attn_cfg = attn_cfg.copy()
        attn_cfg.update(dict(window_size=window_size, qk_scale=None))
        ffn_cfg = ffn_cfg.copy()

        # The parent receives drop_path_rate=0. on purpose; stochastic
        # depth is applied by ``self.drop_path`` built below.
        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=feedforward_channels,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=0.,
            drop_rate=0.,
            num_fcs=num_fcs,
            qkv_bias=bias,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            attn_cfg=attn_cfg,
            ffn_cfg=ffn_cfg)

        # NOTE: drop path for stochastic depth, we shall see if
        # this is better than dropout here
        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate))

        if init_values is None:
            # No layer scale requested: keep the attributes so that
            # ``forward`` can test for them.
            self.gamma_1 = None
            self.gamma_2 = None
        else:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones(embed_dims), requires_grad=True)
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones(embed_dims), requires_grad=True)

    def build_attn(self, attn_cfg):
        self.attn = BEiTAttention(**attn_cfg)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        if self.gamma_1 is not None:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        ffn_out = self.ffn(self.norm2(x))
        if self.gamma_2 is not None:
            ffn_out = self.gamma_2 * ffn_out
        x = x + self.drop_path(ffn_out)
        return x
|
||||
|
||||
|
||||
@MODELS.register_module()
class BEiT(BaseModule):
    """BERT Pre-Training of Image Transformers.

    Args:
        img_size (int | tuple): Input image size. Default: 224.
        patch_size (int): The patch size. Default: 16.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Embedding dimension. Default: 768.
        num_layers (int): Depth of transformer. Default: 12.
        num_heads (int): Number of attention heads. Default: 12.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        out_indices (list | tuple | int): Output from which stages.
            Default: -1.
        qv_bias (bool): Enable bias for qv if True. Default: True.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
            Default: False.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Default: False.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        norm_eval (bool): Whether to set norm layers to eval mode. In
            ``train`` this is applied to ``nn.LayerNorm`` modules.
            Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_values (float): Initialize the values of BEiTAttention and FFN
            with learnable scaling. Default: 0.1.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_ratio=4,
                 out_indices=-1,
                 qv_bias=True,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
                 num_fcs=2,
                 norm_eval=False,
                 pretrained=None,
                 init_values=0.1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.in_channels = in_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.norm_eval = norm_eval
        self.pretrained = pretrained
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.num_fcs = num_fcs
        self.qv_bias = qv_bias
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.patch_norm = patch_norm
        self.init_values = init_values
        # One attention window spans the whole patch grid.
        self.window_size = (img_size[0] // patch_size,
                            img_size[1] // patch_size)
        self.patch_shape = self.window_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        self._build_patch_embedding()
        self._build_layers()

        if isinstance(out_indices, int):
            if out_indices == -1:
                out_indices = num_layers - 1
            self.out_indices = [out_indices]
        elif isinstance(out_indices, (list, tuple)):
            self.out_indices = out_indices
        else:
            raise TypeError('out_indices must be type of int, list or tuple')

        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)

    def _build_patch_embedding(self):
        """Build patch embedding layer."""
        self.patch_embed = PatchEmbed(
            in_channels=self.in_channels,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
            norm_cfg=self.norm_cfg if self.patch_norm else None,
            init_cfg=None)

    def _build_layers(self):
        """Build transformer encoding layers."""

        # Drop path rate grows linearly from 0 to ``drop_path_rate``
        # across the layers.
        dpr = [
            x.item()
            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
        ]
        self.layers = ModuleList()
        for i in range(self.num_layers):
            self.layers.append(
                BEiTTransformerEncoderLayer(
                    embed_dims=self.embed_dims,
                    num_heads=self.num_heads,
                    feedforward_channels=self.mlp_ratio * self.embed_dims,
                    attn_drop_rate=self.attn_drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=self.num_fcs,
                    bias='qv_bias' if self.qv_bias else False,
                    act_cfg=self.act_cfg,
                    norm_cfg=self.norm_cfg,
                    window_size=self.window_size,
                    init_values=self.init_values))

    @property
    def norm1(self):
        # Only defined when ``final_norm`` is True (see ``__init__``).
        return getattr(self, self.norm1_name)

    def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
                                          num):
        """Get new sequence via geometric sequence interpolation.

        Args:
            src_size (int): Pos_embedding size in pre-trained model.
            dst_size (int): Pos_embedding size in the current model.
            sequence (tensor): The relative position bias of the pretrain
                model after removing the extra tokens.
            num (int): Number of attention heads.
        Returns:
            new_sequence (tensor): Geometric sequence interpolate the
                pre-trained relative position bias to the size of
                the current model.
        """

        def geometric_progression(a, r, n):
            return a * (1.0 - r**n) / (1.0 - r)

        # Bisection search for the common ratio ``q`` whose geometric sum
        # over src_size // 2 terms reaches dst_size // 2.
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            gp = geometric_progression(1, q, src_size // 2)
            if gp > dst_size // 2:
                right = q
            else:
                left = q
        # The position of each interpolated point is determined
        # by the ratio obtained by dichotomy.
        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q**(i + 1)
        r_ids = [-_ for _ in reversed(dis)]
        x = r_ids + [0] + dis
        y = r_ids + [0] + dis
        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)

        new_sequence = []
        for i in range(num):
            z = sequence[:, i].view(src_size, src_size).float().numpy()
            # ``interp2d`` was removed in SciPy 1.14. For data on a regular
            # grid its documented replacement is RectBivariateSpline
            # (cubic by default). interp2d indexed ``z`` as (len(y), len(x))
            # while RectBivariateSpline expects (len(x), len(y)), hence the
            # transposes on the way in and out.
            spline = interpolate.RectBivariateSpline(x, y, z.T)
            resized = np.ascontiguousarray(spline(dx, dy).T)
            new_sequence.append(
                torch.Tensor(resized).contiguous().view(-1, 1).to(sequence))
        new_sequence = torch.cat(new_sequence, dim=-1)
        return new_sequence

    def resize_rel_pos_embed(self, checkpoint):
        """Resize relative pos_embed weights.

        This function is modified from
        https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py.  # noqa: E501
        Copyright (c) Microsoft Corporation
        Licensed under the MIT License
        Args:
            checkpoint (dict): Key and value of the pretrain model.
        Returns:
            state_dict (dict): Interpolate the relative pos_embed weights
                in the pre-train model to the current model size.
        """
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Build the model's own state dict once instead of once per key.
        model_state = self.state_dict()

        all_keys = list(state_dict.keys())
        for key in all_keys:
            # The index buffer is rebuilt by the model for its own window
            # size, so the checkpoint copy is discarded.
            if 'relative_position_index' in key:
                state_dict.pop(key)
            # In order to keep the center of pos_bias as consistent as
            # possible after interpolation, and vice versa in the edge
            # area, the geometric sequence interpolation method is adopted.
            if 'relative_position_bias_table' in key:
                rel_pos_bias = state_dict[key]
                src_num_pos, num_attn_heads = rel_pos_bias.size()
                dst_num_pos, _ = model_state[key].size()
                dst_patch_shape = self.patch_shape
                if dst_patch_shape[0] != dst_patch_shape[1]:
                    raise NotImplementedError()
                # Count the number of extra tokens.
                num_extra_tokens = dst_num_pos - (
                    dst_patch_shape[0] * 2 - 1) * (
                        dst_patch_shape[1] * 2 - 1)
                src_size = int((src_num_pos - num_extra_tokens)**0.5)
                dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
                if src_size != dst_size:
                    # Extra (class token) entries are the trailing rows;
                    # they are carried over unchanged.
                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
                    new_rel_pos_bias = self._geometric_sequence_interpolation(
                        src_size, dst_size, rel_pos_bias, num_attn_heads)
                    new_rel_pos_bias = torch.cat(
                        (new_rel_pos_bias, extra_tokens), dim=0)
                    state_dict[key] = new_rel_pos_bias

        return state_dict

    def init_weights(self):

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg.get('type') == 'Pretrained'):
            checkpoint = _load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')
            state_dict = self.resize_rel_pos_embed(checkpoint)
            # Non-strict load: missing/unexpected keys are tolerated.
            self.load_state_dict(state_dict, False)
        elif self.init_cfg is not None:
            super().init_weights()
        else:
            # We only implement the 'jax_impl' initialization implemented at
            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
            # Copyright 2019 Ross Wightman
            # Licensed under the Apache License, Version 2.0 (the "License")
            trunc_normal_(self.cls_token, std=.02)
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if m.bias is not None:
                        if 'ffn' in n:
                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                        else:
                            nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    kaiming_init(m, mode='fan_in', bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)

    def forward(self, inputs):
        B = inputs.shape[0]

        x, hw_shape = self.patch_embed(inputs)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        outs = []
        last_layer = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == last_layer and self.final_norm:
                x = self.norm1(x)
            if i in self.out_indices:
                # Remove class token and reshape token for decoder head
                out = x[:, 1:]
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, nn.LayerNorm):
                    m.eval()
|
||||
332
Seg_All_In_One_MMSeg/mmseg/models/backbones/bisenetv1.py
Normal file
332
Seg_All_In_One_MMSeg/mmseg/models/backbones/bisenetv1.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class SpatialPath(BaseModule):
    """Spatial Path to preserve the spatial size of the original input image
    and encode affluent spatial information.

    Four ``ConvModule`` layers are registered as ``layer1`` .. ``layer4``:
    a 7x7 stride-2 stem, two 3x3 stride-2 convolutions and a final 1x1
    stride-1 projection.

    Args:
        in_channels (int): The number of channels of input
            image. Default: 3.
        num_channels (Tuple[int]): The number of channels of
            each layers in Spatial Path.
            Default: (64, 64, 64, 128).
        conv_cfg (dict | None): Passed to every ``ConvModule``.
            Default: None.
        norm_cfg (dict): Passed to every ``ConvModule``.
            Default: dict(type='BN').
        act_cfg (dict): Passed to every ``ConvModule``.
            Default: dict(type='ReLU').
        init_cfg (dict | None): Initialization config passed to the base
            module. Default: None.
    Returns:
        x (torch.Tensor): Feature map for Feature Fusion Module.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(64, 64, 64, 128),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Adjacent literals instead of a backslash inside the string, which
        # would embed the continuation line's indentation in the message.
        assert len(num_channels) == 4, (
            'Length of input channels '
            'of Spatial Path must be 4!')

        # Names of the registered layers, in forward order.
        self.layers = []
        last_index = len(num_channels) - 1
        for i, out_channels in enumerate(num_channels):
            if i == 0:
                layer_in = in_channels
                kernel_size, stride, padding = 7, 2, 3
            elif i == last_index:
                layer_in = num_channels[i - 1]
                kernel_size, stride, padding = 1, 1, 0
            else:
                layer_in = num_channels[i - 1]
                kernel_size, stride, padding = 3, 2, 1

            layer_name = f'layer{i + 1}'
            self.layers.append(layer_name)
            self.add_module(
                layer_name,
                ConvModule(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, x):
        for layer_name in self.layers:
            x = getattr(self, layer_name)(x)
        return x
|
||||
|
||||
|
||||
class AttentionRefinementModule(BaseModule):
    """Attention Refinement Module (ARM) to refine the features of each stage.

    A 3x3 conv produces the feature map; a globally pooled 1x1 conv followed
    by a sigmoid produces per-channel weights that rescale that feature map.

    Args:
        in_channels (int): The number of input channels.
        out_channel (int): The number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of the activation of the 3x3 conv.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x_out (torch.Tensor): Feature map of Attention Refinement Module.
    """

    def __init__(self,
                 in_channels,
                 out_channel,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv_layer = ConvModule(
            in_channels,
            out_channel,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Attention branch: no activation inside the ConvModule, the
        # sigmoid that follows supplies the gating non-linearity.
        attention_conv = ConvModule(
            out_channel,
            out_channel,
            1,
            bias=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.atten_conv_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)), attention_conv, nn.Sigmoid())

    def forward(self, x):
        """Return the conv feature scaled by its channel attention."""
        feat = self.conv_layer(x)
        return feat * self.atten_conv_layer(feat)
|
||||
|
||||
|
||||
class ContextPath(BaseModule):
    """Context Path to provide sufficient receptive field.

    Args:
        backbone_cfg:(dict): Config of backbone of
            Context Path. It is built through ``MODELS.build`` and its
            forward must yield four feature maps.
        context_channels (Tuple[int]): The number of channel numbers
            of various modules in Context Path. Must have length 3.
            Default: (128, 256, 512).
        align_corners (bool, optional): The align_corners argument of
            resize operation. It is stored on the module; the
            nearest-neighbour resizes in ``forward`` do not pass it.
            Default: False.
        conv_cfg (dict | None): Config of conv layers of the two head
            convs and the global-pooling conv. Default: None.
        norm_cfg (dict | None): Config of norm layers of the same convs.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers of the same convs.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps
            undergoing upsampling from 1/16 and 1/32 downsampling
            feature maps. These two feature maps are used for Feature
            Fusion Module and Auxiliary Head.
    """

    def __init__(self,
                 backbone_cfg,
                 context_channels=(128, 256, 512),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert len(context_channels) == 3, (
            'Length of input channels '
            'of Context Path must be 3!')

        self.backbone = MODELS.build(backbone_cfg)
        self.align_corners = align_corners

        out_ch = context_channels[0]
        # The ARMs are built with their own default conv/norm/act configs.
        self.arm16 = AttentionRefinementModule(context_channels[1], out_ch)
        self.arm32 = AttentionRefinementModule(context_channels[2], out_ch)

        head_kwargs = dict(
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv_head32 = ConvModule(
            in_channels=out_ch, out_channels=out_ch, **head_kwargs)
        self.conv_head16 = ConvModule(
            in_channels=out_ch, out_channels=out_ch, **head_kwargs)

        # Global average pooled context taken from the deepest feature map.
        self.gap_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                in_channels=context_channels[2],
                out_channels=out_ch,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):
        """Refine and upsample the two deepest backbone feature maps."""
        # The first backbone output is not used by the Context Path.
        _, x_8, x_16, x_32 = self.backbone(x)
        global_ctx = self.gap_conv(x_32)

        # Deepest level: ARM output plus global context, resized to the
        # spatial size of the third backbone output.
        feat_32 = self.arm32(x_32) + global_ctx
        feat_32 = resize(input=feat_32, size=x_16.shape[2:], mode='nearest')
        x_32_up = self.conv_head32(feat_32)

        # Next level: ARM output plus the upsampled deeper feature, resized
        # to the spatial size of the second backbone output.
        feat_16 = self.arm16(x_16) + x_32_up
        feat_16 = resize(input=feat_16, size=x_8.shape[2:], mode='nearest')
        x_16_up = self.conv_head16(feat_16)

        return x_16_up, x_32_up
|
||||
|
||||
|
||||
class FeatureFusionModule(BaseModule):
    """Feature Fusion Module to fuse low level output feature of Spatial Path
    and high level output feature of Context Path.

    Args:
        in_channels (int): The number of input channels, i.e. the channel
            count of the two inputs after concatenation.
        out_channels (int): The number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x_out (torch.Tensor): Feature map of Feature Fusion Module.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        pointwise = dict(
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv1 = ConvModule(
            in_channels=in_channels, out_channels=out_channels, **pointwise)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_atten = nn.Sequential(
            ConvModule(
                in_channels=out_channels,
                out_channels=out_channels,
                bias=False,
                **pointwise), nn.Sigmoid())

    def forward(self, x_sp, x_cp):
        """Fuse the spatial and context features along the channel axis."""
        fused = self.conv1(torch.cat([x_sp, x_cp], dim=1))
        # Note: No BN and more 1x1 conv in paper.
        attention = self.conv_atten(self.gap(fused))
        # Attention-weighted feature plus the un-weighted fused feature.
        return fused * attention + fused
|
||||
|
||||
|
||||
@MODELS.register_module()
class BiSeNetV1(BaseModule):
    """BiSeNetV1 backbone.

    This backbone is the implementation of `BiSeNet: Bilateral
    Segmentation Network for Real-time Semantic
    Segmentation <https://arxiv.org/abs/1808.00897>`_.

    Args:
        backbone_cfg:(dict): Config of backbone of
            Context Path.
        in_channels (int): The number of channels of input
            image. Default: 3.
        spatial_channels (Tuple[int]): Size of channel numbers of
            various layers in Spatial Path.
            Default: (64, 64, 64, 128).
        context_channels (Tuple[int]): Size of channel numbers of
            various modules in Context Path.
            Default: (128, 256, 512).
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Index 0 is the fused feature, 1 and 2 are the two Context Path
            outputs. A single int is treated as a one-element tuple.
            Default: (0, 1, 2).
        align_corners (bool, optional): The align_corners argument of
            resize operation, handed to the Context Path.
            Default: False.
        out_channels(int): The number of channels of output.
            It must be the same with `in_channels` of decode_head.
            Default: 256.
        conv_cfg (dict | None): Config of conv layers. Stored on the
            module; it is not forwarded to the sub-modules built here.
            Default: None.
        norm_cfg (dict | None): Config of norm layers. Stored on the
            module; it is not forwarded to the sub-modules built here.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config of activation layers. Stored on the
            module; it is not forwarded to the sub-modules built here.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 backbone_cfg,
                 in_channels=3,
                 spatial_channels=(64, 64, 64, 128),
                 context_channels=(128, 256, 512),
                 out_indices=(0, 1, 2),
                 align_corners=False,
                 out_channels=256,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        assert len(spatial_channels) == 4, (
            'Length of input channels '
            'of Spatial Path must be 4!')

        assert len(context_channels) == 3, (
            'Length of input channels '
            'of Context Path must be 3!')

        # The docstring allows a bare int; forward() iterates over the
        # indices, so normalise it to a tuple here.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )

        self.out_indices = out_indices
        self.align_corners = align_corners
        self.context_path = ContextPath(backbone_cfg, context_channels,
                                        self.align_corners)
        self.spatial_path = SpatialPath(in_channels, spatial_channels)
        # The FFM receives cat(spatial feature, 1/8 context feature), so its
        # input width is the sum of the two, not ``context_channels[1]``
        # (which merely happens to equal that sum for the default values).
        ffm_in_channels = spatial_channels[-1] + context_channels[0]
        self.ffm = FeatureFusionModule(ffm_in_channels, out_channels)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

    def forward(self, x):
        """Return the selected outputs among (fused, context8, context16)."""
        # stole refactoring code from Coin Cheung, thanks
        x_context8, x_context16 = self.context_path(x)
        x_spatial = self.spatial_path(x)
        x_fuse = self.ffm(x_spatial, x_context8)

        outs = [x_fuse, x_context8, x_context16]
        return tuple(outs[i] for i in self.out_indices)
|
||||
622
Seg_All_In_One_MMSeg/mmseg/models/backbones/bisenetv2.py
Normal file
622
Seg_All_In_One_MMSeg/mmseg/models/backbones/bisenetv2.py
Normal file
@@ -0,0 +1,622 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
    """Detail Branch with wide channels and shallow layers to capture low-level
    details and generate high-resolution feature representation.

    Every stage starts with a 3x3 stride-2 conv. The first stage adds one
    further 3x3 stride-1 conv, every later stage adds two.

    Args:
        detail_channels (Tuple[int]): Size of channel numbers of each stage
            in Detail Branch, in paper it has 3 stages.
            Default: (64, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): Feature map of Detail Branch.
    """

    def __init__(self,
                 detail_channels=(64, 64, 128),
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        conv_kwargs = dict(
            kernel_size=3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        stages = []
        prev_channels = in_channels
        for idx, channels in enumerate(detail_channels):
            # One downsampling conv, then 1 (first stage) or 2 refinements.
            num_refine = 1 if idx == 0 else 2
            convs = [
                ConvModule(
                    in_channels=prev_channels,
                    out_channels=channels,
                    stride=2,
                    **conv_kwargs)
            ]
            for _ in range(num_refine):
                convs.append(
                    ConvModule(
                        in_channels=channels,
                        out_channels=channels,
                        stride=1,
                        **conv_kwargs))
            stages.append(nn.Sequential(*convs))
            prev_channels = channels
        self.detail_branch = nn.ModuleList(stages)

    def forward(self, x):
        """Pass the input through every stage in order."""
        for stage in self.detail_branch:
            x = stage(x)
        return x
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    After a stride-2 3x3 conv, the feature is downsampled along two
    parallel paths (a 1x1 + stride-2 3x3 conv pair and a stride-2 max pool)
    whose outputs are concatenated and fused by a 3x3 conv.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): First feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        half_channels = out_channels // 2

        self.conv_first = ConvModule(
            in_channels, out_channels, 3, stride=2, padding=1, **cfgs)
        # Left path: channel squeeze, then strided 3x3 conv.
        self.convs = nn.Sequential(
            ConvModule(
                out_channels, half_channels, 1, stride=1, padding=0, **cfgs),
            ConvModule(
                half_channels, out_channels, 3, stride=2, padding=1, **cfgs))
        # Right path: strided max pooling.
        self.pool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.fuse_last = ConvModule(
            out_channels * 2, out_channels, 3, stride=1, padding=1, **cfgs)

    def forward(self, x):
        """Downsample through both paths and fuse the concatenation."""
        stem = self.conv_first(x)
        left = self.convs(stem)
        right = self.pool(stem)
        return self.fuse_last(torch.cat([left, right], dim=1))
|
||||
|
||||
|
||||
class GELayer(BaseModule):
    """Gather-and-Expansion Layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        stride (int): Stride of GELayer. With ``stride == 1`` the input
            itself is added as the residual; otherwise a depthwise
            separable conv shortcut is used. Default: 1
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): Intermediate feature map in
            Semantic Branch.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 exp_ratio=6,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        mid_channel = in_channels * exp_ratio
        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if stride == 1:
            dw_layers = [
                # ReLU in ConvModule not shown in paper
                ConvModule(
                    in_channels,
                    mid_channel,
                    3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            shortcut = None
        else:
            dw_layers = [
                ConvModule(
                    in_channels,
                    mid_channel,
                    3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    bias=False,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                # ReLU in ConvModule not shown in paper
                ConvModule(
                    mid_channel,
                    mid_channel,
                    3,
                    stride=1,
                    padding=1,
                    groups=mid_channel,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            ]
            shortcut = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=None,
                ))
        self.dwconv = nn.Sequential(*dw_layers)
        self.shortcut = shortcut

        # 1x1 projection back to ``out_channels``; the activation is applied
        # after the residual addition in forward().
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channel,
                out_channels,
                1,
                stride=1,
                padding=0,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
            ))

        self.act = build_activation_layer(act_cfg)

    def forward(self, x):
        """Apply the expansion branch, add the residual and activate."""
        out = self.conv2(self.dwconv(self.conv1(x)))
        residual = x if self.shortcut is None else self.shortcut(x)
        return self.act(out + residual)
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
    """Context Embedding Block for large receptive field in Semantic Branch.

    A globally pooled, normalised and 1x1-convolved copy of the input is
    added back onto the input, and the sum goes through a 3x3 conv.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): Last feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # build_norm_layer returns (name, layer); only the layer is kept.
        pooled_norm = build_norm_layer(norm_cfg, self.in_channels)[1]
        self.gap = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), pooled_norm)
        self.conv_gap = ConvModule(
            self.in_channels,
            self.out_channels,
            1,
            stride=1,
            padding=0,
            **cfgs)
        # Note: in paper here is naive conv2d, no bn-relu
        self.conv_last = ConvModule(
            self.out_channels,
            self.out_channels,
            3,
            stride=1,
            padding=1,
            **cfgs)

    def forward(self, x):
        """Add the embedded global context to ``x`` and refine the sum."""
        context = self.conv_gap(self.gap(x))
        return self.conv_last(x + context)
|
||||
|
||||
|
||||
class SemanticBranch(BaseModule):
    """Semantic Branch which is lightweight with narrow channels and deep
    layers to obtain high-level semantic context.

    The first stage is a ``StemBlock``. Every later stage is a stride-2
    ``GELayer`` followed by stride-1 ``GELayer`` blocks (three in the last
    stage, one otherwise). A ``CEBlock`` is appended as a final extra stage.

    Args:
        semantic_channels(Tuple[int]): Size of channel numbers of
            various stages in Semantic Branch.
            Default: (16, 32, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        semantic_outs (List[torch.Tensor]): List of several feature maps
            for auxiliary heads (Booster) and Bilateral
            Guided Aggregation Layer.
    """

    def __init__(self,
                 semantic_channels=(16, 32, 64, 128),
                 in_channels=3,
                 exp_ratio=6,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.semantic_channels = semantic_channels

        # Names of the registered stages, in execution order.
        self.semantic_stages = []
        num_stages = len(semantic_channels)
        for idx, channels in enumerate(semantic_channels):
            if idx == 0:
                stage = StemBlock(self.in_channels, channels)
            else:
                num_plain = 3 if idx == num_stages - 1 else 1
                blocks = [
                    GELayer(semantic_channels[idx - 1], channels, exp_ratio,
                            2)
                ]
                for _ in range(num_plain):
                    blocks.append(
                        GELayer(channels, channels, exp_ratio, 1))
                stage = nn.Sequential(*blocks)

            stage_name = f'stage{idx + 1}'
            self.add_module(stage_name, stage)
            self.semantic_stages.append(stage_name)

        ce_name = f'stage{num_stages}_CEBlock'
        self.add_module(
            ce_name, CEBlock(semantic_channels[-1], semantic_channels[-1]))
        self.semantic_stages.append(ce_name)

    def forward(self, x):
        """Run every stage and collect each stage's output."""
        semantic_outs = []
        for stage_name in self.semantic_stages:
            x = getattr(self, stage_name)(x)
            semantic_outs.append(x)
        return semantic_outs
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
    """Bilateral Guided Aggregation Layer to fuse the complementary information
    from both Detail Branch and Semantic Branch.

    Args:
        out_channels (int): Number of output channels. The same value is
            used as the input channel count of every conv in this layer.
            Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        output (torch.Tensor): Output feature map for Segment heads.
    """

    def __init__(self,
                 out_channels=128,
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.out_channels = out_channels
        self.align_corners = align_corners
        channels = self.out_channels

        # Shared settings of the two depthwise-separable convs: norm on the
        # depthwise part only, no activations.
        dw_sep_kwargs = dict(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            padding=1,
            dw_norm_cfg=norm_cfg,
            dw_act_cfg=None,
            pw_norm_cfg=None,
            pw_act_cfg=None,
        )

        self.detail_dwconv = nn.Sequential(
            DepthwiseSeparableConvModule(**dw_sep_kwargs))
        self.detail_down = nn.Sequential(
            ConvModule(
                channels,
                channels,
                3,
                stride=2,
                padding=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
        self.semantic_conv = nn.Sequential(
            ConvModule(
                channels,
                channels,
                3,
                stride=1,
                padding=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None))
        self.semantic_dwconv = nn.Sequential(
            DepthwiseSeparableConvModule(**dw_sep_kwargs))
        self.conv = ConvModule(
            channels,
            channels,
            3,
            stride=1,
            padding=1,
            inplace=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )

    def forward(self, x_d, x_s):
        """Fuse the detail feature ``x_d`` with the semantic feature ``x_s``."""
        detail_full = self.detail_dwconv(x_d)
        detail_small = self.detail_down(x_d)
        semantic_gate_full = self.semantic_conv(x_s)
        semantic_gate_small = self.semantic_dwconv(x_s)

        # Bring the semantic gate to the resolution of the detail feature.
        semantic_gate_full = resize(
            input=semantic_gate_full,
            size=detail_full.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        high_res = detail_full * torch.sigmoid(semantic_gate_full)

        # Gate the downsampled detail feature, then resize the product to
        # the size of the first fused map.
        low_res = detail_small * torch.sigmoid(semantic_gate_small)
        low_res = resize(
            input=low_res,
            size=high_res.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        return self.conv(high_res + low_res)
|
||||
|
||||
|
||||
@MODELS.register_module()
class BiSeNetV2(BaseModule):
    """BiSeNetV2: Bilateral Network with Guided Aggregation for
    Real-time Semantic Segmentation.

    This backbone is the implementation of
    `BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.

    Args:
        in_channels (int): Number of channel of input image. Default: 3.
        detail_channels (Tuple[int], optional): Channels of each stage
            in Detail Branch. Default: (64, 64, 128).
        semantic_channels (Tuple[int], optional): Channels of each stage
            in Semantic Branch. Default: (16, 32, 64, 128).
            See Table 1 and Figure 3 of paper for more details.
        semantic_expansion_ratio (int, optional): The expansion factor
            expanding channel number of middle channels in Semantic Branch.
            Default: 6.
        bga_channels (int, optional): Number of middle channels in
            Bilateral Guided Aggregation Layer. Default: 128.
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Index 0 is the aggregated feature; the following indices are
            the Semantic Branch outputs except its last one.
            Default: (0, 1, 2, 3, 4).
        align_corners (bool, optional): The align_corners argument of
            resize operation in Bilateral Guided Aggregation Layer.
            Default: False.
        conv_cfg (dict | None): Config of conv layers. Stored on the
            module; not forwarded to the branches built here.
            Default: None.
        norm_cfg (dict | None): Config of norm layers. Stored on the
            module; not forwarded to the branches built here.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers. Stored on the
            module; not forwarded to the branches built here.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            When None, Kaiming init for ``Conv2d`` and constant 1 for
            ``_BatchNorm``/``GroupNorm`` layers are used.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 detail_channels=(64, 64, 128),
                 semantic_channels=(16, 32, 64, 128),
                 semantic_expansion_ratio=6,
                 bga_channels=128,
                 out_indices=(0, 1, 2, 3, 4),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        if init_cfg is None:
            conv_init = dict(type='Kaiming', layer='Conv2d')
            norm_init = dict(
                type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            init_cfg = [conv_init, norm_init]
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.out_indices = out_indices
        self.detail_channels = detail_channels
        self.semantic_channels = semantic_channels
        self.semantic_expansion_ratio = semantic_expansion_ratio
        self.bga_channels = bga_channels
        self.align_corners = align_corners
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.detail = DetailBranch(self.detail_channels, self.in_channels)
        self.semantic = SemanticBranch(self.semantic_channels,
                                       self.in_channels,
                                       self.semantic_expansion_ratio)
        self.bga = BGALayer(self.bga_channels, self.align_corners)

    def forward(self, x):
        """Return the selected outputs among the aggregated feature and the
        intermediate Semantic Branch features."""
        # stole refactoring code from Coin Cheung, thanks
        detail_feat = self.detail(x)
        semantic_feats = self.semantic(x)
        # The last semantic output feeds the aggregation layer and is not
        # itself part of the candidate outputs.
        head_feat = self.bga(detail_feat, semantic_feats[-1])
        candidates = [head_feat] + semantic_feats[:-1]
        return tuple(candidates[i] for i in self.out_indices)
|
||||
372
Seg_All_In_One_MMSeg/mmseg/models/backbones/cgnet.py
Normal file
372
Seg_All_In_One_MMSeg/mmseg/models/backbones/cgnet.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class GlobalContextExtractor(nn.Module):
    """Global Context Extractor for CGNet.

    This class is employed to refine the joint feature of both local feature
    and surrounding context. The input is average pooled to one value per
    channel, passed through a two-layer bottleneck ending in a sigmoid, and
    the resulting per-channel weights rescale the input.

    Args:
        channel (int): Number of input feature channels.
        reduction (int): Reductions for global context extractor. Must
            satisfy ``1 <= reduction <= channel``. Default: 16.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self, channel, reduction=16, with_cp=False):
        super().__init__()
        self.channel = channel
        self.reduction = reduction
        assert reduction >= 1 and channel >= reduction
        self.with_cp = with_cp

        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale ``x`` channel-wise by its globally pooled context."""

        def _reweight(feat):
            batch, channels = feat.shape[:2]
            pooled = self.avg_pool(feat).view(batch, channels)
            weights = self.fc(pooled).view(batch, channels, 1, 1)
            return feat * weights

        # Checkpointing is only used when the input takes part in autograd.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_reweight, x)
        return _reweight(x)
|
||||
|
||||
|
||||
class ContextGuidedBlock(nn.Module):
    """Context Guided Block for CGNet.

    This class consists of four components: local feature extractor,
    surrounding feature extractor, joint feature extractor and global
    context extractor.

    Args:
        in_channels (int): Number of input feature channels.
        out_channels (int): Number of output feature channels.
        dilation (int): Dilation rate for surrounding context extractor.
            Default: 2.
        reduction (int): Reduction for global context extractor. Default: 16.
        skip_connect (bool): Add input to output or not. Default: True.
            Ignored (treated as False) when ``downsample`` is True.
        downsample (bool): Downsample the input to 1/2 or not. Default: False.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='PReLU'). The passed dict is not modified;
            for 'PReLU' a copy carrying ``num_parameters`` is used instead.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation=2,
                 reduction=16,
                 skip_connect=True,
                 downsample=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.downsample = downsample

        channels = out_channels if downsample else out_channels // 2
        # Build a private copy instead of writing ``num_parameters`` into the
        # caller's dict (or into the shared default argument).
        if (act_cfg is not None and 'type' in act_cfg
                and act_cfg['type'] == 'PReLU'):
            act_cfg = dict(act_cfg, num_parameters=channels)
        kernel_size = 3 if downsample else 1
        stride = 2 if downsample else 1
        padding = (kernel_size - 1) // 2

        # Despite its name this is a strided 3x3 conv when downsampling.
        self.conv1x1 = ConvModule(
            in_channels,
            channels,
            kernel_size,
            stride,
            padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Local feature extractor: depthwise 3x3 conv.
        self.f_loc = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=False)
        # Surrounding context extractor: depthwise dilated 3x3 conv.
        self.f_sur = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=dilation,
            groups=channels,
            dilation=dilation,
            bias=False)

        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
        self.activate = nn.PReLU(2 * channels)

        if downsample:
            # Project the concatenated 2 * channels back to out_channels.
            self.bottleneck = build_conv_layer(
                conv_cfg,
                2 * channels,
                out_channels,
                kernel_size=1,
                bias=False)

        self.skip_connect = skip_connect and not downsample
        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)

    def forward(self, x):

        def _inner_forward(x):
            out = self.conv1x1(x)
            loc = self.f_loc(out)
            sur = self.f_sur(out)

            joi_feat = torch.cat([loc, sur], 1)  # the joint feature
            joi_feat = self.bn(joi_feat)
            joi_feat = self.activate(joi_feat)
            if self.downsample:
                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels
            # f_glo is employed to refine the joint feature
            out = self.f_glo(joi_feat)

            if self.skip_connect:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
|
||||
|
||||
|
||||
class InputInjection(nn.Module):
    """Downsampling module for CGNet.

    Applies ``num_downsampling`` average-pooling layers (kernel 3, stride 2,
    padding 1) one after another.

    Args:
        num_downsampling (int): Number of pooling layers to chain.
    """

    def __init__(self, num_downsampling):
        super().__init__()
        self.pool = nn.ModuleList([
            nn.AvgPool2d(3, stride=2, padding=1)
            for _ in range(num_downsampling)
        ])

    def forward(self, x):
        out = x
        for stage in self.pool:
            out = stage(out)
        return out
|
||||
|
||||
|
||||
@MODELS.register_module()
class CGNet(BaseModule):
    """CGNet backbone.

    This backbone is the implementation of `A Light-weight Context Guided
    Network for Semantic Segmentation <https://arxiv.org/abs/1811.08201>`_.

    Args:
        in_channels (int): Number of input image channels. Normally 3.
        num_channels (tuple[int]): Numbers of feature channels at each stages.
            Default: (32, 64, 128).
        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
            Default: (3, 21).
        dilations (tuple[int]): Dilation rate for surrounding context
            extractors at stage 1 and stage 2. Default: (2, 4).
        reductions (tuple[int]): Reductions for global context extractors at
            stage 1 and stage 2. Default: (8, 16).
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='PReLU'). The passed dict is not modified;
            for 'PReLU' a copy carrying ``num_parameters`` is used instead.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        pretrained (str, optional): model pretrained path. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(32, 64, 128),
                 num_blocks=(3, 21),
                 dilations=(2, 4),
                 reductions=(8, 16),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 norm_eval=False,
                 with_cp=False,
                 pretrained=None,
                 init_cfg=None):

        super().__init__(init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer=['Conv2d', 'Linear']),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm']),
                    dict(type='Constant', val=0, layer='PReLU')
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        self.in_channels = in_channels
        self.num_channels = num_channels
        assert isinstance(self.num_channels, tuple) and len(
            self.num_channels) == 3
        self.num_blocks = num_blocks
        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
        self.dilations = dilations
        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
        self.reductions = reductions
        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Use a private copy so neither the caller's dict nor the shared
        # default argument gets ``num_parameters`` written into it.
        if (act_cfg is not None and 'type' in act_cfg
                and act_cfg['type'] == 'PReLU'):
            act_cfg = dict(act_cfg, num_parameters=num_channels[0])
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # stage 0: three 3x3 convs, only the first one has stride 2
        cur_channels = in_channels
        self.stem = nn.ModuleList()
        for i in range(3):
            self.stem.append(
                ConvModule(
                    cur_channels,
                    num_channels[0],
                    3,
                    2 if i == 0 else 1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            cur_channels = num_channels[0]

        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2
        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4

        # stem output is concatenated with the 2x down-sampled input
        cur_channels += in_channels
        self.norm_prelu_0 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 1
        self.level1 = nn.ModuleList()
        for i in range(num_blocks[0]):
            self.level1.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[1],
                    num_channels[1],
                    dilations[0],
                    reductions[0],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        # last block output + first block output + 4x down-sampled input
        cur_channels = 2 * num_channels[1] + in_channels
        self.norm_prelu_1 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 2
        self.level2 = nn.ModuleList()
        for i in range(num_blocks[1]):
            self.level2.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[2],
                    num_channels[2],
                    dilations[1],
                    reductions[1],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        # first block output + last block output
        cur_channels = 2 * num_channels[2]
        self.norm_prelu_2 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

    def forward(self, x):
        """Return the list of the three stage outputs."""
        output = []

        # stage 0
        inp_2x = self.inject_2x(x)
        inp_4x = self.inject_4x(x)
        for layer in self.stem:
            x = layer(x)
        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
        output.append(x)

        # stage 1
        for i, layer in enumerate(self.level1):
            x = layer(x)
            if i == 0:
                down1 = x  # output of the down-sampling block
        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
        output.append(x)

        # stage 2
        for i, layer in enumerate(self.level2):
            x = layer(x)
            if i == 0:
                down2 = x  # output of the down-sampling block
        x = self.norm_prelu_2(torch.cat([down2, x], 1))
        output.append(x)

        return output

    def train(self, mode=True):
        """Switch to training mode while keeping the normalization layers
        frozen when ``norm_eval`` is set."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
|
||||
222
Seg_All_In_One_MMSeg/mmseg/models/backbones/ddrnet.py
Normal file
222
Seg_All_In_One_MMSeg/mmseg/models/backbones/ddrnet.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType
|
||||
|
||||
|
||||
@MODELS.register_module()
class DDRNet(BaseModule):
    """DDRNet backbone.

    This backbone is the implementation of `Deep Dual-resolution Networks for
    Real-time and Accurate Semantic Segmentation of Road Scenes
    <http://arxiv.org/abs/2101.06085>`_.
    Modified from https://github.com/ydhongHIT/DDRNet.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        channels (int): The base channels of DDRNet. Default: 32.
        ppm_channels (int): The channels of PPM module. Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 32,
                 ppm_channels: int = 128,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.ppm_channels = ppm_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        width_x2 = channels * 2
        width_x4 = channels * 4
        width_x8 = channels * 8

        # stage 0-2
        self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
        self.relu = nn.ReLU()

        # low resolution(context) branch: every layer halves the resolution
        context_specs = (
            (BasicBlock, width_x2, width_x4, 2),
            (BasicBlock, width_x4, width_x8, 2),
            (Bottleneck, width_x8, width_x8, 1),
        )
        self.context_branch_layers = nn.ModuleList()
        for block, inplanes, planes, depth in context_specs:
            self.context_branch_layers.append(
                self._make_layer(
                    block=block,
                    inplanes=inplanes,
                    planes=planes,
                    num_blocks=depth,
                    stride=2))

        # bilateral fusion
        point_conv = dict(kernel_size=1, norm_cfg=self.norm_cfg, act_cfg=None)
        stride2_conv = dict(
            kernel_size=3, stride=2, padding=1, norm_cfg=self.norm_cfg)

        self.compression_1 = ConvModule(width_x4, width_x2, **point_conv)
        self.down_1 = ConvModule(
            width_x2, width_x4, act_cfg=None, **stride2_conv)

        self.compression_2 = ConvModule(width_x8, width_x2, **point_conv)
        self.down_2 = nn.Sequential(
            ConvModule(
                width_x2, width_x4, act_cfg=self.act_cfg, **stride2_conv),
            ConvModule(width_x4, width_x8, act_cfg=None, **stride2_conv))

        # high resolution(spatial) branch: resolution is kept (stride 1)
        self.spatial_branch_layers = nn.ModuleList()
        for block, depth in ((BasicBlock, 2), (BasicBlock, 2),
                             (Bottleneck, 1)):
            self.spatial_branch_layers.append(
                self._make_layer(
                    block=block,
                    inplanes=width_x2,
                    planes=width_x2,
                    num_blocks=depth))

        self.spp = DAPPM(
            channels * 16, ppm_channels, width_x4, num_scales=5)

    def _make_stem_layer(self, in_channels, channels, num_blocks):
        """Build the stem: two stride-2 convs followed by two block groups."""
        conv_kwargs = dict(
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        layers = [
            ConvModule(in_channels, channels, **conv_kwargs),
            ConvModule(channels, channels, **conv_kwargs),
        ]
        layers.append(
            self._make_layer(BasicBlock, channels, channels, num_blocks))
        layers.append(nn.ReLU())
        layers.append(
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2))
        layers.append(nn.ReLU())

        return nn.Sequential(*layers)

    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks; only the first may change shape."""
        out_planes = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes
        # the resolution or the channel count.
        downsample = None
        if not (stride == 1 and inplanes == out_planes):
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, out_planes)[1])

        # NOTE(review): the first block is built without ``norm_cfg`` /
        # ``act_cfg_out`` and therefore uses the block's own defaults --
        # confirm this is intended.
        layers = [
            block(
                in_channels=inplanes,
                channels=planes,
                stride=stride,
                downsample=downsample)
        ]
        last_idx = num_blocks - 1
        for idx in range(1, num_blocks):
            # the final block of the stack has no output activation
            out_act = None if idx == last_idx else self.act_cfg
            layers.append(
                block(
                    in_channels=out_planes,
                    channels=planes,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg_out=out_act))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        out_size = (x.shape[-2] // 8, x.shape[-1] // 8)

        def to_out_size(feat):
            return resize(
                feat,
                size=out_size,
                mode='bilinear',
                align_corners=self.align_corners)

        context_layers = self.context_branch_layers
        spatial_layers = self.spatial_branch_layers

        # stage 0-2
        x = self.stem(x)

        # stage3
        x_c = context_layers[0](x)
        x_s = spatial_layers[0](x)
        # the compressed context must be taken before x_c is updated in place
        comp_c = self.compression_1(self.relu(x_c))
        x_c += self.down_1(self.relu(x_s))
        x_s += to_out_size(comp_c)
        if self.training:
            temp_context = x_s.clone()

        # stage4
        x_c = context_layers[1](self.relu(x_c))
        x_s = spatial_layers[1](self.relu(x_s))
        comp_c = self.compression_2(self.relu(x_c))
        x_c += self.down_2(self.relu(x_s))
        x_s += to_out_size(comp_c)

        # stage5
        x_s = spatial_layers[2](self.relu(x_s))
        x_c = context_layers[2](self.relu(x_c))
        x_c = to_out_size(self.spp(x_c))

        fused = x_s + x_c
        if self.training:
            return (temp_context, fused)
        return fused
|
||||
418
Seg_All_In_One_MMSeg/mmseg/models/backbones/en_bisenetv2.py
Normal file
418
Seg_All_In_One_MMSeg/mmseg/models/backbones/en_bisenetv2.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# === Begin: Final corrected code to be appended for EnBiSeNetV2 ===
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule, Sequential
|
||||
from mmseg.registry import MODELS
|
||||
# Import existing BiSeNetV2 components to inherit from them
|
||||
from.bisenetv2 import BiSeNetV2, GELayer, StemBlock
|
||||
|
||||
# --- Helper Modules for EnBiSeNetV2 ---
|
||||
|
||||
class DepthwiseSeparableConvModule(BaseModule):
    """Depthwise Separable Convolutional Module.

    A depthwise ``ConvModule`` (``groups=in_channels``) followed by a 1x1
    pointwise ``ConvModule``; both use the same norm and activation config.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of the depthwise convolution.
        stride (int): Stride of the depthwise convolution. Default: 1.
        padding (int): Padding of the depthwise convolution. Default: 0.
        dilation (int): Dilation of the depthwise convolution. Default: 1.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg)
        shared_cfg = dict(norm_cfg=norm_cfg, act_cfg=act_cfg)

        # per-channel spatial filtering
        self.depthwise_conv = ConvModule(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            **shared_cfg)
        # 1x1 channel mixing
        self.pointwise_conv = ConvModule(
            in_channels, out_channels, kernel_size=1, **shared_cfg)

    def forward(self, x):
        return self.pointwise_conv(self.depthwise_conv(x))
|
||||
|
||||
class SpatialAttentionModule(BaseModule):
    """Spatial Attention Module.

    Builds a single-channel attention map from the channel-wise mean and
    maximum of the input (7x7 conv + sigmoid) and multiplies the input by it.

    Args:
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = ConvModule(
            2, 1, kernel_size=7, padding=3, act_cfg=None, norm_cfg=None)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        return x * self.sigmoid(self.conv(descriptor))
|
||||
|
||||
class ChannelAttentionModule(BaseModule):
    """Channel Attention Module.

    Computes one gate per channel from the globally average-pooled input
    (Linear -> ReLU -> Linear -> Sigmoid) and scales the input by it.

    Args:
        in_channels (int): Number of input channels.
        reduction (int): Channel reduction ratio of the bottleneck.
            Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, reduction=16, init_cfg=None):
        super().__init__(init_cfg)
        # Keep at least one hidden unit: ``in_channels // reduction`` is 0
        # when in_channels < reduction, which would leave an empty bottleneck.
        hidden_channels = max(in_channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # (b, c, 1, 1) broadcasts over the spatial dimensions of x
        return x * y
|
||||
|
||||
class SEModule(BaseModule):
    """Squeeze-and-Excitation Module.

    Computes one gate per channel from the globally average-pooled input
    (Linear -> ReLU -> Linear -> Sigmoid) and scales the input by it.

    Args:
        channels (int): Number of input channels.
        reduction (int): Channel reduction ratio of the bottleneck.
            Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, channels, reduction=16, init_cfg=None):
        super().__init__(init_cfg)
        # Keep at least one hidden unit: ``channels // reduction`` is 0 when
        # channels < reduction, which would leave an empty bottleneck.
        hidden_channels = max(channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # (b, c, 1, 1) broadcasts over the spatial dimensions of x
        return x * y
|
||||
|
||||
class CEBlockSimAspp(BaseModule):
    """Simplified ASPP for Context Embedding.

    Concatenates a global-average-pooling branch with one dilated-conv
    branch per entry of ``rates`` and projects the result with a 1x1 conv.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels of every branch and of
            the final projection.
        rates (tuple[int]): Dilation (and padding) of the 3x3 conv in each
            branch. Default: (1, 6, 12).
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, out_channels, rates=(1, 6, 12),
                 norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        cfg = dict(norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(in_channels, out_channels, 1, **cfg)
        )

        # every branch first halves the channels (never below one)
        reduced_channels = in_channels // 2 or 1
        self.aspp_branches = nn.ModuleList([
            nn.Sequential(
                ConvModule(in_channels, reduced_channels, 1, **cfg),
                ConvModule(
                    reduced_channels,
                    out_channels,
                    3,
                    padding=rate,
                    dilation=rate,
                    **cfg)
            )
            for rate in rates
        ])

        self.project = ConvModule(
            out_channels * (len(rates) + 1), out_channels, 1, **cfg)

    def forward(self, x):
        # the pooled 1x1 descriptor is stretched back to the input size
        pooled = nn.functional.interpolate(
            self.global_avg_pool(x),
            size=x.shape[2:],
            mode='bilinear',
            align_corners=False)
        branch_outputs = [pooled]
        branch_outputs.extend(branch(x) for branch in self.aspp_branches)
        return self.project(torch.cat(branch_outputs, dim=1))
|
||||
|
||||
# --- Gather-and-Expansion layer with optional Squeeze-and-Excitation ---
|
||||
|
||||
class GELayerWithSE(BaseModule):
    """Gather-and-Expansion Layer with optional SE Module.

    Pipeline: 1x1 expansion conv -> (optional SE) -> depthwise 3x3 conv ->
    1x1 projection conv, added to a shortcut of the input. The shortcut is
    the identity when ``stride == 1`` and ``in_channels == out_channels``,
    otherwise a strided 1x1 conv.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        mid_channels (int): Number of channels after the expansion conv.
        stride (int): Stride of the depthwise conv and of the projection
            shortcut.
        add_se_module (bool): Whether to apply an ``SEModule`` to the
            expanded feature. Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, out_channels, mid_channels, stride,
                 add_se_module=False, norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'), init_cfg=None):
        super().__init__(init_cfg)

        # 1. Pointwise Conv for channel expansion
        self.pw_conv1 = ConvModule(
            in_channels,
            mid_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # 2. Depthwise Conv
        self.dw_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=mid_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # 3. Squeeze-and-Excitation module
        self.add_se_module = add_se_module
        if self.add_se_module:
            self.se = SEModule(mid_channels)

        # 4. Pointwise Conv for channel projection (no activation)
        self.pw_conv2 = ConvModule(
            mid_channels,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # 5. Shortcut connection
        if stride == 1 and in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = ConvModule(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                norm_cfg=norm_cfg,
                act_cfg=None)

    def _forward_impl(self, x, use_se):
        """Run the layer, applying the SE module only when ``use_se``."""
        out = self.pw_conv1(x)
        if use_se:
            out = self.se(out)
        out = self.dw_conv(out)
        out = self.pw_conv2(out)
        return out + self.shortcut(x)

    def forward(self, x):
        # One class-level forward instead of rebinding ``self.forward`` in
        # ``__init__``; the shortcut type is already encoded in
        # ``self.shortcut``.
        return self._forward_impl(x, self.add_se_module)

    # The four variants below are kept only for backward compatibility with
    # code that calls them by name.
    def forward_v1(self, x):
        return self._forward_impl(x, False)

    def forward_v2(self, x):
        return self._forward_impl(x, False)

    def forward_v1_se(self, x):
        return self._forward_impl(x, True)

    def forward_v2_se(self, x):
        return self._forward_impl(x, True)
|
||||
|
||||
class FuseModule(BaseModule):
    """Fuse features from different levels.

    Every feature is bilinearly resized to the spatial size of the last one,
    the results are concatenated along channels and passed through a 3x3
    ``ConvModule``.

    Args:
        in_channels_list (list[int]): Channels of each input feature.
        out_channels (int): Number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels_list, out_channels,
                 norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg)
        total_in = sum(in_channels_list)
        self.conv = ConvModule(
            total_in,
            out_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, features):
        # the last feature defines the common spatial size
        target_size = features[-1].shape[2:]
        resized = [
            nn.functional.interpolate(
                level,
                size=target_size,
                mode='bilinear',
                align_corners=False)
            for level in features
        ]
        return self.conv(torch.cat(resized, dim=1))
|
||||
|
||||
class SimBGABlock(BaseModule):
    """Simplified Bilateral Guided Aggregation Block.

    Each input goes through its own 3x3 ``ConvModule`` (the semantic feature
    is first resized to the detail feature's spatial size); the two results
    are summed and refined by a final 3x3 ``ConvModule``.

    Args:
        detail_channels (int): Channels of the detail feature.
        semantic_channels (int): Channels of the semantic feature.
        out_channels (int): Number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, detail_channels, semantic_channels, out_channels,
                 norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg)
        conv3x3 = dict(padding=1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.conv_detail = ConvModule(
            detail_channels, out_channels, 3, **conv3x3)
        self.conv_semantic = ConvModule(
            semantic_channels, out_channels, 3, **conv3x3)
        self.conv_out = ConvModule(out_channels, out_channels, 3, **conv3x3)

    def forward(self, detail_feat, semantic_feat):
        # bring the semantic feature to the detail resolution
        semantic_up = nn.functional.interpolate(
            semantic_feat,
            size=detail_feat.shape[2:],
            mode='bilinear',
            align_corners=False)
        merged = self.conv_detail(detail_feat) + self.conv_semantic(
            semantic_up)
        return self.conv_out(merged)
|
||||
|
||||
|
||||
# --- Main EnBiSeNetV2 Backbone ---
|
||||
|
||||
@MODELS.register_module()
|
||||
class EnBiSeNetV2(BiSeNetV2):
|
||||
"""
|
||||
This class is the implementation of En_bisenetv2.
|
||||
"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
detail_channels_stages=(64, 64, 128),
|
||||
semantic_channels_stages=(16, 32, 64, 128),
|
||||
semantic_expansion_ratio=6,
|
||||
bga_channels=128,
|
||||
out_indices=(0, 1, 2, 3, 4),
|
||||
align_corners=False,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(BiSeNetV2, self).__init__(init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_indices = out_indices
|
||||
self.align_corners = align_corners
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.detail_branch = self._make_detail_branch(detail_channels_stages)
|
||||
self.spatial_attention = SpatialAttentionModule()
|
||||
|
||||
self.semantic_branch = self._make_semantic_branch(
|
||||
semantic_channels_stages, semantic_expansion_ratio)
|
||||
self.channel_attention = ChannelAttentionModule(semantic_channels_stages[-1])
|
||||
|
||||
self.fuse_module = FuseModule(
|
||||
in_channels_list=[semantic_channels_stages[3], semantic_channels_stages[3]],
|
||||
out_channels=semantic_channels_stages[-1]
|
||||
)
|
||||
|
||||
self.bga_layer = SimBGABlock(
|
||||
detail_channels=detail_channels_stages[-1],
|
||||
semantic_channels=semantic_channels_stages[-1],
|
||||
out_channels=bga_channels
|
||||
)
|
||||
|
||||
def _make_detail_branch(self, channels_stages):
|
||||
layers = []
|
||||
# stage 1
|
||||
layers.append(
|
||||
DepthwiseSeparableConvModule(
|
||||
self.in_channels, channels_stages[0], 3, stride=2, padding=1,
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
layers.append(
|
||||
DepthwiseSeparableConvModule(
|
||||
channels_stages[0], channels_stages[0], 3, stride=1, padding=1,
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# stage 2
|
||||
layers.append(
|
||||
DepthwiseSeparableConvModule(
|
||||
channels_stages[0], channels_stages[1], 3, stride=2, padding=1,
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
layers.append(
|
||||
DepthwiseSeparableConvModule(
|
||||
channels_stages[1], channels_stages[1], 3, stride=1, padding=1,
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# stage 3
|
||||
layers.append(
|
||||
DepthwiseSeparableConvModule(
|
||||
channels_stages[1], channels_stages[2], 3, stride=2, padding=1,
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
layers.append(
|
||||
DepthwiseSeparableConvModule(
|
||||
channels_stages[2], channels_stages[2], 3, stride=1, padding=1,
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
return Sequential(*layers)
|
||||
|
||||
# def _make_detail_branch(self, channels_stages):
|
||||
# layers = [] # <-- FIX
|
||||
# layers.append(
|
||||
# DepthwiseSeparableConvModule(
|
||||
# self.in_channels, channels_stages, 3, stride=2, padding=1,
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# layers.append(
|
||||
# DepthwiseSeparableConvModule(
|
||||
# channels_stages, channels_stages, 3, stride=1, padding=1,
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# layers.append(
|
||||
# DepthwiseSeparableConvModule(
|
||||
# channels_stages, channels_stages[3], 3, stride=2, padding=1,
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# layers.append(
|
||||
# DepthwiseSeparableConvModule(
|
||||
# channels_stages[3], channels_stages[3], 3, stride=1, padding=1,
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# layers.append(
|
||||
# DepthwiseSeparableConvModule(
|
||||
# channels_stages[3], channels_stages[1], 3, stride=2, padding=1,
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# layers.append(
|
||||
# DepthwiseSeparableConvModule(
|
||||
# channels_stages[1], channels_stages[1], 3, stride=1, padding=1,
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
||||
# return Sequential(*layers)
|
||||
|
||||
def _make_semantic_branch(self, channels_stages, expansion_ratio):
|
||||
# stage 1
|
||||
stage1 = StemBlock(self.in_channels, channels_stages[0],
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
|
||||
# stage 2
|
||||
stage2 = self._make_ge_layer(
|
||||
channels_stages[0], channels_stages[1], 2, expansion_ratio, add_se=False)
|
||||
# stage 3
|
||||
stage3 = self._make_ge_layer(
|
||||
channels_stages[1], channels_stages[2], 2, expansion_ratio, add_se=True)
|
||||
# stage 4
|
||||
stage4 = self._make_ge_layer(
|
||||
channels_stages[2], channels_stages[3], 2, expansion_ratio, add_se=True)
|
||||
# stage 5: Simplified ASPP instead of CEBlock
|
||||
stage5 = CEBlockSimAspp(channels_stages[3], channels_stages[3],
|
||||
norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
|
||||
return nn.ModuleDict(
|
||||
{'1': stage1, '2': stage2, '3': stage3, '4': stage4, '5': stage5}
|
||||
)
|
||||
|
||||
# def _make_semantic_branch(self, channels_stages, expansion_ratio):
|
||||
# stage1 = StemBlock(self.in_channels, channels_stages,
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
|
||||
# stage2 = self._make_ge_layer(
|
||||
# channels_stages, channels_stages[3], 2, expansion_ratio, add_se=False)
|
||||
# stage3 = self._make_ge_layer(
|
||||
# channels_stages[3], channels_stages[1], 2, expansion_ratio, add_se=True)
|
||||
# stage4 = self._make_ge_layer(
|
||||
# channels_stages[1], channels_stages[2], 2, expansion_ratio, add_se=True)
|
||||
# stage5 = CEBlockSimAspp(channels_stages[2], channels_stages[2],
|
||||
# norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
|
||||
# return nn.ModuleDict(
|
||||
# {'1': stage1, '2': stage2, '3': stage3, '4': stage4, '5': stage5}
|
||||
# )
|
||||
|
||||
def _make_ge_layer(self, in_channels, out_channels, stride, expansion_ratio, add_se=False):
    """Stack three ``GELayerWithSE`` blocks into one sequential stage.

    The first block maps ``in_channels`` to ``out_channels`` with the given
    ``stride``; the two following blocks keep ``out_channels`` and use
    stride 1. All three share the same ``mid_channels``, computed from the
    stage's *input* width.

    Args:
        in_channels (int): Input channels of the stage.
        out_channels (int): Output channels of the stage.
        stride (int): Stride of the first block.
        expansion_ratio (int): Multiplier applied to ``in_channels`` to get
            the intermediate width.
        add_se (bool): Forwarded as ``add_se_module`` to every block.
            Default: False.

    Returns:
        Sequential: The three blocks in order.
    """
    mid_channels = in_channels * expansion_ratio
    shared = dict(
        add_se_module=add_se, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)

    blocks = [
        GELayerWithSE(in_channels, out_channels, mid_channels, stride,
                      **shared)
    ]
    blocks.extend(
        GELayerWithSE(out_channels, out_channels, mid_channels, 1, **shared)
        for _ in range(2))
    return Sequential(*blocks)
|
||||
|
||||
def forward(self, x):
    """Run both branches and return the feature maps chosen by ``out_indices``.

    The candidate outputs, in index order, are: the attended detail
    feature, the stage-3 and stage-4 semantic features, the fused
    (stage 4 + stage 5) semantic feature after channel attention, and the
    output of the aggregation layer.

    Args:
        x: Input passed to both the detail and the semantic branch.

    Returns:
        tuple: ``candidates[i]`` for every ``i`` in ``self.out_indices``.
    """
    # Detail branch followed by spatial attention.
    detail_out = self.spatial_attention(self.detail_branch(x))

    # Semantic branch: run stages '1'..'5' in sequence, keeping each result.
    semantic = {}
    feat = x
    for key in ('1', '2', '3', '4', '5'):
        feat = self.semantic_branch[key](feat)
        semantic[key] = feat

    # Fuse the two deepest semantic stages, then apply channel attention.
    semantic_fused = self.channel_attention(
        self.fuse_module([semantic['4'], semantic['5']]))

    out = self.bga_layer(detail_out, semantic_fused)

    candidates = [
        detail_out, semantic['3'], semantic['4'], semantic_fused, out
    ]
    return tuple(candidates[i] for i in self.out_indices)
|
||||
|
||||
# === End: Final corrected code to be appended for EnBiSeNetV2 ===
|
||||
329
Seg_All_In_One_MMSeg/mmseg/models/backbones/erfnet.py
Normal file
329
Seg_All_In_One_MMSeg/mmseg/models/backbones/erfnet.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DownsamplerBlock(BaseModule):
    """Downsampler block of ERFNet.

    This module differs slightly from a basic ConvModule: the input goes
    through a stride-2 3x3 convolution and, in parallel, a 2x2 max pooling.
    Both results are concatenated along the channel axis *before* the norm
    layer and the activation are applied.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels. The conv path yields
            ``out_channels - in_channels`` channels; the pooled input
            contributes the remaining ``in_channels``.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN', eps=1e-3).
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # The conv path only has to produce the channels the pooled input
        # does not already provide.
        conv_out_channels = out_channels - in_channels
        self.conv = build_conv_layer(
            self.conv_cfg,
            in_channels,
            conv_out_channels,
            kernel_size=3,
            stride=2,
            padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        self.act = build_activation_layer(self.act_cfg)

    def forward(self, input):
        """Downsample ``input`` and return the normalized, activated result."""
        conv_out = self.conv(input)
        # Resize the pooled map to the conv output's spatial size so the two
        # paths can be concatenated.
        pool_out = resize(
            input=self.pool(input),
            size=conv_out.size()[2:],
            mode='bilinear',
            align_corners=False)
        merged = torch.cat([conv_out, pool_out], 1)
        return self.act(self.bn(merged))
|
||||
|
||||
|
||||
class NonBottleneck1d(BaseModule):
    """Non-bottleneck block of ERFNet.

    The block is a residual unit made of ``num_conv_layer`` pairs of
    factorized convolutions (a 3x1 conv followed by a 1x3 conv). The first
    pair is never dilated; every later pair uses ``dilation``.

    Args:
        channels (int): Number of channels in Non-bottleneck block.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.
        dilation (int): Dilation rate for last two conv layers.
            Default 1.
        num_conv_layer (int): Number of 3x1 and 1x3 convolution layers.
            Default 2.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN', eps=1e-3).
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 channels,
                 drop_rate=0,
                 dilation=1,
                 num_conv_layer=2,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.act = build_activation_layer(self.act_cfg)

        self.convs_layers = nn.ModuleList()
        for layer_idx in range(num_conv_layer):
            if layer_idx == 0:
                # First pair: plain (undilated) factorized convolutions.
                vert_padding, vert_dilation = (1, 0), 1
                horz_padding, horz_dilation = (0, 1), 1
            else:
                # Later pairs: dilate along the axis each kernel spans.
                vert_padding, vert_dilation = (dilation, 0), (dilation, 1)
                horz_padding, horz_dilation = (0, dilation), (1, dilation)

            conv_3x1 = build_conv_layer(
                self.conv_cfg,
                channels,
                channels,
                kernel_size=(3, 1),
                stride=1,
                padding=vert_padding,
                bias=True,
                dilation=vert_dilation)
            conv_1x3 = build_conv_layer(
                self.conv_cfg,
                channels,
                channels,
                kernel_size=(1, 3),
                stride=1,
                padding=horz_padding,
                bias=True,
                dilation=horz_dilation)
            norm = build_norm_layer(self.norm_cfg, channels)[1]
            # The first pair ends with the activation, later pairs end with
            # dropout; the final activation is applied after the residual sum.
            if layer_idx == 0:
                tail = self.act
            else:
                tail = nn.Dropout(p=drop_rate)

            self.convs_layers.extend(
                [conv_3x1, self.act, conv_1x3, norm, tail])

    def forward(self, input):
        """Apply the layer stack, add the residual and activate."""
        residual = input
        for layer in self.convs_layers:
            residual = layer(residual)
        return self.act(residual + input)
|
||||
|
||||
|
||||
class UpsamplerBlock(BaseModule):
    """Upsampler block of ERFNet.

    A stride-2 3x3 transposed convolution followed by a norm layer and an
    activation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers. It is stored on the
            module; the transposed conv itself is always
            ``nn.ConvTranspose2d``.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN', eps=1e-3).
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        deconv_kwargs = dict(
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=True)
        self.conv = nn.ConvTranspose2d(in_channels, out_channels,
                                       **deconv_kwargs)
        self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        self.act = build_activation_layer(self.act_cfg)

    def forward(self, input):
        """Upsample ``input``, then normalize and activate it."""
        return self.act(self.bn(self.conv(input)))
|
||||
|
||||
|
||||
@MODELS.register_module()
class ERFNet(BaseModule):
    """ERFNet backbone.

    This backbone is the implementation of `ERFNet: Efficient Residual
    Factorized ConvNet for Real-time Semantic Segmentation
    <https://ieeexplore.ieee.org/document/8063438>`_.

    Args:
        in_channels (int): The number of channels of input
            image. Default: 3.
        enc_downsample_channels (Tuple[int]): Size of channel
            numbers of various Downsampler block in encoder.
            Default: (16, 64, 128).
        enc_stage_non_bottlenecks (Tuple[int]): Number of stages of
            Non-bottleneck block in encoder. The last entry must be a
            multiple of ``len(enc_non_bottleneck_dilations)``.
            Default: (5, 8).
        enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each
            stage of Non-bottleneck block of encoder.
            Default: (2, 4, 8, 16).
        enc_non_bottleneck_channels (Tuple[int]): Size of channel
            numbers of various Non-bottleneck block in encoder. Only its
            length is validated here; the blocks are built with the widths
            from ``enc_downsample_channels``.
            Default: (64, 128).
        dec_upsample_channels (Tuple[int]): Size of channel numbers of
            various Deconvolution block in decoder. Only its length is used
            here; the blocks are built with the widths from
            ``dec_non_bottleneck_channels``.
            Default: (64, 16).
        dec_stages_non_bottleneck (Tuple[int]): Number of stages of
            Non-bottleneck block in decoder.
            Default: (2, 2).
        dec_non_bottleneck_channels (Tuple[int]): Size of channel
            numbers of various Non-bottleneck block in decoder.
            Default: (64, 16).
        dropout_ratio (float): Probability of an element to be zeroed in
            the encoder Non-bottleneck blocks.
            Default 0.1.
        conv_cfg (dict | None): Config of conv layers. Stored on the
            backbone; the sub-blocks are built with their own defaults.
            Default: None.
        norm_cfg (dict | None): Config of norm layers. Stored on the
            backbone; the sub-blocks are built with their own defaults.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config of activation layers. Stored on the
            backbone; the sub-blocks are built with their own defaults.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 enc_downsample_channels=(16, 64, 128),
                 enc_stage_non_bottlenecks=(5, 8),
                 enc_non_bottleneck_dilations=(2, 4, 8, 16),
                 enc_non_bottleneck_channels=(64, 128),
                 dec_upsample_channels=(64, 16),
                 dec_stages_non_bottleneck=(2, 2),
                 dec_non_bottleneck_channels=(64, 16),
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)

        # Messages use implicit string concatenation: a backslash
        # continuation inside a literal leaks the next line's leading
        # whitespace (or fuses words) into the message.
        num_down = len(enc_downsample_channels)
        assert num_down == len(dec_upsample_channels) + 1, (
            'Number of downsample block of encoder does not '
            'match number of upsample block of decoder!')
        assert num_down == len(enc_stage_non_bottlenecks) + 1, (
            'Number of downsample block of encoder does not match '
            'number of Non-bottleneck block of encoder!')
        assert num_down == len(enc_non_bottleneck_channels) + 1, (
            'Number of downsample block of encoder does not match '
            'number of channels of Non-bottleneck block of encoder!')
        assert (enc_stage_non_bottlenecks[-1] %
                len(enc_non_bottleneck_dilations) == 0), (
                    'Number of Non-bottleneck block in the last encoder '
                    'stage must be a multiple of the number of dilation '
                    'rates of encoder!')
        assert len(dec_upsample_channels) == len(
            dec_stages_non_bottleneck), (
                'Number of upsample block of decoder does not match '
                'number of Non-bottleneck block of decoder!')
        assert len(dec_stages_non_bottleneck) == len(
            dec_non_bottleneck_channels), (
                'Number of Non-bottleneck block of decoder does not match '
                'number of channels of Non-bottleneck block of decoder!')

        self.in_channels = in_channels
        self.enc_downsample_channels = enc_downsample_channels
        self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks
        self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
        self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
        self.dec_upsample_channels = dec_upsample_channels
        self.dec_stages_non_bottleneck = dec_stages_non_bottleneck
        self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
        self.dropout_ratio = dropout_ratio

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # NOTE(review): these configs are only stored; every sub-block below
        # is created with its own default conv/norm/act config.
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # ---- encoder -----------------------------------------------------
        self.encoder.append(
            DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))

        last_stage = num_down - 2
        for i in range(num_down - 1):
            stage_channels = enc_downsample_channels[i + 1]
            self.encoder.append(
                DownsamplerBlock(enc_downsample_channels[i], stage_channels))
            if i == last_stage:
                # Last part of encoder is some dilated NonBottleneck1d
                # blocks: the dilation schedule is repeated until the
                # requested number of blocks is reached.
                repeats = (enc_stage_non_bottlenecks[-1] //
                           len(enc_non_bottleneck_dilations))
                for _ in range(repeats):
                    for rate in enc_non_bottleneck_dilations:
                        self.encoder.append(
                            NonBottleneck1d(stage_channels,
                                            self.dropout_ratio, rate))
            else:
                for _ in range(enc_stage_non_bottlenecks[i]):
                    self.encoder.append(
                        NonBottleneck1d(stage_channels, self.dropout_ratio))

        # ---- decoder -----------------------------------------------------
        prev_channels = enc_downsample_channels[-1]
        for i in range(len(dec_upsample_channels)):
            stage_channels = dec_non_bottleneck_channels[i]
            self.decoder.append(UpsamplerBlock(prev_channels, stage_channels))
            # Decoder blocks use the NonBottleneck1d defaults (no dropout,
            # no dilation).
            for _ in range(dec_stages_non_bottleneck[i]):
                self.decoder.append(NonBottleneck1d(stage_channels))
            prev_channels = stage_channels

    def forward(self, x):
        """Run the encoder then the decoder.

        Returns:
            list: A single-element list holding the decoder output.
        """
        for enc in self.encoder:
            x = enc(x)
        for dec in self.decoder:
            x = dec(x)
        return [x]
|
||||
408
Seg_All_In_One_MMSeg/mmseg/models/backbones/fast_scnn.py
Normal file
408
Seg_All_In_One_MMSeg/mmseg/models/backbones/fast_scnn.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.decode_heads.psp_head import PPM
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, resize
|
||||
|
||||
|
||||
class LearningToDownsample(nn.Module):
    """Learning to downsample module.

    A stride-2 3x3 ConvModule followed by two stride-2 depthwise separable
    convolutions.

    Args:
        in_channels (int): Number of input channels.
        dw_channels (tuple[int]): Number of output channels of the first and
            the second depthwise conv (dwconv) layers.
        out_channels (int): Number of output channels of the whole
            'learning to downsample' module.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
            of depthwise ConvModule. If it is 'default', it will be the same
            as `act_cfg`. Default: None.
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dw_act_cfg=None):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.dw_act_cfg = dw_act_cfg
        first_width = dw_channels[0]
        second_width = dw_channels[1]

        self.conv = ConvModule(
            in_channels,
            first_width,
            3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # Both separable convs share everything but their channel widths.
        dsconv_kwargs = dict(
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            dw_act_cfg=self.dw_act_cfg)
        self.dsconv1 = DepthwiseSeparableConvModule(first_width, second_width,
                                                    **dsconv_kwargs)
        self.dsconv2 = DepthwiseSeparableConvModule(second_width,
                                                    out_channels,
                                                    **dsconv_kwargs)

    def forward(self, x):
        """Apply the three stride-2 layers in sequence."""
        return self.dsconv2(self.dsconv1(self.conv(x)))
|
||||
|
||||
|
||||
class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module.

    Three groups of InvertedResidual blocks, followed by a pyramid pooling
    module (PPM) whose outputs are concatenated with the group-3 feature
    and projected by a 3x3 ConvModule.

    Args:
        in_channels (int): Number of input channels of the GFE module.
            Default: 64
        block_channels (tuple[int]): Tuple of ints. Each int specifies the
            number of output channels of each Inverted Residual module.
            Default: (64, 96, 128)
        out_channels(int): Number of output channels of the GFE module.
            Default: 128
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
            Default: 6
        num_blocks (tuple[int]): Tuple of ints. Each int specifies the
            number of times each Inverted Residual module is repeated.
            The repeated Inverted Residual modules are called a 'group'.
            Default: (3, 3, 3)
        strides (tuple[int]): Tuple of ints. Each int specifies
            the downsampling factor of each 'group'.
            Default: (2, 2, 1)
        pool_scales (tuple[int]): Tuple of ints. Each int specifies
            the parameter required in 'global average pooling' within PPM.
            Default: (1, 2, 3, 6)
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 in_channels=64,
                 block_channels=(64, 96, 128),
                 out_channels=128,
                 expand_ratio=6,
                 num_blocks=(3, 3, 3),
                 strides=(2, 2, 1),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        assert len(block_channels) == len(num_blocks) == 3

        width1, width2, width3 = (block_channels[0], block_channels[1],
                                  block_channels[2])
        self.bottleneck1 = self._make_layer(in_channels, width1,
                                            num_blocks[0], strides[0],
                                            expand_ratio)
        self.bottleneck2 = self._make_layer(width1, width2, num_blocks[1],
                                            strides[1], expand_ratio)
        self.bottleneck3 = self._make_layer(width2, width3, num_blocks[2],
                                            strides[2], expand_ratio)

        layer_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.ppm = PPM(
            pool_scales,
            width3,
            width3 // 4,
            align_corners=align_corners,
            **layer_cfg)

        # Input is the group-3 feature concatenated with the PPM outputs,
        # hence twice the group-3 width.
        self.out = ConvModule(
            width3 * 2, out_channels, 3, padding=1, **layer_cfg)

    def _make_layer(self,
                    in_channels,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_ratio=6):
        """Build one group of ``blocks`` InvertedResidual modules.

        Only the first module changes the channel count and applies
        ``stride``; the remaining ones keep ``out_channels`` with stride 1.
        """
        block_cfg = dict(norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
        group = [
            InvertedResidual(in_channels, out_channels, stride, expand_ratio,
                             **block_cfg)
        ]
        group.extend(
            InvertedResidual(out_channels, out_channels, 1, expand_ratio,
                             **block_cfg) for _ in range(1, blocks))
        return nn.Sequential(*group)

    def forward(self, x):
        """Extract global features from ``x``."""
        feat = self.bottleneck3(self.bottleneck2(self.bottleneck1(x)))
        pooled = self.ppm(feat)
        feat = torch.cat([feat, *pooled], dim=1)
        return self.out(feat)
|
||||
|
||||
|
||||
class FeatureFusionModule(nn.Module):
    """Feature fusion module.

    The lower-resolution feature is resized to the higher-resolution
    feature's spatial size, passed through a grouped 3x3 conv and a 1x1
    conv, then added to the 1x1-projected higher-resolution feature. A ReLU
    is applied to the sum.

    Args:
        higher_in_channels (int): Number of input channels of the
            higher-resolution branch.
        lower_in_channels (int): Number of input channels of the
            lower-resolution branch.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        dwconv_act_cfg (dict): Config of activation layers in 3x3 conv.
            Default: dict(type='ReLU').
        conv_act_cfg (dict): Config of activation layers in the two 1x1 conv.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 higher_in_channels,
                 lower_in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dwconv_act_cfg=dict(type='ReLU'),
                 conv_act_cfg=None,
                 align_corners=False):
        super().__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dwconv_act_cfg = dwconv_act_cfg
        self.conv_act_cfg = conv_act_cfg
        self.align_corners = align_corners

        self.dwconv = ConvModule(
            lower_in_channels,
            out_channels,
            3,
            padding=1,
            groups=out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.dwconv_act_cfg)

        # The two 1x1 projections differ only in their input width.
        pointwise_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.conv_act_cfg)
        self.conv_lower_res = ConvModule(out_channels, out_channels, 1,
                                         **pointwise_cfg)
        self.conv_higher_res = ConvModule(higher_in_channels, out_channels, 1,
                                          **pointwise_cfg)

        self.relu = nn.ReLU(True)

    def forward(self, higher_res_feature, lower_res_feature):
        """Fuse the two features and return the activated sum."""
        target_size = higher_res_feature.size()[2:]
        lower = resize(
            lower_res_feature,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)
        lower = self.conv_lower_res(self.dwconv(lower))

        higher = self.conv_higher_res(higher_res_feature)
        return self.relu(higher + lower)
|
||||
|
||||
|
||||
@MODELS.register_module()
class FastSCNN(BaseModule):
    """Fast-SCNN Backbone.

    This backbone is the implementation of `Fast-SCNN: Fast Semantic
    Segmentation Network <https://arxiv.org/abs/1902.04502>`_.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        downsample_dw_channels (tuple[int]): Number of output channels after
            the first conv layer & the second conv layer in
            Learning-To-Downsample (LTD) module.
            Default: (32, 48).
        global_in_channels (int): Number of input channels of
            Global Feature Extractor(GFE).
            Equal to number of output channels of LTD.
            Default: 64.
        global_block_channels (tuple[int]): Tuple of integers that describe
            the output channels for each of the MobileNet-v2 bottleneck
            residual blocks in GFE.
            Default: (64, 96, 128).
        global_block_strides (tuple[int]): Tuple of integers
            that describe the strides (downsampling factors) for each of the
            MobileNet-v2 bottleneck residual blocks in GFE.
            Default: (2, 2, 1).
        global_out_channels (int): Number of output channels of GFE.
            Default: 128.
        higher_in_channels (int): Number of input channels of the higher
            resolution branch in FFM.
            Equal to global_in_channels.
            Default: 64.
        lower_in_channels (int): Number of input channels of the lower
            resolution branch in FFM.
            Equal to global_out_channels.
            Default: 128.
        fusion_out_channels (int): Number of output channels of FFM.
            Default: 128.
        out_indices (tuple): Tuple of indices of list
            [higher_res_features, lower_res_features, fusion_output].
            Often set to (0,1,2) to enable aux. heads.
            Default: (0, 1, 2).
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
        dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
            of depthwise ConvModule. If it is 'default', it will be the same
            as `act_cfg`. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels=3,
                 downsample_dw_channels=(32, 48),
                 global_in_channels=64,
                 global_block_channels=(64, 96, 128),
                 global_block_strides=(2, 2, 1),
                 global_out_channels=128,
                 higher_in_channels=64,
                 lower_in_channels=128,
                 fusion_out_channels=128,
                 out_indices=(0, 1, 2),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 dw_act_cfg=None,
                 init_cfg=None):

        super().__init__(init_cfg)

        # Fall back to Kaiming init for convs and constant 1 for norm layers
        # when the caller did not provide an init config.
        if init_cfg is None:
            self.init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(
                    type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            ]

        # The GFE sits between LTD and FFM, so its input/output widths must
        # match the two FFM branch widths.
        if global_in_channels != higher_in_channels:
            raise AssertionError('Global Input Channels must be the same '
                                 'with Higher Input Channels!')
        if global_out_channels != lower_in_channels:
            raise AssertionError('Global Output Channels must be the same '
                                 'with Lower Input Channels!')

        self.in_channels = in_channels
        self.downsample_dw_channels1 = downsample_dw_channels[0]
        self.downsample_dw_channels2 = downsample_dw_channels[1]
        self.global_in_channels = global_in_channels
        self.global_block_channels = global_block_channels
        self.global_block_strides = global_block_strides
        self.global_out_channels = global_out_channels
        self.higher_in_channels = higher_in_channels
        self.lower_in_channels = lower_in_channels
        self.fusion_out_channels = fusion_out_channels
        self.out_indices = out_indices
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        layer_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.learning_to_downsample = LearningToDownsample(
            in_channels,
            downsample_dw_channels,
            global_in_channels,
            dw_act_cfg=dw_act_cfg,
            **layer_cfg)
        self.global_feature_extractor = GlobalFeatureExtractor(
            global_in_channels,
            global_block_channels,
            global_out_channels,
            strides=self.global_block_strides,
            align_corners=self.align_corners,
            **layer_cfg)
        self.feature_fusion = FeatureFusionModule(
            higher_in_channels,
            lower_in_channels,
            fusion_out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dwconv_act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def forward(self, x):
        """Return the features selected by ``out_indices``.

        The candidates, in index order, are the LTD output, the GFE output
        and the fused feature.
        """
        higher_res_features = self.learning_to_downsample(x)
        lower_res_features = self.global_feature_extractor(higher_res_features)
        fusion_output = self.feature_fusion(higher_res_features,
                                            lower_res_features)

        candidates = (higher_res_features, lower_res_features, fusion_output)
        return tuple(candidates[i] for i in self.out_indices)
|
||||
642
Seg_All_In_One_MMSeg/mmseg/models/backbones/hrnet.py
Normal file
642
Seg_All_In_One_MMSeg/mmseg/models/backbones/hrnet.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
class HRModule(BaseModule):
|
||||
"""High-Resolution Module for HRNet.
|
||||
|
||||
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
|
||||
is in this module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_branches,
|
||||
blocks,
|
||||
num_blocks,
|
||||
in_channels,
|
||||
num_channels,
|
||||
multiscale_output=True,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
block_init_cfg=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.block_init_cfg = block_init_cfg
|
||||
self._check_branches(num_branches, num_blocks, in_channels,
|
||||
num_channels)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.num_branches = num_branches
|
||||
|
||||
self.multiscale_output = multiscale_output
|
||||
self.norm_cfg = norm_cfg
|
||||
self.conv_cfg = conv_cfg
|
||||
self.with_cp = with_cp
|
||||
self.branches = self._make_branches(num_branches, blocks, num_blocks,
|
||||
num_channels)
|
||||
self.fuse_layers = self._make_fuse_layers()
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
|
||||
def _check_branches(self, num_branches, num_blocks, in_channels,
|
||||
num_channels):
|
||||
"""Check branches configuration."""
|
||||
if num_branches != len(num_blocks):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
|
||||
f'{len(num_blocks)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(num_channels):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
|
||||
f'{len(num_channels)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(in_channels):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
|
||||
f'{len(in_channels)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _make_one_branch(self,
|
||||
branch_index,
|
||||
block,
|
||||
num_blocks,
|
||||
num_channels,
|
||||
stride=1):
|
||||
"""Build one branch."""
|
||||
downsample = None
|
||||
if stride != 1 or \
|
||||
self.in_channels[branch_index] != \
|
||||
num_channels[branch_index] * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index] * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, num_channels[branch_index] *
|
||||
block.expansion)[1])
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
block(
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index],
|
||||
stride,
|
||||
downsample=downsample,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=self.block_init_cfg))
|
||||
self.in_channels[branch_index] = \
|
||||
num_channels[branch_index] * block.expansion
|
||||
for i in range(1, num_blocks[branch_index]):
|
||||
layers.append(
|
||||
block(
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index],
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=self.block_init_cfg))
|
||||
|
||||
return Sequential(*layers)
|
||||
|
||||
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
||||
"""Build multiple branch."""
|
||||
branches = []
|
||||
|
||||
for i in range(num_branches):
|
||||
branches.append(
|
||||
self._make_one_branch(i, block, num_blocks, num_channels))
|
||||
|
||||
return ModuleList(branches)
|
||||
|
||||
def _make_fuse_layers(self):
    """Build the cross-branch fusion layers.

    For every output branch ``dst`` and every input branch ``src``:

    - ``src > dst``: 1x1 conv + norm + bilinear ``Upsample`` by
      ``2 ** (src - dst)``.
    - ``src == dst``: ``None`` (the feature is used as is).
    - ``src < dst``: a chain of ``dst - src`` stride-2 3x3 conv + norm
      steps; every step but the last keeps the source width and is
      followed by a ReLU, the last step switches to the target width.

    Returns:
        nn.ModuleList | None: ``fuse_layers[dst][src]``, or ``None`` when
        the module has a single branch.
    """
    if self.num_branches == 1:
        return None

    branch_count = self.num_branches
    channels = self.in_channels
    conv_cfg = self.conv_cfg
    norm_cfg = self.norm_cfg

    def upsample_path(src, dst):
        """Project ``src`` to the width of ``dst`` and upsample it."""
        return nn.Sequential(
            build_conv_layer(
                conv_cfg,
                channels[src],
                channels[dst],
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            build_norm_layer(norm_cfg, channels[dst])[1],
            # we set align_corners=False for HRNet
            Upsample(
                scale_factor=2**(src - dst),
                mode='bilinear',
                align_corners=False))

    def downsample_path(src, dst):
        """Chain stride-2 convs taking ``src`` down to ``dst``."""
        step_count = dst - src
        steps = []
        for step in range(step_count):
            is_last = step == step_count - 1
            width = channels[dst] if is_last else channels[src]
            ops = [
                build_conv_layer(
                    conv_cfg,
                    channels[src],
                    width,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False),
                build_norm_layer(norm_cfg, width)[1],
            ]
            if not is_last:
                ops.append(nn.ReLU(inplace=False))
            steps.append(nn.Sequential(*ops))
        return nn.Sequential(*steps)

    # With single-scale output only the first branch is fused.
    out_branch_count = branch_count if self.multiscale_output else 1

    fuse_layers = []
    for dst in range(out_branch_count):
        row = []
        for src in range(branch_count):
            if src > dst:
                row.append(upsample_path(src, dst))
            elif src == dst:
                row.append(None)
            else:
                row.append(downsample_path(src, dst))
        fuse_layers.append(nn.ModuleList(row))

    return nn.ModuleList(fuse_layers)
|
||||
|
||||
def forward(self, x):
    """Run every branch, then fuse the branch outputs.

    Args:
        x (list): One input per branch. For multi-branch modules the
            entries are replaced in place by the branch outputs.

    Returns:
        list: The fused outputs, one per row of ``self.fuse_layers``
        (a single-element list for a one-branch module).
    """
    branch_count = self.num_branches
    if branch_count == 1:
        return [self.branches[0](x[0])]

    for idx in range(branch_count):
        x[idx] = self.branches[idx](x[idx])

    fused = []
    for dst, fuse_row in enumerate(self.fuse_layers):
        acc = 0
        for src in range(branch_count):
            if src == dst:
                acc += x[src]
            elif src < dst:
                # Lower-index branch: already brought to the target size by
                # the strided fuse layer.
                acc += fuse_row[src](x[src])
            else:
                # Higher-index branch: resize explicitly to the target
                # spatial size after the fuse layer.
                projected = fuse_row[src](x[src])
                acc = acc + resize(
                    projected,
                    size=x[dst].shape[2:],
                    mode='bilinear',
                    align_corners=False)
        fused.append(self.relu(acc))
    return fused
|
||||
|
||||
|
||||
@MODELS.register_module()
class HRNet(BaseModule):
    """HRNet backbone.

    This backbone is the implementation of `High-Resolution Representations
    for Labeling Pixels and Regions <https://arxiv.org/abs/1904.04514>`_.

    Args:
        extra (dict): Detailed configuration for each stage of HRNet.
            There must be 4 stages, the configuration for each stage must have
            5 keys:

                - num_modules (int): The number of HRModule in this stage.
                - num_branches (int): The number of branches in the HRModule.
                - block (str): The type of convolution block.
                - num_blocks (tuple): The number of blocks in each branch.
                    The length must be equal to num_branches.
                - num_channels (tuple): The number of channels in each branch.
                    The length must be equal to num_branches.
        in_channels (int): Number of input image channels. Normally 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Use `BN` by default.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: False.
        multiscale_output (bool): Whether to output multi-level features
            produced by multiple branches. If False, only the first level
            feature will be output. Default: True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> from mmseg.models import HRNet
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    # Maps the ``block`` string of a stage config to its block class.
    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 frozen_stages=-1,
                 zero_init_residual=False,
                 multiscale_output=True,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        self.zero_init_residual = zero_init_residual
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        elif init_cfg is None:
            # Neither a checkpoint nor an explicit init config was given:
            # fall back to the default conv / norm initialisation.
            self.init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(
                    type='Constant',
                    val=1,
                    layer=['_BatchNorm', 'GroupNorm'])
            ]

        # Assert configurations of 4 stages are in extra
        assert all(f'stage{idx}' in extra for idx in range(1, 5))
        # Assert whether the length of `num_blocks` and `num_channels` are
        # equal to `num_branches`
        for idx in range(1, 5):
            cfg = extra[f'stage{idx}']
            assert len(cfg['num_blocks']) == cfg['num_branches']
            assert len(cfg['num_channels']) == cfg['num_branches']

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.frozen_stages = frozen_stages

        # stem net: two stride-2 3x3 convs, each followed by a norm layer
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        self.conv2 = build_conv_layer(
            self.conv_cfg,
            64,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        def expanded_channels(stage_cfg):
            """Return the per-branch widths scaled by the block expansion."""
            widths = stage_cfg['num_channels']
            expansion = self.blocks_dict[stage_cfg['block']].expansion
            return [width * expansion for width in widths]

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = expanded_channels(self.stage2_cfg)
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = expanded_channels(self.stage3_cfg)
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4 (the only stage that honours `multiscale_output`)
        self.stage4_cfg = self.extra['stage4']
        num_channels = expanded_channels(self.stage4_cfg)
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multiscale_output=multiscale_output)

        self._freeze_stages()

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _residual_init_cfg(self, block):
        """Return the zero-init config for the last norm of ``block``.

        Returns ``None`` unless ``pretrained`` is None, the instance has no
        ``init_cfg`` attribute and ``zero_init_residual`` is set.
        """
        # NOTE(review): `hasattr(self, 'init_cfg')` is presumably always true
        # once `BaseModule.__init__` has run, which would make this branch
        # unreachable - confirm the intended condition before relying on
        # `zero_init_residual`.
        if self.pretrained is None and not hasattr(
                self, 'init_cfg') and self.zero_init_residual:
            if block is BasicBlock:
                return dict(
                    type='Constant', val=0, override=dict(name='norm2'))
            if block is Bottleneck:
                return dict(
                    type='Constant', val=0, override=dict(name='norm3'))
        return None

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer.

        Args:
            num_channels_pre_layer (list[int]): Branch widths produced by the
                previous stage.
            num_channels_cur_layer (list[int]): Branch widths expected by the
                current stage.

        Returns:
            nn.ModuleList: One entry per current branch. An entry is ``None``
            when an existing branch already has the right width.
        """

        def conv_norm_relu(cin, cout, stride):
            return nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    cin,
                    cout,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, cout)[1],
                nn.ReLU(inplace=True))

        cur_count = len(num_channels_cur_layer)
        pre_count = len(num_channels_pre_layer)

        transition_layers = []
        for idx in range(cur_count):
            if idx < pre_count:
                # Existing branch: adapt the width only when it changes.
                if num_channels_cur_layer[idx] != num_channels_pre_layer[idx]:
                    transition_layers.append(
                        conv_norm_relu(num_channels_pre_layer[idx],
                                       num_channels_cur_layer[idx], 1))
                else:
                    transition_layers.append(None)
                continue

            # New branch: derived from the last previous branch through a
            # chain of stride-2 convs; only the final step changes the width.
            step_count = idx + 1 - pre_count
            chain = []
            for step in range(step_count):
                cin = num_channels_pre_layer[-1]
                cout = num_channels_cur_layer[idx] \
                    if step == step_count - 1 else cin
                chain.append(conv_norm_relu(cin, cout, 2))
            transition_layers.append(nn.Sequential(*chain))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Make each layer.

        Args:
            block (type): Residual block class exposing ``expansion``.
            inplanes (int): Input channels of the first block.
            planes (int): Base channels of every block.
            blocks (int): Number of blocks to stack.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            Sequential: The stacked blocks.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or inplanes != out_planes:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, out_planes)[1])

        shared_kwargs = dict(
            with_cp=self.with_cp,
            norm_cfg=self.norm_cfg,
            conv_cfg=self.conv_cfg,
            init_cfg=self._residual_init_cfg(block))

        layers = [
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                **shared_kwargs)
        ]
        for _ in range(1, blocks):
            layers.append(block(out_planes, planes, **shared_kwargs))

        return Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Make each stage.

        Args:
            layer_config (dict): Stage configuration (``num_modules``,
                ``num_branches``, ``num_blocks``, ``num_channels``,
                ``block``).
            in_channels (list[int]): Input width of every branch; passed to
                each ``HRModule`` and returned to the caller.
            multiscale_output (bool): Whether the last module of the stage
                keeps every branch output. Default: True.

        Returns:
            tuple: ``(Sequential of HRModule, in_channels)``.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        block_init_cfg = self._residual_init_cfg(block)

        hr_modules = []
        for module_idx in range(num_modules):
            # multi_scale_output is only used for the last module
            is_last = module_idx == num_modules - 1
            module_multiscale = not (not multiscale_output and is_last)

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    module_multiscale,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    block_init_cfg=block_init_cfg))

        return Sequential(*hr_modules), in_channels

    def _freeze_stages(self):
        """Freeze stages param and norm stats."""
        if self.frozen_stages < 0:
            return

        # Stem: only the norm layers are switched to eval mode, but the
        # parameters of both convs and norms stop receiving gradients.
        self.norm1.eval()
        self.norm2.eval()
        for stem_module in (self.conv1, self.norm1, self.conv2, self.norm2):
            for param in stem_module.parameters():
                param.requires_grad = False

        for stage_idx in range(1, self.frozen_stages + 1):
            stage_name = 'layer1' if stage_idx == 1 else f'stage{stage_idx}'
            to_freeze = [getattr(self, stage_name)]
            # Stage 4 is the last stage and has no outgoing transition.
            if stage_idx != 4:
                to_freeze.append(getattr(self, f'transition{stage_idx}'))
            for module in to_freeze:
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Forward function."""
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        x = self.layer1(x)

        # Stage 2: every branch is derived from the single stage-1 output.
        x_list = []
        for idx in range(self.stage2_cfg['num_branches']):
            transition = self.transition1[idx]
            x_list.append(x if transition is None else transition(x))
        y_list = self.stage2(x_list)

        # Stages 3 and 4: branches with a transition take the last output of
        # the previous stage, the others take their own previous output.
        x_list = []
        for idx in range(self.stage3_cfg['num_branches']):
            transition = self.transition2[idx]
            if transition is None:
                x_list.append(y_list[idx])
            else:
                x_list.append(transition(y_list[-1]))
        y_list = self.stage3(x_list)

        x_list = []
        for idx in range(self.stage4_cfg['num_branches']):
            transition = self.transition3[idx]
            if transition is None:
                x_list.append(y_list[idx])
            else:
                x_list.append(transition(y_list[-1]))
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        """Switch the training mode while keeping frozen stages frozen.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for
                evaluation mode. Default: True.

        Returns:
            HRNet: ``self``, matching ``torch.nn.Module.train`` so that
            ``model.train()`` / ``model.eval()`` can be chained or assigned.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for module in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(module, _BatchNorm):
                    module.eval()
        return self
|
||||
166
Seg_All_In_One_MMSeg/mmseg/models/backbones/icnet.py
Normal file
166
Seg_All_In_One_MMSeg/mmseg/models/backbones/icnet.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..decode_heads.psp_head import PPM
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
class ICNet(BaseModule):
    """ICNet for Real-Time Semantic Segmentation on High-Resolution Images.

    This backbone is the implementation of
    `ICNet <https://arxiv.org/abs/1704.08545>`_.

    Args:
        backbone_cfg (dict): Config dict to build backbone. Usually it is
            ResNet but it can also be other backbones.
        in_channels (int): The number of input image channels. Default: 3.
        layer_channels (Sequence[int]): The numbers of feature channels at
            layer 2 and layer 4 in ResNet. It can also be other backbones.
            Default: (512, 2048).
        light_branch_middle_channels (int): The number of channels of the
            middle layer in light branch. Default: 32.
        psp_out_channels (int): The number of channels of the output of PSP
            module. Default: 512.
        out_channels (Sequence[int]): The numbers of output feature channels
            at each branches. Default: (64, 256, 256).
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Dictionary to construct and config act layer.
            Default: dict(type='ReLU').
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 backbone_cfg,
                 in_channels=3,
                 layer_channels=(512, 2048),
                 light_branch_middle_channels=32,
                 psp_out_channels=512,
                 out_channels=(64, 256, 256),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        if backbone_cfg is None:
            raise TypeError('backbone_cfg must be passed from config file!')
        if init_cfg is None:
            # NOTE(review): the 'Normal' entry sets `mean=0.01`; presumably
            # `std` was intended - confirm before changing.
            init_cfg = [
                dict(type='Kaiming', mode='fan_out', layer='Conv2d'),
                dict(type='Constant', val=1, layer='_BatchNorm'),
                dict(type='Normal', mean=0.01, layer='Linear')
            ]
        super().__init__(init_cfg=init_cfg)
        self.align_corners = align_corners
        self.backbone = MODELS.build(backbone_cfg)

        # Note: Default `ceil_mode` is false in nn.MaxPool2d, set
        # `ceil_mode=True` to keep information in the corner of feature map.
        self.backbone.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=True)

        deep_channels = layer_channels[1]
        self.psp_modules = PPM(
            pool_scales=pool_scales,
            in_channels=deep_channels,
            channels=psp_out_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            align_corners=align_corners)

        # Fuses the backbone feature with one pooled feature per scale.
        self.psp_bottleneck = ConvModule(
            deep_channels + len(pool_scales) * psp_out_channels,
            psp_out_channels,
            3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Light branch: three stride-2 3x3 convs applied to the input image.
        light_widths = (in_channels, light_branch_middle_channels,
                        light_branch_middle_channels, out_channels[0])
        self.conv_sub1 = nn.Sequential(*[
            ConvModule(
                in_channels=width_in,
                out_channels=width_out,
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)
            for width_in, width_out in zip(light_widths[:-1],
                                           light_widths[1:])
        ])

        self.conv_sub2 = ConvModule(
            layer_channels[0],
            out_channels[1],
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        self.conv_sub4 = ConvModule(
            psp_out_channels,
            out_channels[2],
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

    def forward(self, x):
        """Compute the three branch features.

        Args:
            x (Tensor): Input batch.

        Returns:
            list[Tensor]: Outputs of ``conv_sub1``, ``conv_sub2`` and
            ``conv_sub4``, in that order.
        """
        # sub 1: light branch on the unscaled input
        sub1_out = self.conv_sub1(x)

        # sub 2: backbone stem..layer2 on the input resized by 0.5
        feat = resize(
            x,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)
        feat = self.backbone.stem(feat)
        feat = self.backbone.maxpool(feat)
        feat = self.backbone.layer1(feat)
        feat = self.backbone.layer2(feat)
        sub2_out = self.conv_sub2(feat)

        # sub 4: layer3/layer4 + pyramid pooling on the layer2 feature
        # resized by 0.5 again
        feat = resize(
            feat,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)
        feat = self.backbone.layer3(feat)
        feat = self.backbone.layer4(feat)
        pooled = torch.cat(self.psp_modules(feat) + [feat], dim=1)
        sub4_out = self.conv_sub4(self.psp_bottleneck(pooled))

        return [sub1_out, sub2_out, sub4_out]
|
||||
260
Seg_All_In_One_MMSeg/mmseg/models/backbones/mae.py
Normal file
260
Seg_All_In_One_MMSeg/mmseg/models/backbones/mae.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
|
||||
|
||||
|
||||
class MAEAttention(BEiTAttention):
    """Multi-head self-attention with relative position bias used in MAE.

    The only difference from ``BEiTAttention`` is that ``init_weights`` is
    overridden to do nothing, so the relative position bias table is not
    re-initialised by this class.
    """

    def init_weights(self):
        """Leave the relative position bias untouched.

        MAE initializes the relative position bias as zeros, whereas the
        BEiT parent initializes it with ``trunc_normal``. Overriding the
        method with an empty body skips the parent initialisation.
        """
|
||||
|
||||
|
||||
class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
    """One Vision Transformer encoder layer for MAE.

    Identical to ``BEiTTransformerEncoderLayer`` except that the attention
    module is a ``MAEAttention`` instead of a ``BEiTAttention``.
    """

    def build_attn(self, attn_cfg):
        """Create the attention module.

        Args:
            attn_cfg (dict): Keyword arguments forwarded to
                ``MAEAttention``.
        """
        attention = MAEAttention(**attn_cfg)
        self.attn = attention
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MAE(BEiT):
|
||||
"""VisionTransformer with support for patch.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): embedding dimension. Default: 768.
|
||||
num_layers (int): depth of transformer. Default: 12.
|
||||
num_heads (int): number of attention heads. Default: 12.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_values (float): Initialize the values of Attention and FFN
|
||||
with learnable scaling. Defaults to 0.1.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_indices=-1,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
pretrained=None,
|
||||
init_values=0.1,
|
||||
init_cfg=None):
|
||||
super().__init__(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
out_indices=out_indices,
|
||||
qv_bias=False,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
patch_norm=patch_norm,
|
||||
final_norm=final_norm,
|
||||
num_fcs=num_fcs,
|
||||
norm_eval=norm_eval,
|
||||
pretrained=pretrained,
|
||||
init_values=init_values,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
|
||||
self.num_patches = self.patch_shape[0] * self.patch_shape[1]
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, self.num_patches + 1, embed_dims))
|
||||
|
||||
def _build_layers(self):
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
|
||||
]
|
||||
self.layers = ModuleList()
|
||||
for i in range(self.num_layers):
|
||||
self.layers.append(
|
||||
MAETransformerEncoderLayer(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.num_heads,
|
||||
feedforward_channels=self.mlp_ratio * self.embed_dims,
|
||||
attn_drop_rate=self.attn_drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=self.num_fcs,
|
||||
bias=True,
|
||||
act_cfg=self.act_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
window_size=self.patch_shape,
|
||||
init_values=self.init_values))
|
||||
|
||||
def fix_init_weight(self):
|
||||
"""Rescale the initialization according to layer id.
|
||||
|
||||
This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.layers):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
self.apply(_init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
checkpoint = _load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||
state_dict = self.resize_abs_pos_embed(state_dict)
|
||||
self.load_state_dict(state_dict, False)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
# Copyright 2019 Ross Wightman
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def resize_abs_pos_embed(self, state_dict):
|
||||
if 'pos_embed' in state_dict:
|
||||
pos_embed_checkpoint = state_dict['pos_embed']
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int(
|
||||
(pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(self.num_patches**0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
|
||||
embedding_size).permute(
|
||||
0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens,
|
||||
size=(new_size, new_size),
|
||||
mode='bicubic',
|
||||
align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
state_dict['pos_embed'] = new_pos_embed
|
||||
return state_dict
|
||||
|
||||
def forward(self, inputs):
    """Run the backbone and collect feature maps from selected layers.

    The input is patch-embedded, a class token is prepended, the position
    embedding is added and the token sequence is passed through every
    layer in ``self.layers``. After the last layer ``self.norm1`` is
    applied when ``self.final_norm`` is truthy. For each layer index in
    ``self.out_indices`` the class token is dropped and the remaining
    tokens are reshaped to a ``(B, C, H, W)`` map using the ``hw_shape``
    returned by ``self.patch_embed``.

    Args:
        inputs (Tensor): Input batch; only ``inputs.shape[0]`` (batch
            size) is read directly here.

    Returns:
        tuple[Tensor]: One contiguous feature map per selected layer.
    """
    batch_size = inputs.shape[0]
    x, hw_shape = self.patch_embed(inputs)

    # Broadcast the single learnable class token over the batch.
    cls_tokens = self.cls_token.expand(batch_size, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    x = x + self.pos_embed

    last_idx = len(self.layers) - 1
    outs = []
    for i, layer in enumerate(self.layers):
        x = layer(x)
        if i == last_idx and self.final_norm:
            x = self.norm1(x)
        if i not in self.out_indices:
            continue
        # Drop the class token, then fold tokens back into a 2D map.
        patch_tokens = x[:, 1:]
        n, _, c = patch_tokens.shape
        feat = patch_tokens.reshape(n, hw_shape[0], hw_shape[1], c)
        outs.append(feat.permute(0, 3, 1, 2).contiguous())

    return tuple(outs)
|
||||
450
Seg_All_In_One_MMSeg/mmseg/models/backbones/mit.py
Normal file
450
Seg_All_In_One_MMSeg/mmseg/models/backbones/mit.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
|
||||
|
||||
|
||||
class MixFFN(BaseModule):
    """Convolutional feed-forward block (MixFFN) of Segformer.

    Differences from a plain FFN:
    1. 1x1 convolutions are used in place of the Linear layers.
    2. A 3x3 depth-wise convolution is inserted to encode positional
       information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None (no dropout).
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        hidden = feedforward_channels
        # Point-wise expansion: embed_dims -> hidden.
        expand = Conv2d(
            in_channels=embed_dims,
            out_channels=hidden,
            kernel_size=1,
            stride=1,
            bias=True)
        # 3x3 depth-wise conv that provides positional information;
        # padding 1 keeps the spatial size unchanged.
        depthwise = Conv2d(
            in_channels=hidden,
            out_channels=hidden,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=hidden)
        # Point-wise projection back: hidden -> embed_dims.
        project = Conv2d(
            in_channels=hidden,
            out_channels=embed_dims,
            kernel_size=1,
            stride=1,
            bias=True)
        # The same Dropout module is used after the activation and after
        # the projection.
        dropout = nn.Dropout(ffn_drop)
        self.layers = Sequential(expand, depthwise, self.activate, dropout,
                                 project, dropout)

        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        """Apply the FFN to a token sequence and add the shortcut.

        Args:
            x: Token sequence; converted to NCHW with ``hw_shape`` via
                ``nlc_to_nchw`` before the convolutions.
            hw_shape: Spatial shape passed to ``nlc_to_nchw``.
            identity: Shortcut tensor. When None, ``x`` is used.

        Returns:
            ``identity + dropout_layer(ffn(x))`` in sequence layout.
        """
        shortcut = x if identity is None else identity
        feat = nlc_to_nchw(x, hw_shape)
        feat = nchw_to_nlc(self.layers(feat))
        return shortcut + self.dropout_layer(feat)
|
||||
|
||||
|
||||
class EfficientMultiheadAttention(MultiheadAttention):
    """An implementation of Efficient Multi-head Attention of Segformer.

    This module is modified from MultiheadAttention which is a module from
    mmcv.cnn.bricks.transformer. When ``sr_ratio > 1`` the key/value
    sequence is spatially reduced by a strided convolution followed by a
    normalization layer before attention is computed.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        qkv_bias (bool): enable bias for qkv if True. Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 init_cfg=None,
                 batch_first=True,
                 qkv_bias=False,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            dropout_layer=dropout_layer,
            init_cfg=init_cfg,
            batch_first=batch_first,
            bias=qkv_bias)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Non-overlapping strided conv that shrinks the key/value map
            # by ``sr_ratio`` along each spatial axis.
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmseg import digit_version, mmcv_version
        if mmcv_version < digit_version('1.3.17'):
            # Adjacent literals are concatenated verbatim, so each piece
            # carries its own trailing space.
            warnings.warn('The legacy version of forward function in '
                          'EfficientMultiheadAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer support in the '
                          'future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward

    def _reduce_kv(self, x, hw_shape):
        """Return the key/value sequence for ``x``.

        With ``sr_ratio > 1`` the sequence is reshaped to NCHW, reduced by
        ``self.sr``, flattened back and normalized; otherwise ``x`` itself
        is returned.
        """
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            return self.norm(x_kv)
        return x

    def forward(self, x, hw_shape, identity=None):
        """Attend from ``x`` to its (optionally reduced) key/value sequence.

        Args:
            x: Query token sequence.
            hw_shape: Spatial shape used to fold ``x`` into a 2D map for
                the spatial reduction.
            identity: Shortcut tensor. When None, ``x`` is used.

        Returns:
            ``identity + dropout_layer(proj_drop(attention_output))``.
        """
        x_q = x
        x_kv = self._reduce_kv(x, hw_shape)

        if identity is None:
            identity = x_q

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

    def legacy_forward(self, x, hw_shape, identity=None):
        """multi head attention forward in mmcv version < 1.3.17."""

        x_q = x
        x_kv = self._reduce_kv(x, hw_shape)

        if identity is None:
            identity = x_q

        # `need_weights=True` will let nn.MultiHeadAttention
        # `return attn_output, attn_output_weights.sum(dim=1) / num_heads`
        # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set
        # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`.
        # This issue - `https://github.com/pytorch/pytorch/issues/37583` report
        # the error that large scale tensor sum operation may cause cuda error.
        out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]

        return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
    """One Segformer encoder layer: pre-norm attention then pre-norm MixFFN.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): stochastic depth rate. Default: 0.0.
        qkv_bias (bool): enable bias for qkv if True.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 sr_ratio=1,
                 with_cp=False):
        super().__init__()

        # build_norm_layer returns (name, module); only the module is kept.
        _, self.norm1 = build_norm_layer(norm_cfg, embed_dims)

        attn_kwargs = dict(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)
        self.attn = EfficientMultiheadAttention(**attn_kwargs)

        _, self.norm2 = build_norm_layer(norm_cfg, embed_dims)

        ffn_kwargs = dict(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)
        self.ffn = MixFFN(**ffn_kwargs)

        self.with_cp = with_cp

    def forward(self, x, hw_shape):
        """Apply attention and FFN, each with a residual connection.

        Gradient checkpointing is used only when ``with_cp`` is set and
        ``x`` requires grad.
        """

        def _attn_then_ffn(tokens):
            tokens = self.attn(self.norm1(tokens), hw_shape, identity=tokens)
            return self.ffn(self.norm2(tokens), hw_shape, identity=tokens)

        if not (self.with_cp and x.requires_grad):
            return _attn_then_ffn(x)
        return cp.checkpoint(_attn_then_ffn, x)
|
||||
|
||||
|
||||
@MODELS.register_module()
class MixVisionTransformer(BaseModule):
    """The backbone of Segformer.

    This backbone is the implementation of `SegFormer: Simple and
    Efficient Design for Semantic Segmentation with
    Transformers <https://arxiv.org/abs/2105.15203>`_.
    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Base embedding dimension; stage ``i`` uses
            ``embed_dims * num_heads[i]`` channels. Default: 64.
        num_stages (int): The num of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer encode
            layer. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 4, 8].
        patch_sizes (Sequence[int]): The patch_size of each overlapped patch
            embedding. Default: [7, 3, 3, 3].
        strides (Sequence[int]): The stride of each overlapped patch embedding.
            Default: [4, 2, 2, 2].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages. A
            single int is treated as a one-element sequence.
            Default: (0, 1, 2, 3).
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6).
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 4, 8],
                 patch_sizes=[7, 3, 3, 3],
                 strides=[4, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrained=None,
                 init_cfg=None,
                 with_cp=False):
        super().__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims
        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        self.with_cp = with_cp
        assert num_stages == len(num_layers) == len(num_heads) \
            == len(patch_sizes) == len(strides) == len(sr_ratios)

        # A bare int is documented as valid; wrap it so that ``max`` below
        # and the ``in`` test in ``forward`` work on it.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )
        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic num_layer decay rule

        # ``cur`` is the offset of the current stage's first block in dpr.
        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=patch_sizes[i] // 2,
                norm_cfg=norm_cfg)
            layer = ModuleList([
                TransformerEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratio * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    with_cp=with_cp,
                    sr_ratio=sr_ratios[i]) for idx in range(num_layer)
            ])
            # The next stage consumes this stage's output channels.
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            self.layers.append(ModuleList([patch_embed, layer, norm]))
            cur += num_layer

    def init_weights(self):
        """Initialize weights.

        Without an ``init_cfg``, Linear layers get truncated-normal
        weights, LayerNorm is set to weight 1 / bias 0 and Conv2d weights
        are drawn from a normal with std ``sqrt(2 / fan_out)``. Otherwise
        initialization is delegated to the parent class.
        """
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super().init_weights()

    def forward(self, x):
        """Run all stages and return the feature maps of selected stages.

        Each stage is ``[patch_embed, blocks, norm]``; its output is
        converted back to NCHW and fed to the next stage.

        Returns:
            list: Feature maps of the stages listed in ``out_indices``.
        """
        outs = []

        for i, layer in enumerate(self.layers):
            x, hw_shape = layer[0](x)
            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
|
||||
197
Seg_All_In_One_MMSeg/mmseg/models/backbones/mobilenet_v2.py
Normal file
197
Seg_All_In_One_MMSeg/mmseg/models/backbones/mobilenet_v2.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, make_divisible
|
||||
|
||||
|
||||
@MODELS.register_module()
class MobileNetV2(BaseModule):
    """MobileNetV2 backbone.

    This backbone is the implementation of
    `MobileNetV2: Inverted Residuals and Linear Bottlenecks
    <https://arxiv.org/abs/1801.04381>`_.

    Args:
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        strides (Sequence[int], optional): Strides of the first block of each
            layer. Default: (1, 2, 2, 2, 1, 2, 1).
        dilations (Sequence[int]): Dilation of each layer.
            Default: (1, 1, 1, 1, 1, 1, 1).
        out_indices (Sequence[int]): Output from which stages; each item
            must be in range(0, 7). Default: (1, 2, 4, 6).
        frozen_stages (int): Stages to be frozen (all param fixed). 0
            freezes only the stem conv, ``k`` additionally freezes
            ``layer1`` .. ``layer{k}``.
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        pretrained (str, optional): model pretrained path. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks.
    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]

    def __init__(self,
                 widen_factor=1.,
                 strides=(1, 2, 2, 2, 1, 2, 1),
                 dilations=(1, 1, 1, 1, 1, 1, 1),
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        self.widen_factor = widen_factor
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == len(self.arch_settings)
        for index in out_indices:
            if index not in range(0, 7):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 7). But received {index}')

        # Stage 0 is the stem conv and stages 1..len(arch_settings) are
        # the residual layers, so the largest valid value equals the
        # number of residual layers.
        max_frozen = len(self.arch_settings) + 1
        if frozen_stages not in range(-1, max_frozen):
            raise ValueError(
                f'frozen_stages must be in range(-1, {max_frozen}). '
                f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Running input-channel count; updated by ``make_layer``.
        self.in_channels = make_divisible(32 * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # Names of the registered residual layers, in forward order.
        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks = layer_cfg
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                expand_ratio=expand_ratio)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

    def make_layer(self, out_channels, num_blocks, stride, dilation,
                   expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block.
            dilation (int): Dilation of the first block.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    expand_ratio=expand_ratio,
                    dilation=dilation if i == 0 else 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            # Every block after the first takes ``out_channels`` as input.
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the outputs of the layers selected by ``out_indices``.

        A single selected output is returned as a bare tensor, several as
        a tuple.
        """
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first frozen layers."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch mode, then re-apply stage freezing and ``norm_eval``."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
|
||||
267
Seg_All_In_One_MMSeg/mmseg/models/backbones/mobilenet_v3.py
Normal file
267
Seg_All_In_One_MMSeg/mmseg/models/backbones/mobilenet_v3.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils import is_tuple_of
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidualV3 as InvertedResidual
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MobileNetV3(BaseModule):
|
||||
"""MobileNetV3 backbone.
|
||||
|
||||
This backbone is the improved implementation of `Searching for MobileNetV3
|
||||
<https://ieeexplore.ieee.org/document/9008835>`_.
|
||||
|
||||
Args:
|
||||
arch (str): Architecture of mobilnetv3, from {'small', 'large'}.
|
||||
Default: 'small'.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
out_indices (tuple[int]): Output from which layer.
|
||||
Default: (0, 1, 12).
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Default: -1, which means not freezing any parameters.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
# Parameters to build each block:
|
||||
# [kernel size, mid channels, out channels, with_se, act type, stride]
|
||||
arch_settings = {
|
||||
'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
|
||||
[3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
|
||||
[3, 88, 24, False, 'ReLU', 1],
|
||||
[5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
|
||||
[5, 240, 40, True, 'HSwish', 1],
|
||||
[5, 240, 40, True, 'HSwish', 1],
|
||||
[5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
|
||||
[5, 144, 48, True, 'HSwish', 1],
|
||||
[5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
|
||||
[5, 576, 96, True, 'HSwish', 1],
|
||||
[5, 576, 96, True, 'HSwish', 1]],
|
||||
'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
|
||||
[3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
|
||||
[3, 72, 24, False, 'ReLU', 1],
|
||||
[5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
|
||||
[3, 200, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
|
||||
[3, 672, 112, True, 'HSwish', 1],
|
||||
[5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
|
||||
[5, 960, 160, True, 'HSwish', 1],
|
||||
[5, 960, 160, True, 'HSwish', 1]]
|
||||
} # yapf: disable
|
||||
|
||||
def __init__(self,
|
||||
arch='small',
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
out_indices=(0, 1, 12),
|
||||
frozen_stages=-1,
|
||||
reduction_factor=1,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
assert arch in self.arch_settings
|
||||
assert isinstance(reduction_factor, int) and reduction_factor > 0
|
||||
assert is_tuple_of(out_indices, int)
|
||||
for index in out_indices:
|
||||
if index not in range(0, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError(
|
||||
'the item in out_indices must in '
|
||||
f'range(0, {len(self.arch_settings[arch])+2}). '
|
||||
f'But received {index}')
|
||||
|
||||
if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError('frozen_stages must be in range(-1, '
|
||||
f'{len(self.arch_settings[arch])+2}). '
|
||||
f'But received {frozen_stages}')
|
||||
self.arch = arch
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.reduction_factor = reduction_factor
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.layers = self._make_layer()
|
||||
|
||||
def _make_layer(self):
|
||||
layers = []
|
||||
|
||||
# build the first layer (layer0)
|
||||
in_channels = 16
|
||||
layer = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=dict(type='Conv2dAdaptivePadding'),
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
self.add_module('layer0', layer)
|
||||
layers.append('layer0')
|
||||
|
||||
layer_setting = self.arch_settings[self.arch]
|
||||
for i, params in enumerate(layer_setting):
|
||||
(kernel_size, mid_channels, out_channels, with_se, act,
|
||||
stride) = params
|
||||
|
||||
if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
|
||||
i >= 8:
|
||||
mid_channels = mid_channels // self.reduction_factor
|
||||
out_channels = out_channels // self.reduction_factor
|
||||
|
||||
if with_se:
|
||||
se_cfg = dict(
|
||||
channels=mid_channels,
|
||||
ratio=4,
|
||||
act_cfg=(dict(type='ReLU'),
|
||||
dict(type='HSigmoid', bias=3.0, divisor=6.0)))
|
||||
else:
|
||||
se_cfg = None
|
||||
|
||||
layer = InvertedResidual(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
mid_channels=mid_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
se_cfg=se_cfg,
|
||||
with_expand_conv=(in_channels != mid_channels),
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type=act),
|
||||
with_cp=self.with_cp)
|
||||
in_channels = out_channels
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
# build the last layer
|
||||
# block5 layer12 os=32 for small model
|
||||
# block6 layer16 os=32 for large model
|
||||
layer = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=576 if self.arch == 'small' else 960,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
dilation=4,
|
||||
padding=0,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
layer_name = f'layer{len(layer_setting) + 1}'
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
# next, convert backbone MobileNetV3 to a semantic segmentation version
|
||||
if self.arch == 'small':
|
||||
self.layer4.depthwise_conv.conv.stride = (1, 1)
|
||||
self.layer9.depthwise_conv.conv.stride = (1, 1)
|
||||
for i in range(4, len(layers)):
|
||||
layer = getattr(self, layers[i])
|
||||
if isinstance(layer, InvertedResidual):
|
||||
modified_module = layer.depthwise_conv.conv
|
||||
else:
|
||||
modified_module = layer.conv
|
||||
|
||||
if i < 9:
|
||||
modified_module.dilation = (2, 2)
|
||||
pad = 2
|
||||
else:
|
||||
modified_module.dilation = (4, 4)
|
||||
pad = 4
|
||||
|
||||
if not isinstance(modified_module, Conv2dAdaptivePadding):
|
||||
# Adjust padding
|
||||
pad *= (modified_module.kernel_size[0] - 1) // 2
|
||||
modified_module.padding = (pad, pad)
|
||||
else:
|
||||
self.layer7.depthwise_conv.conv.stride = (1, 1)
|
||||
self.layer13.depthwise_conv.conv.stride = (1, 1)
|
||||
for i in range(7, len(layers)):
|
||||
layer = getattr(self, layers[i])
|
||||
if isinstance(layer, InvertedResidual):
|
||||
modified_module = layer.depthwise_conv.conv
|
||||
else:
|
||||
modified_module = layer.conv
|
||||
|
||||
if i < 13:
|
||||
modified_module.dilation = (2, 2)
|
||||
pad = 2
|
||||
else:
|
||||
modified_module.dilation = (4, 4)
|
||||
pad = 4
|
||||
|
||||
if not isinstance(modified_module, Conv2dAdaptivePadding):
|
||||
# Adjust padding
|
||||
pad *= (modified_module.kernel_size[0] - 1) // 2
|
||||
modified_module.padding = (pad, pad)
|
||||
|
||||
return layers
|
||||
|
||||
def forward(self, x):
    """Run the input through every registered layer in order.

    Args:
        x: Input passed to the first layer named in ``self.layers``.

    Returns:
        list: Outputs of the layers whose position in ``self.layers`` is
        contained in ``self.out_indices``, in execution order.
    """
    collected = []
    for position, name in enumerate(self.layers):
        # Layers are stored as attributes and looked up by name.
        x = getattr(self, name)(x)
        if position not in self.out_indices:
            continue
        collected.append(x)
    return collected
|
||||
|
||||
def _freeze_stages(self):
|
||||
for i in range(self.frozen_stages + 1):
|
||||
layer = getattr(self, f'layer{i}')
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
    """Switch between training and evaluation mode.

    After delegating to the parent implementation, the frozen stages are
    frozen again and, when training with ``self.norm_eval`` set, every
    ``_BatchNorm`` sub-module is put back into eval mode.

    Args:
        mode (bool): ``True`` selects training mode, ``False`` evaluation
            mode. Default: True.

    Returns:
        The module itself, so that calls such as ``model = model.train()``
        keep working (same contract as ``torch.nn.Module.train``).
    """
    super().train(mode)
    # Re-apply freezing after the mode switch performed by the parent.
    self._freeze_stages()
    if mode and self.norm_eval:
        for module in self.modules():
            if isinstance(module, _BatchNorm):
                module.eval()
    return self
|
||||
467
Seg_All_In_One_MMSeg/mmseg/models/backbones/mscan.py
Normal file
467
Seg_All_In_One_MMSeg/mmseg/models/backbones/mscan.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Originally from https://github.com/visual-attention-network/segnext
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class Mlp(BaseModule):
    """Convolutional Multi Layer Perceptron (MLP) module.

    The block chains a 1x1 convolution, a 3x3 depth-wise convolution, the
    activation and a second 1x1 convolution. Dropout is applied after the
    activation and after the last convolution.

    Args:
        in_features (int): The dimension of input features.
        hidden_features (int, optional): The dimension of hidden features.
            ``in_features`` is used when this is falsy. Defaults: None.
        out_features (int, optional): The dimension of output features.
            ``in_features`` is used when this is falsy. Defaults: None.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        drop (float): Dropout probability used in the MLP block.
            Defaults: 0.0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        # Depth-wise 3x3 convolution (one group per hidden channel).
        self.dwconv = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features,
            bias=True)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Forward function."""
        hidden = self.act(self.dwconv(self.fc1(x)))
        hidden = self.drop(hidden)
        return self.drop(self.fc2(hidden))
|
||||
|
||||
|
||||
class StemConv(BaseModule):
    """Stem block built from two stride-2 3x3 convolutions.

    The first convolution maps to ``out_channels // 2`` channels and is
    followed by a norm layer and an activation; the second maps to
    ``out_channels`` channels and is followed by a norm layer only.

    Args:
        in_channels (int): The dimension of input channels.
        out_channels (int): The dimension of output channels.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()

        half_channels = out_channels // 2
        stem_layers = [
            nn.Conv2d(
                in_channels,
                half_channels,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1)),
            build_norm_layer(norm_cfg, half_channels)[1],
            build_activation_layer(act_cfg),
            nn.Conv2d(
                half_channels,
                out_channels,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1)),
            build_norm_layer(norm_cfg, out_channels)[1],
        ]
        self.proj = nn.Sequential(*stem_layers)

    def forward(self, x):
        """Forward function.

        Returns:
            tuple: The projected features with the spatial dimensions
            flattened and moved in front of the channel dimension,
            followed by the height and width of the projected map.
        """
        feat = self.proj(x)
        _, _, height, width = feat.size()
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, height, width
|
||||
|
||||
|
||||
class MSCAAttention(BaseModule):
    """Attention Module in Multi-Scale Convolutional Attention Module (MSCA).

    A depth-wise convolution is followed by parallel branches, each made of
    two depth-wise strip convolutions whose kernel shapes are transposes of
    each other. ``forward`` reads ``conv0_1`` .. ``conv2_2``, so
    ``kernel_sizes`` / ``paddings`` must describe three strip branches.

    Args:
        channels (int): The dimension of channels.
        kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        paddings (list): The number of
            corresponding padding value in attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
    """

    def __init__(self,
                 channels,
                 kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 paddings=[2, [0, 3], [0, 5], [0, 10]]):
        super().__init__()
        self.conv0 = nn.Conv2d(
            channels,
            channels,
            kernel_size=kernel_sizes[0],
            padding=paddings[0],
            groups=channels)

        strip_specs = zip(kernel_sizes[1:], paddings[1:])
        for branch, (strip_kernel, strip_pad) in enumerate(strip_specs):
            # First conv uses the kernel as given, second its transpose.
            oriented = (
                (f'conv{branch}_1', strip_kernel, strip_pad),
                (f'conv{branch}_2', strip_kernel[::-1], strip_pad[::-1]),
            )
            for conv_name, kernel, pad in oriented:
                self.add_module(
                    conv_name,
                    nn.Conv2d(
                        channels,
                        channels,
                        tuple(kernel),
                        padding=pad,
                        groups=channels))
        self.conv3 = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Forward function."""
        identity = x.clone()

        base = self.conv0(x)

        # Multi-scale feature extraction with the three strip branches.
        branch_0 = self.conv0_2(self.conv0_1(base))
        branch_1 = self.conv1_2(self.conv1_1(base))
        branch_2 = self.conv2_2(self.conv2_1(base))
        fused = base + branch_0 + branch_1 + branch_2

        # Channel mixing.
        weights = self.conv3(fused)

        # Convolutional attention: re-weight the input.
        return weights * identity
|
||||
|
||||
|
||||
class MSCASpatialAttention(BaseModule):
    """Spatial Attention Module in Multi-Scale Convolutional Attention Module
    (MSCA).

    The input goes through a 1x1 projection, the activation, an
    ``MSCAAttention`` gating unit and a second 1x1 projection; the input is
    then added back as a residual.

    Args:
        in_channels (int): The dimension of channels.
        attention_kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The number of
            corresponding padding value in attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.proj_1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = MSCAAttention(
            in_channels, attention_kernel_sizes, attention_kernel_paddings)
        self.proj_2 = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Forward function."""
        residual = x.clone()
        gated = self.spatial_gating_unit(self.activation(self.proj_1(x)))
        return self.proj_2(gated) + residual
|
||||
|
||||
|
||||
class MSCABlock(BaseModule):
    """Basic Multi-Scale Convolutional Attention Block.

    It leverages the large-kernel attention (LKA) mechanism to build both
    channel and spatial attention. In each branch, it uses two depth-wise
    strip convolutions to approximate standard depth-wise convolutions with
    large kernels. With the default arguments the strip kernel sizes of the
    branches are 7, 11 and 21.

    Args:
        channels (int): The dimension of channels.
        attention_kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The number of
            corresponding padding value in attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        mlp_ratio (float): The ratio of the MLP hidden dimension to
            ``channels``. Defaults: 4.0.
        drop (float): Dropout probability used in the MLP block.
            Defaults: 0.0.
        drop_path (float): The ratio of drop paths.
            Defaults: 0.0.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 channels,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        self.norm1 = build_norm_layer(norm_cfg, channels)[1]
        self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
                                         attention_kernel_paddings, act_cfg)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = build_norm_layer(norm_cfg, channels)[1]
        self.mlp = Mlp(
            in_features=channels,
            hidden_features=int(channels * mlp_ratio),
            act_cfg=act_cfg,
            drop=drop)

        # Learnable per-channel scales of the two residual branches.
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones(channels), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones(channels), requires_grad=True)

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C).
            H (int): Height used to unfold the N tokens into a 2-D map.
            W (int): Width used to unfold the N tokens into a 2-D map.

        Returns:
            torch.Tensor: Tokens of shape (B, N, C).
        """
        batch, num_tokens, dims = x.shape
        feat = x.permute(0, 2, 1).view(batch, dims, H, W)

        attn_scale = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        feat = feat + self.drop_path(attn_scale * self.attn(self.norm1(feat)))

        mlp_scale = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
        feat = feat + self.drop_path(mlp_scale * self.mlp(self.norm2(feat)))

        return feat.view(batch, dims, num_tokens).permute(0, 2, 1)
|
||||
|
||||
|
||||
class OverlapPatchEmbed(BaseModule):
    """Image to Patch Embedding.

    A single strided convolution with padding ``patch_size // 2`` followed
    by a norm layer.

    Args:
        patch_size (int): The patch size (convolution kernel size).
            Defaults: 7.
        stride (int): Stride of the convolutional layer.
            Default: 4.
        in_channels (int): The number of input channels.
            Defaults: 3.
        embed_dim (int): The dimension of the embedding.
            Defaults: 768.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 patch_size=7,
                 stride=4,
                 in_channels=3,
                 embed_dim=768,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()

        half_kernel = patch_size // 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=half_kernel)
        self.norm = build_norm_layer(norm_cfg, embed_dim)[1]

    def forward(self, x):
        """Forward function.

        Returns:
            tuple: The normalized embedding with the spatial dimensions
            flattened and moved in front of the channel dimension,
            followed by the height and width of the embedded map.
        """
        embedded = self.proj(x)
        _, _, height, width = embedded.shape
        embedded = self.norm(embedded)
        tokens = embedded.flatten(2).transpose(1, 2)
        return tokens, height, width
|
||||
|
||||
|
||||
@MODELS.register_module()
class MSCAN(BaseModule):
    """SegNeXt Multi-Scale Convolutional Attention Network (MSCAN) backbone.

    This backbone is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic
    Segmentation <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Args:
        in_channels (int): The number of input channels. Defaults: 3.
        embed_dims (list[int]): Embedding dimension.
            Defaults: [64, 128, 256, 512].
        mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
            Defaults: [4, 4, 4, 4].
        drop_rate (float): Dropout rate. Defaults: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.
        depths (list[int]): Number of MSCA blocks in each stage.
            Default: [3, 4, 6, 3].
        num_stages (int): MSCAN stages. Default: 4.
        attention_kernel_sizes (list): Size of attention kernel in
            Attention Module (Figure 2(b) of original paper).
            Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): Size of attention paddings
            in Attention Module (Figure 2(b) of original paper).
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        act_cfg (dict): Config dict for the activation layer used in the
            MSCA blocks. Defaults: dict(type='GELU').
        norm_cfg (dict): Config of norm layers.
            Defaults: dict(type='SyncBN', requires_grad=True).
        pretrained (str, optional): model pretrained path.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[4, 4, 4, 4],
                 drop_rate=0.,
                 drop_path_rate=0.,
                 depths=[3, 4, 6, 3],
                 num_stages=4,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.depths = depths
        self.num_stages = num_stages

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            if i == 0:
                # The stem consumes the raw input, so it must honour
                # ``in_channels`` instead of assuming 3 channels.
                # NOTE(review): ``act_cfg`` is not forwarded here, so the
                # stem keeps StemConv's own default activation.
                patch_embed = StemConv(
                    in_channels, embed_dims[0], norm_cfg=norm_cfg)
            else:
                # Later stages halve the resolution with a 3x3 embedding
                # fed by the previous stage's channels.
                patch_embed = OverlapPatchEmbed(
                    patch_size=3,
                    stride=2,
                    in_channels=embed_dims[i - 1],
                    embed_dim=embed_dims[i],
                    norm_cfg=norm_cfg)

            block = nn.ModuleList([
                MSCABlock(
                    channels=embed_dims[i],
                    attention_kernel_sizes=attention_kernel_sizes,
                    attention_kernel_paddings=attention_kernel_paddings,
                    mlp_ratio=mlp_ratios[i],
                    drop=drop_rate,
                    drop_path=dpr[cur + j],
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg) for j in range(depths[i])
            ])
            norm = nn.LayerNorm(embed_dims[i])
            cur += depths[i]

            setattr(self, f'patch_embed{i + 1}', patch_embed)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

    def init_weights(self):
        """Initialize modules of MSCAN.

        Without an ``init_cfg`` the Linear, LayerNorm and Conv2d
        sub-modules are initialized here; otherwise initialization is
        delegated to the parent class.
        """
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    # Fan-out based normal init, corrected for groups.
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super().init_weights()

    def forward(self, x):
        """Forward function.

        Returns:
            list[torch.Tensor]: One (B, C, H, W) feature map per stage.
        """
        B = x.shape[0]
        outs = []

        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            block = getattr(self, f'block{i + 1}')
            norm = getattr(self, f'norm{i + 1}')
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            # Tokens back to a channel-first 2-D map for the next stage.
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs
|
||||
629
Seg_All_In_One_MMSeg/mmseg/models/backbones/my_bisnetv2_A1.py
Normal file
629
Seg_All_In_One_MMSeg/mmseg/models/backbones/my_bisnetv2_A1.py
Normal file
@@ -0,0 +1,629 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
    """Detail Branch with wide channels and shallow layers to capture low-level
    details and generate high-resolution feature representation.

    In this variant every convolution is a
    ``DepthwiseSeparableConvModule``. Each stage starts with a stride-2
    convolution; the first stage adds one stride-1 convolution and every
    later stage adds two.

    Args:
        detail_channels (Tuple[int]): Size of channel numbers of each stage
            in Detail Branch, in paper it has 3 stages.
            Default: (64, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): Feature map of Detail Branch.
    """

    def __init__(self,
                 detail_channels=(64, 64, 128),
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # NOTE(review): ``conv_cfg`` is accepted but not used by any of the
        # depthwise separable convolutions below.
        stages = []
        prev_ch = in_channels
        for idx, out_ch in enumerate(detail_channels):
            # V1: depthwise separable convolutions replace ConvModule.
            # Down-sampling convolution that opens the stage.
            convs = [
                DepthwiseSeparableConvModule(
                    in_channels=prev_ch,
                    out_channels=out_ch,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            # One refinement conv in the first stage, two afterwards.
            num_refine = 1 if idx == 0 else 2
            convs.extend(
                DepthwiseSeparableConvModule(
                    in_channels=out_ch,
                    out_channels=out_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg) for _ in range(num_refine))
            stages.append(nn.Sequential(*convs))
            prev_ch = out_ch
        self.detail_branch = nn.ModuleList(stages)

    def forward(self, x):
        """Apply the stages one after another and return the last output."""
        out = x
        for stage in self.detail_branch:
            out = stage(out)
        return out
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    A stride-2 convolution is followed by two parallel paths (a 1x1 then
    stride-2 3x3 convolution pair, and a stride-2 max pooling) whose
    outputs are concatenated along the channel dimension and fused by a
    3x3 convolution.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): First feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Settings shared by every ConvModule of the block.
        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        squeezed = out_channels // 2

        self.conv_first = ConvModule(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            **shared)
        self.convs = nn.Sequential(
            ConvModule(
                in_channels=out_channels,
                out_channels=squeezed,
                kernel_size=1,
                stride=1,
                padding=0,
                **shared),
            ConvModule(
                in_channels=squeezed,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                **shared))
        self.pool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.fuse_last = ConvModule(
            in_channels=out_channels * 2,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            **shared)

    def forward(self, x):
        """Forward function."""
        stem = self.conv_first(x)
        conv_path = self.convs(stem)
        pool_path = self.pool(stem)
        merged = torch.cat([conv_path, pool_path], dim=1)
        return self.fuse_last(merged)
|
||||
|
||||
|
||||
class GELayer(BaseModule):
    """Gather-and-Expansion Layer.

    With ``stride == 1`` the input itself is added to the main path; for
    any other stride a depthwise separable convolution with that stride
    produces the shortcut.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        stride (int): Stride of GELayer. Default: 1
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): Intermediate feature map in
            Semantic Branch.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 exp_ratio=6,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        mid_channel = in_channels * exp_ratio
        # conv/norm settings shared by every ConvModule of the layer.
        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            act_cfg=act_cfg,
            **shared)

        if stride != 1:
            self.dwconv = nn.Sequential(
                # Strided grouped expansion without activation.
                ConvModule(
                    in_channels=in_channels,
                    out_channels=mid_channel,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    bias=False,
                    act_cfg=None,
                    **shared),
                # ReLU in ConvModule not shown in paper
                ConvModule(
                    in_channels=mid_channel,
                    out_channels=mid_channel,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    groups=mid_channel,
                    act_cfg=act_cfg,
                    **shared),
            )
            # Shortcut path that follows the stride of the main path.
            self.shortcut = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=None,
                ))
        else:
            self.dwconv = nn.Sequential(
                # ReLU in ConvModule not shown in paper
                ConvModule(
                    in_channels=in_channels,
                    out_channels=mid_channel,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    act_cfg=act_cfg,
                    **shared))
            self.shortcut = None

        # 1x1 projection from the expanded width to ``out_channels``.
        self.conv2 = nn.Sequential(
            ConvModule(
                in_channels=mid_channel,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                act_cfg=None,
                **shared,
            ))

        self.act = build_activation_layer(act_cfg)

    def forward(self, x):
        """Forward function."""
        identity = x
        main = self.conv2(self.dwconv(self.conv1(x)))
        if self.shortcut is None:
            residual = identity
        else:
            residual = self.shortcut(identity)
        return self.act(main + residual)
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
    """Context Embedding Block for large receptive field in Semantic Branch.

    The input is globally average-pooled, normalized and passed through a
    1x1 convolution; the result is added to the input and the sum goes
    through a final 3x3 convolution.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        x (torch.Tensor): Last feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Global average pooling followed by a norm layer.
        pooling = nn.AdaptiveAvgPool2d((1, 1))
        pooled_norm = build_norm_layer(norm_cfg, self.in_channels)[1]
        self.gap = nn.Sequential(pooling, pooled_norm)
        self.conv_gap = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            **shared)
        # Note: in paper here is naive conv2d, no bn-relu
        self.conv_last = ConvModule(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            **shared)

    def forward(self, x):
        """Forward function."""
        context = self.conv_gap(self.gap(x))
        # The 1x1 pooled context is added onto the full-size input.
        return self.conv_last(x + context)
|
||||
|
||||
|
||||
class SemanticBranch(BaseModule):
    """Semantic Branch which is lightweight with narrow channels and deep
    layers to obtain high-level semantic context.

    Stage 1 is a ``StemBlock``. Every later stage is a stride-2
    ``GELayer`` followed by stride-1 ``GELayer`` blocks (three in the last
    stage, one otherwise). A ``CEBlock`` is appended after the last stage.

    Args:
        semantic_channels(Tuple[int]): Size of channel numbers of
            various stages in Semantic Branch.
            Default: (16, 32, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    Returns:
        semantic_outs (List[torch.Tensor]): List of several feature maps
            for auxiliary heads (Booster) and Bilateral
            Guided Aggregation Layer.
    """

    def __init__(self,
                 semantic_channels=(16, 32, 64, 128),
                 in_channels=3,
                 exp_ratio=6,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.semantic_channels = semantic_channels
        # Names of the registered stages, in execution order.
        self.semantic_stages = []

        last_idx = len(semantic_channels) - 1
        for idx, channels in enumerate(semantic_channels):
            stage_name = f'stage{idx + 1}'
            self.semantic_stages.append(stage_name)
            if idx == 0:
                stage = StemBlock(self.in_channels, channels)
            else:
                # Down-sampling layer, then stride-1 layers.
                num_extra = 3 if idx == last_idx else 1
                layers = [
                    GELayer(semantic_channels[idx - 1], channels, exp_ratio,
                            2)
                ]
                layers.extend(
                    GELayer(channels, channels, exp_ratio, 1)
                    for _ in range(num_extra))
                stage = nn.Sequential(*layers)
            self.add_module(stage_name, stage)

        ce_name = f'stage{len(semantic_channels)}_CEBlock'
        self.add_module(
            ce_name, CEBlock(semantic_channels[-1], semantic_channels[-1]))
        self.semantic_stages.append(ce_name)

    def forward(self, x):
        """Run all stages and collect the output of each one."""
        semantic_outs = []
        for stage_name in self.semantic_stages:
            x = getattr(self, stage_name)(x)
            semantic_outs.append(x)
        return semantic_outs
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
|
||||
"""Bilateral Guided Aggregation Layer to fuse the complementary information
|
||||
from both Detail Branch and Semantic Branch.
|
||||
|
||||
Args:
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 128.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
output (torch.Tensor): Output feature map for Segment heads.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_channels=128,
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.out_channels = out_channels
|
||||
self.align_corners = align_corners
|
||||
self.detail_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.detail_down = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
|
||||
self.semantic_conv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None))
|
||||
self.semantic_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.conv = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
inplace=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
def forward(self, x_d, x_s):
    """Fuse ``x_d`` and ``x_s`` with sigmoid gating at two resolutions.

    Args:
        x_d (torch.Tensor): Input fed to the detail paths.
        x_s (torch.Tensor): Input fed to the semantic paths.

    Returns:
        torch.Tensor: ``self.conv`` applied to the sum of both fused maps,
        at the spatial size of ``self.detail_dwconv(x_d)``.
    """
    detail_keep = self.detail_dwconv(x_d)
    detail_reduced = self.detail_down(x_d)
    semantic_gate_hi = self.semantic_conv(x_s)
    semantic_gate_lo = self.semantic_dwconv(x_s)

    interp = dict(mode='bilinear', align_corners=self.align_corners)

    # High-resolution path: the semantic gate is resized to the detail map.
    semantic_gate_hi = resize(
        input=semantic_gate_hi, size=detail_keep.shape[2:], **interp)
    hi_res = detail_keep * torch.sigmoid(semantic_gate_hi)

    # Low-resolution path: gate first, then resize the product back up.
    lo_res = resize(
        input=detail_reduced * torch.sigmoid(semantic_gate_lo),
        size=hi_res.shape[2:],
        **interp)

    return self.conv(hi_res + lo_res)
|
||||
|
||||
|
||||
@MODELS.register_module()
class My_BiSeNetV2_A1(BaseModule):  # A1: depthwise separable convolution
    """BiSeNetV2: Bilateral Network with Guided Aggregation for
    Real-time Semantic Segmentation.

    This backbone is the implementation of
    `BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.

    Args:
        in_channels (int): Number of channel of input image. Default: 3.
        detail_channels (Tuple[int], optional): Channels of each stage
            in Detail Branch. Default: (64, 64, 128).
        semantic_channels (Tuple[int], optional): Channels of each stage
            in Semantic Branch. Default: (16, 32, 64, 128).
            See Table 1 and Figure 3 of paper for more details.
        semantic_expansion_ratio (int, optional): The expansion factor
            expanding channel number of middle channels in Semantic Branch.
            Default: 6.
        bga_channels (int, optional): Number of middle channels in
            Bilateral Guided Aggregation Layer. Default: 128.
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Index 0 is the aggregated (BGA) output; the following indices
            select the Semantic Branch outputs except its last one. A
            single int is treated as a one-element tuple.
            Default: (0, 1, 2, 3, 4).
        align_corners (bool, optional): The align_corners argument of
            resize operation in Bilateral Guided Aggregation Layer.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 detail_channels=(64, 64, 128),
                 semantic_channels=(16, 32, 64, 128),
                 semantic_expansion_ratio=6,
                 bga_channels=128,
                 out_indices=(0, 1, 2, 3, 4),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        if init_cfg is None:
            init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(
                    type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            ]
        super().__init__(init_cfg=init_cfg)
        # The documented interface allows a bare int, but forward() iterates
        # over out_indices; wrap it so an int no longer raises TypeError.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.detail_channels = detail_channels
        self.semantic_channels = semantic_channels
        self.semantic_expansion_ratio = semantic_expansion_ratio
        self.bga_channels = bga_channels
        self.align_corners = align_corners
        # NOTE(review): conv_cfg / norm_cfg / act_cfg are stored here but are
        # not passed to the three sub-modules built below.
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.detail = DetailBranch(self.detail_channels, self.in_channels)
        self.semantic = SemanticBranch(self.semantic_channels,
                                       self.in_channels,
                                       self.semantic_expansion_ratio)
        self.bga = BGALayer(self.bga_channels, self.align_corners)

    def forward(self, x):
        """Run both branches, aggregate them and pick the requested outputs.

        Args:
            x (torch.Tensor): Input passed to both branches.

        Returns:
            tuple: The outputs selected by ``self.out_indices``.
        """
        # stole refactoring code from Coin Cheung, thanks
        x_detail = self.detail(x)
        x_semantic_lst = self.semantic(x)
        # The last semantic output is consumed by the aggregation layer and
        # replaced, at index 0, by the aggregated result.
        x_head = self.bga(x_detail, x_semantic_lst[-1])
        outs = [x_head] + x_semantic_lst[:-1]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
|
||||
@@ -0,0 +1,591 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
# ===================================================================
|
||||
# Part 1: CAFM module (slightly modified to simplify its output)
|
||||
# ===================================================================
|
||||
class CAFM(nn.Module):
    """Cross-attention fusion of two feature maps of identical shape.

    For each input a per-channel descriptor is built from its global average
    and global maximum, each passed through its own 1x1 bottleneck. The two
    descriptors form a channel-by-channel affinity matrix that re-mixes the
    channels of each input. From every re-mixed map a single-channel spatial
    softmax map is derived and applied residually to the *original* input.

    Args:
        channels (int): Number of channels of both inputs.
        mid_channels (int): Hidden width of the 1x1 bottlenecks.
            Default: 64.
    """

    def __init__(self, channels, mid_channels=64):
        super().__init__()
        # Creation order and attribute names are kept as they were so that
        # state_dict keys stay compatible with existing checkpoints.
        # The two spatial convs are shared by both inputs.
        self.conv1_spatial = nn.Conv2d(2, 1, 3, stride=1, padding=1, groups=1)
        self.conv2_spatial = nn.Conv2d(1, 1, 3, stride=1, padding=1, groups=1)
        self.avg1 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.avg2 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.max1 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.max2 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.avg11 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)
        self.avg22 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)
        self.max11 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)
        self.max22 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)

    @staticmethod
    def _channel_descriptor(flat, avg_in, max_in, avg_out, max_out):
        """Map a (B, C, N) tensor to a (B, C, 1) channel descriptor."""
        avg = torch.mean(flat, dim=-1, keepdim=True).unsqueeze(-1)
        peak, _ = torch.max(flat, dim=-1, keepdim=True)
        peak = peak.unsqueeze(-1)
        avg = avg_out(F.relu(avg_in(avg))).squeeze(-1)
        peak = max_out(F.relu(max_in(peak))).squeeze(-1)
        return avg + peak

    def _spatial_weights(self, attended, shape):
        """Turn a (B, C, N) tensor into a (B, 1, N) softmax map over N."""
        b, c, h, w = shape
        feat = attended.reshape([b, c, h, w])
        avg_out = torch.mean(feat, dim=1, keepdim=True)
        max_out, _ = torch.max(feat, dim=1, keepdim=True)
        weights = torch.cat([avg_out, max_out], dim=1)
        weights = F.relu(self.conv1_spatial(weights))
        weights = self.conv2_spatial(weights)
        return F.softmax(weights.reshape([b, 1, -1]), dim=-1)

    def forward(self, f1, f2):
        """Fuse two feature maps.

        Args:
            f1 (torch.Tensor): First input of shape (B, C, H, W).
            f2 (torch.Tensor): Second input with exactly the same shape.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Enhanced versions of ``f1``
            and ``f2``, both of shape (B, C, H, W).

        Raises:
            ValueError: If the two inputs do not have the same shape.
        """
        # f2 is reshaped with f1's dimensions below, so a same-size but
        # differently laid-out f2 would otherwise be fused silently.
        if f1.shape != f2.shape:
            raise ValueError(
                'CAFM expects two inputs of identical shape, got '
                f'{tuple(f1.shape)} and {tuple(f2.shape)}')

        b, c, h, w = f1.size()
        f1_reshaped = f1.reshape([b, c, -1])
        f2_reshaped = f2.reshape([b, c, -1])

        a1 = self._channel_descriptor(f1_reshaped, self.avg1, self.max1,
                                      self.avg11, self.max11)
        a2 = self._channel_descriptor(f2_reshaped, self.avg2, self.max2,
                                      self.avg22, self.max22)

        # (B, C, 1) x (B, 1, C) -> (B, C, C) cross-channel affinity.
        cross = torch.matmul(a1, a2.transpose(1, 2))
        a1 = torch.matmul(F.softmax(cross, dim=-1), f1_reshaped)
        a2 = torch.matmul(
            F.softmax(cross.transpose(1, 2), dim=-1), f2_reshaped)

        a1_spatial = self._spatial_weights(a1, (b, c, h, w))
        a2_spatial = self._spatial_weights(a2, (b, c, h, w))

        # Residual re-weighting; the spatial map broadcasts over channels.
        f1_out = f1_reshaped * a1_spatial + f1_reshaped
        f2_out = f2_reshaped * a2_spatial + f2_reshaped

        # Return directly in (B, C, H, W) layout.
        return f1_out.reshape(b, c, h, w), f2_out.reshape(b, c, h, w)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Part 2: the new CrossFusedBranches module
|
||||
# ===================================================================
|
||||
class CrossFusedBranches(BaseModule):
    """Detail and Semantic branches with one CAFM cross-fusion point.

    The detail branch is built from depthwise-separable convolutions. After
    detail stage 2 and semantic stage 3, the semantic feature is resized to
    the detail resolution, both are projected by 1x1 convs, fused by CAFM,
    and the enhanced detail feature is added back before detail stage 3.

    Args:
        detail_channels (Sequence[int]): Channels of detail stages 1-3.
        in_channels (int): Number of input channels.
        semantic_channels (Sequence[int]): Channels of semantic stages 1-4.
        semantic_expansion_ratio (int): Expansion ratio given to each
            GELayer.
        conv_cfg (dict | None): Conv config of the detail branch.
            Default: None.
        norm_cfg (dict | None): Norm config of the detail branch and of the
            fusion convs. Default: dict(type='BN').
        act_cfg (dict): Activation config of the detail branch and of the
            fusion convs. Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 detail_channels,
                 in_channels,
                 semantic_channels,
                 semantic_expansion_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        d0, d1, d2 = (detail_channels[0], detail_channels[1],
                      detail_channels[2])
        s0, s1, s2, s3 = (semantic_channels[0], semantic_channels[1],
                          semantic_channels[2], semantic_channels[3])

        def dw_conv(cin, cout, stride):
            # 3x3 depthwise-separable conv used throughout the detail branch.
            return DepthwiseSeparableConvModule(
                cin,
                cout,
                3,
                stride=stride,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        def ge(cin, cout, stride):
            return GELayer(cin, cout, semantic_expansion_ratio, stride)

        # 1. Detail branch: every stage starts with a stride-2 layer.
        self.d_stage1 = nn.Sequential(dw_conv(in_channels, d0, 2),
                                      dw_conv(d0, d0, 1))
        self.d_stage2 = nn.Sequential(
            dw_conv(d0, d1, 2), dw_conv(d1, d1, 1), dw_conv(d1, d1, 1))
        self.d_stage3 = nn.Sequential(
            dw_conv(d1, d2, 2), dw_conv(d2, d2, 1), dw_conv(d2, d2, 1))

        # 2. Semantic branch (built with the sub-modules' default configs).
        self.s_stage1 = StemBlock(in_channels, s0)
        self.s_stage2 = nn.Sequential(ge(s0, s1, 2), ge(s1, s1, 1))
        self.s_stage3 = nn.Sequential(ge(s1, s2, 2), ge(s2, s2, 1))
        self.s_stage4 = nn.Sequential(
            ge(s2, s3, 2), ge(s3, s3, 1), ge(s3, s3, 1), ge(s3, s3, 1))
        self.s_ce_block = CEBlock(s3, s3)

        # 3. Cross fusion between detail stage 2 and semantic stage 3.
        fusion_channels = d1
        self.adapter_d = ConvModule(
            d1, fusion_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.adapter_s = ConvModule(
            s2, fusion_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.cafm = CAFM(channels=fusion_channels)
        self.injection_conv = ConvModule(
            fusion_channels,
            d1,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x):
        """Run both branches.

        Args:
            x (torch.Tensor): Input fed to both branches.

        Returns:
            tuple: ``(detail_final, semantic_outs)`` where ``semantic_outs``
            is the list ``[s1, s2, s3, s4, s_final]``.
        """
        detail = self.d_stage2(self.d_stage1(x))
        s1 = self.s_stage1(x)
        s2 = self.s_stage2(s1)
        s3 = self.s_stage3(s2)

        # Bring the semantic feature to the detail resolution, project both
        # to the fusion width, and keep only the enhanced detail output.
        semantic_up = resize(
            s3, size=detail.shape[2:], mode='bilinear', align_corners=False)
        fused_detail, _ = self.cafm(
            self.adapter_d(detail), self.adapter_s(semantic_up))
        # Inject the fused feature back into the detail branch.
        detail = detail + self.injection_conv(fused_detail)

        detail_final = self.d_stage3(detail)
        s4 = self.s_stage4(s3)
        s_final = self.s_ce_block(s4)

        return detail_final, [s1, s2, s3, s4, s_final]
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    A stride-2 conv is followed by two parallel stride-2 paths (a 1x1 -> 3x3
    conv pair and a max-pool) whose outputs are concatenated and fused by a
    3x3 conv.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): First feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # All four convs share the same conv / norm / activation configs.
        cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        squeezed = out_channels // 2

        self.conv_first = ConvModule(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            **cfgs)
        # Left path: 1x1 channel squeeze, then stride-2 3x3 conv.
        self.convs = nn.Sequential(
            ConvModule(
                in_channels=out_channels,
                out_channels=squeezed,
                kernel_size=1,
                stride=1,
                padding=0,
                **cfgs),
            ConvModule(
                in_channels=squeezed,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                **cfgs))
        # Right path: stride-2 max pooling.
        self.pool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.fuse_last = ConvModule(
            in_channels=out_channels * 2,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            **cfgs)

    def forward(self, x):
        """Apply the stem and return the fused feature map."""
        stem = self.conv_first(x)
        both_paths = torch.cat([self.convs(stem), self.pool(stem)], dim=1)
        return self.fuse_last(both_paths)
|
||||
|
||||
|
||||
class GELayer(BaseModule):
    """Gather-and-Expansion Layer.

    With ``stride == 1`` the input itself is added as the residual, so the
    output must have the same shape as the input. With any other stride the
    residual goes through a depthwise-separable shortcut conv.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        stride (int): Stride of GELayer. Default: 1
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Intermediate feature map in
            Semantic Branch.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 exp_ratio=6,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        mid_channel = in_channels * exp_ratio
        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            act_cfg=act_cfg,
            **shared)

        if stride != 1:
            self.dwconv = nn.Sequential(
                # Strided grouped expansion without activation.
                ConvModule(
                    in_channels=in_channels,
                    out_channels=mid_channel,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    bias=False,
                    act_cfg=None,
                    **shared),
                # ReLU in ConvModule not shown in paper
                ConvModule(
                    in_channels=mid_channel,
                    out_channels=mid_channel,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    groups=mid_channel,
                    act_cfg=act_cfg,
                    **shared),
            )
            self.shortcut = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=None,
                ))
        else:
            self.dwconv = nn.Sequential(
                # ReLU in ConvModule not shown in paper
                ConvModule(
                    in_channels=in_channels,
                    out_channels=mid_channel,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    act_cfg=act_cfg,
                    **shared))
            self.shortcut = None

        # 1x1 projection back to out_channels, without activation.
        self.conv2 = nn.Sequential(
            ConvModule(
                in_channels=mid_channel,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                act_cfg=None,
                **shared))

        self.act = build_activation_layer(act_cfg)

    def forward(self, x):
        """Apply the main path, add the residual, then activate."""
        main = self.conv2(self.dwconv(self.conv1(x)))
        residual = x if self.shortcut is None else self.shortcut(x)
        return self.act(main + residual)
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
    """Context Embedding Block for large receptive field in Semantic Branch.

    The input is globally average-pooled to 1x1, normalised, passed through
    a 1x1 conv, added back onto the input (broadcast over the spatial
    dimensions) and refined by a 3x3 conv.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Last feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

        cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Global average pooling followed by a norm layer on the 1x1 map.
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            build_norm_layer(norm_cfg, self.in_channels)[1])
        self.conv_gap = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            **cfgs)
        # Note: in paper here is naive conv2d, no bn-relu
        self.conv_last = ConvModule(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            **cfgs)

    def forward(self, x):
        """Add the global context onto ``x`` and refine the sum."""
        context = self.conv_gap(self.gap(x))
        return self.conv_last(x + context)
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
    """Bilateral Guided Aggregation Layer to fuse the complementary information
    from both Detail Branch and Semantic Branch.

    Args:
        out_channels (int): Number of output channels.
            Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        output (torch.Tensor): Output feature map for Segment heads.
    """

    def __init__(self,
                 out_channels=128,
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.out_channels = out_channels
        self.align_corners = align_corners

        # Keyword sets shared by several of the layers below.
        same_io = dict(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            padding=1)
        # Depthwise-separable conv: norm on the depthwise part only, no
        # activation anywhere.
        dw_sep = dict(
            same_io,
            stride=1,
            dw_norm_cfg=norm_cfg,
            dw_act_cfg=None,
            pw_norm_cfg=None,
            pw_act_cfg=None)
        # Plain conv + norm, bias-free, without activation.
        conv_norm = dict(
            same_io,
            bias=False,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.detail_dwconv = nn.Sequential(
            DepthwiseSeparableConvModule(**dw_sep))
        # Stride-2 conv followed by stride-2 average pooling.
        self.detail_down = nn.Sequential(
            ConvModule(stride=2, **conv_norm),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
        self.semantic_conv = nn.Sequential(
            ConvModule(stride=1, **conv_norm))
        self.semantic_dwconv = nn.Sequential(
            DepthwiseSeparableConvModule(**dw_sep))
        # Output conv: the only layer here that applies ``act_cfg``.
        self.conv = ConvModule(
            stride=1,
            inplace=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **same_io)

    def forward(self, x_d, x_s):
        """Fuse the detail feature ``x_d`` with the semantic feature ``x_s``.

        Both inputs are gated by the sigmoid of the other branch, once at the
        resolution of the detail path and once at the reduced resolution; the
        two results are summed and passed through ``self.conv``.
        """
        detail_keep = self.detail_dwconv(x_d)
        detail_reduced = self.detail_down(x_d)
        semantic_gate_hi = self.semantic_conv(x_s)
        semantic_gate_lo = self.semantic_dwconv(x_s)

        interp = dict(mode='bilinear', align_corners=self.align_corners)

        # High-resolution path: the semantic gate is resized to the detail
        # map before gating.
        semantic_gate_hi = resize(
            input=semantic_gate_hi, size=detail_keep.shape[2:], **interp)
        hi_res = detail_keep * torch.sigmoid(semantic_gate_hi)

        # Low-resolution path: gate first, then resize the product back up.
        lo_res = resize(
            input=detail_reduced * torch.sigmoid(semantic_gate_lo),
            size=hi_res.shape[2:],
            **interp)

        return self.conv(hi_res + lo_res)
|
||||
|
||||
|
||||
@MODELS.register_module()
class My_BiSeNetV2_A1_add_A2(BaseModule):
    """BiSeNetV2-style backbone whose two branches are built and run jointly
    by a single ``CrossFusedBranches`` module, followed by a ``BGALayer``.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        detail_channels (Tuple[int]): Channels of the detail stages.
            Default: (64, 64, 128).
        semantic_channels (Tuple[int]): Channels of the semantic stages.
            Default: (16, 32, 64, 128).
        semantic_expansion_ratio (int): Expansion ratio passed on to
            ``CrossFusedBranches``. Default: 6.
        bga_channels (int): Channels of the aggregation layer. Default: 128.
        out_indices (Tuple[int]): Which outputs to return; index 0 is the
            aggregated output. Default: (0, 1, 2, 3, 4).
        align_corners (bool): ``align_corners`` given to the aggregation
            layer. Default: False.
        conv_cfg (dict | None): Passed on to ``CrossFusedBranches``.
            Default: None.
        norm_cfg (dict | None): Passed on to ``CrossFusedBranches``.
            Default: dict(type='BN').
        act_cfg (dict): Passed on to ``CrossFusedBranches``.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            When None, Kaiming init for Conv2d and constant 1 for norm layers
            is used. Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 detail_channels=(64, 64, 128),
                 semantic_channels=(16, 32, 64, 128),
                 semantic_expansion_ratio=6,
                 bga_channels=128,
                 out_indices=(0, 1, 2, 3, 4),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        fallback_init = [
            dict(type='Kaiming', layer='Conv2d'),
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
        super().__init__(
            init_cfg=fallback_init if init_cfg is None else init_cfg)
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.align_corners = align_corners

        # Core change: one CrossFusedBranches module replaces the separate
        # detail and semantic branch modules.
        branch_kwargs = dict(
            detail_channels=detail_channels,
            in_channels=in_channels,
            semantic_channels=semantic_channels,
            semantic_expansion_ratio=semantic_expansion_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.branches = CrossFusedBranches(**branch_kwargs)

        self.bga = BGALayer(bga_channels, self.align_corners)

    def forward(self, x):
        """Return the outputs selected by ``self.out_indices`` as a tuple."""
        # Core change: a single call yields the final output of both
        # branches.
        detail_feat, semantic_feats = self.branches(x)

        # The last semantic output feeds the aggregation layer; its result
        # takes index 0 in front of the remaining semantic outputs.
        aggregated = self.bga(detail_feat, semantic_feats[-1])
        candidates = [aggregated] + semantic_feats[:-1]
        return tuple(candidates[idx] for idx in self.out_indices)
|
||||
597
Seg_All_In_One_MMSeg/mmseg/models/backbones/my_bisnetv2_A2.py
Normal file
597
Seg_All_In_One_MMSeg/mmseg/models/backbones/my_bisnetv2_A2.py
Normal file
@@ -0,0 +1,597 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
# ===================================================================
|
||||
# Part 1: CAFM module (slightly modified to simplify its output)
|
||||
# ===================================================================
|
||||
class CAFM(nn.Module):
    """Cross-attention fusion of two feature maps of identical shape.

    For each input a per-channel descriptor is built from its global average
    and global maximum, each passed through its own 1x1 bottleneck. The two
    descriptors form a channel-by-channel affinity matrix that re-mixes the
    channels of each input. From every re-mixed map a single-channel spatial
    softmax map is derived and applied residually to the *original* input.

    Args:
        channels (int): Number of channels of both inputs.
        mid_channels (int): Hidden width of the 1x1 bottlenecks.
            Default: 64.
    """

    def __init__(self, channels, mid_channels=64):
        super().__init__()
        # Creation order and attribute names are kept as they were so that
        # state_dict keys stay compatible with existing checkpoints.
        # The two spatial convs are shared by both inputs.
        self.conv1_spatial = nn.Conv2d(2, 1, 3, stride=1, padding=1, groups=1)
        self.conv2_spatial = nn.Conv2d(1, 1, 3, stride=1, padding=1, groups=1)
        self.avg1 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.avg2 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.max1 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.max2 = nn.Conv2d(channels, mid_channels, 1, stride=1, padding=0)
        self.avg11 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)
        self.avg22 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)
        self.max11 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)
        self.max22 = nn.Conv2d(mid_channels, channels, 1, stride=1, padding=0)

    @staticmethod
    def _channel_descriptor(flat, avg_in, max_in, avg_out, max_out):
        """Map a (B, C, N) tensor to a (B, C, 1) channel descriptor."""
        avg = torch.mean(flat, dim=-1, keepdim=True).unsqueeze(-1)
        peak, _ = torch.max(flat, dim=-1, keepdim=True)
        peak = peak.unsqueeze(-1)
        avg = avg_out(F.relu(avg_in(avg))).squeeze(-1)
        peak = max_out(F.relu(max_in(peak))).squeeze(-1)
        return avg + peak

    def _spatial_weights(self, attended, shape):
        """Turn a (B, C, N) tensor into a (B, 1, N) softmax map over N."""
        b, c, h, w = shape
        feat = attended.reshape([b, c, h, w])
        avg_out = torch.mean(feat, dim=1, keepdim=True)
        max_out, _ = torch.max(feat, dim=1, keepdim=True)
        weights = torch.cat([avg_out, max_out], dim=1)
        weights = F.relu(self.conv1_spatial(weights))
        weights = self.conv2_spatial(weights)
        return F.softmax(weights.reshape([b, 1, -1]), dim=-1)

    def forward(self, f1, f2):
        """Fuse two feature maps.

        Args:
            f1 (torch.Tensor): First input of shape (B, C, H, W).
            f2 (torch.Tensor): Second input with exactly the same shape.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Enhanced versions of ``f1``
            and ``f2``, both of shape (B, C, H, W).

        Raises:
            ValueError: If the two inputs do not have the same shape.
        """
        # f2 is reshaped with f1's dimensions below, so a same-size but
        # differently laid-out f2 would otherwise be fused silently.
        if f1.shape != f2.shape:
            raise ValueError(
                'CAFM expects two inputs of identical shape, got '
                f'{tuple(f1.shape)} and {tuple(f2.shape)}')

        b, c, h, w = f1.size()
        f1_reshaped = f1.reshape([b, c, -1])
        f2_reshaped = f2.reshape([b, c, -1])

        a1 = self._channel_descriptor(f1_reshaped, self.avg1, self.max1,
                                      self.avg11, self.max11)
        a2 = self._channel_descriptor(f2_reshaped, self.avg2, self.max2,
                                      self.avg22, self.max22)

        # (B, C, 1) x (B, 1, C) -> (B, C, C) cross-channel affinity.
        cross = torch.matmul(a1, a2.transpose(1, 2))
        a1 = torch.matmul(F.softmax(cross, dim=-1), f1_reshaped)
        a2 = torch.matmul(
            F.softmax(cross.transpose(1, 2), dim=-1), f2_reshaped)

        a1_spatial = self._spatial_weights(a1, (b, c, h, w))
        a2_spatial = self._spatial_weights(a2, (b, c, h, w))

        # Residual re-weighting; the spatial map broadcasts over channels.
        f1_out = f1_reshaped * a1_spatial + f1_reshaped
        f2_out = f2_reshaped * a2_spatial + f2_reshaped

        # Return directly in (B, C, H, W) layout.
        return f1_out.reshape(b, c, h, w), f2_out.reshape(b, c, h, w)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Part 2: the new CrossFusedBranches module
|
||||
# ===================================================================
|
||||
class CrossFusedBranches(BaseModule):
|
||||
def __init__(self,
|
||||
detail_channels,
|
||||
in_channels,
|
||||
semantic_channels,
|
||||
semantic_expansion_ratio,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
# 1. --- Porting Detail Branch Layers ---
|
||||
# 复制自原始 DetailBranch 的代码,但拆分为独立的 stage 模块
|
||||
self.d_stage1 = nn.Sequential(
|
||||
ConvModule(in_channels, detail_channels[0], 3, stride=2, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg),
|
||||
ConvModule(detail_channels[0], detail_channels[0], 3, stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
)
|
||||
self.d_stage2 = nn.Sequential(
|
||||
ConvModule(detail_channels[0], detail_channels[1], 3, stride=2, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg),
|
||||
ConvModule(detail_channels[1], detail_channels[1], 3, stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg),
|
||||
ConvModule(detail_channels[1], detail_channels[1], 3, stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
)
|
||||
self.d_stage3 = nn.Sequential(
|
||||
ConvModule(detail_channels[1], detail_channels[2], 3, stride=2, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg),
|
||||
ConvModule(detail_channels[2], detail_channels[2], 3, stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg),
|
||||
ConvModule(detail_channels[2], detail_channels[2], 3, stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
)
|
||||
|
||||
# 2. --- Porting Semantic Branch Layers ---
|
||||
# 复制自原始 SemanticBranch 的代码
|
||||
self.s_stage1 = StemBlock(in_channels, semantic_channels[0])
|
||||
self.s_stage2 = nn.Sequential(
|
||||
GELayer(semantic_channels[0], semantic_channels[1], semantic_expansion_ratio, 2),
|
||||
GELayer(semantic_channels[1], semantic_channels[1], semantic_expansion_ratio, 1)
|
||||
)
|
||||
self.s_stage3 = nn.Sequential(
|
||||
GELayer(semantic_channels[1], semantic_channels[2], semantic_expansion_ratio, 2),
|
||||
GELayer(semantic_channels[2], semantic_channels[2], semantic_expansion_ratio, 1)
|
||||
)
|
||||
self.s_stage4 = nn.Sequential(
|
||||
GELayer(semantic_channels[2], semantic_channels[3], semantic_expansion_ratio, 2),
|
||||
GELayer(semantic_channels[3], semantic_channels[3], semantic_expansion_ratio, 1),
|
||||
GELayer(semantic_channels[3], semantic_channels[3], semantic_expansion_ratio, 1),
|
||||
GELayer(semantic_channels[3], semantic_channels[3], semantic_expansion_ratio, 1)
|
||||
)
|
||||
self.s_ce_block = CEBlock(semantic_channels[3], semantic_channels[3])
|
||||
|
||||
# 3. --- Cross Fusion Modules ---
|
||||
# 定义交叉融合点:Detail Stage 2 (d2) 和 Semantic Stage 3 (s3)
|
||||
fusion_channels = detail_channels[1] # e.g., 64
|
||||
self.adapter_d = ConvModule(detail_channels[1], fusion_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
self.adapter_s = ConvModule(semantic_channels[2], fusion_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
self.cafm = CAFM(channels=fusion_channels)
|
||||
self.injection_conv = ConvModule(fusion_channels, detail_channels[1], 3, padding=1, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
    """Run both branches with one cross-fusion step between them.

    Args:
        x (torch.Tensor): Input fed to both the detail and the
            semantic branch.

    Returns:
        tuple: ``(detail_final, semantic_outs)`` where ``detail_final``
        is the output of the last detail stage and ``semantic_outs`` is
        the list ``[s1, s2, s3, s4, s_final]`` of semantic features.
    """
    # Detail branch up to stage 2.
    detail_mid = self.d_stage2(self.d_stage1(x))

    # Semantic branch up to stage 3.
    sem1 = self.s_stage1(x)
    sem2 = self.s_stage2(sem1)
    sem3 = self.s_stage3(sem2)

    # Cross fusion: bring the semantic feature to the detail resolution,
    # project both features to the fusion width, then fuse with CAFM.
    sem_resized = resize(
        sem3,
        size=detail_mid.shape[2:],
        mode='bilinear',
        align_corners=False)
    detail_proj = self.adapter_d(detail_mid)
    sem_proj = self.adapter_s(sem_resized)
    # Only the semantically enhanced detail feature is used; the second
    # CAFM output is discarded.
    enhanced, _ = self.cafm(detail_proj, sem_proj)

    # Inject the fused feature back into the detail branch as a residual.
    detail_injected = detail_mid + self.injection_conv(enhanced)

    # Remaining stages of both branches.
    detail_final = self.d_stage3(detail_injected)
    sem4 = self.s_stage4(sem3)
    sem_last = self.s_ce_block(sem4)

    return detail_final, [sem1, sem2, sem3, sem4, sem_last]
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    A stride-2 3x3 conv is applied first. Its output goes through two
    parallel paths (a 1x1 conv followed by a stride-2 3x3 conv, and a
    stride-2 max pooling) whose results are concatenated along the channel
    dimension and fused by a final 3x3 conv.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): First feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Layer configs shared by every ConvModule of the stem.
        common = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        squeezed_channels = out_channels // 2

        self.conv_first = ConvModule(
            in_channels, out_channels, 3, stride=2, padding=1, **common)

        # Left path: channel squeeze, then stride-2 spatial reduction.
        squeeze = ConvModule(
            out_channels, squeezed_channels, 1, stride=1, padding=0,
            **common)
        reduce = ConvModule(
            squeezed_channels, out_channels, 3, stride=2, padding=1,
            **common)
        self.convs = nn.Sequential(squeeze, reduce)

        # Right path: stride-2 max pooling.
        self.pool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)

        # Fuses the concatenation of both paths (2 * out_channels).
        self.fuse_last = ConvModule(
            out_channels * 2, out_channels, 3, stride=1, padding=1,
            **common)

    def forward(self, x):
        stem = self.conv_first(x)
        paths = [self.convs(stem), self.pool(stem)]
        return self.fuse_last(torch.cat(paths, dim=1))
|
||||
|
||||
|
||||
class GELayer(BaseModule):
    """Gather-and-Expansion Layer.

    A 3x3 conv is followed by depthwise conv(s) that expand the channels by
    ``exp_ratio`` and a 1x1 projection conv. The result is added to a
    shortcut (the input itself when ``stride == 1``, otherwise a depthwise
    separable conv of the input) and passed through the activation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        stride (int): Stride of GELayer. Default: 1
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Intermediate feature map in
            Semantic Branch.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 exp_ratio=6,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        mid_channel = in_channels * exp_ratio

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if stride == 1:
            # Single expanding depthwise conv; the input is the shortcut.
            # ReLU in ConvModule not shown in paper
            depthwise = [
                ConvModule(
                    in_channels,
                    mid_channel,
                    3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            shortcut = None
        else:
            # Strided expanding depthwise conv (no activation) followed by
            # a second depthwise conv; the shortcut is a strided depthwise
            # separable conv so shapes match.
            depthwise = [
                ConvModule(
                    in_channels,
                    mid_channel,
                    3,
                    stride=stride,
                    padding=1,
                    groups=in_channels,
                    bias=False,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                # ReLU in ConvModule not shown in paper
                ConvModule(
                    mid_channel,
                    mid_channel,
                    3,
                    stride=1,
                    padding=1,
                    groups=mid_channel,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            ]
            shortcut = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=None,
                ))
        self.dwconv = nn.Sequential(*depthwise)
        self.shortcut = shortcut

        # 1x1 projection back to ``out_channels`` without activation; the
        # activation is applied after the residual addition.
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channel,
                out_channels,
                1,
                stride=1,
                padding=0,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
            ))

        self.act = build_activation_layer(act_cfg)

    def forward(self, x):
        out = self.conv2(self.dwconv(self.conv1(x)))
        residual = x if self.shortcut is None else self.shortcut(x)
        return self.act(out + residual)
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
    """Context Embedding Block for large receptive field in Semantic Branch.

    The input is globally average pooled and normalized, passed through a
    1x1 conv, added back onto the input and refined by a 3x3 conv.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Last feature map in Semantic Branch.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=16,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Layer configs shared by both ConvModules below.
        common = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Global context: 1x1 average pooling followed by a norm layer.
        global_pool = nn.AdaptiveAvgPool2d((1, 1))
        pooled_norm = build_norm_layer(norm_cfg, self.in_channels)[1]
        self.gap = nn.Sequential(global_pool, pooled_norm)

        self.conv_gap = ConvModule(
            self.in_channels,
            self.out_channels,
            1,
            stride=1,
            padding=0,
            **common)
        # Note: in paper here is naive conv2d, no bn-relu
        self.conv_last = ConvModule(
            self.out_channels,
            self.out_channels,
            3,
            stride=1,
            padding=1,
            **common)

    def forward(self, x):
        context = self.conv_gap(self.gap(x))
        # The (N, C, 1, 1) context is broadcast over the spatial dims.
        return self.conv_last(x + context)
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
    """Bilateral Guided Aggregation Layer to fuse the complementary information
    from both Detail Branch and Semantic Branch.

    Each branch produces two features. The sigmoid of a semantic feature
    gates a detail feature, once at the detail resolution and once on the
    downsampled detail feature; the two gated results are summed at the
    detail resolution and passed through a final 3x3 conv.

    Args:
        out_channels (int): Number of output channels.
            Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        output (torch.Tensor): Output feature map for Segment heads.
    """

    def __init__(self,
                 out_channels=128,
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.out_channels = out_channels
        self.align_corners = align_corners

        # Both depthwise separable convs share the same hyper-parameters.
        dw_sep_kwargs = dict(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            dw_norm_cfg=norm_cfg,
            dw_act_cfg=None,
            pw_norm_cfg=None,
            pw_act_cfg=None,
        )
        # Both plain 3x3 convs are bias-free and have no activation.
        plain_conv_kwargs = dict(
            bias=False, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

        self.detail_dwconv = nn.Sequential(
            DepthwiseSeparableConvModule(**dw_sep_kwargs))

        # Stride-2 conv followed by stride-2 average pooling.
        detail_strided_conv = ConvModule(
            self.out_channels,
            self.out_channels,
            3,
            stride=2,
            padding=1,
            **plain_conv_kwargs)
        detail_avg_pool = nn.AvgPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.detail_down = nn.Sequential(detail_strided_conv,
                                         detail_avg_pool)

        self.semantic_conv = nn.Sequential(
            ConvModule(
                self.out_channels,
                self.out_channels,
                3,
                stride=1,
                padding=1,
                **plain_conv_kwargs))

        self.semantic_dwconv = nn.Sequential(
            DepthwiseSeparableConvModule(**dw_sep_kwargs))

        self.conv = ConvModule(
            self.out_channels,
            self.out_channels,
            3,
            stride=1,
            padding=1,
            inplace=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )

    def forward(self, x_d, x_s):
        detail_full = self.detail_dwconv(x_d)
        detail_small = self.detail_down(x_d)
        gate_full = self.semantic_conv(x_s)
        gate_small = self.semantic_dwconv(x_s)

        # Gate the full-resolution detail feature with the resized semantic
        # feature.
        gate_full = resize(
            input=gate_full,
            size=detail_full.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        fused_full = detail_full * torch.sigmoid(gate_full)

        # Gate the downsampled detail feature, then resize the result to
        # the size of the first fused feature.
        fused_small = resize(
            input=detail_small * torch.sigmoid(gate_small),
            size=fused_full.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        return self.conv(fused_full + fused_small)
|
||||
|
||||
|
||||
@MODELS.register_module()
class My_BiSeNetV2_A2(BaseModule):
    """BiSeNetV2-style backbone whose two branches are cross-fused.

    The detail and semantic branches are both built inside a single
    ``CrossFusedBranches`` module; their final outputs are aggregated by a
    ``BGALayer``.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        detail_channels (tuple[int]): Channels of the detail branch stages.
            Default: (64, 64, 128).
        semantic_channels (tuple[int]): Channels of the semantic branch
            stages. Default: (16, 32, 64, 128).
        semantic_expansion_ratio (int): Expansion ratio forwarded to the
            semantic branch. Default: 6.
        bga_channels (int): Number of channels of the ``BGALayer``.
            Default: 128.
        out_indices (tuple[int]): Indices into
            ``[bga_output, s1, s2, s3, s4]`` selecting the returned
            features. Default: (0, 1, 2, 3, 4).
        align_corners (bool): align_corners argument passed to the
            ``BGALayer``. Default: False.
        conv_cfg (dict | None): Config of conv layers. Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            When None, Kaiming init is used for ``Conv2d`` and constant 1
            for ``_BatchNorm`` / ``GroupNorm`` layers. Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 detail_channels=(64, 64, 128),
                 semantic_channels=(16, 32, 64, 128),
                 semantic_expansion_ratio=6,
                 bga_channels=128,
                 out_indices=(0, 1, 2, 3, 4),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        if init_cfg is None:
            conv_init = dict(type='Kaiming', layer='Conv2d')
            norm_init = dict(
                type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            init_cfg = [conv_init, norm_init]
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.out_indices = out_indices
        self.align_corners = align_corners

        # Key change: one CrossFusedBranches module replaces the separate
        # ``self.detail`` and ``self.semantic`` branches.
        branch_kwargs = dict(
            detail_channels=detail_channels,
            in_channels=in_channels,
            semantic_channels=semantic_channels,
            semantic_expansion_ratio=semantic_expansion_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        self.branches = CrossFusedBranches(**branch_kwargs)

        # NOTE(review): conv_cfg / norm_cfg / act_cfg are not forwarded
        # here, so the BGA layer is built with its own defaults.
        self.bga = BGALayer(bga_channels, self.align_corners)

    def forward(self, x):
        # Key change: a single call yields the final outputs of both
        # branches.
        detail_feat, semantic_feats = self.branches(x)

        # The rest is unchanged: aggregate the detail feature with the last
        # semantic feature, then expose the remaining semantic features.
        head_feat = self.bga(detail_feat, semantic_feats[-1])
        candidates = [head_feat] + semantic_feats[:-1]
        return tuple([candidates[idx] for idx in self.out_indices])
|
||||
522
Seg_All_In_One_MMSeg/mmseg/models/backbones/pidnet.py
Normal file
522
Seg_All_In_One_MMSeg/mmseg/models/backbones/pidnet.py
Normal file
@@ -0,0 +1,522 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.runner import CheckpointLoader
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType
|
||||
from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck
|
||||
|
||||
|
||||
class PagFM(BaseModule):
    """Pixel-attention-guided fusion module.

    Both inputs are projected by 1x1 convs, the I-branch projection is
    resized to the P-branch resolution, and a sigmoid attention map built
    from their product decides per position how much of the (resized)
    I-branch feature replaces the P-branch feature.

    Args:
        in_channels (int): The number of input channels.
        channels (int): The number of channels.
        after_relu (bool): Whether to use ReLU before attention.
            Default: False.
        with_channel (bool): Whether to use channel attention.
            Default: False.
        upsample_mode (str): The mode of upsample. Default: 'bilinear'.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 after_relu: bool = False,
                 with_channel: bool = False,
                 upsample_mode: str = 'bilinear',
                 norm_cfg: OptConfigType = dict(type='BN'),
                 # The key must be spelled ``type``: the registry resolves
                 # the activation class from it when ``after_relu`` is set.
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.after_relu = after_relu
        self.with_channel = with_channel
        self.upsample_mode = upsample_mode
        # 1x1 projections of the I-branch and P-branch features.
        self.f_i = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        self.f_p = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        if with_channel:
            # Maps the product back to ``in_channels`` so the attention is
            # computed per channel instead of per pixel only.
            self.up = ConvModule(
                channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        if after_relu:
            self.relu = MODELS.build(act_cfg)

    def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
        """Forward function.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.

        Returns:
            Tensor: The feature map with pixel-attention-guided fusion.
        """
        if self.after_relu:
            x_p = self.relu(x_p)
            x_i = self.relu(x_i)

        # Project the I-branch feature and bring it to the P-branch size.
        f_i = self.f_i(x_i)
        f_i = F.interpolate(
            f_i,
            size=x_p.shape[2:],
            mode=self.upsample_mode,
            align_corners=False)

        f_p = self.f_p(x_p)

        if self.with_channel:
            sigma = torch.sigmoid(self.up(f_p * f_i))
        else:
            # Sum over channels gives a single-channel attention map.
            sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))

        x_i = F.interpolate(
            x_i,
            size=x_p.shape[2:],
            mode=self.upsample_mode,
            align_corners=False)

        # Convex combination of the two branches weighted by the attention.
        out = sigma * x_i + (1 - sigma) * x_p
        return out
|
||||
|
||||
|
||||
class Bag(BaseModule):
    """Boundary-attention-guided fusion module.

    The sigmoid of the D-branch feature blends the P-branch and I-branch
    features, and the blend is passed through a single ``ConvModule``.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The kernel size of the convolution. Default: 3.
        padding (int): The padding of the convolution. Default: 1.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Config dict for convolution layer.
            Default: dict(order=('norm', 'act', 'conv')).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        # ``conv_cfg`` is unpacked as extra keyword arguments of ConvModule
        # (by default it sets the norm/act/conv ordering).
        module_kwargs = dict(
            padding=padding, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.conv = ConvModule(in_channels, out_channels, kernel_size,
                               **module_kwargs, **conv_cfg)

    def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
        """Forward function.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.
            x_d (Tensor): The feature map from D branch.

        Returns:
            Tensor: The feature map with boundary-attention-guided fusion.
        """
        boundary_gate = torch.sigmoid(x_d)
        blended = boundary_gate * x_p + (1 - boundary_gate) * x_i
        return self.conv(blended)
|
||||
|
||||
|
||||
class LightBag(BaseModule):
    """Light Boundary-attention-guided fusion module.

    The sigmoid of the D-branch feature mixes the P-branch and I-branch
    features in two complementary ways; each mix goes through its own 1x1
    ``ConvModule`` and the two results are summed.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer. Default: None.
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        # Both 1x1 projections share the same hyper-parameters.
        proj_kwargs = dict(kernel_size=1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.f_p = ConvModule(in_channels, out_channels, **proj_kwargs)
        self.f_i = ConvModule(in_channels, out_channels, **proj_kwargs)

    def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
        """Forward function.

        Args:
            x_p (Tensor): The feature map from P branch.
            x_i (Tensor): The feature map from I branch.
            x_d (Tensor): The feature map from D branch.

        Returns:
            Tensor: The feature map with light boundary-attention-guided
                fusion.
        """
        boundary_gate = torch.sigmoid(x_d)

        p_mix = (1 - boundary_gate) * x_i + x_p
        i_mix = x_i + boundary_gate * x_p
        proj_p = self.f_p(p_mix)
        proj_i = self.f_i(i_mix)

        return proj_p + proj_i
|
||||
|
||||
|
||||
@MODELS.register_module()
class PIDNet(BaseModule):
    """PIDNet backbone.

    This backbone is the implementation of `PIDNet: A Real-time Semantic
    Segmentation Network Inspired from PID Controller
    <https://arxiv.org/abs/2206.02066>`_.
    Modified from https://github.com/XuJiacong/PIDNet.

    Licensed under the MIT License.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        channels (int): The number of channels in the stem layer. Default: 64.
        ppm_channels (int): The number of channels in the PPM layer.
            Default: 96.
        num_stem_blocks (int): The number of blocks in the stem layer.
            Default: 2.
        num_branch_blocks (int): The number of blocks in the branch layer.
            Default: 3.
        align_corners (bool): The align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict): Config dict for initialization. Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 64,
                 ppm_channels: int = 96,
                 num_stem_blocks: int = 2,
                 num_branch_blocks: int = 3,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None,
                 **kwargs):
        super().__init__(init_cfg)
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stem layer
        self.stem = self._make_stem_layer(in_channels, channels,
                                          num_stem_blocks)
        self.relu = nn.ReLU()

        # I Branch: three stride-2 stages given as
        # (block, in_channels, channels, num_blocks).
        i_stage_specs = (
            (BasicBlock, channels * 2, channels * 4, num_branch_blocks),
            (BasicBlock, channels * 4, channels * 8, num_branch_blocks),
            (Bottleneck, channels * 8, channels * 8, 2),
        )
        self.i_branch_layers = nn.ModuleList()
        for block, stage_in, stage_width, stage_blocks in i_stage_specs:
            self.i_branch_layers.append(
                self._make_layer(
                    block=block,
                    in_channels=stage_in,
                    channels=stage_width,
                    num_blocks=stage_blocks,
                    stride=2))

        # P Branch: three stride-1 stages at a fixed width, given as
        # (block, num_blocks).
        p_stage_specs = (
            (BasicBlock, num_stem_blocks),
            (BasicBlock, num_stem_blocks),
            (Bottleneck, 1),
        )
        self.p_branch_layers = nn.ModuleList()
        for block, stage_blocks in p_stage_specs:
            self.p_branch_layers.append(
                self._make_layer(
                    block=block,
                    in_channels=channels * 2,
                    channels=channels * 2,
                    num_blocks=stage_blocks))

        # 1x1 convs that compress I-branch features before they are fused
        # into the P branch.
        compress_kwargs = dict(
            kernel_size=1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
        self.compression_1 = ConvModule(channels * 4, channels * 2,
                                        **compress_kwargs)
        self.compression_2 = ConvModule(channels * 8, channels * 2,
                                        **compress_kwargs)
        self.pag_1 = PagFM(channels * 2, channels)
        self.pag_2 = PagFM(channels * 2, channels)

        # D Branch: the width and the pooling / fusion modules depend on
        # ``num_stem_blocks``.
        if num_stem_blocks == 2:
            self.d_branch_layers = nn.ModuleList([
                self._make_single_layer(BasicBlock, channels * 2, channels),
                self._make_layer(Bottleneck, channels, channels, 1)
            ])
            channel_expand = 1
            spp_module = PAPPM
            dfm_module = LightBag
            act_cfg_dfm = None
        else:
            self.d_branch_layers = nn.ModuleList([
                self._make_single_layer(BasicBlock, channels * 2,
                                        channels * 2),
                self._make_single_layer(BasicBlock, channels * 2, channels * 2)
            ])
            channel_expand = 2
            spp_module = DAPPM
            dfm_module = Bag
            act_cfg_dfm = act_cfg

        # 3x3 convs that project I-branch features before they are added
        # to the D branch.
        diff_kwargs = dict(
            kernel_size=3,
            padding=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.diff_1 = ConvModule(channels * 4, channels * channel_expand,
                                 **diff_kwargs)
        self.diff_2 = ConvModule(channels * 8, channels * 2, **diff_kwargs)

        self.spp = spp_module(
            channels * 16, ppm_channels, channels * 4, num_scales=5)
        self.dfm = dfm_module(
            channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm)

        # Last D-branch stage, shared by both variants.
        self.d_branch_layers.append(
            self._make_layer(Bottleneck, channels * 2, channels * 2, 1))

    def _make_stem_layer(self, in_channels: int, channels: int,
                         num_blocks: int) -> nn.Sequential:
        """Make stem layer.

        Args:
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            num_blocks (int): Number of blocks.

        Returns:
            nn.Sequential: The stem layer.
        """
        # Both leading convs are stride-2 3x3 ConvModules.
        stem_conv_kwargs = dict(
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        layers = [
            ConvModule(in_channels, channels, **stem_conv_kwargs),
            ConvModule(channels, channels, **stem_conv_kwargs),
            self._make_layer(BasicBlock, channels, channels, num_blocks),
            nn.ReLU(),
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def _make_shortcut(self, block, in_channels: int, channels: int,
                       stride: int):
        """Build the 1x1 shortcut conv of a block, or None if not needed.

        A projection is created only when the block changes the spatial
        size (``stride != 1``) or the number of channels.
        """
        out_channels = channels * block.expansion
        if stride == 1 and in_channels == out_channels:
            return None
        return ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

    def _make_layer(self,
                    block: BasicBlock,
                    in_channels: int,
                    channels: int,
                    num_blocks: int,
                    stride: int = 1) -> nn.Sequential:
        """Make layer for PIDNet backbone.

        Args:
            block (BasicBlock): Basic block.
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Sequential: The Branch Layer.
        """
        downsample = self._make_shortcut(block, in_channels, channels, stride)
        layers = [block(in_channels, channels, stride, downsample)]

        # Following blocks keep the resolution; the last one is built with
        # ``act_cfg_out=None``.
        expanded_channels = channels * block.expansion
        last_index = num_blocks - 1
        for index in range(1, num_blocks):
            out_act = None if index == last_index else self.act_cfg
            layers.append(
                block(
                    expanded_channels,
                    channels,
                    stride=1,
                    act_cfg_out=out_act))
        return nn.Sequential(*layers)

    def _make_single_layer(self,
                           block: Union[BasicBlock, Bottleneck],
                           in_channels: int,
                           channels: int,
                           stride: int = 1) -> nn.Module:
        """Make single layer for PIDNet backbone.

        Args:
            block (BasicBlock or Bottleneck): Basic block or Bottleneck.
            in_channels (int): Number of input channels.
            channels (int): Number of output channels.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Module
        """
        downsample = self._make_shortcut(block, in_channels, channels, stride)
        return block(
            in_channels, channels, stride, downsample, act_cfg_out=None)

    def init_weights(self):
        """Initialize the weights in backbone.

        Since the D branch is not initialized by the pre-trained model, we
        initialize it with the same method as the ResNet.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if self.init_cfg is None:
            return

        # Load the checkpoint on top of the default init; non-strict
        # loading tolerates keys that do not match.
        assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                              f'specify `Pretrained` in ' \
                                              f'`init_cfg` in ' \
                                              f'{self.__class__.__name__} '
        ckpt = CheckpointLoader.load_checkpoint(
            self.init_cfg['checkpoint'], map_location='cpu')
        self.load_state_dict(ckpt, strict=False)

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (B, C, H, W).

        Returns:
            Tensor or tuple[Tensor]: If self.training is True, return
                tuple[Tensor], else return Tensor.
        """
        # All branches are fused at 1/8 of the input resolution.
        w_out = x.shape[-1] // 8
        h_out = x.shape[-2] // 8

        def to_out_size(feat: Tensor) -> Tensor:
            return F.interpolate(
                feat,
                size=[h_out, w_out],
                mode='bilinear',
                align_corners=self.align_corners)

        i_stage3, i_stage4, i_stage5 = self.i_branch_layers
        p_stage3, p_stage4, p_stage5 = self.p_branch_layers
        d_stage3, d_stage4, d_stage5 = self.d_branch_layers

        # stage 0-2
        x = self.stem(x)

        # stage 3
        x_i = self.relu(i_stage3(x))
        x_p = p_stage3(x)
        x_d = d_stage3(x)

        x_p = self.pag_1(x_p, self.compression_1(x_i))
        x_d += to_out_size(self.diff_1(x_i))
        if self.training:
            # Snapshot of the P branch, returned only in training mode.
            temp_p = x_p.clone()

        # stage 4
        x_i = self.relu(i_stage4(x_i))
        x_p = p_stage4(self.relu(x_p))
        x_d = d_stage4(self.relu(x_d))

        x_p = self.pag_2(x_p, self.compression_2(x_i))
        x_d += to_out_size(self.diff_2(x_i))
        if self.training:
            # Snapshot of the D branch, returned only in training mode.
            temp_d = x_d.clone()

        # stage 5
        x_i = i_stage5(x_i)
        x_p = p_stage5(self.relu(x_p))
        x_d = d_stage5(self.relu(x_d))

        x_i = to_out_size(self.spp(x_i))
        out = self.dfm(x_p, x_i, x_d)
        return (temp_p, out, temp_d) if self.training else out
|
||||
318
Seg_All_In_One_MMSeg/mmseg/models/backbones/resnest.py
Normal file
318
Seg_All_In_One_MMSeg/mmseg/models/backbones/resnest.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNetV1d
|
||||
|
||||
|
||||
class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    With ``radix > 1`` the input is viewed as ``(batch, groups, radix, -1)``
    and a softmax is taken across the radix dimension; otherwise an
    element-wise sigmoid is applied.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        if self.radix <= 1:
            return torch.sigmoid(x)

        batch = x.size(0)
        # Move radix in front of groups so the softmax over dim 1
        # normalizes across the radix splits.
        per_radix = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
        weights = F.softmax(per_radix, dim=1)
        return weights.reshape(batch, -1)
|
||||
|
||||
|
||||
class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels (int): Same as nn.Conv2d.
        channels (int): Number of output channels. The internal convolution
            produces ``channels * radix`` feature maps which the attention
            fuses back to ``channels``.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Cardinality. The main convolution uses
            ``groups * radix`` groups, the two 1x1 layers use ``groups``.
        radix (int): Radix of SplitAttentionConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        dcn (dict): Config dict for DCN. The dict is not modified.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super().__init__()
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        # Work on a copy: ``fallback_on_stride`` is popped below and the
        # caller's dict may be shared with other layers.
        self.dcn = dcn.copy() if self.with_dcn else None
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = self.dcn
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer named "norm0" """
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def forward(self, x):
        """Apply the convolution and fuse its radix splits with attention."""
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch = x.size(0)
        if self.radix > 1:
            # (batch, radix, channels, H, W): one slice per radix split
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            # attention-weighted sum over the radix splits
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()
|
||||
|
||||
|
||||
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    The 3x3 convolution of the parent bottleneck is replaced by a
    ``SplitAttentionConv2d`` and the parent's ``norm2`` layer is removed.

    Args:
        inplanes (int): Input planes of this block.
        planes (int): Middle planes of this block.
        groups (int): Groups of conv2. Default: 1.
        base_width (int): Width per group of conv2. 64x4d indicates
            ``groups=64, base_width=4`` and 32x8d indicates
            ``groups=32, base_width=8``. Default: 4.
        base_channels (int): Channel count that ``base_width`` is scaled
            against when ``groups > 1``. Default: 64.
        radix (int): Radix of SplitAttentionConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Key word arguments for base class.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super().__init__(inplanes, planes, **kwargs)

        width = self.planes if groups == 1 else math.floor(
            self.planes * (base_width / base_channels)) * groups
        out_planes = self.planes * self.expansion

        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, first_norm = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm3_name, last_norm = build_norm_layer(
            self.norm_cfg, out_planes, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, first_norm)
        self.with_modulated_dcn = False
        # when the average pool carries the stride, conv2 itself keeps stride 1
        split_stride = 1 if self.avg_down_stride else self.conv2_stride
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=split_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        # the split-attention conv has its own norm, drop the inherited one
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg, width, out_planes, kernel_size=1, bias=False)
        self.add_module(self.norm3_name, last_norm)

    def forward(self, x):
        """Run the block: residual branch, shortcut addition, final ReLU."""

        def _residual(inputs):
            shortcut = inputs

            feat = self.relu(self.norm1(self.conv1(inputs)))
            if self.with_plugins:
                feat = self.forward_plugin(feat,
                                           self.after_conv1_plugin_names)

            feat = self.conv2(feat)
            if self.avg_down_stride:
                feat = self.avd_layer(feat)
            if self.with_plugins:
                feat = self.forward_plugin(feat,
                                           self.after_conv2_plugin_names)

            feat = self.norm3(self.conv3(feat))
            if self.with_plugins:
                feat = self.forward_plugin(feat,
                                           self.after_conv3_plugin_names)

            if self.downsample is not None:
                shortcut = self.downsample(inputs)

            feat += shortcut
            return feat

        use_checkpoint = self.with_cp and x.requires_grad
        out = cp.checkpoint(_residual, x) if use_checkpoint else _residual(x)
        return self.relu(out)
|
||||
|
||||
|
||||
@MODELS.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    This backbone is the implementation of `ResNeSt:
    Split-Attention Networks <https://arxiv.org/abs/2004.08955>`_.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1
        base_width (int): Base width of Bottleneck. Default: 4
        radix (int): Radix of SplitAttentionConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        # Stored before the parent constructor runs because that constructor
        # calls ``make_res_layer``, which reads these attributes.
        self.avg_down_stride = avg_down_stride
        self.reduction_factor = reduction_factor
        self.radix = radix
        self.base_width = base_width
        self.groups = groups
        super().__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        split_attention_cfg = dict(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride)
        return ResLayer(**split_attention_cfg, **kwargs)
|
||||
712
Seg_All_In_One_MMSeg/mmseg/models/backbones/resnet.py
Normal file
712
Seg_All_In_One_MMSeg/mmseg/models/backbones/resnet.py
Normal file
@@ -0,0 +1,712 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
|
||||
|
||||
class BasicBlock(BaseModule):
    """Basic block for ResNet: two 3x3 convolutions plus a shortcut."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.norm1_name, first_norm = build_norm_layer(
            norm_cfg, planes, postfix=1)
        self.norm2_name, second_norm = build_norm_layer(
            norm_cfg, planes, postfix=2)

        # the first conv carries the stride and the dilation of the block
        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, first_norm)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, second_norm)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Forward function."""

        def _residual(inputs):
            shortcut = inputs

            feat = self.relu(self.norm1(self.conv1(inputs)))
            feat = self.norm2(self.conv2(feat))

            if self.downsample is not None:
                shortcut = self.downsample(inputs)

            feat += shortcut
            return feat

        use_checkpoint = self.with_cp and x.requires_grad
        out = cp.checkpoint(_residual, x) if use_checkpoint else _residual(x)
        return self.relu(out)
|
||||
|
||||
|
||||
class Bottleneck(BaseModule):
|
||||
"""Bottleneck block for ResNet.
|
||||
|
||||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
|
||||
"caffe", the stride-two layer is the first 1x1 conv layer.
|
||||
"""
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
             inplanes,
             planes,
             stride=1,
             dilation=1,
             downsample=None,
             style='pytorch',
             with_cp=False,
             conv_cfg=None,
             norm_cfg=dict(type='BN'),
             dcn=None,
             plugins=None,
             init_cfg=None):
    """Build the 1x1 -> 3x3 -> 1x1 bottleneck and its optional plugins.

    ``style`` selects which convolution carries ``stride``: the 3x3 one for
    ``'pytorch'``, the first 1x1 one for ``'caffe'``. When ``dcn`` is given
    (and ``fallback_on_stride`` is not set in it) the 3x3 convolution is
    built from the ``dcn`` config instead of ``conv_cfg``.
    """
    super().__init__(init_cfg)
    assert style in ['pytorch', 'caffe']
    assert dcn is None or isinstance(dcn, dict)
    assert plugins is None or isinstance(plugins, list)
    if plugins is not None:
        allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
        assert all(p['position'] in allowed_position for p in plugins)

    self.inplanes = inplanes
    self.planes = planes
    self.stride = stride
    self.dilation = dilation
    self.style = style
    self.with_cp = with_cp
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg
    self.dcn = dcn
    self.with_dcn = dcn is not None
    self.plugins = plugins
    self.with_plugins = plugins is not None

    if self.with_plugins:

        def _cfgs_at(position):
            # plugin cfgs requested for one insertion point
            return [
                item['cfg'] for item in plugins
                if item['position'] == position
            ]

        self.after_conv1_plugins = _cfgs_at('after_conv1')
        self.after_conv2_plugins = _cfgs_at('after_conv2')
        self.after_conv3_plugins = _cfgs_at('after_conv3')

    if self.style == 'pytorch':
        self.conv1_stride, self.conv2_stride = 1, stride
    else:
        self.conv1_stride, self.conv2_stride = stride, 1

    out_planes = planes * self.expansion
    self.norm1_name, first_norm = build_norm_layer(
        norm_cfg, planes, postfix=1)
    self.norm2_name, second_norm = build_norm_layer(
        norm_cfg, planes, postfix=2)
    self.norm3_name, third_norm = build_norm_layer(
        norm_cfg, out_planes, postfix=3)

    self.conv1 = build_conv_layer(
        conv_cfg,
        inplanes,
        planes,
        kernel_size=1,
        stride=self.conv1_stride,
        bias=False)
    self.add_module(self.norm1_name, first_norm)

    fallback_on_stride = False
    if self.with_dcn:
        fallback_on_stride = dcn.pop('fallback_on_stride', False)
    if self.with_dcn and not fallback_on_stride:
        assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
        conv2_cfg = dcn
    else:
        conv2_cfg = conv_cfg
    self.conv2 = build_conv_layer(
        conv2_cfg,
        planes,
        planes,
        kernel_size=3,
        stride=self.conv2_stride,
        padding=dilation,
        dilation=dilation,
        bias=False)

    self.add_module(self.norm2_name, second_norm)
    self.conv3 = build_conv_layer(
        conv_cfg, planes, out_planes, kernel_size=1, bias=False)
    self.add_module(self.norm3_name, third_norm)

    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample

    if self.with_plugins:
        self.after_conv1_plugin_names = self.make_block_plugins(
            planes, self.after_conv1_plugins)
        self.after_conv2_plugin_names = self.make_block_plugins(
            planes, self.after_conv2_plugins)
        self.after_conv3_plugin_names = self.make_block_plugins(
            out_planes, self.after_conv3_plugins)
|
||||
|
||||
def make_block_plugins(self, in_channels, plugins):
    """Build the given plugins and register them on this block.

    Args:
        in_channels (int): Input channels of plugin.
        plugins (list[dict]): List of plugins cfg to build.

    Returns:
        list[str]: List of the names of plugin.
    """
    assert isinstance(plugins, list)
    registered = []
    for plugin_cfg in plugins:
        # copy so that popping ``postfix`` leaves the caller's cfg intact
        cfg = plugin_cfg.copy()
        postfix = cfg.pop('postfix', '')
        name, layer = build_plugin_layer(
            cfg, in_channels=in_channels, postfix=postfix)
        assert not hasattr(self, name), f'duplicate plugin {name}'
        self.add_module(name, layer)
        registered.append(name)
    return registered
|
||||
|
||||
def forward_plugin(self, x, plugin_names):
    """Forward function for plugins.

    The plugins are applied one after another: each plugin receives the
    output of the previous one, and the first receives ``x``.

    Args:
        x: Input passed to the first plugin.
        plugin_names (list[str]): Attribute names of the plugins to run, in
            order.

    Returns:
        Output of the last plugin, or ``x`` when ``plugin_names`` is empty.
    """
    out = x
    for name in plugin_names:
        # feed the running result, not the original input, so that several
        # plugins at the same position are chained instead of overwritten
        out = getattr(self, name)(out)
    return out
|
||||
|
||||
@property
def norm1(self):
    """nn.Module: normalization layer after the first convolution layer"""
    layer_name = self.norm1_name
    return getattr(self, layer_name)
|
||||
|
||||
@property
def norm2(self):
    """nn.Module: normalization layer after the second convolution layer"""
    layer_name = self.norm2_name
    return getattr(self, layer_name)
|
||||
|
||||
@property
def norm3(self):
    """nn.Module: normalization layer after the third convolution layer"""
    layer_name = self.norm3_name
    return getattr(self, layer_name)
|
||||
|
||||
def forward(self, x):
    """Forward function."""

    def _residual(inputs):
        shortcut = inputs

        feat = self.relu(self.norm1(self.conv1(inputs)))
        if self.with_plugins:
            feat = self.forward_plugin(feat, self.after_conv1_plugin_names)

        feat = self.relu(self.norm2(self.conv2(feat)))
        if self.with_plugins:
            feat = self.forward_plugin(feat, self.after_conv2_plugin_names)

        feat = self.norm3(self.conv3(feat))
        if self.with_plugins:
            feat = self.forward_plugin(feat, self.after_conv3_plugin_names)

        if self.downsample is not None:
            shortcut = self.downsample(inputs)

        feat += shortcut
        return feat

    # checkpointing is only used when gradients flow through the input
    use_checkpoint = self.with_cp and x.requires_grad
    out = cp.checkpoint(_residual, x) if use_checkpoint else _residual(x)
    return self.relu(out)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNet(BaseModule):
|
||||
"""ResNet backbone.
|
||||
|
||||
This backbone is the improved implementation of `Deep Residual Learning
|
||||
for Image Recognition <https://arxiv.org/abs/1512.03385>`_.
|
||||
|
||||
Args:
|
||||
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
stem_channels (int): Number of stem channels. Default: 64.
|
||||
base_channels (int): Number of base channels of res layer. Default: 64.
|
||||
num_stages (int): Resnet stages, normally 4. Default: 4.
|
||||
strides (Sequence[int]): Strides of the first block of each stage.
|
||||
Default: (1, 2, 2, 2).
|
||||
dilations (Sequence[int]): Dilation of each stage.
|
||||
Default: (1, 1, 1, 1).
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
|
||||
layer is the 3x3 conv layer, otherwise the stride-two layer is
|
||||
the first 1x1 conv layer. Default: 'pytorch'.
|
||||
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
|
||||
Default: False.
|
||||
avg_down (bool): Use AvgPool instead of stride conv when
|
||||
downsampling in the bottleneck. Default: False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Default: -1.
|
||||
conv_cfg (dict | None): Dictionary to construct and config conv layer.
|
||||
When conv_cfg is None, cfg will be set to dict(type='Conv2d').
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
dcn (dict | None): Dictionary to construct and config DCN conv layer.
|
||||
When dcn is not None, conv_cfg must be None. Default: None.
|
||||
stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
|
||||
stage. The length of stage_with_dcn is equal to num_stages.
|
||||
Default: (False, False, False, False).
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
|
||||
- position (str, required): Position inside block to insert plugin,
|
||||
options: 'after_conv1', 'after_conv2', 'after_conv3'.
|
||||
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
Default: None.
|
||||
multi_grid (Sequence[int]|None): Multi grid dilation rates of last
|
||||
stage. Default: None.
|
||||
contract_dilation (bool): Whether contract first dilation of each layer
|
||||
Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity. Default: True.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import ResNet
|
||||
>>> import torch
|
||||
>>> self = ResNet(depth=18)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 3, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 64, 8, 8)
|
||||
(1, 128, 4, 4)
|
||||
(1, 256, 2, 2)
|
||||
(1, 512, 1, 1)
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
18: (BasicBlock, (2, 2, 2, 2)),
|
||||
34: (BasicBlock, (3, 4, 6, 3)),
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self,
             depth,
             in_channels=3,
             stem_channels=64,
             base_channels=64,
             num_stages=4,
             strides=(1, 2, 2, 2),
             dilations=(1, 1, 1, 1),
             out_indices=(0, 1, 2, 3),
             style='pytorch',
             deep_stem=False,
             avg_down=False,
             frozen_stages=-1,
             conv_cfg=None,
             norm_cfg=dict(type='BN', requires_grad=True),
             norm_eval=False,
             dcn=None,
             stage_with_dcn=(False, False, False, False),
             plugins=None,
             multi_grid=None,
             contract_dilation=False,
             with_cp=False,
             zero_init_residual=True,
             pretrained=None,
             init_cfg=None):
    """Resolve the init config, then build the stem and the res stages."""
    super().__init__(init_cfg)
    if depth not in self.arch_settings:
        raise KeyError(f'invalid depth {depth} for resnet')

    self.pretrained = pretrained
    self.zero_init_residual = zero_init_residual
    assert not (init_cfg and pretrained), \
        'init_cfg and pretrained cannot be setting at the same time'

    block_init_cfg = None
    if isinstance(pretrained, str):
        warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                      'please use "init_cfg" instead')
        self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
    elif pretrained is not None:
        raise TypeError('pretrained must be a str or None')
    elif init_cfg is None:
        # default initialisation when the user supplied nothing
        self.init_cfg = [
            dict(type='Kaiming', layer='Conv2d'),
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
        if self.zero_init_residual:
            block_cls = self.arch_settings[depth][0]
            last_norm = None
            if block_cls is BasicBlock:
                last_norm = 'norm2'
            elif block_cls is Bottleneck:
                last_norm = 'norm3'
            if last_norm is not None:
                # zero the last norm of every block of these two types
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name=last_norm))

    self.depth = depth
    self.stem_channels = stem_channels
    self.base_channels = base_channels
    self.num_stages = num_stages
    assert 1 <= num_stages <= 4
    self.strides = strides
    self.dilations = dilations
    assert len(strides) == len(dilations) == num_stages
    self.out_indices = out_indices
    assert max(out_indices) < num_stages
    self.style = style
    self.deep_stem = deep_stem
    self.avg_down = avg_down
    self.frozen_stages = frozen_stages
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg
    self.with_cp = with_cp
    self.norm_eval = norm_eval
    self.dcn = dcn
    self.stage_with_dcn = stage_with_dcn
    if dcn is not None:
        assert len(stage_with_dcn) == num_stages
    self.plugins = plugins
    self.multi_grid = multi_grid
    self.contract_dilation = contract_dilation
    self.block, stage_blocks = self.arch_settings[depth]
    self.stage_blocks = stage_blocks[:num_stages]
    self.inplanes = stem_channels

    self._make_stem_layer(in_channels, stem_channels)

    self.res_layers = []
    last_stage = len(self.stage_blocks) - 1
    for stage_idx, num_blocks in enumerate(self.stage_blocks):
        stage_dcn = self.dcn if self.stage_with_dcn[stage_idx] else None
        stage_plugins = None
        if plugins is not None:
            stage_plugins = self.make_stage_plugins(plugins, stage_idx)
        # multi grid is applied to last layer only
        stage_multi_grid = multi_grid if stage_idx == last_stage else None
        planes = base_channels * 2**stage_idx
        res_layer = self.make_res_layer(
            block=self.block,
            inplanes=self.inplanes,
            planes=planes,
            num_blocks=num_blocks,
            stride=strides[stage_idx],
            dilation=dilations[stage_idx],
            style=self.style,
            avg_down=self.avg_down,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            dcn=stage_dcn,
            plugins=stage_plugins,
            multi_grid=stage_multi_grid,
            contract_dilation=contract_dilation,
            init_cfg=block_init_cfg)
        # the next stage consumes the expanded output of this one
        self.inplanes = planes * self.block.expansion
        layer_name = f'layer{stage_idx+1}'
        self.add_module(layer_name, res_layer)
        self.res_layers.append(layer_name)

    self._freeze_stages()

    self.feat_dim = self.block.expansion * base_channels * 2**last_stage
|
||||
|
||||
def make_stage_plugins(self, plugins, stage_idx):
    """Select the plugin configs that apply to stage ``stage_idx``.

    Currently we support to insert 'context_block',
    'empirical_attention_block', 'nonlocal_block' into the backbone like
    ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
    Bottleneck.

    An example of plugins format could be :
    >>> plugins=[
    ...     dict(cfg=dict(type='xxx', arg1='xxx'),
    ...          stages=(False, True, True, True),
    ...          position='after_conv2'),
    ...     dict(cfg=dict(type='yyy'),
    ...          stages=(True, True, True, True),
    ...          position='after_conv3'),
    ...     dict(cfg=dict(type='zzz', postfix='1'),
    ...          stages=(True, True, True, True),
    ...          position='after_conv3'),
    ...     dict(cfg=dict(type='zzz', postfix='2'),
    ...          stages=(True, True, True, True),
    ...          position='after_conv3')
    ... ]
    >>> self = ResNet(depth=18)
    >>> stage_plugins = self.make_stage_plugins(plugins, 0)
    >>> assert len(stage_plugins) == 3

    Suppose 'stage_idx=0', the structure of blocks in the stage would be:
        conv1-> conv2->conv3->yyy->zzz1->zzz2
    Suppose 'stage_idx=1', the structure of blocks in the stage would be:
        conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2

    If stages is missing, the plugin would be applied to all stages.

    Args:
        plugins (list[dict]): List of plugins cfg to build. The postfix is
            required if multiple same type plugins are inserted.
        stage_idx (int): Index of stage to build

    Returns:
        list[dict]: Plugins for current stage, as copies of the input
        dicts with the ``stages`` key removed.
    """
    selected = []
    for plugin_cfg in plugins:
        cfg = plugin_cfg.copy()
        stage_flags = cfg.pop('stages', None)
        assert stage_flags is None or len(stage_flags) == self.num_stages
        # skip plugins that are switched off for this stage
        if stage_flags is not None and not stage_flags[stage_idx]:
            continue
        selected.append(cfg)

    return selected
|
||||
|
||||
def make_res_layer(self, **kwargs):
    """Pack all blocks in a stage into a ``ResLayer``."""
    stage = ResLayer(**kwargs)
    return stage
|
||||
|
||||
@property
def norm1(self):
    """nn.Module: the normalization layer named "norm1" """
    layer_name = self.norm1_name
    return getattr(self, layer_name)
|
||||
|
||||
def _make_stem_layer(self, in_channels, stem_channels):
    """Make stem layer for ResNet.

    With ``deep_stem`` the stem is three 3x3 conv-norm-ReLU groups stored as
    ``self.stem``; otherwise it is a single 7x7 conv with its norm and ReLU.
    A 3x3 max pool with stride 2 is created in both cases.
    """
    if self.deep_stem:
        half_channels = stem_channels // 2
        # (in_channels, out_channels, stride) of each 3x3 convolution
        conv_specs = (
            (in_channels, half_channels, 2),
            (half_channels, half_channels, 1),
            (half_channels, stem_channels, 1),
        )
        stem_layers = []
        for conv_in, conv_out, conv_stride in conv_specs:
            stem_layers.append(
                build_conv_layer(
                    self.conv_cfg,
                    conv_in,
                    conv_out,
                    kernel_size=3,
                    stride=conv_stride,
                    padding=1,
                    bias=False))
            stem_layers.append(build_norm_layer(self.norm_cfg, conv_out)[1])
            stem_layers.append(nn.ReLU(inplace=True))
        self.stem = nn.Sequential(*stem_layers)
    else:
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            stem_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)
        self.norm1_name, stem_norm = build_norm_layer(
            self.norm_cfg, stem_channels, postfix=1)
        self.add_module(self.norm1_name, stem_norm)
        self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def _freeze_stages(self):
    """Freeze stages param and norm stats."""

    def _stop_grad(module):
        for weight in module.parameters():
            weight.requires_grad = False

    if self.frozen_stages >= 0:
        # the stem is frozen as soon as any freezing is requested
        if self.deep_stem:
            self.stem.eval()
            _stop_grad(self.stem)
        else:
            self.norm1.eval()
            _stop_grad(self.conv1)
            _stop_grad(self.norm1)

    for stage_no in range(1, self.frozen_stages + 1):
        stage = getattr(self, f'layer{stage_no}')
        stage.eval()
        _stop_grad(stage)
|
||||
|
||||
def forward(self, x):
    """Forward function.

    Returns:
        tuple: Outputs of the stages whose index is in ``out_indices``, in
        stage order.
    """
    if self.deep_stem:
        feat = self.stem(x)
    else:
        feat = self.relu(self.norm1(self.conv1(x)))
    feat = self.maxpool(feat)

    collected = []
    for stage_idx, layer_name in enumerate(self.res_layers):
        feat = getattr(self, layer_name)(feat)
        if stage_idx in self.out_indices:
            collected.append(feat)
    return tuple(collected)
|
||||
|
||||
def train(self, mode=True):
    """Set the training mode while keeping frozen parts in eval mode.

    After the regular mode switch, the frozen stages are put back into eval
    mode, and when ``norm_eval`` is set every batch-norm layer is kept in
    eval mode as well.

    Args:
        mode (bool): Whether to set training mode (``True``) or evaluation
            mode (``False``). Default: True.

    Returns:
        The module itself, as ``nn.Module.train`` does, so that calls such
        as ``model.train()`` and ``model.eval()`` can be chained.
    """
    super().train(mode)
    self._freeze_stages()
    if mode and self.norm_eval:
        for m in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(m, _BatchNorm):
                m.eval()
    return self
|
||||
|
||||
|
||||
@MODELS.register_module()
class ResNetV1c(ResNet):
    """ResNetV1c variant of ResNet.

    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in
    the input stem with three 3x3 convs. For more details please refer to `Bag
    of Tricks for Image Classification with Convolutional Neural Networks
    <https://arxiv.org/abs/1812.01187>`_.
    """

    def __init__(self, **kwargs):
        # deep stem only; the downsampling path stays a strided conv
        variant_cfg = dict(deep_stem=True, avg_down=False)
        super().__init__(**variant_cfg, **kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module()
class ResNetV1d(ResNet):
    """ResNetV1d variant of ResNet.

    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
    the input stem with three 3x3 convs. And in the downsampling block, a 2x2
    avg_pool with stride 2 is added before conv, whose stride is changed to 1.
    """

    def __init__(self, **kwargs):
        # deep stem plus average-pool downsampling
        variant_cfg = dict(deep_stem=True, avg_down=True)
        super().__init__(**variant_cfg, **kwargs)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user