first commit
This commit is contained in:
1
Seg_All_In_One_MMSeg/tests/__init__.py
Normal file
1
Seg_All_In_One_MMSeg/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
60
Seg_All_In_One_MMSeg/tests/test_apis/test_inferencer.py
Normal file
60
Seg_All_In_One_MMSeg/tests/test_apis/test_inferencer.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import ConfigDict
|
||||
from utils import * # noqa: F401, F403
|
||||
|
||||
from mmseg.apis import MMSegInferencer
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def test_inferencer():
|
||||
register_all_modules()
|
||||
|
||||
visualizer = dict(
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
name='visualizer')
|
||||
|
||||
cfg_dict = dict(
|
||||
model=dict(
|
||||
type='InferExampleModel',
|
||||
data_preprocessor=dict(type='SegDataPreProcessor'),
|
||||
backbone=dict(type='InferExampleBackbone'),
|
||||
decode_head=dict(type='InferExampleHead'),
|
||||
test_cfg=dict(mode='whole')),
|
||||
visualizer=visualizer,
|
||||
test_dataloader=dict(
|
||||
dataset=dict(
|
||||
type='ExampleDataset',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]), ))
|
||||
cfg = ConfigDict(cfg_dict)
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
ckpt = model.state_dict()
|
||||
ckpt_filename = tempfile.mktemp()
|
||||
torch.save(ckpt, ckpt_filename)
|
||||
|
||||
# test initialization
|
||||
infer = MMSegInferencer(cfg, ckpt_filename)
|
||||
|
||||
# test forward
|
||||
img = np.random.randint(0, 256, (4, 4, 3))
|
||||
infer(img)
|
||||
|
||||
imgs = [img, img]
|
||||
infer(imgs)
|
||||
results = infer(imgs, out_dir=tempfile.gettempdir())
|
||||
|
||||
# test results
|
||||
assert 'predictions' in results
|
||||
assert 'visualization' in results
|
||||
assert len(results['predictions']) == 2
|
||||
assert results['predictions'][0].shape == (4, 4)
|
||||
73
Seg_All_In_One_MMSeg/tests/test_apis/test_rs_inferencer.py
Normal file
73
Seg_All_In_One_MMSeg/tests/test_apis/test_rs_inferencer.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
from mmengine import ConfigDict, init_default_scope
|
||||
from utils import * # noqa: F401, F403
|
||||
|
||||
from mmseg.apis import RSImage, RSInferencer
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class TestRSImage(TestCase):
|
||||
|
||||
def test_read_whole_image(self):
|
||||
init_default_scope('mmseg')
|
||||
img_path = osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||
rs_image = RSImage(img_path)
|
||||
window_size = (16, 16)
|
||||
rs_image.create_grids(window_size)
|
||||
image_data = rs_image.read(rs_image.grids[0])
|
||||
self.assertIsNotNone(image_data)
|
||||
|
||||
def test_write_image_data(self):
|
||||
init_default_scope('mmseg')
|
||||
img_path = osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||
rs_image = RSImage(img_path)
|
||||
window_size = (16, 16)
|
||||
rs_image.create_grids(window_size)
|
||||
data = np.random.random((16, 16)).astype(np.int8)
|
||||
rs_image.write(data, rs_image.grids[0])
|
||||
|
||||
|
||||
class TestRSInferencer(TestCase):
|
||||
|
||||
def test_read_and_inference(self):
|
||||
init_default_scope('mmseg')
|
||||
cfg_dict = dict(
|
||||
model=dict(
|
||||
type='InferExampleModel',
|
||||
data_preprocessor=dict(type='SegDataPreProcessor'),
|
||||
backbone=dict(type='InferExampleBackbone'),
|
||||
decode_head=dict(type='InferExampleHead'),
|
||||
test_cfg=dict(mode='whole')),
|
||||
test_dataloader=dict(
|
||||
dataset=dict(
|
||||
type='ExampleDataset',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
])),
|
||||
test_pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
])
|
||||
cfg = ConfigDict(cfg_dict)
|
||||
model = MODELS.build(cfg.model)
|
||||
model.cfg = cfg
|
||||
inferencer = RSInferencer.from_model(model)
|
||||
|
||||
img_path = osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir/0.png')
|
||||
rs_image = RSImage(img_path)
|
||||
window_size = (16, 16)
|
||||
stride = (16, 16)
|
||||
inferencer.run(rs_image, window_size, stride)
|
||||
38
Seg_All_In_One_MMSeg/tests/test_apis/utils.py
Normal file
38
Seg_All_In_One_MMSeg/tests/test_apis/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
|
||||
from mmseg.models import EncoderDecoder
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleHead')
|
||||
class ExampleDecodeHead(BaseDecodeHead):
|
||||
|
||||
def __init__(self, num_classes=19, out_channels=None):
|
||||
super().__init__(
|
||||
3, 3, num_classes=num_classes, out_channels=out_channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.cls_seg(inputs[0])
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleBackbone')
|
||||
class ExampleBackbone(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
return [self.conv(x)]
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleModel')
|
||||
class ExampleModel(EncoderDecoder):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
175
Seg_All_In_One_MMSeg/tests/test_config.py
Normal file
175
Seg_All_In_One_MMSeg/tests/test_config.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import glob
|
||||
import os
|
||||
from os.path import dirname, exists, isdir, join, relpath
|
||||
|
||||
import numpy as np
|
||||
from mmengine import Config
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.registry import init_default_scope
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
"""Find the predefined segmentor config directory."""
|
||||
try:
|
||||
# Assume we are running in the source mmsegmentation repo
|
||||
repo_dpath = dirname(dirname(__file__))
|
||||
except NameError:
|
||||
# For IPython development when this __file__ is not defined
|
||||
import mmseg
|
||||
repo_dpath = dirname(dirname(mmseg.__file__))
|
||||
config_dpath = join(repo_dpath, 'configs')
|
||||
if not exists(config_dpath):
|
||||
raise Exception('Cannot find config path')
|
||||
return config_dpath
|
||||
|
||||
|
||||
def test_config_build_segmentor():
|
||||
"""Test that all segmentation models defined in the configs can be
|
||||
initialized."""
|
||||
init_default_scope('mmseg')
|
||||
config_dpath = _get_config_directory()
|
||||
print(f'Found config_dpath = {config_dpath!r}')
|
||||
|
||||
config_fpaths = []
|
||||
# one config each sub folder
|
||||
for sub_folder in os.listdir(config_dpath):
|
||||
if isdir(sub_folder):
|
||||
config_fpaths.append(
|
||||
list(glob.glob(join(config_dpath, sub_folder, '*.py')))[0])
|
||||
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
|
||||
print(f'Using {len(config_names)} config files')
|
||||
|
||||
for config_fname in config_names:
|
||||
config_fpath = join(config_dpath, config_fname)
|
||||
config_mod = Config.fromfile(config_fpath)
|
||||
|
||||
config_mod.model
|
||||
print(f'Building segmentor, config_fpath = {config_fpath!r}')
|
||||
|
||||
# Remove pretrained keys to allow for testing in an offline environment
|
||||
if 'pretrained' in config_mod.model:
|
||||
config_mod.model['pretrained'] = None
|
||||
|
||||
print(f'building {config_fname}')
|
||||
segmentor = build_segmentor(config_mod.model)
|
||||
assert segmentor is not None
|
||||
|
||||
head_config = config_mod.model['decode_head']
|
||||
_check_decode_head(head_config, segmentor.decode_head)
|
||||
|
||||
|
||||
def test_config_data_pipeline():
|
||||
"""Test whether the data pipeline is valid and can process corner cases.
|
||||
|
||||
CommandLine:
|
||||
xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
"""
|
||||
|
||||
init_default_scope('mmseg')
|
||||
config_dpath = _get_config_directory()
|
||||
print(f'Found config_dpath = {config_dpath!r}')
|
||||
|
||||
import glob
|
||||
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
|
||||
print(f'Using {len(config_names)} config files')
|
||||
|
||||
for config_fname in config_names:
|
||||
config_fpath = join(config_dpath, config_fname)
|
||||
print(f'Building data pipeline, config_fpath = {config_fpath!r}')
|
||||
config_mod = Config.fromfile(config_fpath)
|
||||
|
||||
# remove loading pipeline
|
||||
load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||
to_float32 = load_img_pipeline.get('to_float32', False)
|
||||
del config_mod.train_pipeline[0]
|
||||
del config_mod.test_pipeline[0]
|
||||
# remove loading annotation in test pipeline
|
||||
load_anno_idx = -1
|
||||
for i in range(len(config_mod.test_pipeline)):
|
||||
if config_mod.test_pipeline[i].type in ('LoadAnnotations',
|
||||
'LoadDepthAnnotation'):
|
||||
load_anno_idx = i
|
||||
del config_mod.test_pipeline[load_anno_idx]
|
||||
|
||||
train_pipeline = Compose(config_mod.train_pipeline)
|
||||
test_pipeline = Compose(config_mod.test_pipeline)
|
||||
|
||||
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||
if to_float32:
|
||||
img = img.astype(np.float32)
|
||||
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||
depth = np.random.rand(1024, 2048).astype(np.float32)
|
||||
|
||||
results = dict(
|
||||
filename='test_img.png',
|
||||
ori_filename='test_img.png',
|
||||
img=img,
|
||||
img_shape=img.shape,
|
||||
ori_shape=img.shape,
|
||||
gt_seg_map=seg,
|
||||
gt_depth_map=depth)
|
||||
results['seg_fields'] = ['gt_seg_map']
|
||||
_check_concat_cd_input(config_mod, results)
|
||||
print(f'Test training data pipeline: \n{train_pipeline!r}')
|
||||
output_results = train_pipeline(results)
|
||||
assert output_results is not None
|
||||
|
||||
_check_concat_cd_input(config_mod, results)
|
||||
print(f'Test testing data pipeline: \n{test_pipeline!r}')
|
||||
output_results = test_pipeline(results)
|
||||
assert output_results is not None
|
||||
|
||||
|
||||
def _check_concat_cd_input(config_mod: Config, results: dict):
|
||||
keys = []
|
||||
pipeline = config_mod.train_pipeline.copy()
|
||||
pipeline.extend(config_mod.test_pipeline)
|
||||
for t in pipeline:
|
||||
keys.append(t.type)
|
||||
if 'ConcatCDInput' in keys:
|
||||
results.update({'img2': results['img']})
|
||||
|
||||
|
||||
def _check_decode_head(decode_head_cfg, decode_head):
|
||||
if isinstance(decode_head_cfg, list):
|
||||
assert isinstance(decode_head, nn.ModuleList)
|
||||
assert len(decode_head_cfg) == len(decode_head)
|
||||
num_heads = len(decode_head)
|
||||
for i in range(num_heads):
|
||||
_check_decode_head(decode_head_cfg[i], decode_head[i])
|
||||
return
|
||||
# check consistency between head_config and roi_head
|
||||
assert decode_head_cfg['type'] == decode_head.__class__.__name__
|
||||
|
||||
assert decode_head_cfg['type'] == decode_head.__class__.__name__
|
||||
|
||||
in_channels = decode_head_cfg.in_channels
|
||||
input_transform = decode_head.input_transform
|
||||
assert input_transform in ['resize_concat', 'multiple_select', None]
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(decode_head.in_index, (list, tuple))
|
||||
assert len(in_channels) == len(decode_head.in_index)
|
||||
elif input_transform == 'resize_concat':
|
||||
assert sum(in_channels) == decode_head.in_channels
|
||||
else:
|
||||
assert in_channels == decode_head.in_channels
|
||||
|
||||
if decode_head_cfg['type'] == 'PointHead':
|
||||
assert decode_head_cfg.channels+decode_head_cfg.num_classes == \
|
||||
decode_head.fc_seg.in_channels
|
||||
assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes
|
||||
elif decode_head_cfg['type'] == 'VPDDepthHead':
|
||||
assert decode_head.out_channels == 1
|
||||
else:
|
||||
assert decode_head_cfg.channels == decode_head.conv_seg.in_channels
|
||||
assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes
|
||||
475
Seg_All_In_One_MMSeg/tests/test_datasets/test_dataset.py
Normal file
475
Seg_All_In_One_MMSeg/tests/test_datasets/test_dataset.py
Normal file
@@ -0,0 +1,475 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from mmseg.datasets import (ADE20KDataset, BaseSegDataset, BDD100KDataset,
|
||||
CityscapesDataset, COCOStuffDataset,
|
||||
DecathlonDataset, DSDLSegDataset, ISPRSDataset,
|
||||
LIPDataset, LoveDADataset, MapillaryDataset_v1,
|
||||
MapillaryDataset_v2, NYUDataset, PascalVOCDataset,
|
||||
PotsdamDataset, REFUGEDataset, SynapseDataset,
|
||||
iSAIDDataset)
|
||||
from mmseg.registry import DATASETS
|
||||
from mmseg.utils import get_classes, get_palette
|
||||
|
||||
try:
|
||||
from dsdl.dataset import DSDLDataset
|
||||
except ImportError:
|
||||
DSDLDataset = None
|
||||
|
||||
|
||||
def test_classes():
|
||||
assert list(
|
||||
CityscapesDataset.METAINFO['classes']) == get_classes('cityscapes')
|
||||
assert list(PascalVOCDataset.METAINFO['classes']) == get_classes(
|
||||
'voc') == get_classes('pascal_voc')
|
||||
assert list(ADE20KDataset.METAINFO['classes']) == get_classes(
|
||||
'ade') == get_classes('ade20k')
|
||||
assert list(
|
||||
COCOStuffDataset.METAINFO['classes']) == get_classes('cocostuff')
|
||||
assert list(LoveDADataset.METAINFO['classes']) == get_classes('loveda')
|
||||
assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
|
||||
assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
|
||||
assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')
|
||||
assert list(
|
||||
MapillaryDataset_v1.METAINFO['classes']) == get_classes('mapillary_v1')
|
||||
assert list(
|
||||
MapillaryDataset_v2.METAINFO['classes']) == get_classes('mapillary_v2')
|
||||
assert list(BDD100KDataset.METAINFO['classes']) == get_classes('bdd100k')
|
||||
with pytest.raises(ValueError):
|
||||
get_classes('unsupported')
|
||||
|
||||
|
||||
def test_classes_file_path():
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
classes_path = f'{tmp_file.name}.txt'
|
||||
train_pipeline = []
|
||||
kwargs = dict(
|
||||
pipeline=train_pipeline,
|
||||
data_prefix=dict(img_path='./', seg_map_path='./'),
|
||||
metainfo=dict(classes=classes_path))
|
||||
|
||||
# classes.txt with full categories
|
||||
categories = get_classes('cityscapes')
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
dataset = CityscapesDataset(**kwargs)
|
||||
assert list(dataset.metainfo['classes']) == categories
|
||||
assert dataset.label_map is None
|
||||
|
||||
# classes.txt with sub categories
|
||||
categories = ['road', 'sidewalk', 'building']
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
dataset = CityscapesDataset(**kwargs)
|
||||
assert list(dataset.metainfo['classes']) == categories
|
||||
assert dataset.label_map is not None
|
||||
|
||||
# classes.txt with unknown categories
|
||||
categories = ['road', 'sidewalk', 'unknown']
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
CityscapesDataset(**kwargs)
|
||||
|
||||
tmp_file.close()
|
||||
os.remove(classes_path)
|
||||
assert not osp.exists(classes_path)
|
||||
|
||||
|
||||
def test_palette():
|
||||
assert CityscapesDataset.METAINFO['palette'] == get_palette('cityscapes')
|
||||
assert PascalVOCDataset.METAINFO['palette'] == get_palette(
|
||||
'voc') == get_palette('pascal_voc')
|
||||
assert ADE20KDataset.METAINFO['palette'] == get_palette(
|
||||
'ade') == get_palette('ade20k')
|
||||
assert LoveDADataset.METAINFO['palette'] == get_palette('loveda')
|
||||
assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
|
||||
assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
|
||||
assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')
|
||||
assert list(
|
||||
MapillaryDataset_v1.METAINFO['palette']) == get_palette('mapillary_v1')
|
||||
assert list(
|
||||
MapillaryDataset_v2.METAINFO['palette']) == get_palette('mapillary_v2')
|
||||
assert list(BDD100KDataset.METAINFO['palette']) == get_palette('bdd100k')
|
||||
with pytest.raises(ValueError):
|
||||
get_palette('unsupported')
|
||||
|
||||
|
||||
def test_custom_dataset():
|
||||
|
||||
# with 'img_path' and 'seg_map_path' in data_prefix
|
||||
train_dataset = BaseSegDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
|
||||
data_prefix=dict(
|
||||
img_path='imgs/',
|
||||
seg_map_path='gts/',
|
||||
),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png')
|
||||
assert len(train_dataset) == 5
|
||||
|
||||
# with 'img_path' and 'seg_map_path' in data_prefix and ann_file
|
||||
train_dataset = BaseSegDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
|
||||
data_prefix=dict(
|
||||
img_path='imgs/',
|
||||
seg_map_path='gts/',
|
||||
),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png',
|
||||
ann_file='splits/train.txt')
|
||||
assert len(train_dataset) == 4
|
||||
|
||||
# no data_root
|
||||
train_dataset = BaseSegDataset(
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/gts')),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png')
|
||||
assert len(train_dataset) == 5
|
||||
|
||||
# with data_root but 'img_path' and 'seg_map_path' in data_prefix are
|
||||
# abs path
|
||||
train_dataset = BaseSegDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/gts')),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png')
|
||||
assert len(train_dataset) == 5
|
||||
|
||||
# test_mode=True
|
||||
test_dataset = BaseSegDataset(
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
|
||||
img_suffix='img.jpg',
|
||||
test_mode=True,
|
||||
metainfo=dict(classes=('pseudo_class', )))
|
||||
assert len(test_dataset) == 5
|
||||
|
||||
# training data get
|
||||
train_data = train_dataset[0]
|
||||
assert isinstance(train_data, dict)
|
||||
assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
|
||||
assert 'seg_map_path' in train_data and osp.isfile(
|
||||
train_data['seg_map_path'])
|
||||
|
||||
# test data get
|
||||
test_data = test_dataset[0]
|
||||
assert isinstance(test_data, dict)
|
||||
assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
|
||||
assert 'seg_map_path' in train_data and osp.isfile(
|
||||
train_data['seg_map_path'])
|
||||
|
||||
|
||||
def test_ade():
|
||||
test_dataset = ADE20KDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs')))
|
||||
assert len(test_dataset) == 5
|
||||
|
||||
|
||||
def test_cityscapes():
|
||||
test_dataset = CityscapesDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_cityscapes_dataset/leftImg8bit/val'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_cityscapes_dataset/gtFine/val')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_loveda():
|
||||
test_dataset = LoveDADataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 3
|
||||
|
||||
|
||||
def test_potsdam():
|
||||
test_dataset = PotsdamDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_potsdam_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_potsdam_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_vaihingen():
|
||||
test_dataset = ISPRSDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_vaihingen_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_vaihingen_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_synapse():
|
||||
test_dataset = SynapseDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_synapse_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_synapse_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 2
|
||||
|
||||
|
||||
def test_refuge():
|
||||
test_dataset = REFUGEDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_refuge_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_refuge_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_isaid():
|
||||
test_dataset = iSAIDDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_isaid_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 2
|
||||
test_dataset = iSAIDDataset(
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_isaid_dataset/ann_dir')),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_isaid_dataset/splits/train.txt'))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_decathlon():
|
||||
data_root = osp.join(osp.dirname(__file__), '../data')
|
||||
# test load training dataset
|
||||
test_dataset = DecathlonDataset(
|
||||
pipeline=[], data_root=data_root, ann_file='dataset.json')
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
# test load test dataset
|
||||
test_dataset = DecathlonDataset(
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
ann_file='dataset.json',
|
||||
test_mode=True)
|
||||
assert len(test_dataset) == 3
|
||||
|
||||
|
||||
def test_lip():
|
||||
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_lip_dataset')
|
||||
# train load training dataset
|
||||
train_dataset = LIPDataset(
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='train_images', seg_map_path='train_segmentations'))
|
||||
assert len(train_dataset) == 1
|
||||
|
||||
# test load training dataset
|
||||
test_dataset = LIPDataset(
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='val_images', seg_map_path='val_segmentations'))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_mapillary():
|
||||
test_dataset = MapillaryDataset_v1(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_mapillary_dataset/images'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_mapillary_dataset/v1.2')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_bdd100k():
|
||||
test_dataset = BDD100KDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_bdd100k_dataset/images/10k/val'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val')))
|
||||
assert len(test_dataset) == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset, classes', [
|
||||
('ADE20KDataset', ('wall', 'building')),
|
||||
('CityscapesDataset', ('road', 'sidewalk')),
|
||||
('BaseSegDataset', ('bus', 'car')),
|
||||
('PascalVOCDataset', ('aeroplane', 'bicycle')),
|
||||
])
|
||||
def test_custom_classes_override_default(dataset, classes):
|
||||
|
||||
dataset_class = DATASETS.get(dataset)
|
||||
original_classes = dataset_class.METAINFO.get('classes', None)
|
||||
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
ann_file = tmp_file.name
|
||||
img_path = tempfile.mkdtemp()
|
||||
|
||||
# Test setting classes as a tuple
|
||||
custom_dataset = dataset_class(
|
||||
data_prefix=dict(img_path=img_path),
|
||||
ann_file=ann_file,
|
||||
metainfo=dict(classes=classes),
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.metainfo['classes'] != original_classes
|
||||
assert custom_dataset.metainfo['classes'] == classes
|
||||
if not isinstance(custom_dataset, BaseSegDataset):
|
||||
assert isinstance(custom_dataset.label_map, dict)
|
||||
|
||||
# Test setting classes as a list
|
||||
custom_dataset = dataset_class(
|
||||
data_prefix=dict(img_path=img_path),
|
||||
ann_file=ann_file,
|
||||
metainfo=dict(classes=list(classes)),
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.metainfo['classes'] != original_classes
|
||||
assert custom_dataset.metainfo['classes'] == list(classes)
|
||||
if not isinstance(custom_dataset, BaseSegDataset):
|
||||
assert isinstance(custom_dataset.label_map, dict)
|
||||
|
||||
# Test overriding not a subset
|
||||
custom_dataset = dataset_class(
|
||||
ann_file=ann_file,
|
||||
data_prefix=dict(img_path=img_path),
|
||||
metainfo=dict(classes=[classes[0]]),
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.metainfo['classes'] != original_classes
|
||||
assert custom_dataset.metainfo['classes'] == [classes[0]]
|
||||
if not isinstance(custom_dataset, BaseSegDataset):
|
||||
assert isinstance(custom_dataset.label_map, dict)
|
||||
|
||||
# Test default behavior
|
||||
if dataset_class is BaseSegDataset:
|
||||
with pytest.raises(AssertionError):
|
||||
custom_dataset = dataset_class(
|
||||
ann_file=ann_file,
|
||||
data_prefix=dict(img_path=img_path),
|
||||
metainfo=None,
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
else:
|
||||
custom_dataset = dataset_class(
|
||||
data_prefix=dict(img_path=img_path),
|
||||
ann_file=ann_file,
|
||||
metainfo=None,
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.METAINFO['classes'] == original_classes
|
||||
assert custom_dataset.label_map is None
|
||||
|
||||
|
||||
def test_custom_dataset_random_palette_is_generated():
|
||||
dataset = BaseSegDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(img_path=tempfile.mkdtemp()),
|
||||
ann_file=tempfile.mkdtemp(),
|
||||
metainfo=dict(classes=('bus', 'car')),
|
||||
lazy_init=True,
|
||||
test_mode=True)
|
||||
assert len(dataset.metainfo['palette']) == 2
|
||||
for class_color in dataset.metainfo['palette']:
|
||||
assert len(class_color) == 3
|
||||
assert all(x >= 0 and x <= 255 for x in class_color)
|
||||
|
||||
|
||||
def test_custom_dataset_custom_palette():
|
||||
dataset = BaseSegDataset(
|
||||
data_prefix=dict(img_path=tempfile.mkdtemp()),
|
||||
ann_file=tempfile.mkdtemp(),
|
||||
metainfo=dict(
|
||||
classes=('bus', 'car'), palette=[[100, 100, 100], [200, 200,
|
||||
200]]),
|
||||
lazy_init=True,
|
||||
test_mode=True)
|
||||
assert tuple(dataset.metainfo['palette']) == tuple([[100, 100, 100],
|
||||
[200, 200, 200]])
|
||||
# test custom class and palette don't match
|
||||
with pytest.raises(ValueError):
|
||||
dataset = BaseSegDataset(
|
||||
data_prefix=dict(img_path=tempfile.mkdtemp()),
|
||||
ann_file=tempfile.mkdtemp(),
|
||||
metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
|
||||
lazy_init=True)
|
||||
|
||||
|
||||
def test_dsdlseg_dataset():
|
||||
if DSDLDataset is not None:
|
||||
dataset = DSDLSegDataset(
|
||||
data_root='tests/data/dsdl_seg', ann_file='set-train/train.yaml')
|
||||
assert len(dataset) == 3
|
||||
assert len(dataset.metainfo['classes']) == 21
|
||||
else:
|
||||
ImportWarning('Package `dsdl` is not installed.')
|
||||
|
||||
|
||||
def test_nyu_dataset():
|
||||
dataset = NYUDataset(
|
||||
data_root='tests/data/pseudo_nyu_dataset',
|
||||
data_prefix=dict(img_path='images', depth_map_path='annotations'),
|
||||
)
|
||||
assert len(dataset) == 1
|
||||
data = dataset[0]
|
||||
assert data.get('depth_map_path', None) is not None
|
||||
assert data.get('category_id', -1) == 26
|
||||
152
Seg_All_In_One_MMSeg/tests/test_datasets/test_dataset_builder.py
Normal file
152
Seg_All_In_One_MMSeg/tests/test_datasets/test_dataset_builder.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
from mmengine.dataset import ConcatDataset, RepeatDataset
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.datasets import MultiImageMixDataset
|
||||
from mmseg.registry import DATASETS
|
||||
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ToyDataset:
|
||||
|
||||
def __init__(self, cnt=0):
|
||||
self.cnt = cnt
|
||||
|
||||
def __item__(self, idx):
|
||||
return idx
|
||||
|
||||
def __len__(self):
|
||||
return 100
|
||||
|
||||
|
||||
def test_build_dataset():
|
||||
cfg = dict(type='ToyDataset')
|
||||
dataset = DATASETS.build(cfg)
|
||||
assert isinstance(dataset, ToyDataset)
|
||||
assert dataset.cnt == 0
|
||||
dataset = DATASETS.build(cfg, default_args=dict(cnt=1))
|
||||
assert isinstance(dataset, ToyDataset)
|
||||
assert dataset.cnt == 1
|
||||
|
||||
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
|
||||
data_prefix = dict(img_path='imgs/', seg_map_path='gts/')
|
||||
|
||||
# test RepeatDataset
|
||||
cfg = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
serialize_data=False)
|
||||
dataset = DATASETS.build(cfg)
|
||||
dataset_repeat = RepeatDataset(dataset=dataset, times=5)
|
||||
assert isinstance(dataset_repeat, RepeatDataset)
|
||||
assert len(dataset_repeat) == 25
|
||||
|
||||
# test ConcatDataset
|
||||
# We use same dir twice for simplicity
|
||||
# with data_prefix.seg_map_path
|
||||
cfg1 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
serialize_data=False)
|
||||
cfg2 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
serialize_data=False)
|
||||
dataset1 = DATASETS.build(cfg1)
|
||||
dataset2 = DATASETS.build(cfg2)
|
||||
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
|
||||
assert isinstance(dataset_concat, ConcatDataset)
|
||||
assert len(dataset_concat) == 10
|
||||
|
||||
# test MultiImageMixDataset
|
||||
dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[])
|
||||
assert isinstance(dataset, MultiImageMixDataset)
|
||||
assert len(dataset) == 10
|
||||
|
||||
cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
|
||||
|
||||
dataset = MultiImageMixDataset(dataset=cfg, pipeline=[])
|
||||
assert isinstance(dataset, MultiImageMixDataset)
|
||||
assert len(dataset) == 10
|
||||
|
||||
# with data_prefix.seg_map_path, ann_file
|
||||
cfg1 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
ann_file='splits/train.txt',
|
||||
serialize_data=False)
|
||||
cfg2 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
ann_file='splits/val.txt',
|
||||
serialize_data=False)
|
||||
|
||||
dataset1 = DATASETS.build(cfg1)
|
||||
dataset2 = DATASETS.build(cfg2)
|
||||
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
|
||||
assert isinstance(dataset_concat, ConcatDataset)
|
||||
assert len(dataset_concat) == 5
|
||||
|
||||
# test mode
|
||||
cfg1 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path='imgs/'),
|
||||
test_mode=True,
|
||||
metainfo=dict(classes=('pseudo_class', )),
|
||||
serialize_data=False)
|
||||
cfg2 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path='imgs/'),
|
||||
test_mode=True,
|
||||
metainfo=dict(classes=('pseudo_class', )),
|
||||
serialize_data=False)
|
||||
|
||||
dataset1 = DATASETS.build(cfg1)
|
||||
dataset2 = DATASETS.build(cfg2)
|
||||
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
|
||||
assert isinstance(dataset_concat, ConcatDataset)
|
||||
assert len(dataset_concat) == 10
|
||||
|
||||
# test mode with ann_files
|
||||
cfg1 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path='imgs/'),
|
||||
ann_file='splits/val.txt',
|
||||
test_mode=True,
|
||||
metainfo=dict(classes=('pseudo_class', )),
|
||||
serialize_data=False)
|
||||
cfg2 = dict(
|
||||
type='BaseSegDataset',
|
||||
pipeline=[],
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path='imgs/'),
|
||||
ann_file='splits/val.txt',
|
||||
test_mode=True,
|
||||
metainfo=dict(classes=('pseudo_class', )),
|
||||
serialize_data=False)
|
||||
|
||||
dataset1 = DATASETS.build(cfg1)
|
||||
dataset2 = DATASETS.build(cfg2)
|
||||
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
|
||||
assert isinstance(dataset_concat, ConcatDataset)
|
||||
assert len(dataset_concat) == 2
|
||||
61
Seg_All_In_One_MMSeg/tests/test_datasets/test_formatting.py
Normal file
61
Seg_All_In_One_MMSeg/tests/test_datasets/test_formatting.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
from mmseg.datasets.transforms import PackSegInputs
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestPackSegInputs(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Setup the model and optimizer which are used in every test method.
|
||||
|
||||
TestCase calls functions in this order: setUp() -> testMethod() ->
|
||||
tearDown() -> cleanUp()
|
||||
"""
|
||||
data_prefix = osp.join(osp.dirname(__file__), '../../data')
|
||||
img_path = osp.join(data_prefix, 'color.jpg')
|
||||
rng = np.random.RandomState(0)
|
||||
self.results = {
|
||||
'img_path': img_path,
|
||||
'ori_shape': (300, 400),
|
||||
'pad_shape': (600, 800),
|
||||
'img_shape': (600, 800),
|
||||
'scale_factor': 2.0,
|
||||
'flip': False,
|
||||
'flip_direction': 'horizontal',
|
||||
'img_norm_cfg': None,
|
||||
'img': rng.rand(300, 400),
|
||||
'gt_seg_map': rng.rand(300, 400),
|
||||
}
|
||||
self.meta_keys = ('img_path', 'ori_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction')
|
||||
|
||||
def test_transform(self):
|
||||
transform = PackSegInputs(meta_keys=self.meta_keys)
|
||||
results = transform(copy.deepcopy(self.results))
|
||||
self.assertIn('data_samples', results)
|
||||
self.assertIsInstance(results['data_samples'], SegDataSample)
|
||||
self.assertIsInstance(results['data_samples'].gt_sem_seg,
|
||||
BaseDataElement)
|
||||
self.assertEqual(results['data_samples'].ori_shape,
|
||||
results['data_samples'].gt_sem_seg.shape)
|
||||
results = copy.deepcopy(self.results)
|
||||
# test dataset shape is not 2D
|
||||
results['gt_seg_map'] = np.random.rand(3, 300, 400)
|
||||
msg = 'the segmentation map is 2D'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
results = transform(results)
|
||||
self.assertEqual(results['data_samples'].ori_shape,
|
||||
results['data_samples'].gt_sem_seg.shape)
|
||||
|
||||
def test_repr(self):
|
||||
transform = PackSegInputs(meta_keys=self.meta_keys)
|
||||
self.assertEqual(
|
||||
repr(transform), f'PackSegInputs(meta_keys={self.meta_keys})')
|
||||
295
Seg_All_In_One_MMSeg/tests/test_datasets/test_loading.py
Normal file
295
Seg_All_In_One_MMSeg/tests/test_datasets/test_loading.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.datasets.transforms import LoadAnnotations # noqa
|
||||
from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData,
|
||||
LoadBiomedicalImageFromFile,
|
||||
LoadDepthAnnotation,
|
||||
LoadImageFromNDArray)
|
||||
|
||||
|
||||
class TestLoading:
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
cls.data_prefix = osp.join(osp.dirname(__file__), '../data')
|
||||
|
||||
def test_load_img(self):
|
||||
results = dict(img_path=osp.join(self.data_prefix, 'color.jpg'))
|
||||
transform = LoadImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img_path'] == osp.join(self.data_prefix, 'color.jpg')
|
||||
assert results['img'].shape == (288, 512, 3)
|
||||
assert results['img'].dtype == np.uint8
|
||||
assert results['ori_shape'] == results['img'].shape[:2]
|
||||
assert repr(transform) == transform.__class__.__name__ + \
|
||||
"(ignore_empty=False, to_float32=False, color_type='color'," + \
|
||||
" imdecode_backend='cv2', backend_args=None)"
|
||||
|
||||
# to_float32
|
||||
transform = LoadImageFromFile(to_float32=True)
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img'].dtype == np.float32
|
||||
|
||||
# gray image
|
||||
results = dict(img_path=osp.join(self.data_prefix, 'gray.jpg'))
|
||||
transform = LoadImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img'].shape == (288, 512, 3)
|
||||
assert results['img'].dtype == np.uint8
|
||||
|
||||
transform = LoadImageFromFile(color_type='unchanged')
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img'].shape == (288, 512)
|
||||
assert results['img'].dtype == np.uint8
|
||||
|
||||
def test_load_seg(self):
|
||||
seg_path = osp.join(self.data_prefix, 'seg.png')
|
||||
results = dict(
|
||||
seg_map_path=seg_path, reduce_zero_label=True, seg_fields=[])
|
||||
transform = LoadAnnotations()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['gt_seg_map'].shape == (288, 512)
|
||||
assert results['gt_seg_map'].dtype == np.uint8
|
||||
assert repr(transform) == transform.__class__.__name__ + \
|
||||
"(reduce_zero_label=True, imdecode_backend='pillow', " + \
|
||||
'backend_args=None)'
|
||||
|
||||
# reduce_zero_label
|
||||
transform = LoadAnnotations(reduce_zero_label=True)
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['gt_seg_map'].shape == (288, 512)
|
||||
assert results['gt_seg_map'].dtype == np.uint8
|
||||
|
||||
def test_load_seg_custom_classes(self):
|
||||
|
||||
test_img = np.random.rand(10, 10)
|
||||
test_gt = np.zeros_like(test_img)
|
||||
test_gt[2:4, 2:4] = 1
|
||||
test_gt[2:4, 6:8] = 2
|
||||
test_gt[6:8, 2:4] = 3
|
||||
test_gt[6:8, 6:8] = 4
|
||||
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
img_path = osp.join(tmp_dir.name, 'img.jpg')
|
||||
gt_path = osp.join(tmp_dir.name, 'gt.png')
|
||||
|
||||
mmcv.imwrite(test_img, img_path)
|
||||
mmcv.imwrite(test_gt, gt_path)
|
||||
|
||||
# test only train with label with id 3
|
||||
results = dict(
|
||||
img_path=img_path,
|
||||
seg_map_path=gt_path,
|
||||
label_map={
|
||||
0: 0,
|
||||
1: 0,
|
||||
2: 0,
|
||||
3: 1,
|
||||
4: 0
|
||||
},
|
||||
reduce_zero_label=False,
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_seg_map']
|
||||
|
||||
true_mask = np.zeros_like(gt_array)
|
||||
true_mask[6:8, 2:4] = 1
|
||||
|
||||
assert results['seg_fields'] == ['gt_seg_map']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test only train with label with id 4 and 3
|
||||
results = dict(
|
||||
img_path=osp.join(self.data_prefix, 'color.jpg'),
|
||||
seg_map_path=gt_path,
|
||||
label_map={
|
||||
0: 0,
|
||||
1: 0,
|
||||
2: 0,
|
||||
3: 2,
|
||||
4: 1
|
||||
},
|
||||
reduce_zero_label=False,
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_seg_map']
|
||||
|
||||
true_mask = np.zeros_like(gt_array)
|
||||
true_mask[6:8, 2:4] = 2
|
||||
true_mask[6:8, 6:8] = 1
|
||||
|
||||
assert results['seg_fields'] == ['gt_seg_map']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test with removing a class and reducing zero label simultaneously
|
||||
results = dict(
|
||||
img_path=img_path,
|
||||
seg_map_path=gt_path,
|
||||
# since reduce_zero_label is True, there are only 4 real classes.
|
||||
# if the full set of classes is ["A", "B", "C", "D"], the
|
||||
# following label map simulates the dataset option
|
||||
# classes=["A", "C", "D"] which removes class "B".
|
||||
label_map={
|
||||
0: 0,
|
||||
1: 255, # simulate removing class 1
|
||||
2: 1,
|
||||
3: 2
|
||||
},
|
||||
reduce_zero_label=True, # reduce zero label
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
# reduce zero label
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_seg_map']
|
||||
|
||||
true_mask = np.ones_like(gt_array) * 255 # all zeros get mapped to 255
|
||||
true_mask[2:4, 2:4] = 0 # 1s are reduced to class 0 mapped to class 0
|
||||
true_mask[2:4, 6:8] = 255 # 2s are reduced to class 1 which is removed
|
||||
true_mask[6:8, 2:4] = 1 # 3s are reduced to class 2 mapped to class 1
|
||||
true_mask[6:8, 6:8] = 2 # 4s are reduced to class 3 mapped to class 2
|
||||
|
||||
assert results['seg_fields'] == ['gt_seg_map']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test no custom classes
|
||||
results = dict(
|
||||
img_path=img_path,
|
||||
seg_map_path=gt_path,
|
||||
reduce_zero_label=False,
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_seg_map']
|
||||
|
||||
assert results['seg_fields'] == ['gt_seg_map']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, test_gt)
|
||||
|
||||
tmp_dir.cleanup()
|
||||
|
||||
def test_load_image_from_ndarray(self):
|
||||
results = {'img': np.zeros((256, 256, 3), dtype=np.uint8)}
|
||||
transform = LoadImageFromNDArray()
|
||||
results = transform(results)
|
||||
|
||||
assert results['img'].shape == (256, 256, 3)
|
||||
assert results['img'].dtype == np.uint8
|
||||
assert results['img_shape'] == (256, 256)
|
||||
assert results['ori_shape'] == (256, 256)
|
||||
|
||||
# to_float32
|
||||
transform = LoadImageFromNDArray(to_float32=True)
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img'].dtype == np.float32
|
||||
|
||||
# test repr
|
||||
transform = LoadImageFromNDArray()
|
||||
assert repr(transform) == ('LoadImageFromNDArray('
|
||||
'ignore_empty=False, '
|
||||
'to_float32=False, '
|
||||
"color_type='color', "
|
||||
"imdecode_backend='cv2', "
|
||||
'backend_args=None)')
|
||||
|
||||
def test_load_biomedical_img(self):
|
||||
results = dict(
|
||||
img_path=osp.join(self.data_prefix, 'biomedical.nii.gz'))
|
||||
transform = LoadBiomedicalImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img_path'] == osp.join(self.data_prefix,
|
||||
'biomedical.nii.gz')
|
||||
assert len(results['img'].shape) == 4
|
||||
assert results['img'].dtype == np.float32
|
||||
assert results['ori_shape'] == results['img'].shape[1:]
|
||||
assert repr(transform) == ('LoadBiomedicalImageFromFile('
|
||||
"decode_backend='nifti', "
|
||||
'to_xyz=False, '
|
||||
'to_float32=True, '
|
||||
'backend_args=None)')
|
||||
|
||||
def test_load_biomedical_annotation(self):
|
||||
results = dict(
|
||||
seg_map_path=osp.join(self.data_prefix, 'biomedical_ann.nii.gz'))
|
||||
transform = LoadBiomedicalAnnotation()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert len(results['gt_seg_map'].shape) == 3
|
||||
assert results['gt_seg_map'].dtype == np.float32
|
||||
|
||||
def test_load_biomedical_data(self):
|
||||
input_results = dict(
|
||||
img_path=osp.join(self.data_prefix, 'biomedical.npy'))
|
||||
transform = LoadBiomedicalData(with_seg=True)
|
||||
results = transform(copy.deepcopy(input_results))
|
||||
assert results['img_path'] == osp.join(self.data_prefix,
|
||||
'biomedical.npy')
|
||||
assert results['img'][0].shape == results['gt_seg_map'].shape
|
||||
assert results['img'].dtype == np.float32
|
||||
assert results['ori_shape'] == results['img'].shape[1:]
|
||||
assert repr(transform) == ('LoadBiomedicalData('
|
||||
'with_seg=True, '
|
||||
"decode_backend='numpy', "
|
||||
'to_xyz=False, '
|
||||
'backend_args=None)')
|
||||
|
||||
transform = LoadBiomedicalData(with_seg=False)
|
||||
results = transform(copy.deepcopy(input_results))
|
||||
assert len(results['img'].shape) == 4
|
||||
assert results.get('gt_seg_map') is None
|
||||
assert repr(transform) == ('LoadBiomedicalData('
|
||||
'with_seg=False, '
|
||||
"decode_backend='numpy', "
|
||||
'to_xyz=False, '
|
||||
'backend_args=None)')
|
||||
|
||||
def test_load_depth_annotation(self):
|
||||
input_results = dict(
|
||||
img_path='tests/data/pseudo_nyu_dataset/images/'
|
||||
'bookstore_0001d_00001.jpg',
|
||||
depth_map_path='tests/data/pseudo_nyu_dataset/'
|
||||
'annotations/bookstore_0001d_00001.png',
|
||||
category_id=-1,
|
||||
seg_fields=[])
|
||||
transform = LoadDepthAnnotation(depth_rescale_factor=0.001)
|
||||
results = transform(input_results)
|
||||
assert 'gt_depth_map' in results
|
||||
assert results['gt_depth_map'].shape[:2] == mmcv.imread(
|
||||
input_results['depth_map_path']).shape[:2]
|
||||
assert results['gt_depth_map'].dtype == np.float32
|
||||
assert 'gt_depth_map' in results['seg_fields']
|
||||
1273
Seg_All_In_One_MMSeg/tests/test_datasets/test_transform.py
Normal file
1273
Seg_All_In_One_MMSeg/tests/test_datasets/test_transform.py
Normal file
File diff suppressed because it is too large
Load Diff
131
Seg_All_In_One_MMSeg/tests/test_datasets/test_tta.py
Normal file
131
Seg_All_In_One_MMSeg/tests/test_datasets/test_tta.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import pytest
|
||||
|
||||
from mmseg.datasets.transforms import * # noqa
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
def test_multi_scale_flip_aug():
|
||||
# test exception
|
||||
with pytest.raises(TypeError):
|
||||
tta_transform = dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[dict(type='Resize', keep_ratio=False)],
|
||||
)
|
||||
TRANSFORMS.build(tta_transform)
|
||||
|
||||
tta_transform = dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[[
|
||||
dict(type='Resize', scale=scale, keep_ratio=False)
|
||||
for scale in [(256, 256), (512, 512), (1024, 1024)]
|
||||
], [dict(type='mmseg.PackSegInputs')]])
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
|
||||
results = dict()
|
||||
# (288, 512, 3)
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
results['img'] = img
|
||||
results['ori_shape'] = img.shape
|
||||
results['ori_height'] = img.shape[0]
|
||||
results['ori_width'] = img.shape[1]
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
|
||||
(3, 512, 512),
|
||||
(3, 1024, 1024)]
|
||||
|
||||
tta_transform = dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale=scale, keep_ratio=False)
|
||||
for scale in [(256, 256), (512, 512), (1024, 1024)]
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='mmseg.PackSegInputs')]
|
||||
])
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results: dict = tta_module(results.copy())
|
||||
assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
|
||||
(3, 256, 256),
|
||||
(3, 512, 512),
|
||||
(3, 512, 512),
|
||||
(3, 1024, 1024),
|
||||
(3, 1024, 1024)]
|
||||
assert [
|
||||
data_sample.metainfo['flip']
|
||||
for data_sample in tta_results['data_samples']
|
||||
] == [False, True, False, True, False, True]
|
||||
|
||||
tta_transform = dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[[dict(type='Resize', scale=(512, 512), keep_ratio=False)],
|
||||
[dict(type='mmseg.PackSegInputs')]])
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [tta_results['inputs'][0].shape] == [(3, 512, 512)]
|
||||
|
||||
tta_transform = dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[dict(type='Resize', scale=(512, 512), keep_ratio=False)],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='mmseg.PackSegInputs')]
|
||||
])
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [img.shape for img in tta_results['inputs']] == [(3, 512, 512),
|
||||
(3, 512, 512)]
|
||||
assert [
|
||||
data_sample.metainfo['flip']
|
||||
for data_sample in tta_results['data_samples']
|
||||
] == [False, True]
|
||||
|
||||
tta_transform = dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=False)
|
||||
for r in [0.5, 1.0, 2.0]
|
||||
], [dict(type='mmseg.PackSegInputs')]])
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
|
||||
(3, 288, 512),
|
||||
(3, 576, 1024)]
|
||||
|
||||
tta_transform = dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in [0.5, 1.0, 2.0]
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='mmseg.PackSegInputs')]
|
||||
])
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
|
||||
(3, 144, 256),
|
||||
(3, 288, 512),
|
||||
(3, 288, 512),
|
||||
(3, 576, 1024),
|
||||
(3, 576, 1024)]
|
||||
assert [
|
||||
data_sample.metainfo['flip']
|
||||
for data_sample in tta_results['data_samples']
|
||||
] == [False, True, False, True, False, True]
|
||||
21
Seg_All_In_One_MMSeg/tests/test_digit_version.py
Normal file
21
Seg_All_In_One_MMSeg/tests/test_digit_version.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg import digit_version
|
||||
|
||||
|
||||
def test_digit_version():
|
||||
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
|
||||
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
|
||||
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
|
||||
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
|
||||
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
|
||||
assert digit_version('1.0') == digit_version('1.0.0')
|
||||
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
|
||||
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0b')
|
||||
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
|
||||
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
|
||||
assert digit_version('1.0.0') < digit_version('1.0.0post')
|
||||
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
|
||||
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
|
||||
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
|
||||
@@ -0,0 +1,300 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# from copyreg import constructor
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.optim.optimizer import build_optim_wrapper
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \
|
||||
LearningRateDecayOptimizerConstructor
|
||||
|
||||
init_default_scope('mmseg')
|
||||
|
||||
base_lr = 1
|
||||
decay_rate = 2
|
||||
base_wd = 0.05
|
||||
weight_decay = 0.05
|
||||
|
||||
expected_stage_wise_lr_wd_convnext = [{
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 128
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 1
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 64
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 64
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 32
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 32
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 16
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 16
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 8
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 8
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 128
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 1
|
||||
}]
|
||||
|
||||
expected_layer_wise_lr_wd_convnext = [{
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 128
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 1
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 64
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 64
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 32
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 32
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 16
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 16
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 2
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 2
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 128
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 1
|
||||
}]
|
||||
|
||||
expected_layer_wise_wd_lr_beit = [{
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 16
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 8
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 8
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 4
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 4
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 2
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 2
|
||||
}, {
|
||||
'weight_decay': 0.05,
|
||||
'lr_scale': 1
|
||||
}, {
|
||||
'weight_decay': 0.0,
|
||||
'lr_scale': 1
|
||||
}]
|
||||
|
||||
|
||||
class ToyConvNeXt(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stages = nn.ModuleList()
|
||||
for i in range(4):
|
||||
stage = nn.Sequential(ConvModule(3, 4, kernel_size=1, bias=True))
|
||||
self.stages.append(stage)
|
||||
self.norm0 = nn.BatchNorm2d(2)
|
||||
|
||||
# add some variables to meet unit test coverate rate
|
||||
self.cls_token = nn.Parameter(torch.ones(1))
|
||||
self.mask_token = nn.Parameter(torch.ones(1))
|
||||
self.pos_embed = nn.Parameter(torch.ones(1))
|
||||
self.stem_norm = nn.Parameter(torch.ones(1))
|
||||
self.downsample_norm0 = nn.BatchNorm2d(2)
|
||||
self.downsample_norm1 = nn.BatchNorm2d(2)
|
||||
self.downsample_norm2 = nn.BatchNorm2d(2)
|
||||
self.lin = nn.Parameter(torch.ones(1))
|
||||
self.lin.requires_grad = False
|
||||
self.downsample_layers = nn.ModuleList()
|
||||
for _ in range(4):
|
||||
stage = nn.Sequential(nn.Conv2d(3, 4, kernel_size=1, bias=True))
|
||||
self.downsample_layers.append(stage)
|
||||
|
||||
|
||||
class ToyBEiT(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# add some variables to meet unit test coverate rate
|
||||
self.cls_token = nn.Parameter(torch.ones(1))
|
||||
self.patch_embed = nn.Parameter(torch.ones(1))
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(3):
|
||||
layer = nn.Conv2d(3, 3, 1)
|
||||
self.layers.append(layer)
|
||||
|
||||
|
||||
class ToyMAE(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# add some variables to meet unit test coverate rate
|
||||
self.cls_token = nn.Parameter(torch.ones(1))
|
||||
self.patch_embed = nn.Parameter(torch.ones(1))
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(3):
|
||||
layer = nn.Conv2d(3, 3, 1)
|
||||
self.layers.append(layer)
|
||||
|
||||
|
||||
class ToySegmentor(nn.Module):
|
||||
|
||||
def __init__(self, backbone):
|
||||
super().__init__()
|
||||
self.backbone = backbone
|
||||
self.decode_head = nn.Conv2d(2, 2, kernel_size=1, groups=2)
|
||||
|
||||
|
||||
class PseudoDataParallel(nn.Module):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.module = model
|
||||
|
||||
|
||||
class ToyViT(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
def check_optimizer_lr_wd(optimizer, gt_lr_wd):
|
||||
assert isinstance(optimizer, torch.optim.AdamW)
|
||||
assert optimizer.defaults['lr'] == base_lr
|
||||
assert optimizer.defaults['weight_decay'] == base_wd
|
||||
param_groups = optimizer.param_groups
|
||||
print(param_groups)
|
||||
assert len(param_groups) == len(gt_lr_wd)
|
||||
for i, param_dict in enumerate(param_groups):
|
||||
assert param_dict['weight_decay'] == gt_lr_wd[i]['weight_decay']
|
||||
assert param_dict['lr_scale'] == gt_lr_wd[i]['lr_scale']
|
||||
assert param_dict['lr_scale'] == param_dict['lr']
|
||||
|
||||
|
||||
def test_learning_rate_decay_optimizer_constructor():
|
||||
|
||||
# Test lr wd for ConvNeXT
|
||||
backbone = ToyConvNeXt()
|
||||
model = PseudoDataParallel(ToySegmentor(backbone))
|
||||
# stagewise decay
|
||||
stagewise_paramwise_cfg = dict(
|
||||
decay_rate=decay_rate, decay_type='stage_wise', num_layers=6)
|
||||
optimizer_cfg = dict(
|
||||
type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=stagewise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_stage_wise_lr_wd_convnext)
|
||||
# layerwise decay
|
||||
layerwise_paramwise_cfg = dict(
|
||||
decay_rate=decay_rate, decay_type='layer_wise', num_layers=6)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=layerwise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_lr_wd_convnext)
|
||||
|
||||
# Test lr wd for BEiT
|
||||
backbone = ToyBEiT()
|
||||
model = PseudoDataParallel(ToySegmentor(backbone))
|
||||
|
||||
layerwise_paramwise_cfg = dict(
|
||||
decay_rate=decay_rate, decay_type='layer_wise', num_layers=3)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=layerwise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_wd_lr_beit)
|
||||
|
||||
# Test invalidation of lr wd for Vit
|
||||
backbone = ToyViT()
|
||||
model = PseudoDataParallel(ToySegmentor(backbone))
|
||||
with pytest.raises(NotImplementedError):
|
||||
optim_constructor = LearningRateDecayOptimizerConstructor(
|
||||
optim_wrapper_cfg, layerwise_paramwise_cfg)
|
||||
optim_constructor(model)
|
||||
with pytest.raises(NotImplementedError):
|
||||
optim_constructor = LearningRateDecayOptimizerConstructor(
|
||||
optim_wrapper_cfg, stagewise_paramwise_cfg)
|
||||
optim_constructor(model)
|
||||
|
||||
# Test lr wd for MAE
|
||||
backbone = ToyMAE()
|
||||
model = PseudoDataParallel(ToySegmentor(backbone))
|
||||
|
||||
layerwise_paramwise_cfg = dict(
|
||||
decay_rate=decay_rate, decay_type='layer_wise', num_layers=3)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer_cfg,
|
||||
paramwise_cfg=layerwise_paramwise_cfg,
|
||||
constructor='LearningRateDecayOptimizerConstructor')
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_wd_lr_beit)
|
||||
|
||||
|
||||
def test_beit_layer_decay_optimizer_constructor():
|
||||
|
||||
# paramwise_cfg with BEiTExampleModel
|
||||
backbone = ToyBEiT()
|
||||
model = PseudoDataParallel(ToySegmentor(backbone))
|
||||
paramwise_cfg = dict(layer_decay_rate=2, num_layers=3)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
constructor='LayerDecayOptimizerConstructor',
|
||||
paramwise_cfg=paramwise_cfg,
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05))
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
# optimizer = optim_wrapper_builder(model)
|
||||
check_optimizer_lr_wd(optim_wrapper.optimizer,
|
||||
expected_layer_wise_wd_lr_beit)
|
||||
33
Seg_All_In_One_MMSeg/tests/test_engine/test_optimizer.py
Normal file
33
Seg_All_In_One_MMSeg/tests/test_engine/test_optimizer.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.optim import build_optim_wrapper
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param1 = nn.Parameter(torch.ones(1))
|
||||
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
|
||||
self.bn = nn.BatchNorm2d(2)
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
base_lr = 0.01
|
||||
base_wd = 0.0001
|
||||
momentum = 0.9
|
||||
|
||||
|
||||
def test_build_optimizer():
|
||||
model = ExampleModel()
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
|
||||
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
|
||||
# test whether optimizer is successfully built from parent.
|
||||
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.engine.hooks import SegVisualizationHook
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
|
||||
class TestVisualizationHook(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
||||
h = 288
|
||||
w = 512
|
||||
num_class = 2
|
||||
|
||||
SegLocalVisualizer.get_instance('visualizer')
|
||||
SegLocalVisualizer.dataset_meta = dict(
|
||||
classes=('background', 'foreground'),
|
||||
palette=[[120, 120, 120], [6, 230, 230]])
|
||||
|
||||
data_sample = SegDataSample()
|
||||
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
|
||||
self.data_batch = [{'data_sample': data_sample}] * 2
|
||||
|
||||
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
self.outputs = [pred_seg_data_sample] * 2
|
||||
|
||||
def test_after_iter(self):
|
||||
runner = Mock()
|
||||
runner.iter = 1
|
||||
hook = SegVisualizationHook(draw=True, interval=1)
|
||||
hook._after_iter(
|
||||
runner, 1, self.data_batch, self.outputs, mode='train')
|
||||
hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val')
|
||||
hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test')
|
||||
|
||||
def test_after_val_iter(self):
|
||||
runner = Mock()
|
||||
runner.iter = 2
|
||||
hook = SegVisualizationHook(interval=1)
|
||||
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
|
||||
|
||||
hook = SegVisualizationHook(draw=True, interval=1)
|
||||
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
|
||||
|
||||
hook = SegVisualizationHook(
|
||||
draw=True, interval=1, show=True, wait_time=1)
|
||||
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
|
||||
|
||||
def test_after_test_iter(self):
|
||||
runner = Mock()
|
||||
hook = SegVisualizationHook(draw=True, interval=1)
|
||||
assert hook._test_index == 0
|
||||
hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
|
||||
assert hook._test_index == len(self.outputs)
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import shutil
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.evaluation import CityscapesMetric
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestCityscapesMetric(TestCase):
|
||||
|
||||
def _demo_mm_inputs(self,
|
||||
batch_size=1,
|
||||
image_shapes=(3, 128, 256),
|
||||
num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 64, 64)
|
||||
num_classes (int): number of different classes.
|
||||
Default to 5.
|
||||
"""
|
||||
if isinstance(image_shapes, list):
|
||||
assert len(image_shapes) == batch_size
|
||||
else:
|
||||
image_shapes = [image_shapes] * batch_size
|
||||
|
||||
packed_inputs = []
|
||||
for idx in range(batch_size):
|
||||
image_shape = image_shapes[idx]
|
||||
_, h, w = image_shape
|
||||
|
||||
data_sample = SegDataSample()
|
||||
gt_semantic_seg = np.random.randint(
|
||||
0, num_classes, (1, h, w), dtype=np.uint8)
|
||||
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
|
||||
gt_sem_seg_data = dict(data=gt_semantic_seg)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
data_sample = data_sample.to_dict()
|
||||
data_sample[
|
||||
'seg_map_path'] = 'tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa
|
||||
packed_inputs.append(data_sample)
|
||||
|
||||
return packed_inputs
|
||||
|
||||
def _demo_mm_model_output(self,
|
||||
batch_size=1,
|
||||
image_shapes=(3, 128, 256),
|
||||
num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 64, 64)
|
||||
num_classes (int): number of different classes.
|
||||
Default to 5.
|
||||
"""
|
||||
results_dict = dict()
|
||||
_, h, w = image_shapes
|
||||
seg_logit = torch.randn(batch_size, num_classes, h, w)
|
||||
results_dict['seg_logits'] = seg_logit
|
||||
seg_pred = np.random.randint(
|
||||
0, num_classes, (batch_size, h, w), dtype=np.uint8)
|
||||
seg_pred = torch.LongTensor(seg_pred)
|
||||
results_dict['pred_sem_seg'] = seg_pred
|
||||
|
||||
batch_datasampes = [
|
||||
SegDataSample()
|
||||
for _ in range(results_dict['pred_sem_seg'].shape[0])
|
||||
]
|
||||
for key, value in results_dict.items():
|
||||
for i in range(value.shape[0]):
|
||||
setattr(batch_datasampes[i], key, PixelData(data=value[i]))
|
||||
|
||||
_predictions = []
|
||||
for pred in batch_datasampes:
|
||||
test_data = pred.to_dict()
|
||||
test_data[
|
||||
'img_path'] = 'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png' # noqa
|
||||
_predictions.append(test_data)
|
||||
|
||||
return _predictions
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
|
||||
data_batch = self._demo_mm_inputs(2)
|
||||
predictions = self._demo_mm_model_output(2)
|
||||
data_samples = [
|
||||
dict(**data, **result)
|
||||
for data, result in zip(data_batch, predictions)
|
||||
]
|
||||
# test keep_results should be True when format_only is True
|
||||
with pytest.raises(AssertionError):
|
||||
CityscapesMetric(
|
||||
output_dir='tmp', format_only=True, keep_results=False)
|
||||
|
||||
# test evaluate with cityscape metric
|
||||
metric = CityscapesMetric(output_dir='tmp')
|
||||
metric.process(data_batch, data_samples)
|
||||
res = metric.evaluate(2)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# test format_only
|
||||
metric = CityscapesMetric(
|
||||
output_dir='tmp', format_only=True, keep_results=True)
|
||||
metric.process(data_batch, data_samples)
|
||||
metric.evaluate(2)
|
||||
assert osp.exists('tmp')
|
||||
assert osp.isfile('tmp/frankfurt_000000_000294_leftImg8bit.png')
|
||||
shutil.rmtree('tmp')
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import shutil
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.evaluation import DepthMetric
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestDepthMetric(TestCase):
|
||||
|
||||
def _demo_mm_inputs(self,
|
||||
batch_size=2,
|
||||
image_shapes=(3, 64, 64),
|
||||
num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 64, 64)
|
||||
num_classes (int): number of different classes.
|
||||
Default to 5.
|
||||
"""
|
||||
if isinstance(image_shapes, list):
|
||||
assert len(image_shapes) == batch_size
|
||||
else:
|
||||
image_shapes = [image_shapes] * batch_size
|
||||
|
||||
data_samples = []
|
||||
for idx in range(batch_size):
|
||||
image_shape = image_shapes[idx]
|
||||
_, h, w = image_shape
|
||||
|
||||
data_sample = SegDataSample()
|
||||
gt_depth_map = torch.rand((1, h, w)) * 10
|
||||
data_sample.gt_depth_map = PixelData(data=gt_depth_map)
|
||||
|
||||
data_samples.append(data_sample.to_dict())
|
||||
|
||||
return data_samples
|
||||
|
||||
def _demo_mm_model_output(self,
|
||||
data_samples,
|
||||
batch_size=2,
|
||||
image_shapes=(3, 64, 64),
|
||||
num_classes=5):
|
||||
|
||||
_, h, w = image_shapes
|
||||
|
||||
for data_sample in data_samples:
|
||||
data_sample['pred_depth_map'] = dict(data=torch.randn(1, h, w))
|
||||
|
||||
data_sample[
|
||||
'img_path'] = 'tests/data/pseudo_dataset/imgs/00000_img.jpg'
|
||||
return data_samples
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
|
||||
data_samples = self._demo_mm_inputs()
|
||||
data_samples = self._demo_mm_model_output(data_samples)
|
||||
|
||||
depth_metric = DepthMetric()
|
||||
depth_metric.process([0] * len(data_samples), data_samples)
|
||||
res = depth_metric.compute_metrics(depth_metric.results)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# test save depth map file in output_dir
|
||||
depth_metric = DepthMetric(output_dir='tmp')
|
||||
depth_metric.process([0] * len(data_samples), data_samples)
|
||||
assert osp.exists('tmp')
|
||||
assert osp.isfile('tmp/00000_img.png')
|
||||
shutil.rmtree('tmp')
|
||||
|
||||
# test format_only
|
||||
depth_metric = DepthMetric(output_dir='tmp', format_only=True)
|
||||
depth_metric.process([0] * len(data_samples), data_samples)
|
||||
assert depth_metric.results == []
|
||||
assert osp.exists('tmp')
|
||||
assert osp.isfile('tmp/00000_img.png')
|
||||
shutil.rmtree('tmp')
|
||||
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import shutil
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.evaluation import IoUMetric
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestIoUMetric(TestCase):
|
||||
|
||||
def _demo_mm_inputs(self,
|
||||
batch_size=2,
|
||||
image_shapes=(3, 64, 64),
|
||||
num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 64, 64)
|
||||
num_classes (int): number of different classes.
|
||||
Default to 5.
|
||||
"""
|
||||
if isinstance(image_shapes, list):
|
||||
assert len(image_shapes) == batch_size
|
||||
else:
|
||||
image_shapes = [image_shapes] * batch_size
|
||||
|
||||
data_samples = []
|
||||
for idx in range(batch_size):
|
||||
image_shape = image_shapes[idx]
|
||||
_, h, w = image_shape
|
||||
|
||||
data_sample = SegDataSample()
|
||||
gt_semantic_seg = np.random.randint(
|
||||
0, num_classes, (1, h, w), dtype=np.uint8)
|
||||
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
|
||||
gt_sem_seg_data = dict(data=gt_semantic_seg)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
data_samples.append(data_sample.to_dict())
|
||||
|
||||
return data_samples
|
||||
|
||||
def _demo_mm_model_output(self,
|
||||
data_samples,
|
||||
batch_size=2,
|
||||
image_shapes=(3, 64, 64),
|
||||
num_classes=5):
|
||||
|
||||
_, h, w = image_shapes
|
||||
|
||||
for data_sample in data_samples:
|
||||
data_sample['seg_logits'] = dict(
|
||||
data=torch.randn(num_classes, h, w))
|
||||
data_sample['pred_sem_seg'] = dict(
|
||||
data=torch.randint(0, num_classes, (1, h, w)))
|
||||
data_sample[
|
||||
'img_path'] = 'tests/data/pseudo_dataset/imgs/00000_img.jpg'
|
||||
return data_samples
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
|
||||
data_samples = self._demo_mm_inputs()
|
||||
data_samples = self._demo_mm_model_output(data_samples)
|
||||
|
||||
iou_metric = IoUMetric(iou_metrics=['mIoU'])
|
||||
iou_metric.dataset_meta = dict(
|
||||
classes=['wall', 'building', 'sky', 'floor', 'tree'],
|
||||
label_map=dict(),
|
||||
reduce_zero_label=False)
|
||||
iou_metric.process([0] * len(data_samples), data_samples)
|
||||
res = iou_metric.evaluate(2)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# test save segment file in output_dir
|
||||
iou_metric = IoUMetric(iou_metrics=['mIoU'], output_dir='tmp')
|
||||
iou_metric.dataset_meta = dict(
|
||||
classes=['wall', 'building', 'sky', 'floor', 'tree'],
|
||||
label_map=dict(),
|
||||
reduce_zero_label=False)
|
||||
iou_metric.process([0] * len(data_samples), data_samples)
|
||||
assert osp.exists('tmp')
|
||||
assert osp.isfile('tmp/00000_img.png')
|
||||
shutil.rmtree('tmp')
|
||||
|
||||
# test format_only
|
||||
iou_metric = IoUMetric(
|
||||
iou_metrics=['mIoU'], output_dir='tmp', format_only=True)
|
||||
iou_metric.dataset_meta = dict(
|
||||
classes=['wall', 'building', 'sky', 'floor', 'tree'],
|
||||
label_map=dict(),
|
||||
reduce_zero_label=False)
|
||||
iou_metric.process([0] * len(data_samples), data_samples)
|
||||
assert iou_metric.results == []
|
||||
assert osp.exists('tmp')
|
||||
assert osp.isfile('tmp/00000_img.png')
|
||||
shutil.rmtree('tmp')
|
||||
1
Seg_All_In_One_MMSeg/tests/test_models/__init__.py
Normal file
1
Seg_All_In_One_MMSeg/tests/test_models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.structures import InstanceData
|
||||
|
||||
from mmseg.models.assigners import HungarianAssigner
|
||||
|
||||
|
||||
class TestHungarianAssigner(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
HungarianAssigner([])
|
||||
|
||||
def test_hungarian_match_assigner(self):
|
||||
assigner = HungarianAssigner([
|
||||
dict(type='ClassificationCost', weight=2.0),
|
||||
dict(type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True),
|
||||
dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
|
||||
])
|
||||
num_classes = 3
|
||||
num_masks = 10
|
||||
num_points = 20
|
||||
gt_instances = InstanceData()
|
||||
gt_instances.labels = torch.randint(0, num_classes, (num_classes, ))
|
||||
gt_instances.masks = torch.randint(0, 2, (num_classes, num_points))
|
||||
pred_instances = InstanceData()
|
||||
pred_instances.scores = torch.rand((num_masks, num_classes))
|
||||
pred_instances.masks = torch.rand((num_masks, num_points))
|
||||
|
||||
matched_quiery_inds, matched_label_inds = \
|
||||
assigner.assign(pred_instances, gt_instances)
|
||||
unique_quiery_inds = torch.unique(matched_quiery_inds)
|
||||
unique_label_inds = torch.unique(matched_label_inds)
|
||||
self.assertTrue(len(unique_quiery_inds) == len(matched_quiery_inds))
|
||||
self.assertTrue(
|
||||
torch.equal(unique_label_inds, torch.arange(0, num_classes)))
|
||||
|
||||
def test_cls_match_cost(self):
|
||||
num_classes = 3
|
||||
num_masks = 10
|
||||
gt_instances = InstanceData()
|
||||
gt_instances.labels = torch.randint(0, num_classes, (num_classes, ))
|
||||
pred_instances = InstanceData()
|
||||
pred_instances.scores = torch.rand((num_masks, num_classes))
|
||||
|
||||
# test ClassificationCost
|
||||
assigner = HungarianAssigner(dict(type='ClassificationCost'))
|
||||
matched_quiery_inds, matched_label_inds = \
|
||||
assigner.assign(pred_instances, gt_instances)
|
||||
unique_quiery_inds = torch.unique(matched_quiery_inds)
|
||||
unique_label_inds = torch.unique(matched_label_inds)
|
||||
self.assertTrue(len(unique_quiery_inds) == len(matched_quiery_inds))
|
||||
self.assertTrue(
|
||||
torch.equal(unique_label_inds, torch.arange(0, num_classes)))
|
||||
|
||||
def test_mask_match_cost(self):
|
||||
num_classes = 3
|
||||
num_masks = 10
|
||||
num_points = 20
|
||||
gt_instances = InstanceData()
|
||||
gt_instances.masks = torch.randint(0, 2, (num_classes, num_points))
|
||||
pred_instances = InstanceData()
|
||||
pred_instances.masks = torch.rand((num_masks, num_points))
|
||||
|
||||
# test DiceCost
|
||||
assigner = HungarianAssigner(
|
||||
dict(type='DiceCost', pred_act=True, eps=1.0))
|
||||
assign_result = assigner.assign(pred_instances, gt_instances)
|
||||
self.assertTrue(len(assign_result[0]) == len(assign_result[1]))
|
||||
|
||||
# test CrossEntropyLossCost
|
||||
assigner = HungarianAssigner(
|
||||
dict(type='CrossEntropyLossCost', use_sigmoid=True))
|
||||
assign_result = assigner.assign(pred_instances, gt_instances)
|
||||
self.assertTrue(len(assign_result[0]) == len(assign_result[1]))
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
@@ -0,0 +1,185 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones.beit import BEiT
|
||||
from .utils import check_norm_state
|
||||
|
||||
|
||||
def test_beit_backbone():
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = BEiT()
|
||||
model.init_weights(pretrained=0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# img_size must be int or tuple
|
||||
model = BEiT(img_size=512.0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# out_indices must be int ,list or tuple
|
||||
model = BEiT(out_indices=1.)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The length of img_size tuple must be lower than 3.
|
||||
BEiT(img_size=(224, 224, 224))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained must be None or Str.
|
||||
BEiT(pretrained=123)
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = BEiT(img_size=(224, ))
|
||||
model.init_weights()
|
||||
model(imgs)
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = BEiT(img_size=(224, 224))
|
||||
model(imgs)
|
||||
|
||||
# Test norm_eval = True
|
||||
model = BEiT(norm_eval=True)
|
||||
model.train()
|
||||
|
||||
# Test BEiT backbone with input size of 224 and patch size of 16
|
||||
model = BEiT()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# Test qv_bias
|
||||
model = BEiT(qv_bias=False)
|
||||
model.train()
|
||||
|
||||
# Test out_indices = list
|
||||
model = BEiT(out_indices=[2, 4, 8, 12])
|
||||
model.train()
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
# Test image size = (224, 224)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test BEiT backbone with input size of 256 and patch size of 16
|
||||
model = BEiT(img_size=(256, 256))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 256, 256)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 16, 16)
|
||||
|
||||
# Test BEiT backbone with input size of 32 and patch size of 16
|
||||
model = BEiT(img_size=(32, 32))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 32, 32)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 2, 2)
|
||||
|
||||
# Test unbalanced size input image
|
||||
model = BEiT(img_size=(112, 224))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 112, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 7, 14)
|
||||
|
||||
# Test irregular input image
|
||||
model = BEiT(img_size=(234, 345))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 234, 345)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 21)
|
||||
|
||||
# Test init_values=0
|
||||
model = BEiT(init_values=0)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test final norm
|
||||
model = BEiT(final_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test patch norm
|
||||
model = BEiT(patch_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
|
||||
def test_beit_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
# Test all combinations of pretrained and init_cfg
|
||||
# pretrained=None, init_cfg=None
|
||||
model = BEiT(pretrained=None, init_cfg=None)
|
||||
assert model.init_cfg is None
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
model = BEiT(
|
||||
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# test resize_rel_pos_embed
|
||||
value = torch.randn(732, 16)
|
||||
ckpt = {
|
||||
'state_dict': {
|
||||
'layers.0.attn.relative_position_index': 0,
|
||||
'layers.0.attn.relative_position_bias_table': value
|
||||
}
|
||||
}
|
||||
model = BEiT(img_size=(512, 512))
|
||||
# If scipy is installed, this AttributeError would not be raised.
|
||||
from mmengine.utils import is_installed
|
||||
if not is_installed('scipy'):
|
||||
with pytest.raises(AttributeError):
|
||||
model.resize_rel_pos_embed(ckpt)
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
model = BEiT(pretrained=None, init_cfg=123)
|
||||
with pytest.raises(TypeError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg=None
|
||||
model = BEiT(pretrained=path, init_cfg=None)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = BEiT(
|
||||
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
with pytest.raises(AssertionError):
|
||||
model = BEiT(pretrained=path, init_cfg=123)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=None
|
||||
with pytest.raises(TypeError):
|
||||
model = BEiT(pretrained=123, init_cfg=None)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = BEiT(
|
||||
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=123, whose type is unsupported
|
||||
with pytest.raises(AssertionError):
|
||||
model = BEiT(pretrained=123, init_cfg=123)
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import BiSeNetV1
|
||||
from mmseg.models.backbones.bisenetv1 import (AttentionRefinementModule,
|
||||
ContextPath, FeatureFusionModule,
|
||||
SpatialPath)
|
||||
|
||||
|
||||
def test_bisenetv1_backbone():
|
||||
# Test BiSeNetV1 Standard Forward
|
||||
backbone_cfg = dict(
|
||||
type='ResNet',
|
||||
in_channels=3,
|
||||
depth=18,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 1, 1),
|
||||
strides=(1, 2, 2, 2),
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True)
|
||||
model = BiSeNetV1(in_channels=3, backbone_cfg=backbone_cfg)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 64, 128)
|
||||
feat = model(imgs)
|
||||
|
||||
assert len(feat) == 3
|
||||
# output for segment Head
|
||||
assert feat[0].shape == torch.Size([batch_size, 256, 8, 16])
|
||||
# for auxiliary head 1
|
||||
assert feat[1].shape == torch.Size([batch_size, 128, 8, 16])
|
||||
# for auxiliary head 2
|
||||
assert feat[2].shape == torch.Size([batch_size, 128, 4, 8])
|
||||
|
||||
# Test input with rare shape
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 95, 27)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# BiSeNetV1 spatial path channel constraints.
|
||||
BiSeNetV1(
|
||||
backbone_cfg=backbone_cfg,
|
||||
in_channels=3,
|
||||
spatial_channels=(16, 16, 16))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# BiSeNetV1 context path constraints.
|
||||
BiSeNetV1(
|
||||
backbone_cfg=backbone_cfg,
|
||||
in_channels=3,
|
||||
context_channels=(16, 32, 64, 128))
|
||||
|
||||
|
||||
def test_bisenetv1_spatial_path():
|
||||
with pytest.raises(AssertionError):
|
||||
# BiSeNetV1 spatial path channel constraints.
|
||||
SpatialPath(num_channels=(16, 16, 16), in_channels=3)
|
||||
|
||||
|
||||
def test_bisenetv1_context_path():
|
||||
backbone_cfg = dict(
|
||||
type='ResNet',
|
||||
in_channels=3,
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 1, 1),
|
||||
strides=(1, 2, 2, 2),
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# BiSeNetV1 context path constraints.
|
||||
ContextPath(
|
||||
backbone_cfg=backbone_cfg, context_channels=(16, 32, 64, 128))
|
||||
|
||||
|
||||
def test_bisenetv1_attention_refinement_module():
|
||||
x_arm = AttentionRefinementModule(32, 8)
|
||||
assert x_arm.conv_layer.in_channels == 32
|
||||
assert x_arm.conv_layer.out_channels == 8
|
||||
assert x_arm.conv_layer.kernel_size == (3, 3)
|
||||
x = torch.randn(2, 32, 8, 16)
|
||||
x_out = x_arm(x)
|
||||
assert x_out.shape == torch.Size([2, 8, 8, 16])
|
||||
|
||||
|
||||
def test_bisenetv1_feature_fusion_module():
|
||||
ffm = FeatureFusionModule(16, 32)
|
||||
assert ffm.conv1.in_channels == 16
|
||||
assert ffm.conv1.out_channels == 32
|
||||
assert ffm.conv1.kernel_size == (1, 1)
|
||||
assert ffm.gap.output_size == (1, 1)
|
||||
assert ffm.conv_atten[0].in_channels == 32
|
||||
assert ffm.conv_atten[0].out_channels == 32
|
||||
assert ffm.conv_atten[0].kernel_size == (1, 1)
|
||||
|
||||
ffm = FeatureFusionModule(16, 16)
|
||||
x1 = torch.randn(2, 8, 8, 16)
|
||||
x2 = torch.randn(2, 8, 8, 16)
|
||||
x_out = ffm(x1, x2)
|
||||
assert x_out.shape == torch.Size([2, 16, 8, 16])
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.models.backbones import BiSeNetV2
|
||||
from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch,
|
||||
SemanticBranch)
|
||||
|
||||
|
||||
def test_bisenetv2_backbone():
|
||||
# Test BiSeNetV2 Standard Forward
|
||||
model = BiSeNetV2()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 128, 256)
|
||||
feat = model(imgs)
|
||||
|
||||
assert len(feat) == 5
|
||||
# output for segment Head
|
||||
assert feat[0].shape == torch.Size([batch_size, 128, 16, 32])
|
||||
# for auxiliary head 1
|
||||
assert feat[1].shape == torch.Size([batch_size, 16, 32, 64])
|
||||
# for auxiliary head 2
|
||||
assert feat[2].shape == torch.Size([batch_size, 32, 16, 32])
|
||||
# for auxiliary head 3
|
||||
assert feat[3].shape == torch.Size([batch_size, 64, 8, 16])
|
||||
# for auxiliary head 4
|
||||
assert feat[4].shape == torch.Size([batch_size, 128, 4, 8])
|
||||
|
||||
# Test input with rare shape
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 95, 27)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
|
||||
|
||||
def test_bisenetv2_DetailBranch():
|
||||
x = torch.randn(1, 3, 32, 64)
|
||||
detail_branch = DetailBranch(detail_channels=(64, 16, 32))
|
||||
assert isinstance(detail_branch.detail_branch[0][0], ConvModule)
|
||||
x_out = detail_branch(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 4, 8])
|
||||
|
||||
|
||||
def test_bisenetv2_SemanticBranch():
|
||||
semantic_branch = SemanticBranch(semantic_channels=(16, 32, 64, 128))
|
||||
assert semantic_branch.stage1.pool.stride == 2
|
||||
|
||||
|
||||
def test_bisenetv2_BGALayer():
|
||||
x_a = torch.randn(1, 8, 8, 16)
|
||||
x_b = torch.randn(1, 8, 2, 4)
|
||||
bga = BGALayer(out_channels=8)
|
||||
assert isinstance(bga.conv, ConvModule)
|
||||
x_out = bga(x_a, x_b)
|
||||
assert x_out.shape == torch.Size([1, 8, 8, 16])
|
||||
@@ -0,0 +1,187 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmcv
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
||||
from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
|
||||
make_divisible)
|
||||
|
||||
|
||||
def test_make_divisible():
|
||||
# test with min_value = None
|
||||
assert make_divisible(10, 4) == 12
|
||||
assert make_divisible(9, 4) == 12
|
||||
assert make_divisible(1, 4) == 4
|
||||
|
||||
# test with min_value = 8
|
||||
assert make_divisible(10, 4, 8) == 12
|
||||
assert make_divisible(9, 4, 8) == 12
|
||||
assert make_divisible(1, 4, 8) == 8
|
||||
|
||||
|
||||
def test_inv_residual():
|
||||
with pytest.raises(AssertionError):
|
||||
# test stride assertion.
|
||||
InvertedResidual(32, 32, 3, 4)
|
||||
|
||||
# test default config with res connection.
|
||||
# set expand_ratio = 4, stride = 1 and inp=oup.
|
||||
inv_module = InvertedResidual(32, 32, 1, 4)
|
||||
assert inv_module.use_res_connect
|
||||
assert inv_module.conv[0].kernel_size == (1, 1)
|
||||
assert inv_module.conv[0].padding == 0
|
||||
assert inv_module.conv[1].kernel_size == (3, 3)
|
||||
assert inv_module.conv[1].padding == 1
|
||||
assert inv_module.conv[0].with_norm
|
||||
assert inv_module.conv[1].with_norm
|
||||
x = torch.rand(1, 32, 64, 64)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 32, 64, 64)
|
||||
|
||||
# test inv_residual module without res connection.
|
||||
# set expand_ratio = 4, stride = 2.
|
||||
inv_module = InvertedResidual(32, 32, 2, 4)
|
||||
assert not inv_module.use_res_connect
|
||||
assert inv_module.conv[0].kernel_size == (1, 1)
|
||||
x = torch.rand(1, 32, 64, 64)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 32, 32, 32)
|
||||
|
||||
# test expand_ratio == 1
|
||||
inv_module = InvertedResidual(32, 32, 1, 1)
|
||||
assert inv_module.conv[0].kernel_size == (3, 3)
|
||||
x = torch.rand(1, 32, 64, 64)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 32, 64, 64)
|
||||
|
||||
# test with checkpoint forward
|
||||
inv_module = InvertedResidual(32, 32, 1, 1, with_cp=True)
|
||||
assert inv_module.with_cp
|
||||
x = torch.rand(1, 32, 64, 64, requires_grad=True)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 32, 64, 64)
|
||||
|
||||
|
||||
def test_inv_residualv3():
|
||||
with pytest.raises(AssertionError):
|
||||
# test stride assertion.
|
||||
InvertedResidualV3(32, 32, 16, stride=3)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# test assertion.
|
||||
InvertedResidualV3(32, 32, 16, with_expand_conv=False)
|
||||
|
||||
# test with se_cfg=None, with_expand_conv=False
|
||||
inv_module = InvertedResidualV3(32, 32, 32, with_expand_conv=False)
|
||||
|
||||
assert inv_module.with_res_shortcut is True
|
||||
assert inv_module.with_se is False
|
||||
assert inv_module.with_expand_conv is False
|
||||
assert not hasattr(inv_module, 'expand_conv')
|
||||
assert isinstance(inv_module.depthwise_conv.conv, torch.nn.Conv2d)
|
||||
assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
|
||||
assert inv_module.depthwise_conv.conv.stride == (1, 1)
|
||||
assert inv_module.depthwise_conv.conv.padding == (1, 1)
|
||||
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
|
||||
assert isinstance(inv_module.depthwise_conv.activate, torch.nn.ReLU)
|
||||
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
|
||||
assert inv_module.linear_conv.conv.stride == (1, 1)
|
||||
assert inv_module.linear_conv.conv.padding == (0, 0)
|
||||
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
|
||||
|
||||
x = torch.rand(1, 32, 64, 64)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 32, 64, 64)
|
||||
|
||||
# test with se_cfg and with_expand_conv
|
||||
se_cfg = dict(
|
||||
channels=16,
|
||||
ratio=4,
|
||||
act_cfg=(dict(type='ReLU'),
|
||||
dict(type='HSigmoid', bias=3.0, divisor=6.0)))
|
||||
act_cfg = dict(type='HSwish')
|
||||
inv_module = InvertedResidualV3(
|
||||
32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg)
|
||||
assert inv_module.with_res_shortcut is False
|
||||
assert inv_module.with_se is True
|
||||
assert inv_module.with_expand_conv is True
|
||||
assert inv_module.expand_conv.conv.kernel_size == (1, 1)
|
||||
assert inv_module.expand_conv.conv.stride == (1, 1)
|
||||
assert inv_module.expand_conv.conv.padding == (0, 0)
|
||||
|
||||
assert isinstance(inv_module.depthwise_conv.conv,
|
||||
mmcv.cnn.bricks.Conv2dAdaptivePadding)
|
||||
assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
|
||||
assert inv_module.depthwise_conv.conv.stride == (2, 2)
|
||||
assert inv_module.depthwise_conv.conv.padding == (0, 0)
|
||||
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
|
||||
|
||||
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
|
||||
assert inv_module.linear_conv.conv.stride == (1, 1)
|
||||
assert inv_module.linear_conv.conv.padding == (0, 0)
|
||||
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
|
||||
|
||||
if (TORCH_VERSION == 'parrots'
|
||||
or digit_version(TORCH_VERSION) < digit_version('1.7')):
|
||||
# Note: Use PyTorch official HSwish
|
||||
# when torch>=1.7 after MMCV >= 1.4.5.
|
||||
# Hardswish is not supported when PyTorch version < 1.6.
|
||||
# And Hardswish in PyTorch 1.6 does not support inplace.
|
||||
# More details could be found from:
|
||||
# https://github.com/open-mmlab/mmcv/pull/1709
|
||||
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
|
||||
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
|
||||
else:
|
||||
assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish)
|
||||
assert isinstance(inv_module.depthwise_conv.activate,
|
||||
torch.nn.Hardswish)
|
||||
|
||||
x = torch.rand(1, 32, 64, 64)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (1, 40, 32, 32)
|
||||
|
||||
# test with checkpoint forward
|
||||
inv_module = InvertedResidualV3(
|
||||
32, 40, 16, 3, 2, se_cfg=se_cfg, act_cfg=act_cfg, with_cp=True)
|
||||
assert inv_module.with_cp
|
||||
x = torch.randn(2, 32, 64, 64, requires_grad=True)
|
||||
output = inv_module(x)
|
||||
assert output.shape == (2, 40, 32, 32)
|
||||
|
||||
|
||||
def test_se_layer():
|
||||
with pytest.raises(AssertionError):
|
||||
# test act_cfg assertion.
|
||||
SELayer(32, act_cfg=(dict(type='ReLU'), ))
|
||||
|
||||
# test config with channels = 16.
|
||||
se_layer = SELayer(16)
|
||||
assert se_layer.conv1.conv.kernel_size == (1, 1)
|
||||
assert se_layer.conv1.conv.stride == (1, 1)
|
||||
assert se_layer.conv1.conv.padding == (0, 0)
|
||||
assert isinstance(se_layer.conv1.activate, torch.nn.ReLU)
|
||||
assert se_layer.conv2.conv.kernel_size == (1, 1)
|
||||
assert se_layer.conv2.conv.stride == (1, 1)
|
||||
assert se_layer.conv2.conv.padding == (0, 0)
|
||||
assert isinstance(se_layer.conv2.activate, mmcv.cnn.HSigmoid)
|
||||
|
||||
x = torch.rand(1, 16, 64, 64)
|
||||
output = se_layer(x)
|
||||
assert output.shape == (1, 16, 64, 64)
|
||||
|
||||
# test config with channels = 16, act_cfg = dict(type='ReLU').
|
||||
se_layer = SELayer(16, act_cfg=dict(type='ReLU'))
|
||||
assert se_layer.conv1.conv.kernel_size == (1, 1)
|
||||
assert se_layer.conv1.conv.stride == (1, 1)
|
||||
assert se_layer.conv1.conv.padding == (0, 0)
|
||||
assert isinstance(se_layer.conv1.activate, torch.nn.ReLU)
|
||||
assert se_layer.conv2.conv.kernel_size == (1, 1)
|
||||
assert se_layer.conv2.conv.stride == (1, 1)
|
||||
assert se_layer.conv2.conv.padding == (0, 0)
|
||||
assert isinstance(se_layer.conv2.activate, torch.nn.ReLU)
|
||||
|
||||
x = torch.rand(1, 16, 64, 64)
|
||||
output = se_layer(x)
|
||||
assert output.shape == (1, 16, 64, 64)
|
||||
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import CGNet
|
||||
from mmseg.models.backbones.cgnet import (ContextGuidedBlock,
|
||||
GlobalContextExtractor)
|
||||
|
||||
|
||||
def test_cgnet_GlobalContextExtractor():
|
||||
block = GlobalContextExtractor(16, 16, with_cp=True)
|
||||
x = torch.randn(2, 16, 64, 64, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([2, 16, 64, 64])
|
||||
|
||||
|
||||
def test_cgnet_context_guided_block():
|
||||
with pytest.raises(AssertionError):
|
||||
# cgnet ContextGuidedBlock GlobalContextExtractor channel and reduction
|
||||
# constraints.
|
||||
ContextGuidedBlock(8, 8)
|
||||
|
||||
# test cgnet ContextGuidedBlock with checkpoint forward
|
||||
block = ContextGuidedBlock(
|
||||
16, 16, act_cfg=dict(type='PReLU'), with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(2, 16, 64, 64, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([2, 16, 64, 64])
|
||||
|
||||
# test cgnet ContextGuidedBlock without checkpoint forward
|
||||
block = ContextGuidedBlock(32, 32)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(3, 32, 32, 32)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([3, 32, 32, 32])
|
||||
|
||||
# test cgnet ContextGuidedBlock with down sampling
|
||||
block = ContextGuidedBlock(32, 32, downsample=True)
|
||||
assert block.conv1x1.conv.in_channels == 32
|
||||
assert block.conv1x1.conv.out_channels == 32
|
||||
assert block.conv1x1.conv.kernel_size == (3, 3)
|
||||
assert block.conv1x1.conv.stride == (2, 2)
|
||||
assert block.conv1x1.conv.padding == (1, 1)
|
||||
|
||||
assert block.f_loc.in_channels == 32
|
||||
assert block.f_loc.out_channels == 32
|
||||
assert block.f_loc.kernel_size == (3, 3)
|
||||
assert block.f_loc.stride == (1, 1)
|
||||
assert block.f_loc.padding == (1, 1)
|
||||
assert block.f_loc.groups == 32
|
||||
assert block.f_loc.dilation == (1, 1)
|
||||
assert block.f_loc.bias is None
|
||||
|
||||
assert block.f_sur.in_channels == 32
|
||||
assert block.f_sur.out_channels == 32
|
||||
assert block.f_sur.kernel_size == (3, 3)
|
||||
assert block.f_sur.stride == (1, 1)
|
||||
assert block.f_sur.padding == (2, 2)
|
||||
assert block.f_sur.groups == 32
|
||||
assert block.f_sur.dilation == (2, 2)
|
||||
assert block.f_sur.bias is None
|
||||
|
||||
assert block.bottleneck.in_channels == 64
|
||||
assert block.bottleneck.out_channels == 32
|
||||
assert block.bottleneck.kernel_size == (1, 1)
|
||||
assert block.bottleneck.stride == (1, 1)
|
||||
assert block.bottleneck.bias is None
|
||||
|
||||
x = torch.randn(1, 32, 32, 32)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 16, 16])
|
||||
|
||||
# test cgnet ContextGuidedBlock without down sampling
|
||||
block = ContextGuidedBlock(32, 32, downsample=False)
|
||||
assert block.conv1x1.conv.in_channels == 32
|
||||
assert block.conv1x1.conv.out_channels == 16
|
||||
assert block.conv1x1.conv.kernel_size == (1, 1)
|
||||
assert block.conv1x1.conv.stride == (1, 1)
|
||||
assert block.conv1x1.conv.padding == (0, 0)
|
||||
|
||||
assert block.f_loc.in_channels == 16
|
||||
assert block.f_loc.out_channels == 16
|
||||
assert block.f_loc.kernel_size == (3, 3)
|
||||
assert block.f_loc.stride == (1, 1)
|
||||
assert block.f_loc.padding == (1, 1)
|
||||
assert block.f_loc.groups == 16
|
||||
assert block.f_loc.dilation == (1, 1)
|
||||
assert block.f_loc.bias is None
|
||||
|
||||
assert block.f_sur.in_channels == 16
|
||||
assert block.f_sur.out_channels == 16
|
||||
assert block.f_sur.kernel_size == (3, 3)
|
||||
assert block.f_sur.stride == (1, 1)
|
||||
assert block.f_sur.padding == (2, 2)
|
||||
assert block.f_sur.groups == 16
|
||||
assert block.f_sur.dilation == (2, 2)
|
||||
assert block.f_sur.bias is None
|
||||
|
||||
x = torch.randn(1, 32, 32, 32)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 32, 32])
|
||||
|
||||
|
||||
def test_cgnet_backbone():
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid num_channels
|
||||
CGNet(num_channels=(32, 64, 128, 256))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid num_blocks
|
||||
CGNet(num_blocks=(3, 21, 3))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid dilation
|
||||
CGNet(num_blocks=2)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid reduction
|
||||
CGNet(reductions=16)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid num_channels and reduction
|
||||
CGNet(num_channels=(32, 64, 128), reductions=(64, 129))
|
||||
|
||||
# Test CGNet with default settings
|
||||
model = CGNet()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([2, 35, 112, 112])
|
||||
assert feat[1].shape == torch.Size([2, 131, 56, 56])
|
||||
assert feat[2].shape == torch.Size([2, 256, 28, 28])
|
||||
|
||||
# Test CGNet with norm_eval True and with_cp True
|
||||
model = CGNet(norm_eval=True, with_cp=True)
|
||||
with pytest.raises(TypeError):
|
||||
# check invalid pretrained
|
||||
model.init_weights(pretrained=8)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([2, 35, 112, 112])
|
||||
assert feat[1].shape == torch.Size([2, 131, 56, 56])
|
||||
assert feat[2].shape == torch.Size([2, 256, 28, 28])
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.models.text_encoder import CLIPTextEncoder
|
||||
from mmseg.utils import get_classes
|
||||
|
||||
|
||||
def test_clip_text_encoder():
|
||||
init_default_scope('mmseg')
|
||||
# test vocabulary
|
||||
output_dims = 8
|
||||
embed_dims = 32
|
||||
vocabulary = ['cat', 'dog', 'bird', 'car', 'bike']
|
||||
cfg = dict(
|
||||
vocabulary=vocabulary,
|
||||
templates=['a photo of a {}.'],
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims)
|
||||
cfg = Config(cfg)
|
||||
|
||||
text_encoder = CLIPTextEncoder(**cfg)
|
||||
if torch.cuda.is_available():
|
||||
text_encoder = text_encoder.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
class_embeds = text_encoder()
|
||||
assert class_embeds.shape == (len(vocabulary) + 1, output_dims)
|
||||
|
||||
# test dataset name
|
||||
cfg = dict(
|
||||
dataset_name='vaihingen',
|
||||
templates=['a photo of a {}.'],
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims)
|
||||
cfg = Config(cfg)
|
||||
|
||||
text_encoder = CLIPTextEncoder(**cfg)
|
||||
with torch.no_grad():
|
||||
class_embeds = text_encoder()
|
||||
class_nums = len(get_classes('vaihingen'))
|
||||
assert class_embeds.shape == (class_nums + 1, output_dims)
|
||||
@@ -0,0 +1,146 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import ERFNet
|
||||
from mmseg.models.backbones.erfnet import (DownsamplerBlock, NonBottleneck1d,
|
||||
UpsamplerBlock)
|
||||
|
||||
|
||||
def test_erfnet_backbone():
|
||||
# Test ERFNet Standard Forward.
|
||||
model = ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 256, 512)
|
||||
output = model(imgs)
|
||||
|
||||
# output for segment Head
|
||||
assert output[0].shape == torch.Size([batch_size, 16, 128, 256])
|
||||
|
||||
# Test input with rare shape
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 527, 279)
|
||||
output = model(imgs)
|
||||
assert len(output[0]) == batch_size
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder downsample block and decoder upsample block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(128, 64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder downsample block and encoder Non-bottleneck block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8, 10),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder downsample block and
|
||||
# channels of encoder Non-bottleneck block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128, 256),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder Non-bottleneck block and number of its channels.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8, 3),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of decoder upsample block and decoder Non-bottleneck block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2, 3),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of decoder Non-bottleneck block and number of its channels.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16, 8),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
|
||||
|
||||
def test_erfnet_downsampler_block():
|
||||
x_db = DownsamplerBlock(16, 64)
|
||||
assert x_db.conv.in_channels == 16
|
||||
assert x_db.conv.out_channels == 48
|
||||
assert len(x_db.bn.weight) == 64
|
||||
assert x_db.pool.kernel_size == 2
|
||||
assert x_db.pool.stride == 2
|
||||
|
||||
|
||||
def test_erfnet_non_bottleneck_1d():
|
||||
x_nb1d = NonBottleneck1d(16, 0, 1)
|
||||
assert x_nb1d.convs_layers[0].in_channels == 16
|
||||
assert x_nb1d.convs_layers[0].out_channels == 16
|
||||
assert x_nb1d.convs_layers[2].in_channels == 16
|
||||
assert x_nb1d.convs_layers[2].out_channels == 16
|
||||
assert x_nb1d.convs_layers[5].in_channels == 16
|
||||
assert x_nb1d.convs_layers[5].out_channels == 16
|
||||
assert x_nb1d.convs_layers[7].in_channels == 16
|
||||
assert x_nb1d.convs_layers[7].out_channels == 16
|
||||
assert x_nb1d.convs_layers[9].p == 0
|
||||
|
||||
|
||||
def test_erfnet_upsampler_block():
|
||||
x_ub = UpsamplerBlock(64, 16)
|
||||
assert x_ub.conv.in_channels == 64
|
||||
assert x_ub.conv.out_channels == 16
|
||||
assert len(x_ub.bn.weight) == 16
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import FastSCNN
|
||||
|
||||
|
||||
def test_fastscnn_backbone():
|
||||
with pytest.raises(AssertionError):
|
||||
# Fast-SCNN channel constraints.
|
||||
FastSCNN(
|
||||
3, (32, 48),
|
||||
64, (64, 96, 128), (2, 2, 1),
|
||||
global_out_channels=127,
|
||||
higher_in_channels=64,
|
||||
lower_in_channels=128)
|
||||
|
||||
# Test FastSCNN Standard Forward
|
||||
model = FastSCNN(
|
||||
in_channels=3,
|
||||
downsample_dw_channels=(4, 6),
|
||||
global_in_channels=8,
|
||||
global_block_channels=(8, 12, 16),
|
||||
global_block_strides=(2, 2, 1),
|
||||
global_out_channels=16,
|
||||
higher_in_channels=8,
|
||||
lower_in_channels=16,
|
||||
fusion_out_channels=16,
|
||||
)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 4
|
||||
imgs = torch.randn(batch_size, 3, 64, 128)
|
||||
feat = model(imgs)
|
||||
|
||||
assert len(feat) == 3
|
||||
# higher-res
|
||||
assert feat[0].shape == torch.Size([batch_size, 8, 8, 16])
|
||||
# lower-res
|
||||
assert feat[1].shape == torch.Size([batch_size, 16, 2, 4])
|
||||
# FFM output
|
||||
assert feat[2].shape == torch.Size([batch_size, 16, 8, 16])
|
||||
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.models.backbones.hrnet import HRModule, HRNet
|
||||
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
@pytest.mark.parametrize('block', [BasicBlock, Bottleneck])
|
||||
def test_hrmodule(block):
|
||||
# Test multiscale forward
|
||||
num_channles = (32, 64)
|
||||
in_channels = [c * block.expansion for c in num_channles]
|
||||
hrmodule = HRModule(
|
||||
num_branches=2,
|
||||
blocks=block,
|
||||
in_channels=in_channels,
|
||||
num_blocks=(4, 4),
|
||||
num_channels=num_channles,
|
||||
)
|
||||
|
||||
feats = [
|
||||
torch.randn(1, in_channels[0], 64, 64),
|
||||
torch.randn(1, in_channels[1], 32, 32)
|
||||
]
|
||||
feats = hrmodule(feats)
|
||||
|
||||
assert len(feats) == 2
|
||||
assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
|
||||
assert feats[1].shape == torch.Size([1, in_channels[1], 32, 32])
|
||||
|
||||
# Test single scale forward
|
||||
num_channles = (32, 64)
|
||||
in_channels = [c * block.expansion for c in num_channles]
|
||||
hrmodule = HRModule(
|
||||
num_branches=2,
|
||||
blocks=block,
|
||||
in_channels=in_channels,
|
||||
num_blocks=(4, 4),
|
||||
num_channels=num_channles,
|
||||
multiscale_output=False,
|
||||
)
|
||||
|
||||
feats = [
|
||||
torch.randn(1, in_channels[0], 64, 64),
|
||||
torch.randn(1, in_channels[1], 32, 32)
|
||||
]
|
||||
feats = hrmodule(feats)
|
||||
|
||||
assert len(feats) == 1
|
||||
assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
|
||||
|
||||
|
||||
def test_hrnet_backbone():
|
||||
# only have 3 stages
|
||||
extra = dict(
|
||||
stage1=dict(
|
||||
num_modules=1,
|
||||
num_branches=1,
|
||||
block='BOTTLENECK',
|
||||
num_blocks=(4, ),
|
||||
num_channels=(64, )),
|
||||
stage2=dict(
|
||||
num_modules=1,
|
||||
num_branches=2,
|
||||
block='BASIC',
|
||||
num_blocks=(4, 4),
|
||||
num_channels=(32, 64)),
|
||||
stage3=dict(
|
||||
num_modules=4,
|
||||
num_branches=3,
|
||||
block='BASIC',
|
||||
num_blocks=(4, 4, 4),
|
||||
num_channels=(32, 64, 128)))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# HRNet now only support 4 stages
|
||||
HRNet(extra=extra)
|
||||
extra['stage4'] = dict(
|
||||
num_modules=3,
|
||||
num_branches=3, # should be 4
|
||||
block='BASIC',
|
||||
num_blocks=(4, 4, 4, 4),
|
||||
num_channels=(32, 64, 128, 256))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# len(num_blocks) should equal num_branches
|
||||
HRNet(extra=extra)
|
||||
|
||||
extra['stage4']['num_branches'] = 4
|
||||
|
||||
# Test hrnetv2p_w32
|
||||
model = HRNet(extra=extra)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 64, 64)
|
||||
feats = model(imgs)
|
||||
assert len(feats) == 4
|
||||
assert feats[0].shape == torch.Size([1, 32, 16, 16])
|
||||
assert feats[3].shape == torch.Size([1, 256, 2, 2])
|
||||
|
||||
# Test single scale output
|
||||
model = HRNet(extra=extra, multiscale_output=False)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 64, 64)
|
||||
feats = model(imgs)
|
||||
assert len(feats) == 1
|
||||
assert feats[0].shape == torch.Size([1, 32, 16, 16])
|
||||
|
||||
# Test HRNET with two stage frozen
|
||||
frozen_stages = 2
|
||||
model = HRNet(extra, frozen_stages=frozen_stages)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert model.norm1.training is False
|
||||
|
||||
for layer in [model.conv1, model.norm1]:
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
if i == 1:
|
||||
layer = getattr(model, f'layer{i}')
|
||||
transition = getattr(model, f'transition{i}')
|
||||
elif i == 4:
|
||||
layer = getattr(model, f'stage{i}')
|
||||
else:
|
||||
layer = getattr(model, f'stage{i}')
|
||||
transition = getattr(model, f'transition{i}')
|
||||
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
for mod in transition.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in transition.parameters():
|
||||
assert param.requires_grad is False
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import ICNet
|
||||
|
||||
|
||||
def test_icnet_backbone():
|
||||
with pytest.raises(TypeError):
|
||||
# Must give backbone dict in config file.
|
||||
ICNet(
|
||||
in_channels=3,
|
||||
layer_channels=(128, 512),
|
||||
light_branch_middle_channels=8,
|
||||
psp_out_channels=128,
|
||||
out_channels=(16, 128, 128),
|
||||
backbone_cfg=None)
|
||||
|
||||
# Test ICNet Standard Forward
|
||||
model = ICNet(
|
||||
layer_channels=(128, 512),
|
||||
backbone_cfg=dict(
|
||||
type='ResNetV1c',
|
||||
in_channels=3,
|
||||
depth=18,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
)
|
||||
assert hasattr(model.backbone,
|
||||
'maxpool') and model.backbone.maxpool.ceil_mode is True
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 32, 64)
|
||||
feat = model(imgs)
|
||||
|
||||
assert model.psp_modules[0][0].output_size == 1
|
||||
assert model.psp_modules[1][0].output_size == 2
|
||||
assert model.psp_modules[2][0].output_size == 3
|
||||
assert model.psp_bottleneck.padding == 1
|
||||
assert model.conv_sub1[0].padding == 1
|
||||
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([batch_size, 64, 4, 8])
|
||||
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones.mae import MAE
|
||||
from .utils import check_norm_state
|
||||
|
||||
|
||||
def test_mae_backbone():
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = MAE()
|
||||
model.init_weights(pretrained=0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# img_size must be int or tuple
|
||||
model = MAE(img_size=512.0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# out_indices must be int ,list or tuple
|
||||
model = MAE(out_indices=1.)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The length of img_size tuple must be lower than 3.
|
||||
MAE(img_size=(224, 224, 224))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained must be None or Str.
|
||||
MAE(pretrained=123)
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = MAE(img_size=(224, ))
|
||||
model.init_weights()
|
||||
model(imgs)
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = MAE(img_size=(224, 224))
|
||||
model(imgs)
|
||||
|
||||
# Test norm_eval = True
|
||||
model = MAE(norm_eval=True)
|
||||
model.train()
|
||||
|
||||
# Test BEiT backbone with input size of 224 and patch size of 16
|
||||
model = MAE()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
# Test out_indices = list
|
||||
model = MAE(out_indices=[2, 4, 8, 12])
|
||||
model.train()
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
# Test image size = (224, 224)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test MAE backbone with input size of 256 and patch size of 16
|
||||
model = MAE(img_size=(256, 256))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 256, 256)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 16, 16)
|
||||
|
||||
# Test MAE backbone with input size of 32 and patch size of 16
|
||||
model = MAE(img_size=(32, 32))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 32, 32)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 2, 2)
|
||||
|
||||
# Test unbalanced size input image
|
||||
model = MAE(img_size=(112, 224))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 112, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 7, 14)
|
||||
|
||||
# Test irregular input image
|
||||
model = MAE(img_size=(234, 345))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 234, 345)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 21)
|
||||
|
||||
# Test init_values=0
|
||||
model = MAE(init_values=0)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test final norm
|
||||
model = MAE(final_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test patch norm
|
||||
model = MAE(patch_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
|
||||
def test_mae_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
# Test all combinations of pretrained and init_cfg
|
||||
# pretrained=None, init_cfg=None
|
||||
model = MAE(pretrained=None, init_cfg=None)
|
||||
assert model.init_cfg is None
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
model = MAE(
|
||||
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# test resize_rel_pos_embed
|
||||
value = torch.randn(732, 16)
|
||||
abs_pos_embed_value = torch.rand(1, 17, 768)
|
||||
ckpt = {
|
||||
'state_dict': {
|
||||
'layers.0.attn.relative_position_index': 0,
|
||||
'layers.0.attn.relative_position_bias_table': value,
|
||||
'pos_embed': abs_pos_embed_value
|
||||
}
|
||||
}
|
||||
model = MAE(img_size=(512, 512))
|
||||
# If scipy is installed, this AttributeError would not be raised.
|
||||
from mmengine.utils import is_installed
|
||||
if not is_installed('scipy'):
|
||||
with pytest.raises(AttributeError):
|
||||
model.resize_rel_pos_embed(ckpt)
|
||||
|
||||
# test resize abs pos embed
|
||||
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
model = MAE(pretrained=None, init_cfg=123)
|
||||
with pytest.raises(TypeError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg=None
|
||||
model = MAE(pretrained=path, init_cfg=None)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = MAE(
|
||||
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
with pytest.raises(AssertionError):
|
||||
model = MAE(pretrained=path, init_cfg=123)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=None
|
||||
with pytest.raises(TypeError):
|
||||
model = MAE(pretrained=123, init_cfg=None)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = MAE(
|
||||
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=123, whose type is unsupported
|
||||
with pytest.raises(AssertionError):
|
||||
model = MAE(pretrained=123, init_cfg=123)
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import MixVisionTransformer
|
||||
from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
|
||||
TransformerEncoderLayer)
|
||||
|
||||
|
||||
def test_mit():
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained represents pretrain url and must be str or None.
|
||||
MixVisionTransformer(pretrained=123)
|
||||
|
||||
# Test normal input
|
||||
H, W = (224, 224)
|
||||
temp = torch.randn((1, 3, H, W))
|
||||
model = MixVisionTransformer(
|
||||
embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 32, H // 4, W // 4)
|
||||
assert outs[1].shape == (1, 64, H // 8, W // 8)
|
||||
assert outs[2].shape == (1, 160, H // 16, W // 16)
|
||||
assert outs[3].shape == (1, 256, H // 32, W // 32)
|
||||
|
||||
# Test non-squared input
|
||||
H, W = (224, 256)
|
||||
temp = torch.randn((1, 3, H, W))
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 32, H // 4, W // 4)
|
||||
assert outs[1].shape == (1, 64, H // 8, W // 8)
|
||||
assert outs[2].shape == (1, 160, H // 16, W // 16)
|
||||
assert outs[3].shape == (1, 256, H // 32, W // 32)
|
||||
|
||||
# Test MixFFN
|
||||
FFN = MixFFN(64, 128)
|
||||
hw_shape = (32, 32)
|
||||
token_len = 32 * 32
|
||||
temp = torch.randn((1, token_len, 64))
|
||||
# Self identity
|
||||
out = FFN(temp, hw_shape)
|
||||
assert out.shape == (1, token_len, 64)
|
||||
# Out identity
|
||||
outs = FFN(temp, hw_shape, temp)
|
||||
assert out.shape == (1, token_len, 64)
|
||||
|
||||
# Test EfficientMHA
|
||||
MHA = EfficientMultiheadAttention(64, 2)
|
||||
hw_shape = (32, 32)
|
||||
token_len = 32 * 32
|
||||
temp = torch.randn((1, token_len, 64))
|
||||
# Self identity
|
||||
out = MHA(temp, hw_shape)
|
||||
assert out.shape == (1, token_len, 64)
|
||||
# Out identity
|
||||
outs = MHA(temp, hw_shape, temp)
|
||||
assert out.shape == (1, token_len, 64)
|
||||
|
||||
# Test TransformerEncoderLayer with checkpoint forward
|
||||
block = TransformerEncoderLayer(
|
||||
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 56 * 56, 64)
|
||||
x_out = block(x, (56, 56))
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||
|
||||
|
||||
def test_mit_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
# Test all combinations of pretrained and init_cfg
|
||||
# pretrained=None, init_cfg=None
|
||||
model = MixVisionTransformer(pretrained=None, init_cfg=None)
|
||||
assert model.init_cfg is None
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
model = MixVisionTransformer(
|
||||
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
model = MixVisionTransformer(pretrained=None, init_cfg=123)
|
||||
with pytest.raises(TypeError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg=None
|
||||
model = MixVisionTransformer(pretrained=path, init_cfg=None)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
MixVisionTransformer(
|
||||
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
with pytest.raises(AssertionError):
|
||||
MixVisionTransformer(pretrained=path, init_cfg=123)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=None
|
||||
with pytest.raises(TypeError):
|
||||
MixVisionTransformer(pretrained=123, init_cfg=None)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
MixVisionTransformer(
|
||||
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=123, whose type is unsupported
|
||||
with pytest.raises(AssertionError):
|
||||
MixVisionTransformer(pretrained=123, init_cfg=123)
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import MobileNetV3
|
||||
|
||||
|
||||
def test_mobilenet_v3():
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid arch
|
||||
MobileNetV3('big')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid reduction_factor
|
||||
MobileNetV3(reduction_factor=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# check invalid out_indices
|
||||
MobileNetV3(out_indices=(0, 1, 15))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# check invalid frozen_stages
|
||||
MobileNetV3(frozen_stages=15)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# check invalid pretrained
|
||||
model = MobileNetV3()
|
||||
model.init_weights(pretrained=8)
|
||||
|
||||
# Test MobileNetV3 with default settings
|
||||
model = MobileNetV3()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 56, 56)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == (2, 16, 28, 28)
|
||||
assert feat[1].shape == (2, 16, 14, 14)
|
||||
assert feat[2].shape == (2, 576, 7, 7)
|
||||
|
||||
# Test MobileNetV3 with arch = 'large'
|
||||
model = MobileNetV3(arch='large', out_indices=(1, 3, 16))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 56, 56)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == (2, 16, 28, 28)
|
||||
assert feat[1].shape == (2, 24, 14, 14)
|
||||
assert feat[2].shape == (2, 960, 7, 7)
|
||||
|
||||
# Test MobileNetV3 with norm_eval True, with_cp True and frozen_stages=5
|
||||
model = MobileNetV3(norm_eval=True, with_cp=True, frozen_stages=5)
|
||||
with pytest.raises(TypeError):
|
||||
# check invalid pretrained
|
||||
model.init_weights(pretrained=8)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 56, 56)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == (2, 16, 28, 28)
|
||||
assert feat[1].shape == (2, 16, 14, 14)
|
||||
assert feat[2].shape == (2, 576, 7, 7)
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import MSCAN
|
||||
from mmseg.models.backbones.mscan import (MSCAAttention, MSCASpatialAttention,
|
||||
OverlapPatchEmbed, StemConv)
|
||||
|
||||
|
||||
def test_mscan_backbone():
|
||||
# Test MSCAN Standard Forward
|
||||
model = MSCAN(
|
||||
embed_dims=[8, 16, 32, 64],
|
||||
norm_cfg=dict(type='BN', requires_grad=True))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 64, 128)
|
||||
feat = model(imgs)
|
||||
|
||||
assert len(feat) == 4
|
||||
# output for segment Head
|
||||
assert feat[0].shape == torch.Size([batch_size, 8, 16, 32])
|
||||
assert feat[1].shape == torch.Size([batch_size, 16, 8, 16])
|
||||
assert feat[2].shape == torch.Size([batch_size, 32, 4, 8])
|
||||
assert feat[3].shape == torch.Size([batch_size, 64, 2, 4])
|
||||
|
||||
# Test input with rare shape
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 95, 27)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
|
||||
|
||||
def test_mscan_overlap_patch_embed_module():
|
||||
x_overlap_patch_embed = OverlapPatchEmbed(
|
||||
norm_cfg=dict(type='BN', requires_grad=True))
|
||||
assert x_overlap_patch_embed.proj.in_channels == 3
|
||||
assert x_overlap_patch_embed.norm.weight.shape == torch.Size([768])
|
||||
x = torch.randn(2, 3, 16, 32)
|
||||
x_out, H, W = x_overlap_patch_embed(x)
|
||||
assert x_out.shape == torch.Size([2, 32, 768])
|
||||
|
||||
|
||||
def test_mscan_spatial_attention_module():
|
||||
x_spatial_attention = MSCASpatialAttention(8)
|
||||
assert x_spatial_attention.proj_1.kernel_size == (1, 1)
|
||||
assert x_spatial_attention.proj_2.stride == (1, 1)
|
||||
x = torch.randn(2, 8, 16, 32)
|
||||
x_out = x_spatial_attention(x)
|
||||
assert x_out.shape == torch.Size([2, 8, 16, 32])
|
||||
|
||||
|
||||
def test_mscan_attention_module():
|
||||
x_attention = MSCAAttention(8)
|
||||
assert x_attention.conv0.weight.shape[0] == 8
|
||||
assert x_attention.conv3.kernel_size == (1, 1)
|
||||
x = torch.randn(2, 8, 16, 32)
|
||||
x_out = x_attention(x)
|
||||
assert x_out.shape == torch.Size([2, 8, 16, 32])
|
||||
|
||||
|
||||
def test_mscan_stem_module():
|
||||
x_stem = StemConv(8, 8, norm_cfg=dict(type='BN', requires_grad=True))
|
||||
assert x_stem.proj[0].weight.shape[0] == 4
|
||||
assert x_stem.proj[-1].weight.shape[0] == 8
|
||||
x = torch.randn(2, 8, 16, 32)
|
||||
x_out, H, W = x_stem(x)
|
||||
assert x_out.shape == torch.Size([2, 32, 8])
|
||||
assert (H, W) == (4, 8)
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def test_pidnet_backbone():
|
||||
# Test PIDNet Standard Forward
|
||||
norm_cfg = dict(type='BN', requires_grad=True)
|
||||
backbone_cfg = dict(
|
||||
type='PIDNet',
|
||||
in_channels=3,
|
||||
channels=32,
|
||||
ppm_channels=96,
|
||||
num_stem_blocks=2,
|
||||
num_branch_blocks=3,
|
||||
align_corners=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='ReLU', inplace=True))
|
||||
model = MODELS.build(backbone_cfg)
|
||||
model.init_weights()
|
||||
|
||||
# Test init weights
|
||||
temp_file = tempfile.NamedTemporaryFile()
|
||||
temp_file.close()
|
||||
torch.save(model.state_dict(), temp_file.name)
|
||||
backbone_cfg.update(
|
||||
init_cfg=dict(type='Pretrained', checkpoint=temp_file.name))
|
||||
model = MODELS.build(backbone_cfg)
|
||||
model.init_weights()
|
||||
os.remove(temp_file.name)
|
||||
|
||||
# Test eval mode
|
||||
model.eval()
|
||||
batch_size = 1
|
||||
imgs = torch.randn(batch_size, 3, 64, 128)
|
||||
feats = model(imgs)
|
||||
|
||||
assert type(feats) == torch.Tensor
|
||||
assert feats.shape == torch.Size([batch_size, 128, 8, 16])
|
||||
|
||||
# Test train mode
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 64, 128)
|
||||
feats = model(imgs)
|
||||
|
||||
assert len(feats) == 3
|
||||
# test output for P branch
|
||||
assert feats[0].shape == torch.Size([batch_size, 64, 8, 16])
|
||||
# test output for I branch
|
||||
assert feats[1].shape == torch.Size([batch_size, 128, 8, 16])
|
||||
# test output for D branch
|
||||
assert feats[2].shape == torch.Size([batch_size, 64, 8, 16])
|
||||
|
||||
# Test pidnet-m
|
||||
backbone_cfg.update(channels=64)
|
||||
model = MODELS.build(backbone_cfg)
|
||||
feats = model(imgs)
|
||||
|
||||
assert len(feats) == 3
|
||||
# test output for P branch
|
||||
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
|
||||
# test output for I branch
|
||||
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
|
||||
# test output for D branch
|
||||
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])
|
||||
|
||||
# Test pidnet-l
|
||||
backbone_cfg.update(
|
||||
channels=64, ppm_channesl=112, num_stem_blocks=3, num_branch_blocks=4)
|
||||
model = MODELS.build(backbone_cfg)
|
||||
feats = model(imgs)
|
||||
|
||||
assert len(feats) == 3
|
||||
# test output for P branch
|
||||
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
|
||||
# test output for I branch
|
||||
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
|
||||
# test output for D branch
|
||||
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import ResNeSt
|
||||
from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
|
||||
|
||||
|
||||
def test_resnest_bottleneck():
|
||||
with pytest.raises(AssertionError):
|
||||
# Style must be in ['pytorch', 'caffe']
|
||||
BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow')
|
||||
|
||||
# Test ResNeSt Bottleneck structure
|
||||
block = BottleneckS(
|
||||
64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch')
|
||||
assert block.avd_layer.stride == 2
|
||||
assert block.conv2.channels == 256
|
||||
|
||||
# Test ResNeSt Bottleneck forward
|
||||
block = BottleneckS(64, 16, radix=2, reduction_factor=4)
|
||||
x = torch.randn(2, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([2, 64, 56, 56])
|
||||
|
||||
|
||||
def test_resnest_backbone():
|
||||
with pytest.raises(KeyError):
|
||||
# ResNeSt depth should be in [50, 101, 152, 200]
|
||||
ResNeSt(depth=18)
|
||||
|
||||
# Test ResNeSt with radix 2, reduction_factor 4
|
||||
model = ResNeSt(
|
||||
depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([2, 256, 56, 56])
|
||||
assert feat[1].shape == torch.Size([2, 512, 28, 28])
|
||||
assert feat[2].shape == torch.Size([2, 1024, 14, 14])
|
||||
assert feat[3].shape == torch.Size([2, 2048, 7, 7])
|
||||
@@ -0,0 +1,575 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.ops import DeformConv2dPack
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
from torch.nn.modules import AvgPool2d, GroupNorm
|
||||
|
||||
from mmseg.models.backbones import ResNet, ResNetV1d
|
||||
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||
from mmseg.models.utils import ResLayer
|
||||
from .utils import all_zeros, check_norm_state, is_block, is_norm
|
||||
|
||||
|
||||
def test_resnet_basic_block():
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
BasicBlock(64, 64, dcn=dcn)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
BasicBlock(64, 64, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
BasicBlock(64, 64, plugins=plugins)
|
||||
|
||||
# Test BasicBlock with checkpoint forward
|
||||
block = BasicBlock(16, 16, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 16, 28, 28)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 16, 28, 28])
|
||||
|
||||
# test BasicBlock structure and forward
|
||||
block = BasicBlock(32, 32)
|
||||
assert block.conv1.in_channels == 32
|
||||
assert block.conv1.out_channels == 32
|
||||
assert block.conv1.kernel_size == (3, 3)
|
||||
assert block.conv2.in_channels == 32
|
||||
assert block.conv2.out_channels == 32
|
||||
assert block.conv2.kernel_size == (3, 3)
|
||||
x = torch.randn(1, 32, 28, 28)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 28, 28])
|
||||
|
||||
|
||||
def test_resnet_bottleneck():
|
||||
with pytest.raises(AssertionError):
|
||||
# Style must be in ['pytorch', 'caffe']
|
||||
Bottleneck(64, 64, style='tensorflow')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Allowed positions are 'after_conv1', 'after_conv2', 'after_conv3'
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv4')
|
||||
]
|
||||
Bottleneck(64, 16, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Need to specify different postfix to avoid duplicate plugin name
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3'),
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
Bottleneck(64, 16, plugins=plugins)
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
# Plugin type is not supported
|
||||
plugins = [dict(cfg=dict(type='WrongPlugin'), position='after_conv3')]
|
||||
Bottleneck(64, 16, plugins=plugins)
|
||||
|
||||
# Test Bottleneck with checkpoint forward
|
||||
block = Bottleneck(64, 16, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test Bottleneck style
|
||||
block = Bottleneck(64, 64, stride=2, style='pytorch')
|
||||
assert block.conv1.stride == (1, 1)
|
||||
assert block.conv2.stride == (2, 2)
|
||||
block = Bottleneck(64, 64, stride=2, style='caffe')
|
||||
assert block.conv1.stride == (2, 2)
|
||||
assert block.conv2.stride == (1, 1)
|
||||
|
||||
# Test Bottleneck DCN
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
with pytest.raises(AssertionError):
|
||||
Bottleneck(64, 64, dcn=dcn, conv_cfg=dict(type='Conv'))
|
||||
block = Bottleneck(64, 64, dcn=dcn)
|
||||
assert isinstance(block.conv2, DeformConv2dPack)
|
||||
|
||||
# Test Bottleneck forward
|
||||
block = Bottleneck(64, 16)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test Bottleneck with 1 ContextBlock after conv3
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
block = Bottleneck(64, 16, plugins=plugins)
|
||||
assert block.context_block.in_channels == 64
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test Bottleneck with 1 GeneralizedAttention after conv2
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
block = Bottleneck(64, 16, plugins=plugins)
|
||||
assert block.gen_attention_block.in_channels == 16
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test Bottleneck with 1 GeneralizedAttention after conv2, 1 NonLocal2d
|
||||
# after conv2, 1 ContextBlock after conv3
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2'),
|
||||
dict(cfg=dict(type='NonLocal2d'), position='after_conv2'),
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
block = Bottleneck(64, 16, plugins=plugins)
|
||||
assert block.gen_attention_block.in_channels == 16
|
||||
assert block.nonlocal_block.in_channels == 16
|
||||
assert block.context_block.in_channels == 64
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test Bottleneck with 1 ContextBlock after conv2, 2 ContextBlock after
|
||||
# conv3
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1),
|
||||
position='after_conv2'),
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2),
|
||||
position='after_conv3'),
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=3),
|
||||
position='after_conv3')
|
||||
]
|
||||
block = Bottleneck(64, 16, plugins=plugins)
|
||||
assert block.context_block1.in_channels == 16
|
||||
assert block.context_block2.in_channels == 64
|
||||
assert block.context_block3.in_channels == 64
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
|
||||
def test_resnet_res_layer():
|
||||
# Test ResLayer of 3 Bottleneck w\o downsample
|
||||
layer = ResLayer(Bottleneck, 64, 16, 3)
|
||||
assert len(layer) == 3
|
||||
assert layer[0].conv1.in_channels == 64
|
||||
assert layer[0].conv1.out_channels == 16
|
||||
for i in range(1, len(layer)):
|
||||
assert layer[i].conv1.in_channels == 64
|
||||
assert layer[i].conv1.out_channels == 16
|
||||
for i in range(len(layer)):
|
||||
assert layer[i].downsample is None
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = layer(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test ResLayer of 3 Bottleneck with downsample
|
||||
layer = ResLayer(Bottleneck, 64, 64, 3)
|
||||
assert layer[0].downsample[0].out_channels == 256
|
||||
for i in range(1, len(layer)):
|
||||
assert layer[i].downsample is None
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = layer(x)
|
||||
assert x_out.shape == torch.Size([1, 256, 56, 56])
|
||||
|
||||
# Test ResLayer of 3 Bottleneck with stride=2
|
||||
layer = ResLayer(Bottleneck, 64, 64, 3, stride=2)
|
||||
assert layer[0].downsample[0].out_channels == 256
|
||||
assert layer[0].downsample[0].stride == (2, 2)
|
||||
for i in range(1, len(layer)):
|
||||
assert layer[i].downsample is None
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = layer(x)
|
||||
assert x_out.shape == torch.Size([1, 256, 28, 28])
|
||||
|
||||
# Test ResLayer of 3 Bottleneck with stride=2 and average downsample
|
||||
layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True)
|
||||
assert isinstance(layer[0].downsample[0], AvgPool2d)
|
||||
assert layer[0].downsample[1].out_channels == 256
|
||||
assert layer[0].downsample[1].stride == (1, 1)
|
||||
for i in range(1, len(layer)):
|
||||
assert layer[i].downsample is None
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = layer(x)
|
||||
assert x_out.shape == torch.Size([1, 256, 28, 28])
|
||||
|
||||
# Test ResLayer of 3 Bottleneck with dilation=2
|
||||
layer = ResLayer(Bottleneck, 64, 16, 3, dilation=2)
|
||||
for i in range(len(layer)):
|
||||
assert layer[i].conv2.dilation == (2, 2)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = layer(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test ResLayer of 3 Bottleneck with dilation=2, contract_dilation=True
|
||||
layer = ResLayer(Bottleneck, 64, 16, 3, dilation=2, contract_dilation=True)
|
||||
assert layer[0].conv2.dilation == (1, 1)
|
||||
for i in range(1, len(layer)):
|
||||
assert layer[i].conv2.dilation == (2, 2)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = layer(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# Test ResLayer of 3 Bottleneck with dilation=2, multi_grid
|
||||
layer = ResLayer(Bottleneck, 64, 16, 3, dilation=2, multi_grid=(1, 2, 4))
|
||||
assert layer[0].conv2.dilation == (1, 1)
|
||||
assert layer[1].conv2.dilation == (2, 2)
|
||||
assert layer[2].conv2.dilation == (4, 4)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = layer(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
|
||||
def test_resnet_backbone():
|
||||
"""Test resnet backbone."""
|
||||
with pytest.raises(KeyError):
|
||||
# ResNet depth should be in [18, 34, 50, 101, 152]
|
||||
ResNet(20)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# In ResNet: 1 <= num_stages <= 4
|
||||
ResNet(50, num_stages=0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# len(stage_with_dcn) == num_stages
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
ResNet(50, dcn=dcn, stage_with_dcn=(True, ))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# len(stage_with_plugin) == num_stages
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
stages=(False, True, True),
|
||||
position='after_conv3')
|
||||
]
|
||||
ResNet(50, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# In ResNet: 1 <= num_stages <= 4
|
||||
ResNet(18, num_stages=5)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# len(strides) == len(dilations) == num_stages
|
||||
ResNet(18, strides=(1, ), dilations=(1, 1), num_stages=3)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = ResNet(18, pretrained=0)
|
||||
model.init_weights()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Style must be in ['pytorch', 'caffe']
|
||||
ResNet(50, style='tensorflow')
|
||||
|
||||
# Test ResNet18 norm_eval=True
|
||||
model = ResNet(18, norm_eval=True)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test ResNet18 with torchvision pretrained weight
|
||||
model = ResNet(
|
||||
depth=18, norm_eval=True, pretrained='torchvision://resnet18')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test ResNet18 with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = ResNet(18, frozen_stages=frozen_stages)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert model.norm1.training is False
|
||||
for layer in [model.conv1, model.norm1]:
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
layer = getattr(model, f'layer{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test ResNet18V1d with first stage frozen
|
||||
model = ResNetV1d(depth=18, frozen_stages=frozen_stages)
|
||||
assert len(model.stem) == 9
|
||||
model.init_weights()
|
||||
model.train()
|
||||
check_norm_state(model.stem, False)
|
||||
for param in model.stem.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
layer = getattr(model, f'layer{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test ResNet18 forward
|
||||
model = ResNet(18)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 128, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 512, 7, 7])
|
||||
|
||||
# Test ResNet18 with BatchNorm forward
|
||||
model = ResNet(18)
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 128, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 512, 7, 7])
|
||||
|
||||
# Test ResNet18 with layers 1, 2, 3 out forward
|
||||
model = ResNet(18, out_indices=(0, 1, 2))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 112, 112)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([1, 64, 28, 28])
|
||||
assert feat[1].shape == torch.Size([1, 128, 14, 14])
|
||||
assert feat[2].shape == torch.Size([1, 256, 7, 7])
|
||||
|
||||
# Test ResNet18 with checkpoint forward
|
||||
model = ResNet(18, with_cp=True)
|
||||
for m in model.modules():
|
||||
if is_block(m):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 128, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 512, 7, 7])
|
||||
|
||||
# Test ResNet18 with checkpoint forward
|
||||
model = ResNet(18, with_cp=True)
|
||||
for m in model.modules():
|
||||
if is_block(m):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 128, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 512, 7, 7])
|
||||
|
||||
# Test ResNet18 with GroupNorm forward
|
||||
model = ResNet(
|
||||
18, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, GroupNorm)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 128, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 512, 7, 7])
|
||||
|
||||
# Test ResNet50 with 1 GeneralizedAttention after conv2, 1 NonLocal2d
|
||||
# after conv2, 1 ContextBlock after conv3 in layers 2, 3, 4
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
stages=(False, True, True, True),
|
||||
position='after_conv2'),
|
||||
dict(cfg=dict(type='NonLocal2d'), position='after_conv2'),
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
stages=(False, True, True, False),
|
||||
position='after_conv3')
|
||||
]
|
||||
model = ResNet(50, plugins=plugins)
|
||||
for m in model.layer1.modules():
|
||||
if is_block(m):
|
||||
assert not hasattr(m, 'context_block')
|
||||
assert not hasattr(m, 'gen_attention_block')
|
||||
assert m.nonlocal_block.in_channels == 64
|
||||
for m in model.layer2.modules():
|
||||
if is_block(m):
|
||||
assert m.nonlocal_block.in_channels == 128
|
||||
assert m.gen_attention_block.in_channels == 128
|
||||
assert m.context_block.in_channels == 512
|
||||
|
||||
for m in model.layer3.modules():
|
||||
if is_block(m):
|
||||
assert m.nonlocal_block.in_channels == 256
|
||||
assert m.gen_attention_block.in_channels == 256
|
||||
assert m.context_block.in_channels == 1024
|
||||
|
||||
for m in model.layer4.modules():
|
||||
if is_block(m):
|
||||
assert m.nonlocal_block.in_channels == 512
|
||||
assert m.gen_attention_block.in_channels == 512
|
||||
assert not hasattr(m, 'context_block')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 256, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 512, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
|
||||
|
||||
# Test ResNet50 with 1 ContextBlock after conv2, 1 ContextBlock after
|
||||
# conv3 in layers 2, 3, 4
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=1),
|
||||
stages=(False, True, True, False),
|
||||
position='after_conv3'),
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16, postfix=2),
|
||||
stages=(False, True, True, False),
|
||||
position='after_conv3')
|
||||
]
|
||||
|
||||
model = ResNet(50, plugins=plugins)
|
||||
for m in model.layer1.modules():
|
||||
if is_block(m):
|
||||
assert not hasattr(m, 'context_block')
|
||||
assert not hasattr(m, 'context_block1')
|
||||
assert not hasattr(m, 'context_block2')
|
||||
for m in model.layer2.modules():
|
||||
if is_block(m):
|
||||
assert not hasattr(m, 'context_block')
|
||||
assert m.context_block1.in_channels == 512
|
||||
assert m.context_block2.in_channels == 512
|
||||
|
||||
for m in model.layer3.modules():
|
||||
if is_block(m):
|
||||
assert not hasattr(m, 'context_block')
|
||||
assert m.context_block1.in_channels == 1024
|
||||
assert m.context_block2.in_channels == 1024
|
||||
|
||||
for m in model.layer4.modules():
|
||||
if is_block(m):
|
||||
assert not hasattr(m, 'context_block')
|
||||
assert not hasattr(m, 'context_block1')
|
||||
assert not hasattr(m, 'context_block2')
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 256, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 512, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
|
||||
|
||||
# Test ResNet18 zero initialization of residual
|
||||
model = ResNet(18, zero_init_residual=True)
|
||||
model.init_weights()
|
||||
for m in model.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
assert all_zeros(m.norm3)
|
||||
elif isinstance(m, BasicBlock):
|
||||
assert all_zeros(m.norm2)
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 128, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 512, 7, 7])
|
||||
|
||||
# Test ResNetV1d forward
|
||||
model = ResNetV1d(depth=18)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 128, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 512, 7, 7])
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import ResNeXt
|
||||
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
||||
from .utils import is_block
|
||||
|
||||
|
||||
def test_renext_bottleneck():
|
||||
with pytest.raises(AssertionError):
|
||||
# Style must be in ['pytorch', 'caffe']
|
||||
BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')
|
||||
|
||||
# Test ResNeXt Bottleneck structure
|
||||
block = BottleneckX(
|
||||
64, 64, groups=32, base_width=4, stride=2, style='pytorch')
|
||||
assert block.conv2.stride == (2, 2)
|
||||
assert block.conv2.groups == 32
|
||||
assert block.conv2.out_channels == 128
|
||||
|
||||
# Test ResNeXt Bottleneck with DCN
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
with pytest.raises(AssertionError):
|
||||
# conv_cfg must be None if dcn is not None
|
||||
BottleneckX(
|
||||
64,
|
||||
64,
|
||||
groups=32,
|
||||
base_width=4,
|
||||
dcn=dcn,
|
||||
conv_cfg=dict(type='Conv'))
|
||||
BottleneckX(64, 64, dcn=dcn)
|
||||
|
||||
# Test ResNeXt Bottleneck forward
|
||||
block = BottleneckX(64, 16, groups=32, base_width=4)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
|
||||
def test_resnext_backbone():
|
||||
with pytest.raises(KeyError):
|
||||
# ResNeXt depth should be in [50, 101, 152]
|
||||
ResNeXt(depth=18)
|
||||
|
||||
# Test ResNeXt with group 32, base_width 4
|
||||
model = ResNeXt(depth=50, groups=32, base_width=4)
|
||||
print(model)
|
||||
for m in model.modules():
|
||||
if is_block(m):
|
||||
assert m.conv2.groups == 32
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == torch.Size([1, 256, 56, 56])
|
||||
assert feat[1].shape == torch.Size([1, 512, 28, 28])
|
||||
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
|
||||
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import STDCContextPathNet
|
||||
from mmseg.models.backbones.stdc import (AttentionRefinementModule,
|
||||
FeatureFusionModule, STDCModule,
|
||||
STDCNet)
|
||||
|
||||
|
||||
def test_stdc_context_path_net():
|
||||
# Test STDCContextPathNet Standard Forward
|
||||
model = STDCContextPathNet(
|
||||
backbone_cfg=dict(
|
||||
type='STDCNet',
|
||||
stdc_type='STDCNet1',
|
||||
in_channels=3,
|
||||
channels=(32, 64, 256, 512, 1024),
|
||||
bottleneck_type='cat',
|
||||
num_convs=4,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
with_final_conv=True),
|
||||
last_in_channels=(1024, 512),
|
||||
out_channels=128,
|
||||
ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 256, 512)
|
||||
feat = model(imgs)
|
||||
|
||||
assert len(feat) == 4
|
||||
# output for segment Head
|
||||
assert feat[0].shape == torch.Size([batch_size, 256, 32, 64])
|
||||
# for auxiliary head 1
|
||||
assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
|
||||
# for auxiliary head 2
|
||||
assert feat[2].shape == torch.Size([batch_size, 128, 32, 64])
|
||||
# for auxiliary head 3
|
||||
assert feat[3].shape == torch.Size([batch_size, 256, 32, 64])
|
||||
|
||||
# Test input with rare shape
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 527, 279)
|
||||
model = STDCContextPathNet(
|
||||
backbone_cfg=dict(
|
||||
type='STDCNet',
|
||||
stdc_type='STDCNet1',
|
||||
in_channels=3,
|
||||
channels=(32, 64, 256, 512, 1024),
|
||||
bottleneck_type='add',
|
||||
num_convs=4,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
with_final_conv=False),
|
||||
last_in_channels=(1024, 512),
|
||||
out_channels=128,
|
||||
ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
|
||||
|
||||
def test_stdcnet():
|
||||
with pytest.raises(AssertionError):
|
||||
# STDC backbone constraints.
|
||||
STDCNet(
|
||||
stdc_type='STDCNet3',
|
||||
in_channels=3,
|
||||
channels=(32, 64, 256, 512, 1024),
|
||||
bottleneck_type='cat',
|
||||
num_convs=4,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
with_final_conv=False)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# STDC bottleneck type constraints.
|
||||
STDCNet(
|
||||
stdc_type='STDCNet1',
|
||||
in_channels=3,
|
||||
channels=(32, 64, 256, 512, 1024),
|
||||
bottleneck_type='dog',
|
||||
num_convs=4,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
with_final_conv=False)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# STDC channels length constraints.
|
||||
STDCNet(
|
||||
stdc_type='STDCNet1',
|
||||
in_channels=3,
|
||||
channels=(16, 32, 64, 256, 512, 1024),
|
||||
bottleneck_type='cat',
|
||||
num_convs=4,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
with_final_conv=False)
|
||||
|
||||
|
||||
def test_feature_fusion_module():
|
||||
x_ffm = FeatureFusionModule(in_channels=64, out_channels=32)
|
||||
assert x_ffm.conv0.in_channels == 64
|
||||
assert x_ffm.attention[1].in_channels == 32
|
||||
assert x_ffm.attention[2].in_channels == 8
|
||||
assert x_ffm.attention[2].out_channels == 32
|
||||
x1 = torch.randn(2, 32, 32, 64)
|
||||
x2 = torch.randn(2, 32, 32, 64)
|
||||
x_out = x_ffm(x1, x2)
|
||||
assert x_out.shape == torch.Size([2, 32, 32, 64])
|
||||
|
||||
|
||||
def test_attention_refinement_module():
|
||||
x_arm = AttentionRefinementModule(128, 32)
|
||||
assert x_arm.conv_layer.in_channels == 128
|
||||
assert x_arm.atten_conv_layer[1].conv.out_channels == 32
|
||||
x = torch.randn(2, 128, 32, 64)
|
||||
x_out = x_arm(x)
|
||||
assert x_out.shape == torch.Size([2, 32, 32, 64])
|
||||
|
||||
|
||||
def test_stdc_module():
|
||||
x_stdc = STDCModule(in_channels=32, out_channels=32, stride=4)
|
||||
assert x_stdc.layers[0].conv.in_channels == 32
|
||||
assert x_stdc.layers[3].conv.out_channels == 4
|
||||
x = torch.randn(2, 32, 32, 64)
|
||||
x_out = x_stdc(x)
|
||||
assert x_out.shape == torch.Size([2, 32, 32, 64])
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones.swin import SwinBlock, SwinTransformer
|
||||
|
||||
|
||||
def test_swin_block():
|
||||
# test SwinBlock structure and forward
|
||||
block = SwinBlock(embed_dims=32, num_heads=4, feedforward_channels=128)
|
||||
assert block.ffn.embed_dims == 32
|
||||
assert block.attn.w_msa.num_heads == 4
|
||||
assert block.ffn.feedforward_channels == 128
|
||||
x = torch.randn(1, 56 * 56, 32)
|
||||
x_out = block(x, (56, 56))
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 32])
|
||||
|
||||
# Test BasicBlock with checkpoint forward
|
||||
block = SwinBlock(
|
||||
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 56 * 56, 64)
|
||||
x_out = block(x, (56, 56))
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||
|
||||
|
||||
def test_swin_transformer():
|
||||
"""Test Swin Transformer backbone."""
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained arg must be str or None.
|
||||
SwinTransformer(pretrained=123)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Because swin uses non-overlapping patch embed, so the stride of patch
|
||||
# embed must be equal to patch size.
|
||||
SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)
|
||||
|
||||
# test pretrained image size
|
||||
with pytest.raises(AssertionError):
|
||||
SwinTransformer(pretrain_img_size=(112, 112, 112))
|
||||
|
||||
# Test absolute position embedding
|
||||
temp = torch.randn((1, 3, 112, 112))
|
||||
model = SwinTransformer(pretrain_img_size=112, use_abs_pos_embed=True)
|
||||
model.init_weights()
|
||||
model(temp)
|
||||
|
||||
# Test patch norm
|
||||
model = SwinTransformer(patch_norm=False)
|
||||
model(temp)
|
||||
|
||||
# Test normal inference
|
||||
temp = torch.randn((1, 3, 256, 256))
|
||||
model = SwinTransformer()
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 96, 64, 64)
|
||||
assert outs[1].shape == (1, 192, 32, 32)
|
||||
assert outs[2].shape == (1, 384, 16, 16)
|
||||
assert outs[3].shape == (1, 768, 8, 8)
|
||||
|
||||
# Test abnormal inference size
|
||||
temp = torch.randn((1, 3, 255, 255))
|
||||
model = SwinTransformer()
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 96, 64, 64)
|
||||
assert outs[1].shape == (1, 192, 32, 32)
|
||||
assert outs[2].shape == (1, 384, 16, 16)
|
||||
assert outs[3].shape == (1, 768, 8, 8)
|
||||
|
||||
# Test abnormal inference size
|
||||
temp = torch.randn((1, 3, 112, 137))
|
||||
model = SwinTransformer()
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 96, 28, 35)
|
||||
assert outs[1].shape == (1, 192, 14, 18)
|
||||
assert outs[2].shape == (1, 384, 7, 9)
|
||||
assert outs[3].shape == (1, 768, 4, 5)
|
||||
|
||||
# Test frozen
|
||||
model = SwinTransformer(frozen_stages=4)
|
||||
model.train()
|
||||
for p in model.parameters():
|
||||
assert not p.requires_grad
|
||||
|
||||
# Test absolute position embedding frozen
|
||||
model = SwinTransformer(frozen_stages=4, use_abs_pos_embed=True)
|
||||
model.train()
|
||||
for p in model.parameters():
|
||||
assert not p.requires_grad
|
||||
|
||||
# Test Swin with checkpoint forward
|
||||
temp = torch.randn((1, 3, 56, 56))
|
||||
model = SwinTransformer(with_cp=True)
|
||||
for m in model.modules():
|
||||
if isinstance(m, SwinBlock):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
model(temp)
|
||||
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import TIMMBackbone
|
||||
from .utils import check_norm_state
|
||||
|
||||
|
||||
def test_timm_backbone():
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = TIMMBackbone()
|
||||
model.init_weights(pretrained=0)
|
||||
|
||||
# Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
|
||||
# Test resnet18 from timm, norm_layer='BN2d'
|
||||
model = TIMMBackbone(
|
||||
model_name='resnet18',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=32,
|
||||
norm_layer='BN2d')
|
||||
|
||||
# Test resnet18 from timm, norm_layer='SyncBN'
|
||||
model = TIMMBackbone(
|
||||
model_name='resnet18',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=32,
|
||||
norm_layer='SyncBN2d')
|
||||
|
||||
# Test resnet18 from timm, features_only=True, output_stride=32
|
||||
model = TIMMBackbone(
|
||||
model_name='resnet18',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=32)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feats = model(imgs)
|
||||
feats = [feat.shape for feat in feats]
|
||||
assert len(feats) == 5
|
||||
assert feats[0] == torch.Size((1, 64, 112, 112))
|
||||
assert feats[1] == torch.Size((1, 64, 56, 56))
|
||||
assert feats[2] == torch.Size((1, 128, 28, 28))
|
||||
assert feats[3] == torch.Size((1, 256, 14, 14))
|
||||
assert feats[4] == torch.Size((1, 512, 7, 7))
|
||||
|
||||
# Test resnet18 from timm, features_only=True, output_stride=16
|
||||
model = TIMMBackbone(
|
||||
model_name='resnet18',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=16)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feats = model(imgs)
|
||||
feats = [feat.shape for feat in feats]
|
||||
assert len(feats) == 5
|
||||
assert feats[0] == torch.Size((1, 64, 112, 112))
|
||||
assert feats[1] == torch.Size((1, 64, 56, 56))
|
||||
assert feats[2] == torch.Size((1, 128, 28, 28))
|
||||
assert feats[3] == torch.Size((1, 256, 14, 14))
|
||||
assert feats[4] == torch.Size((1, 512, 14, 14))
|
||||
|
||||
# Test resnet18 from timm, features_only=True, output_stride=8
|
||||
model = TIMMBackbone(
|
||||
model_name='resnet18',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=8)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feats = model(imgs)
|
||||
feats = [feat.shape for feat in feats]
|
||||
assert len(feats) == 5
|
||||
assert feats[0] == torch.Size((1, 64, 112, 112))
|
||||
assert feats[1] == torch.Size((1, 64, 56, 56))
|
||||
assert feats[2] == torch.Size((1, 128, 28, 28))
|
||||
assert feats[3] == torch.Size((1, 256, 28, 28))
|
||||
assert feats[4] == torch.Size((1, 512, 28, 28))
|
||||
|
||||
# Test efficientnet_b1 with pretrained weights
|
||||
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)
|
||||
|
||||
# Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8
|
||||
model = TIMMBackbone(
|
||||
model_name='resnetv2_50x1_bitm',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=8)
|
||||
imgs = torch.randn(1, 3, 8, 8)
|
||||
feats = model(imgs)
|
||||
feats = [feat.shape for feat in feats]
|
||||
assert len(feats) == 5
|
||||
assert feats[0] == torch.Size((1, 64, 4, 4))
|
||||
assert feats[1] == torch.Size((1, 256, 2, 2))
|
||||
assert feats[2] == torch.Size((1, 512, 1, 1))
|
||||
assert feats[3] == torch.Size((1, 1024, 1, 1))
|
||||
assert feats[4] == torch.Size((1, 2048, 1, 1))
|
||||
|
||||
# Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8
|
||||
model = TIMMBackbone(
|
||||
model_name='resnetv2_50x3_bitm',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=8)
|
||||
imgs = torch.randn(1, 3, 8, 8)
|
||||
feats = model(imgs)
|
||||
feats = [feat.shape for feat in feats]
|
||||
assert len(feats) == 5
|
||||
assert feats[0] == torch.Size((1, 192, 4, 4))
|
||||
assert feats[1] == torch.Size((1, 768, 2, 2))
|
||||
assert feats[2] == torch.Size((1, 1536, 1, 1))
|
||||
assert feats[3] == torch.Size((1, 3072, 1, 1))
|
||||
assert feats[4] == torch.Size((1, 6144, 1, 1))
|
||||
|
||||
# Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8
|
||||
model = TIMMBackbone(
|
||||
model_name='resnetv2_101x1_bitm',
|
||||
features_only=True,
|
||||
pretrained=False,
|
||||
output_stride=8)
|
||||
imgs = torch.randn(1, 3, 8, 8)
|
||||
feats = model(imgs)
|
||||
feats = [feat.shape for feat in feats]
|
||||
assert len(feats) == 5
|
||||
assert feats[0] == torch.Size((1, 64, 4, 4))
|
||||
assert feats[1] == torch.Size((1, 256, 2, 2))
|
||||
assert feats[2] == torch.Size((1, 512, 1, 1))
|
||||
assert feats[3] == torch.Size((1, 1024, 1, 1))
|
||||
assert feats[4] == torch.Size((1, 2048, 1, 1))
|
||||
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones.twins import (PCPVT, SVT,
|
||||
ConditionalPositionEncoding,
|
||||
LocallyGroupedSelfAttention)
|
||||
|
||||
|
||||
def test_pcpvt():
|
||||
# Test normal input
|
||||
H, W = (224, 224)
|
||||
temp = torch.randn((1, 3, H, W))
|
||||
model = PCPVT(
|
||||
embed_dims=[32, 64, 160, 256],
|
||||
num_heads=[1, 2, 5, 8],
|
||||
mlp_ratios=[8, 8, 4, 4],
|
||||
qkv_bias=True,
|
||||
depths=[3, 4, 6, 3],
|
||||
sr_ratios=[8, 4, 2, 1],
|
||||
norm_after_stage=False)
|
||||
model.init_weights()
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 32, H // 4, W // 4)
|
||||
assert outs[1].shape == (1, 64, H // 8, W // 8)
|
||||
assert outs[2].shape == (1, 160, H // 16, W // 16)
|
||||
assert outs[3].shape == (1, 256, H // 32, W // 32)
|
||||
|
||||
|
||||
def test_svt():
|
||||
# Test normal input
|
||||
H, W = (224, 224)
|
||||
temp = torch.randn((1, 3, H, W))
|
||||
model = SVT(
|
||||
embed_dims=[32, 64, 128],
|
||||
num_heads=[1, 2, 4],
|
||||
mlp_ratios=[4, 4, 4],
|
||||
qkv_bias=False,
|
||||
depths=[4, 4, 4],
|
||||
windiow_sizes=[7, 7, 7],
|
||||
norm_after_stage=True)
|
||||
|
||||
model.init_weights()
|
||||
outs = model(temp)
|
||||
assert outs[0].shape == (1, 32, H // 4, W // 4)
|
||||
assert outs[1].shape == (1, 64, H // 8, W // 8)
|
||||
assert outs[2].shape == (1, 128, H // 16, W // 16)
|
||||
|
||||
|
||||
def test_svt_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
# Test all combinations of pretrained and init_cfg
|
||||
# pretrained=None, init_cfg=None
|
||||
model = SVT(pretrained=None, init_cfg=None)
|
||||
assert model.init_cfg is None
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
model = SVT(
|
||||
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
model = SVT(pretrained=None, init_cfg=123)
|
||||
with pytest.raises(TypeError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg=None
|
||||
model = SVT(pretrained=path, init_cfg=None)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = SVT(
|
||||
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
with pytest.raises(AssertionError):
|
||||
model = SVT(pretrained=path, init_cfg=123)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=None
|
||||
with pytest.raises(TypeError):
|
||||
model = SVT(pretrained=123, init_cfg=None)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = SVT(
|
||||
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=123, whose type is unsupported
|
||||
with pytest.raises(AssertionError):
|
||||
model = SVT(pretrained=123, init_cfg=123)
|
||||
|
||||
|
||||
def test_pcpvt_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
# Test all combinations of pretrained and init_cfg
|
||||
# pretrained=None, init_cfg=None
|
||||
model = PCPVT(pretrained=None, init_cfg=None)
|
||||
assert model.init_cfg is None
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
model = PCPVT(
|
||||
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
model = PCPVT(pretrained=None, init_cfg=123)
|
||||
with pytest.raises(TypeError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg=None
|
||||
model = PCPVT(pretrained=path, init_cfg=None)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = PCPVT(
|
||||
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
with pytest.raises(AssertionError):
|
||||
model = PCPVT(pretrained=path, init_cfg=123)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=None
|
||||
with pytest.raises(TypeError):
|
||||
model = PCPVT(pretrained=123, init_cfg=None)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = PCPVT(
|
||||
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=123, whose type is unsupported
|
||||
with pytest.raises(AssertionError):
|
||||
model = PCPVT(pretrained=123, init_cfg=123)
|
||||
|
||||
|
||||
def test_locallygrouped_self_attention_module():
|
||||
LSA = LocallyGroupedSelfAttention(embed_dims=32, window_size=3)
|
||||
outs = LSA(torch.randn(1, 3136, 32), (56, 56))
|
||||
assert outs.shape == torch.Size([1, 3136, 32])
|
||||
|
||||
|
||||
def test_conditional_position_encoding_module():
|
||||
CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
|
||||
outs = CPE(torch.randn(1, 3136, 32), (56, 56))
|
||||
assert outs.shape == torch.Size([1, 784, 32])
|
||||
@@ -0,0 +1,825 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
||||
InterpConv, UNet, UpConvBlock)
|
||||
from mmseg.models.utils import Upsample
|
||||
from .utils import check_norm_state
|
||||
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def test_unet_basic_conv_block():
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
BasicConvBlock(64, 64, dcn=dcn)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
BasicConvBlock(64, 64, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
BasicConvBlock(64, 64, plugins=plugins)
|
||||
|
||||
# test BasicConvBlock with checkpoint forward
|
||||
block = BasicConvBlock(16, 16, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 16, 64, 64, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 16, 64, 64])
|
||||
|
||||
block = BasicConvBlock(16, 16, with_cp=False)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(1, 16, 64, 64)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 16, 64, 64])
|
||||
|
||||
# test BasicConvBlock with stride convolution to downsample
|
||||
block = BasicConvBlock(16, 16, stride=2)
|
||||
x = torch.randn(1, 16, 64, 64)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 16, 32, 32])
|
||||
|
||||
# test BasicConvBlock structure and forward
|
||||
block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
|
||||
assert block.convs[0].conv.in_channels == 16
|
||||
assert block.convs[0].conv.out_channels == 64
|
||||
assert block.convs[0].conv.kernel_size == (3, 3)
|
||||
assert block.convs[0].conv.dilation == (1, 1)
|
||||
assert block.convs[0].conv.padding == (1, 1)
|
||||
|
||||
assert block.convs[1].conv.in_channels == 64
|
||||
assert block.convs[1].conv.out_channels == 64
|
||||
assert block.convs[1].conv.kernel_size == (3, 3)
|
||||
assert block.convs[1].conv.dilation == (3, 3)
|
||||
assert block.convs[1].conv.padding == (3, 3)
|
||||
|
||||
assert block.convs[2].conv.in_channels == 64
|
||||
assert block.convs[2].conv.out_channels == 64
|
||||
assert block.convs[2].conv.kernel_size == (3, 3)
|
||||
assert block.convs[2].conv.dilation == (3, 3)
|
||||
assert block.convs[2].conv.padding == (3, 3)
|
||||
|
||||
|
||||
def test_deconv_module():
|
||||
with pytest.raises(AssertionError):
|
||||
# kernel_size should be greater than or equal to scale_factor and
|
||||
# (kernel_size - scale_factor) should be even numbers
|
||||
DeconvModule(64, 32, kernel_size=1, scale_factor=2)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# kernel_size should be greater than or equal to scale_factor and
|
||||
# (kernel_size - scale_factor) should be even numbers
|
||||
DeconvModule(64, 32, kernel_size=3, scale_factor=2)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# kernel_size should be greater than or equal to scale_factor and
|
||||
# (kernel_size - scale_factor) should be even numbers
|
||||
DeconvModule(64, 32, kernel_size=5, scale_factor=4)
|
||||
|
||||
# test DeconvModule with checkpoint forward and upsample 2X.
|
||||
block = DeconvModule(64, 32, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
block = DeconvModule(64, 32, with_cp=False)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test DeconvModule with different kernel size for upsample 2X.
|
||||
x = torch.randn(1, 64, 64, 64)
|
||||
block = DeconvModule(64, 32, kernel_size=2, scale_factor=2)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 128, 128])
|
||||
|
||||
block = DeconvModule(64, 32, kernel_size=6, scale_factor=2)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 128, 128])
|
||||
|
||||
# test DeconvModule with different kernel size for upsample 4X.
|
||||
x = torch.randn(1, 64, 64, 64)
|
||||
block = DeconvModule(64, 32, kernel_size=4, scale_factor=4)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
block = DeconvModule(64, 32, kernel_size=6, scale_factor=4)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
|
||||
def test_interp_conv():
|
||||
# test InterpConv with checkpoint forward and upsample 2X.
|
||||
block = InterpConv(64, 32, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
block = InterpConv(64, 32, with_cp=False)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test InterpConv with conv_first=False for upsample 2X.
|
||||
block = InterpConv(64, 32, conv_first=False)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test InterpConv with conv_first=True for upsample 2X.
|
||||
block = InterpConv(64, 32, conv_first=True)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], ConvModule)
|
||||
assert isinstance(block.interp_upsample[1], Upsample)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test InterpConv with bilinear upsample for upsample 2X.
|
||||
block = InterpConv(
|
||||
64,
|
||||
32,
|
||||
conv_first=False,
|
||||
upsample_cfg=dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False))
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
assert block.interp_upsample[0].mode == 'bilinear'
|
||||
|
||||
# test InterpConv with nearest upsample for upsample 2X.
|
||||
block = InterpConv(
|
||||
64,
|
||||
32,
|
||||
conv_first=False,
|
||||
upsample_cfg=dict(scale_factor=2, mode='nearest'))
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
assert block.interp_upsample[0].mode == 'nearest'
|
||||
|
||||
|
||||
def test_up_conv_block():
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
UpConvBlock(BasicConvBlock, 64, 32, 32, dcn=dcn)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
|
||||
|
||||
# test UpConvBlock with checkpoint forward and upsample 2X.
|
||||
block = UpConvBlock(BasicConvBlock, 64, 32, 32, with_cp=True)
|
||||
skip_x = torch.randn(1, 32, 256, 256, requires_grad=True)
|
||||
x = torch.randn(1, 64, 128, 128, requires_grad=True)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with upsample=True for upsample 2X. The spatial size of
|
||||
# skip_x is 2X larger than x.
|
||||
block = UpConvBlock(
|
||||
BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type='InterpConv'))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with upsample=False for upsample 2X. The spatial size of
|
||||
# skip_x is the same as that of x.
|
||||
block = UpConvBlock(BasicConvBlock, 64, 32, 32, upsample_cfg=None)
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 256, 256)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with different upsample method for upsample 2X.
|
||||
# The upsample method is interpolation upsample (bilinear or nearest).
|
||||
block = UpConvBlock(
|
||||
BasicConvBlock,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
upsample_cfg=dict(
|
||||
type='InterpConv',
|
||||
upsample_cfg=dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with different upsample method for upsample 2X.
|
||||
# The upsample method is deconvolution upsample.
|
||||
block = UpConvBlock(
|
||||
BasicConvBlock,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
upsample_cfg=dict(type='DeconvModule', kernel_size=4, scale_factor=2))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test BasicConvBlock structure and forward
|
||||
block = UpConvBlock(
|
||||
conv_block=BasicConvBlock,
|
||||
in_channels=64,
|
||||
skip_channels=32,
|
||||
out_channels=32,
|
||||
num_convs=3,
|
||||
dilation=3,
|
||||
upsample_cfg=dict(
|
||||
type='InterpConv',
|
||||
upsample_cfg=dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
assert block.conv_block.convs[0].conv.in_channels == 64
|
||||
assert block.conv_block.convs[0].conv.out_channels == 32
|
||||
assert block.conv_block.convs[0].conv.kernel_size == (3, 3)
|
||||
assert block.conv_block.convs[0].conv.dilation == (1, 1)
|
||||
assert block.conv_block.convs[0].conv.padding == (1, 1)
|
||||
|
||||
assert block.conv_block.convs[1].conv.in_channels == 32
|
||||
assert block.conv_block.convs[1].conv.out_channels == 32
|
||||
assert block.conv_block.convs[1].conv.kernel_size == (3, 3)
|
||||
assert block.conv_block.convs[1].conv.dilation == (3, 3)
|
||||
assert block.conv_block.convs[1].conv.padding == (3, 3)
|
||||
|
||||
assert block.conv_block.convs[2].conv.in_channels == 32
|
||||
assert block.conv_block.convs[2].conv.out_channels == 32
|
||||
assert block.conv_block.convs[2].conv.kernel_size == (3, 3)
|
||||
assert block.conv_block.convs[2].conv.dilation == (3, 3)
|
||||
assert block.conv_block.convs[2].conv.padding == (3, 3)
|
||||
|
||||
assert block.upsample.interp_upsample[1].conv.in_channels == 64
|
||||
assert block.upsample.interp_upsample[1].conv.out_channels == 32
|
||||
assert block.upsample.interp_upsample[1].conv.kernel_size == (1, 1)
|
||||
assert block.upsample.interp_upsample[1].conv.dilation == (1, 1)
|
||||
assert block.upsample.interp_upsample[1].conv.padding == (0, 0)
|
||||
|
||||
|
||||
def test_unet():
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
UNet(3, 64, 5, dcn=dcn)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
UNet(3, 64, 5, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
UNet(3, 64, 5, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be divisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=4,
|
||||
strides=(1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2),
|
||||
downsamples=(True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be divisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 16.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be divisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be divisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 2, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be divisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 32.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=6,
|
||||
strides=(1, 1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matches strides, len(strides)=num_stages
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matches strides, len(enc_num_convs)=num_stages
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matches strides, len(dec_num_convs)=num_stages-1
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matches strides, len(downsamples)=num_stages-1
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matches strides, len(enc_dilations)=num_stages
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matches strides, len(dec_dilations)=num_stages-1
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
# test UNet norm_eval=True
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
norm_eval=True)
|
||||
unet.train()
|
||||
assert check_norm_state(unet.modules(), False)
|
||||
|
||||
# test UNet norm_eval=False
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
norm_eval=False)
|
||||
unet.train()
|
||||
assert check_norm_state(unet.modules(), True)
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 16.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 8, 8])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 2, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 2.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, False, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 64, 64])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 64, 64])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 64, 64])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 1.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(False, False, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 128, 128])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 128, 128])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 128, 128])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 128, 128])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 16.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 8, 8])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 2, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
|
||||
# test UNet init_weights method.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=4,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
pretrained=None)
|
||||
unet.init_weights()
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 4, 128, 128])
|
||||
@@ -0,0 +1,185 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones.vit import (TransformerEncoderLayer,
|
||||
VisionTransformer)
|
||||
from .utils import check_norm_state
|
||||
|
||||
|
||||
def test_vit_backbone():
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = VisionTransformer()
|
||||
model.init_weights(pretrained=0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# img_size must be int or tuple
|
||||
model = VisionTransformer(img_size=512.0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# out_indices must be int ,list or tuple
|
||||
model = VisionTransformer(out_indices=1.)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# test upsample_pos_embed function
|
||||
x = torch.randn(1, 196)
|
||||
VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# The length of img_size tuple must be lower than 3.
|
||||
VisionTransformer(img_size=(224, 224, 224))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Pretrained must be None or Str.
|
||||
VisionTransformer(pretrained=123)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# with_cls_token must be True when output_cls_token == True
|
||||
VisionTransformer(with_cls_token=False, output_cls_token=True)
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = VisionTransformer(img_size=(224, ))
|
||||
model.init_weights()
|
||||
model(imgs)
|
||||
|
||||
# Test img_size isinstance tuple
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
model = VisionTransformer(img_size=(224, 224))
|
||||
model(imgs)
|
||||
|
||||
# Test norm_eval = True
|
||||
model = VisionTransformer(norm_eval=True)
|
||||
model.train()
|
||||
|
||||
# Test ViT backbone with input size of 224 and patch size of 16
|
||||
model = VisionTransformer()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
# Test normal size input image
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test large size input image
|
||||
imgs = torch.randn(1, 3, 256, 256)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 16, 16)
|
||||
|
||||
# Test small size input image
|
||||
imgs = torch.randn(1, 3, 32, 32)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 2, 2)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test unbalanced size input image
|
||||
imgs = torch.randn(1, 3, 112, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 7, 14)
|
||||
|
||||
# Test irregular input image
|
||||
imgs = torch.randn(1, 3, 234, 345)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 15, 22)
|
||||
|
||||
# Test with_cp=True
|
||||
model = VisionTransformer(with_cp=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test with_cls_token=False
|
||||
model = VisionTransformer(with_cls_token=False)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test final norm
|
||||
model = VisionTransformer(final_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test patch norm
|
||||
model = VisionTransformer(patch_norm=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[-1].shape == (1, 768, 14, 14)
|
||||
|
||||
# Test output_cls_token
|
||||
model = VisionTransformer(with_cls_token=True, output_cls_token=True)
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[0][0].shape == (1, 768, 14, 14)
|
||||
assert feat[0][1].shape == (1, 768)
|
||||
|
||||
# Test TransformerEncoderLayer with checkpoint forward
|
||||
block = TransformerEncoderLayer(
|
||||
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 56 * 56, 64)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 56 * 56, 64])
|
||||
|
||||
|
||||
def test_vit_init():
|
||||
path = 'PATH_THAT_DO_NOT_EXIST'
|
||||
# Test all combinations of pretrained and init_cfg
|
||||
# pretrained=None, init_cfg=None
|
||||
model = VisionTransformer(pretrained=None, init_cfg=None)
|
||||
assert model.init_cfg is None
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
model = VisionTransformer(
|
||||
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
model = VisionTransformer(pretrained=None, init_cfg=123)
|
||||
with pytest.raises(TypeError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg=None
|
||||
model = VisionTransformer(pretrained=path, init_cfg=None)
|
||||
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
||||
# Test loading a checkpoint from an non-existent file
|
||||
with pytest.raises(OSError):
|
||||
model.init_weights()
|
||||
|
||||
# pretrained loads pretrain from an non-existent file
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(
|
||||
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(pretrained=path, init_cfg=123)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=None
|
||||
with pytest.raises(TypeError):
|
||||
model = VisionTransformer(pretrained=123, init_cfg=None)
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg loads pretrain from an non-existent file
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(
|
||||
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
||||
|
||||
# pretrain=123, whose type is unsupported
|
||||
# init_cfg=123, whose type is unsupported
|
||||
with pytest.raises(AssertionError):
|
||||
model = VisionTransformer(pretrained=123, init_cfg=123)
|
||||
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from os.path import dirname, join
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine import Config
|
||||
|
||||
import mmseg
|
||||
from mmseg.models.backbones import VPD
|
||||
|
||||
|
||||
class TestVPD(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
||||
repo_dpath = dirname(dirname(mmseg.__file__))
|
||||
config_dpath = join(repo_dpath, 'configs/_base_/models/vpd_sd.py')
|
||||
vpd_cfg = Config.fromfile(config_dpath).stable_diffusion_cfg
|
||||
vpd_cfg.pop('checkpoint')
|
||||
|
||||
self.vpd_model = VPD(
|
||||
diffusion_cfg=vpd_cfg,
|
||||
class_embed_path='https://download.openmmlab.com/mmsegmentation/'
|
||||
'v0.5/vpd/nyu_class_embeddings.pth',
|
||||
class_embed_select=True,
|
||||
pad_shape=64,
|
||||
unet_cfg=dict(use_attn=False),
|
||||
)
|
||||
|
||||
def test_forward(self):
|
||||
# test forward without class_id
|
||||
x = torch.randn(1, 3, 60, 60)
|
||||
with torch.no_grad():
|
||||
out = self.vpd_model(x)
|
||||
|
||||
self.assertEqual(len(out), 4)
|
||||
self.assertListEqual(list(out[0].shape), [1, 320, 8, 8])
|
||||
self.assertListEqual(list(out[1].shape), [1, 640, 4, 4])
|
||||
self.assertListEqual(list(out[2].shape), [1, 1280, 2, 2])
|
||||
self.assertListEqual(list(out[3].shape), [1, 1280, 1, 1])
|
||||
|
||||
# test forward with class_id
|
||||
x = torch.randn(1, 3, 60, 60)
|
||||
with torch.no_grad():
|
||||
out = self.vpd_model((x, torch.tensor([2])))
|
||||
|
||||
self.assertEqual(len(out), 4)
|
||||
self.assertListEqual(list(out[0].shape), [1, 320, 8, 8])
|
||||
self.assertListEqual(list(out[1].shape), [1, 640, 4, 4])
|
||||
self.assertListEqual(list(out[2].shape), [1, 1280, 2, 2])
|
||||
self.assertListEqual(list(out[3].shape), [1, 1280, 1, 1])
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
||||
|
||||
|
||||
def is_block(modules):
|
||||
"""Check if is ResNet building block."""
|
||||
if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def all_zeros(modules):
|
||||
"""Check if the weight(and bias) is all zero."""
|
||||
weight_zero = torch.allclose(modules.weight.data,
|
||||
torch.zeros_like(modules.weight.data))
|
||||
if hasattr(modules, 'bias'):
|
||||
bias_zero = torch.allclose(modules.bias.data,
|
||||
torch.zeros_like(modules.bias.data))
|
||||
else:
|
||||
bias_zero = True
|
||||
|
||||
return weight_zero and bias_zero
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.models import SegDataPreProcessor
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestSegDataPreProcessor(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
# test mean is None
|
||||
processor = SegDataPreProcessor()
|
||||
self.assertTrue(not hasattr(processor, 'mean'))
|
||||
self.assertTrue(processor._enable_normalize is False)
|
||||
|
||||
# test mean is not None
|
||||
processor = SegDataPreProcessor(mean=[0, 0, 0], std=[1, 1, 1])
|
||||
self.assertTrue(hasattr(processor, 'mean'))
|
||||
self.assertTrue(hasattr(processor, 'std'))
|
||||
self.assertTrue(processor._enable_normalize)
|
||||
|
||||
# please specify both mean and std
|
||||
with self.assertRaises(AssertionError):
|
||||
SegDataPreProcessor(mean=[0, 0, 0])
|
||||
|
||||
# bgr2rgb and rgb2bgr cannot be set to True at the same time
|
||||
with self.assertRaises(AssertionError):
|
||||
SegDataPreProcessor(bgr_to_rgb=True, rgb_to_bgr=True)
|
||||
|
||||
def test_forward(self):
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
**{'data': torch.randint(0, 10, (1, 11, 10))})
|
||||
processor = SegDataPreProcessor(
|
||||
mean=[0, 0, 0], std=[1, 1, 1], size=(20, 20))
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
torch.randint(0, 256, (3, 11, 10))
|
||||
],
|
||||
'data_samples': [data_sample, data_sample]
|
||||
}
|
||||
out = processor(data, training=True)
|
||||
self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
|
||||
self.assertEqual(len(out['data_samples']), 2)
|
||||
|
||||
# test predict with padding
|
||||
processor = SegDataPreProcessor(
|
||||
mean=[0, 0, 0],
|
||||
std=[1, 1, 1],
|
||||
size=(20, 20),
|
||||
test_cfg=dict(size_divisor=15))
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
],
|
||||
'data_samples': [data_sample]
|
||||
}
|
||||
out = processor(data, training=False)
|
||||
self.assertEqual(out['inputs'].shape[2] % 15, 0)
|
||||
self.assertEqual(out['inputs'].shape[3] % 15, 0)
|
||||
229
Seg_All_In_One_MMSeg/tests/test_models/test_forward.py
Normal file
229
Seg_All_In_One_MMSeg/tests/test_models/test_forward.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""pytest tests/test_forward.py."""
|
||||
import copy
|
||||
from os.path import dirname, exists, join
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import PixelData
|
||||
from mmengine.utils import is_list_of, is_tuple_of
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 128, 128)
|
||||
num_classes (int): number of different labels a
|
||||
box might have. Default to 10.
|
||||
"""
|
||||
if isinstance(image_shapes, list):
|
||||
assert len(image_shapes) == batch_size
|
||||
else:
|
||||
image_shapes = [image_shapes] * batch_size
|
||||
|
||||
inputs = []
|
||||
data_samples = []
|
||||
for idx in range(batch_size):
|
||||
image_shape = image_shapes[idx]
|
||||
c, h, w = image_shape
|
||||
image = np.random.randint(0, 255, size=image_shape, dtype=np.uint8)
|
||||
|
||||
mm_input = torch.from_numpy(image)
|
||||
|
||||
img_meta = {
|
||||
'img_id': idx,
|
||||
'img_shape': image_shape[1:],
|
||||
'ori_shape': image_shape[1:],
|
||||
'pad_shape': image_shape[1:],
|
||||
'filename': '<demo>.png',
|
||||
'scale_factor': 1.0,
|
||||
'flip': False,
|
||||
'flip_direction': None,
|
||||
}
|
||||
|
||||
data_sample = SegDataSample()
|
||||
data_sample.set_metainfo(img_meta)
|
||||
|
||||
gt_semantic_seg = np.random.randint(
|
||||
0, num_classes, (1, h, w), dtype=np.uint8)
|
||||
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
|
||||
gt_sem_seg_data = dict(data=gt_semantic_seg)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
inputs.append(mm_input)
|
||||
data_samples.append(data_sample)
|
||||
return dict(inputs=inputs, data_samples=data_samples)
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
"""Find the predefined segmentor config directory."""
|
||||
try:
|
||||
# Assume we are running in the source mmsegmentation repo
|
||||
repo_dpath = dirname(dirname(dirname(__file__)))
|
||||
except NameError:
|
||||
# For IPython development when this __file__ is not defined
|
||||
import mmseg
|
||||
repo_dpath = dirname(dirname(dirname(mmseg.__file__)))
|
||||
config_dpath = join(repo_dpath, 'configs')
|
||||
if not exists(config_dpath):
|
||||
raise Exception('Cannot find config path')
|
||||
return config_dpath
|
||||
|
||||
|
||||
def _get_config_module(fname):
|
||||
"""Load a configuration as a python module."""
|
||||
from mmengine import Config
|
||||
config_dpath = _get_config_directory()
|
||||
config_fpath = join(config_dpath, fname)
|
||||
config_mod = Config.fromfile(config_fpath)
|
||||
return config_mod
|
||||
|
||||
|
||||
def _get_segmentor_cfg(fname):
|
||||
"""Grab configs necessary to create a segmentor.
|
||||
|
||||
These are deep copied to allow for safe modification of parameters without
|
||||
influencing other tests.
|
||||
"""
|
||||
config = _get_config_module(fname)
|
||||
model = copy.deepcopy(config.model)
|
||||
return model
|
||||
|
||||
|
||||
def test_pspnet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'pspnet/pspnet_r18-d8_4xb2-80k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_fcn_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'fcn/fcn_r18-d8_4xb2-80k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_deeplabv3_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'deeplabv3/deeplabv3_r18-d8_4xb2-80k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_deeplabv3plus_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'deeplabv3plus/deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_gcnet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_ccnet_forward():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip('CCNet requires CUDA')
|
||||
_test_encoder_decoder_forward(
|
||||
'ccnet/ccnet_r50-d8_4xb2-40k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_upernet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_hrnet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'hrnet/fcn_hr18s_4xb2-40k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_ocrnet_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_sem_fpn_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def test_mobilenet_v2_forward():
|
||||
_test_encoder_decoder_forward(
|
||||
'mobilenet_v2/mobilenet-v2-d8_pspnet_4xb2-80k_cityscapes-512x1024.py')
|
||||
|
||||
|
||||
def get_world_size(process_group):
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
def _check_input_dim(self, inputs):
|
||||
pass
|
||||
|
||||
|
||||
@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
|
||||
_check_input_dim)
|
||||
@patch('torch.distributed.get_world_size', get_world_size)
|
||||
def _test_encoder_decoder_forward(cfg_file):
|
||||
model = _get_segmentor_cfg(cfg_file)
|
||||
model['pretrained'] = None
|
||||
model['test_cfg']['mode'] = 'whole'
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
segmentor = build_segmentor(model)
|
||||
segmentor.init_weights()
|
||||
|
||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||
num_classes = segmentor.decode_head[-1].num_classes
|
||||
else:
|
||||
num_classes = segmentor.decode_head.num_classes
|
||||
# batch_size=2 for BatchNorm
|
||||
packed_inputs = _demo_mm_inputs(
|
||||
batch_size=2, image_shapes=(3, 4, 4), num_classes=num_classes)
|
||||
# convert to cuda Tensor if applicable
|
||||
if torch.cuda.is_available():
|
||||
segmentor = segmentor.cuda()
|
||||
else:
|
||||
segmentor = revert_sync_batchnorm(segmentor)
|
||||
|
||||
# Test forward train
|
||||
data = segmentor.data_preprocessor(packed_inputs, True)
|
||||
losses = segmentor.forward(**data, mode='loss')
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
packed_inputs = _demo_mm_inputs(
|
||||
batch_size=1, image_shapes=(3, 32, 32), num_classes=num_classes)
|
||||
data = segmentor.data_preprocessor(packed_inputs, False)
|
||||
with torch.no_grad():
|
||||
segmentor.eval()
|
||||
# Test forward predict
|
||||
batch_results = segmentor.forward(**data, mode='predict')
|
||||
assert len(batch_results) == 1
|
||||
assert is_list_of(batch_results, SegDataSample)
|
||||
assert batch_results[0].pred_sem_seg.shape == (32, 32)
|
||||
assert batch_results[0].seg_logits.data.shape == (num_classes, 32, 32)
|
||||
assert batch_results[0].gt_sem_seg.shape == (32, 32)
|
||||
|
||||
# Test forward tensor
|
||||
batch_results = segmentor.forward(**data, mode='tensor')
|
||||
assert isinstance(batch_results, Tensor) or is_tuple_of(
|
||||
batch_results, Tensor)
|
||||
|
||||
# Test forward predict without ground truth
|
||||
data.pop('data_samples')
|
||||
batch_results = segmentor.forward(**data, mode='predict')
|
||||
assert len(batch_results) == 1
|
||||
assert is_list_of(batch_results, SegDataSample)
|
||||
assert batch_results[0].pred_sem_seg.shape == (32, 32)
|
||||
|
||||
# Test forward tensor without ground truth
|
||||
batch_results = segmentor.forward(**data, mode='tensor')
|
||||
assert isinstance(batch_results, Tensor) or is_tuple_of(
|
||||
batch_results, Tensor)
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import ANNHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_ann_head():
|
||||
|
||||
inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 8, 21, 21)]
|
||||
head = ANNHead(
|
||||
in_channels=[4, 8],
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
in_index=[-2, -1],
|
||||
project_channels=8)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 21, 21)
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import APCHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_apc_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# pool_scales must be list|tuple
|
||||
APCHead(in_channels=8, channels=2, num_classes=19, pool_scales=1)
|
||||
|
||||
# test no norm_cfg
|
||||
head = APCHead(in_channels=8, channels=2, num_classes=19)
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
# test with norm_cfg
|
||||
head = APCHead(
|
||||
in_channels=8,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN'))
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
|
||||
# fusion=True
|
||||
inputs = [torch.randn(1, 8, 45, 45)]
|
||||
head = APCHead(
|
||||
in_channels=8,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
pool_scales=(1, 2, 3),
|
||||
fusion=True)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.fusion is True
|
||||
assert head.acm_modules[0].pool_scale == 1
|
||||
assert head.acm_modules[1].pool_scale == 2
|
||||
assert head.acm_modules[2].pool_scale == 3
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
# fusion=False
|
||||
inputs = [torch.randn(1, 8, 45, 45)]
|
||||
head = APCHead(
|
||||
in_channels=8,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
pool_scales=(1, 2, 3),
|
||||
fusion=False)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.fusion is False
|
||||
assert head.acm_modules[0].pool_scale == 1
|
||||
assert head.acm_modules[1].pool_scale == 2
|
||||
assert head.acm_modules[2].pool_scale == 3
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import ASPPHead, DepthwiseSeparableASPPHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_aspp_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# pool_scales must be list|tuple
|
||||
ASPPHead(in_channels=8, channels=4, num_classes=19, dilations=1)
|
||||
|
||||
# test no norm_cfg
|
||||
head = ASPPHead(in_channels=8, channels=4, num_classes=19)
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
# test with norm_cfg
|
||||
head = ASPPHead(
|
||||
in_channels=8,
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN'))
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
|
||||
inputs = [torch.randn(1, 8, 45, 45)]
|
||||
head = ASPPHead(
|
||||
in_channels=8, channels=4, num_classes=19, dilations=(1, 12, 24))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.aspp_modules[0].conv.dilation == (1, 1)
|
||||
assert head.aspp_modules[1].conv.dilation == (12, 12)
|
||||
assert head.aspp_modules[2].conv.dilation == (24, 24)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_dw_aspp_head():
|
||||
|
||||
# test w.o. c1
|
||||
inputs = [torch.randn(1, 8, 45, 45)]
|
||||
head = DepthwiseSeparableASPPHead(
|
||||
c1_in_channels=0,
|
||||
c1_channels=0,
|
||||
in_channels=8,
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
dilations=(1, 12, 24))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.c1_bottleneck is None
|
||||
assert head.aspp_modules[0].conv.dilation == (1, 1)
|
||||
assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12)
|
||||
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
# test with c1
|
||||
inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 16, 21, 21)]
|
||||
head = DepthwiseSeparableASPPHead(
|
||||
c1_in_channels=4,
|
||||
c1_channels=2,
|
||||
in_channels=16,
|
||||
channels=8,
|
||||
num_classes=19,
|
||||
dilations=(1, 12, 24))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.c1_bottleneck.in_channels == 4
|
||||
assert head.c1_bottleneck.out_channels == 2
|
||||
assert head.aspp_modules[0].conv.dilation == (1, 1)
|
||||
assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12)
|
||||
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import CCHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_cc_head():
|
||||
head = CCHead(in_channels=16, channels=8, num_classes=19)
|
||||
assert len(head.convs) == 2
|
||||
assert hasattr(head, 'cca')
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip('CCHead requires CUDA')
|
||||
inputs = [torch.randn(1, 16, 23, 23)]
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,193 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.structures import SegDataSample
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
|
||||
def test_decode_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# default input_transform doesn't accept multiple inputs
|
||||
BaseDecodeHead([32, 16], 16, num_classes=19)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# default input_transform doesn't accept multiple inputs
|
||||
BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# supported mode is resize_concat only
|
||||
BaseDecodeHead(32, 16, num_classes=19, input_transform='concat')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# in_channels should be list|tuple
|
||||
BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# in_index should be list|tuple
|
||||
BaseDecodeHead([32],
|
||||
16,
|
||||
in_index=-1,
|
||||
num_classes=19,
|
||||
input_transform='resize_concat')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# len(in_index) should equal len(in_channels)
|
||||
BaseDecodeHead([32, 16],
|
||||
16,
|
||||
num_classes=19,
|
||||
in_index=[-1],
|
||||
input_transform='resize_concat')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# out_channels should be equal to num_classes
|
||||
BaseDecodeHead(32, 16, num_classes=19, out_channels=18)
|
||||
|
||||
# test out_channels
|
||||
head = BaseDecodeHead(32, 16, num_classes=2)
|
||||
assert head.out_channels == 2
|
||||
|
||||
# test out_channels == 1 and num_classes == 2
|
||||
head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
|
||||
assert head.out_channels == 1 and head.num_classes == 2
|
||||
|
||||
# test default dropout
|
||||
head = BaseDecodeHead(32, 16, num_classes=19)
|
||||
assert hasattr(head, 'dropout') and head.dropout.p == 0.1
|
||||
|
||||
# test set dropout
|
||||
head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2)
|
||||
assert hasattr(head, 'dropout') and head.dropout.p == 0.2
|
||||
|
||||
# test no input_transform
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
head = BaseDecodeHead(32, 16, num_classes=19)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.in_channels == 32
|
||||
assert head.input_transform is None
|
||||
transformed_inputs = head._transform_inputs(inputs)
|
||||
assert transformed_inputs.shape == (1, 32, 45, 45)
|
||||
|
||||
# test input_transform = resize_concat
|
||||
inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
|
||||
head = BaseDecodeHead([32, 16],
|
||||
16,
|
||||
num_classes=19,
|
||||
in_index=[0, 1],
|
||||
input_transform='resize_concat')
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.in_channels == 48
|
||||
assert head.input_transform == 'resize_concat'
|
||||
transformed_inputs = head._transform_inputs(inputs)
|
||||
assert transformed_inputs.shape == (1, 48, 45, 45)
|
||||
|
||||
# test multi-loss, loss_decode is dict
|
||||
with pytest.raises(TypeError):
|
||||
# loss_decode must be a dict or sequence of dict.
|
||||
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
|
||||
|
||||
inputs = torch.randn(2, 19, 8, 8).float()
|
||||
data_samples = [
|
||||
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
|
||||
for _ in range(2)
|
||||
]
|
||||
|
||||
head = BaseDecodeHead(
|
||||
3,
|
||||
16,
|
||||
num_classes=19,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
loss = head.loss_by_feat(
|
||||
seg_logits=inputs, batch_data_samples=data_samples)
|
||||
assert 'loss_ce' in loss
|
||||
|
||||
# test multi-loss, loss_decode is list of dict
|
||||
inputs = torch.randn(2, 19, 8, 8).float()
|
||||
data_samples = [
|
||||
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
|
||||
for _ in range(2)
|
||||
]
|
||||
head = BaseDecodeHead(
|
||||
3,
|
||||
16,
|
||||
num_classes=19,
|
||||
loss_decode=[
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_1'),
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_2')
|
||||
])
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
|
||||
loss = head.loss_by_feat(
|
||||
seg_logits=inputs, batch_data_samples=data_samples)
|
||||
assert 'loss_1' in loss
|
||||
assert 'loss_2' in loss
|
||||
|
||||
# 'loss_decode' must be a dict or sequence of dict
|
||||
with pytest.raises(TypeError):
|
||||
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
|
||||
with pytest.raises(TypeError):
|
||||
BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)
|
||||
|
||||
# test multi-loss, loss_decode is list of dict
|
||||
inputs = torch.randn(2, 19, 8, 8).float()
|
||||
data_samples = [
|
||||
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
|
||||
for _ in range(2)
|
||||
]
|
||||
head = BaseDecodeHead(
|
||||
3,
|
||||
16,
|
||||
num_classes=19,
|
||||
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_2'),
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_3')))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
loss = head.loss_by_feat(
|
||||
seg_logits=inputs, batch_data_samples=data_samples)
|
||||
assert 'loss_1' in loss
|
||||
assert 'loss_2' in loss
|
||||
assert 'loss_3' in loss
|
||||
|
||||
# test multi-loss, loss_decode is list of dict, names of them are identical
|
||||
inputs = torch.randn(2, 19, 8, 8).float()
|
||||
data_samples = [
|
||||
SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
|
||||
for _ in range(2)
|
||||
]
|
||||
head = BaseDecodeHead(
|
||||
3,
|
||||
16,
|
||||
num_classes=19,
|
||||
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_ce'),
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
loss_3 = head.loss_by_feat(
|
||||
seg_logits=inputs, batch_data_samples=data_samples)
|
||||
|
||||
head = BaseDecodeHead(
|
||||
3,
|
||||
16,
|
||||
num_classes=19,
|
||||
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
loss = head.loss_by_feat(
|
||||
seg_logits=inputs, batch_data_samples=data_samples)
|
||||
assert 'loss_ce' in loss
|
||||
assert 'loss_ce' in loss_3
|
||||
assert loss_3['loss_ce'] == 3 * loss['loss_ce']
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import DMHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_dm_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# filter_sizes must be list|tuple
|
||||
DMHead(in_channels=8, channels=4, num_classes=19, filter_sizes=1)
|
||||
|
||||
# test no norm_cfg
|
||||
head = DMHead(in_channels=8, channels=4, num_classes=19)
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
# test with norm_cfg
|
||||
head = DMHead(
|
||||
in_channels=8,
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN'))
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
|
||||
# fusion=True
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = DMHead(
|
||||
in_channels=8,
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
filter_sizes=(1, 3, 5),
|
||||
fusion=True)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.fusion is True
|
||||
assert head.dcm_modules[0].filter_size == 1
|
||||
assert head.dcm_modules[1].filter_size == 3
|
||||
assert head.dcm_modules[2].filter_size == 5
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# fusion=False
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = DMHead(
|
||||
in_channels=8,
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
filter_sizes=(1, 3, 5),
|
||||
fusion=False)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.fusion is False
|
||||
assert head.dcm_modules[0].filter_size == 1
|
||||
assert head.dcm_modules[1].filter_size == 3
|
||||
assert head.dcm_modules[2].filter_size == 5
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import DNLHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_dnl_head():
|
||||
# DNL with 'embedded_gaussian' mode
|
||||
head = DNLHead(in_channels=8, channels=4, num_classes=19)
|
||||
assert len(head.convs) == 2
|
||||
assert hasattr(head, 'dnl_block')
|
||||
assert head.dnl_block.temperature == 0.05
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# NonLocal2d with 'dot_product' mode
|
||||
head = DNLHead(
|
||||
in_channels=8, channels=4, num_classes=19, mode='dot_product')
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# NonLocal2d with 'gaussian' mode
|
||||
head = DNLHead(in_channels=8, channels=4, num_classes=19, mode='gaussian')
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# NonLocal2d with 'concatenation' mode
|
||||
head = DNLHead(
|
||||
in_channels=8, channels=4, num_classes=19, mode='concatenation')
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import DPTHead
|
||||
|
||||
|
||||
def test_dpt_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# input_transform must be 'multiple_select'
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3])
|
||||
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3],
|
||||
input_transform='multiple_select')
|
||||
|
||||
inputs = [[torch.randn(4, 768, 2, 2),
|
||||
torch.randn(4, 768)] for _ in range(4)]
|
||||
output = head(inputs)
|
||||
assert output.shape == torch.Size((4, 19, 16, 16))
|
||||
|
||||
# test readout operation
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3],
|
||||
input_transform='multiple_select',
|
||||
readout_type='add')
|
||||
output = head(inputs)
|
||||
assert output.shape == torch.Size((4, 19, 16, 16))
|
||||
|
||||
head = DPTHead(
|
||||
in_channels=[768, 768, 768, 768],
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
in_index=[0, 1, 2, 3],
|
||||
input_transform='multiple_select',
|
||||
readout_type='project')
|
||||
output = head(inputs)
|
||||
assert output.shape == torch.Size((4, 19, 16, 16))
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import EMAHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_emanet_head():
|
||||
head = EMAHead(
|
||||
in_channels=4,
|
||||
ema_channels=3,
|
||||
channels=2,
|
||||
num_stages=3,
|
||||
num_bases=2,
|
||||
num_classes=19)
|
||||
for param in head.ema_mid_conv.parameters():
|
||||
assert not param.requires_grad
|
||||
assert hasattr(head, 'ema_module')
|
||||
inputs = [torch.randn(1, 4, 23, 23)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import SyncBatchNorm
|
||||
|
||||
from mmseg.models.decode_heads import DepthwiseSeparableFCNHead, FCNHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_fcn_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# num_convs must be not less than 0
|
||||
FCNHead(num_classes=19, num_convs=-1)
|
||||
|
||||
# test no norm_cfg
|
||||
head = FCNHead(in_channels=8, channels=4, num_classes=19)
|
||||
for m in head.modules():
|
||||
if isinstance(m, ConvModule):
|
||||
assert not m.with_norm
|
||||
|
||||
# test with norm_cfg
|
||||
head = FCNHead(
|
||||
in_channels=8,
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN'))
|
||||
for m in head.modules():
|
||||
if isinstance(m, ConvModule):
|
||||
assert m.with_norm and isinstance(m.bn, SyncBatchNorm)
|
||||
|
||||
# test concat_input=False
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = FCNHead(
|
||||
in_channels=8, channels=4, num_classes=19, concat_input=False)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert len(head.convs) == 2
|
||||
assert not head.concat_input and not hasattr(head, 'conv_cat')
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# test concat_input=True
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = FCNHead(
|
||||
in_channels=8, channels=4, num_classes=19, concat_input=True)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert len(head.convs) == 2
|
||||
assert head.concat_input
|
||||
assert head.conv_cat.in_channels == 12
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# test kernel_size=3
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = FCNHead(in_channels=8, channels=4, num_classes=19)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
for i in range(len(head.convs)):
|
||||
assert head.convs[i].kernel_size == (3, 3)
|
||||
assert head.convs[i].padding == 1
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# test kernel_size=1
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = FCNHead(in_channels=8, channels=4, num_classes=19, kernel_size=1)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
for i in range(len(head.convs)):
|
||||
assert head.convs[i].kernel_size == (1, 1)
|
||||
assert head.convs[i].padding == 0
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# test num_conv
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = FCNHead(in_channels=8, channels=4, num_classes=19, num_convs=1)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert len(head.convs) == 1
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
# test num_conv = 0
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
head = FCNHead(
|
||||
in_channels=8,
|
||||
channels=8,
|
||||
num_classes=19,
|
||||
num_convs=0,
|
||||
concat_input=False)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert isinstance(head.convs, torch.nn.Identity)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
|
||||
|
||||
def test_sep_fcn_head():
|
||||
# test sep_fcn_head with concat_input=False
|
||||
head = DepthwiseSeparableFCNHead(
|
||||
in_channels=128,
|
||||
channels=128,
|
||||
concat_input=False,
|
||||
num_classes=19,
|
||||
in_index=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
|
||||
x = [torch.rand(2, 128, 8, 8)]
|
||||
output = head(x)
|
||||
assert output.shape == (2, head.num_classes, 8, 8)
|
||||
assert not head.concat_input
|
||||
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||
assert head.conv_seg.kernel_size == (1, 1)
|
||||
|
||||
head = DepthwiseSeparableFCNHead(
|
||||
in_channels=64,
|
||||
channels=64,
|
||||
concat_input=True,
|
||||
num_classes=19,
|
||||
in_index=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
|
||||
x = [torch.rand(3, 64, 8, 8)]
|
||||
output = head(x)
|
||||
assert output.shape == (3, head.num_classes, 8, 8)
|
||||
assert head.concat_input
|
||||
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import GCHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_gc_head():
|
||||
head = GCHead(in_channels=4, channels=4, num_classes=19)
|
||||
assert len(head.convs) == 2
|
||||
assert hasattr(head, 'gc_block')
|
||||
inputs = [torch.randn(1, 4, 23, 23)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import LightHamHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
|
||||
|
||||
|
||||
def test_ham_head():
|
||||
|
||||
# test without sync_bn
|
||||
head = LightHamHead(
|
||||
in_channels=[16, 32, 64],
|
||||
in_index=[1, 2, 3],
|
||||
channels=64,
|
||||
ham_channels=64,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=ham_norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
||||
ham_kwargs=dict(
|
||||
MD_S=1,
|
||||
MD_R=64,
|
||||
train_steps=6,
|
||||
eval_steps=7,
|
||||
inv_t=100,
|
||||
rand_init=True))
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
inputs = [
|
||||
torch.randn(1, 8, 32, 32),
|
||||
torch.randn(1, 16, 16, 16),
|
||||
torch.randn(1, 32, 8, 8),
|
||||
torch.randn(1, 64, 4, 4)
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.in_channels == [16, 32, 64]
|
||||
assert head.hamburger.ham_in.in_channels == 64
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 16, 16)
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import ISAHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_isa_head():
|
||||
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
isa_head = ISAHead(
|
||||
in_channels=8,
|
||||
channels=4,
|
||||
num_classes=19,
|
||||
isa_channels=4,
|
||||
down_factor=(8, 8))
|
||||
if torch.cuda.is_available():
|
||||
isa_head, inputs = to_cuda(isa_head, inputs)
|
||||
output = isa_head(inputs)
|
||||
assert output.shape == (1, isa_head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import LRASPPHead
|
||||
|
||||
|
||||
def test_lraspp_head():
|
||||
with pytest.raises(ValueError):
|
||||
# check invalid input_transform
|
||||
LRASPPHead(
|
||||
in_channels=(4, 4, 123),
|
||||
in_index=(0, 1, 2),
|
||||
channels=32,
|
||||
input_transform='resize_concat',
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid branch_channels
|
||||
LRASPPHead(
|
||||
in_channels=(4, 4, 123),
|
||||
in_index=(0, 1, 2),
|
||||
channels=32,
|
||||
branch_channels=64,
|
||||
input_transform='multiple_select',
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
|
||||
|
||||
# test with default settings
|
||||
lraspp_head = LRASPPHead(
|
||||
in_channels=(4, 4, 123),
|
||||
in_index=(0, 1, 2),
|
||||
channels=32,
|
||||
input_transform='multiple_select',
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
|
||||
inputs = [
|
||||
torch.randn(2, 4, 45, 45),
|
||||
torch.randn(2, 4, 28, 28),
|
||||
torch.randn(2, 123, 14, 14)
|
||||
]
|
||||
with pytest.raises(RuntimeError):
|
||||
# check invalid inputs
|
||||
output = lraspp_head(inputs)
|
||||
|
||||
inputs = [
|
||||
torch.randn(2, 4, 111, 111),
|
||||
torch.randn(2, 4, 77, 77),
|
||||
torch.randn(2, 123, 55, 55)
|
||||
]
|
||||
output = lraspp_head(inputs)
|
||||
assert output.shape == (2, 19, 111, 111)
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.models.decode_heads import Mask2FormerHead
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import SampleList
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_mask2former_head():
|
||||
num_classes = 19
|
||||
cfg = dict(
|
||||
in_channels=[96, 192, 384, 768],
|
||||
strides=[4, 8, 16, 32],
|
||||
feat_channels=256,
|
||||
out_channels=256,
|
||||
num_classes=num_classes,
|
||||
num_queries=100,
|
||||
num_transformer_feat_level=3,
|
||||
align_corners=False,
|
||||
pixel_decoder=dict(
|
||||
type='mmdet.MSDeformAttnPixelDecoder',
|
||||
num_outs=3,
|
||||
norm_cfg=dict(type='GN', num_groups=32),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
encoder=dict( # DeformableDetrTransformerEncoder
|
||||
num_layers=6,
|
||||
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
|
||||
self_attn_cfg=dict( # MultiScaleDeformableAttention
|
||||
embed_dims=256,
|
||||
num_heads=8,
|
||||
num_levels=3,
|
||||
num_points=4,
|
||||
im2col_step=64,
|
||||
dropout=0.0,
|
||||
batch_first=True,
|
||||
norm_cfg=None,
|
||||
init_cfg=None),
|
||||
ffn_cfg=dict(
|
||||
embed_dims=256,
|
||||
feedforward_channels=1024,
|
||||
num_fcs=2,
|
||||
ffn_drop=0.0,
|
||||
act_cfg=dict(type='ReLU', inplace=True))),
|
||||
init_cfg=None),
|
||||
positional_encoding=dict( # SinePositionalEncoding
|
||||
num_feats=128, normalize=True),
|
||||
init_cfg=None),
|
||||
enforce_decoder_input_project=False,
|
||||
positional_encoding=dict( # SinePositionalEncoding
|
||||
num_feats=128, normalize=True),
|
||||
transformer_decoder=dict( # Mask2FormerTransformerDecoder
|
||||
return_intermediate=True,
|
||||
num_layers=9,
|
||||
layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
|
||||
self_attn_cfg=dict( # MultiheadAttention
|
||||
embed_dims=256,
|
||||
num_heads=8,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
dropout_layer=None,
|
||||
batch_first=True),
|
||||
cross_attn_cfg=dict( # MultiheadAttention
|
||||
embed_dims=256,
|
||||
num_heads=8,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
dropout_layer=None,
|
||||
batch_first=True),
|
||||
ffn_cfg=dict(
|
||||
embed_dims=256,
|
||||
feedforward_channels=2048,
|
||||
num_fcs=2,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
ffn_drop=0.0,
|
||||
dropout_layer=None,
|
||||
add_identity=True)),
|
||||
init_cfg=None),
|
||||
loss_cls=dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=2.0,
|
||||
reduction='mean',
|
||||
class_weight=[1.0] * num_classes + [0.1]),
|
||||
loss_mask=dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='mean',
|
||||
loss_weight=5.0),
|
||||
loss_dice=dict(
|
||||
type='mmdet.DiceLoss',
|
||||
use_sigmoid=True,
|
||||
activate=True,
|
||||
reduction='mean',
|
||||
naive_dice=True,
|
||||
eps=1.0,
|
||||
loss_weight=5.0),
|
||||
train_cfg=dict(
|
||||
num_points=12544,
|
||||
oversample_ratio=3.0,
|
||||
importance_sample_ratio=0.75,
|
||||
assigner=dict(
|
||||
type='mmdet.HungarianAssigner',
|
||||
match_costs=[
|
||||
dict(type='mmdet.ClassificationCost', weight=2.0),
|
||||
dict(
|
||||
type='mmdet.CrossEntropyLossCost',
|
||||
weight=5.0,
|
||||
use_sigmoid=True),
|
||||
dict(
|
||||
type='mmdet.DiceCost',
|
||||
weight=5.0,
|
||||
pred_act=True,
|
||||
eps=1.0)
|
||||
]),
|
||||
sampler=dict(type='mmdet.MaskPseudoSampler')))
|
||||
cfg = Config(cfg)
|
||||
head = Mask2FormerHead(**cfg)
|
||||
|
||||
inputs = [
|
||||
torch.rand((2, 96, 8, 8)),
|
||||
torch.rand((2, 192, 4, 4)),
|
||||
torch.rand((2, 384, 2, 2)),
|
||||
torch.rand((2, 768, 1, 1))
|
||||
]
|
||||
|
||||
data_samples: SampleList = []
|
||||
for i in range(2):
|
||||
data_sample = SegDataSample()
|
||||
img_meta = {}
|
||||
img_meta['img_shape'] = (32, 32)
|
||||
img_meta['ori_shape'] = (32, 32)
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
data=torch.randint(0, num_classes, (1, 32, 32)))
|
||||
data_sample.set_metainfo(img_meta)
|
||||
data_samples.append(data_sample)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
for data_sample in data_samples:
|
||||
data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda()
|
||||
|
||||
loss_dict = head.loss(inputs, data_samples, None)
|
||||
assert isinstance(loss_dict, dict)
|
||||
|
||||
batch_img_metas = []
|
||||
for data_sample in data_samples:
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
|
||||
seg_logits = head.predict(inputs, batch_img_metas, None)
|
||||
assert seg_logits.shape == torch.Size((2, num_classes, 32, 32))
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from os.path import dirname, join
|
||||
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
def test_maskformer_head():
|
||||
init_default_scope('mmseg')
|
||||
repo_dpath = dirname(dirname(__file__))
|
||||
cfg = Config.fromfile(
|
||||
join(
|
||||
repo_dpath,
|
||||
'../../configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py' # noqa
|
||||
))
|
||||
cfg.model.train_cfg = None
|
||||
decode_head = MODELS.build(cfg.model.decode_head)
|
||||
inputs = (torch.randn(1, 256, 32, 32), torch.randn(1, 512, 16, 16),
|
||||
torch.randn(1, 1024, 8, 8), torch.randn(1, 2048, 4, 4))
|
||||
# test inference
|
||||
batch_img_metas = [
|
||||
dict(
|
||||
scale_factor=(1.0, 1.0),
|
||||
img_shape=(512, 683),
|
||||
ori_shape=(512, 683))
|
||||
]
|
||||
test_cfg = dict(mode='whole')
|
||||
output = decode_head.predict(inputs, batch_img_metas, test_cfg)
|
||||
assert output.shape == (1, 150, 512, 683)
|
||||
|
||||
# test training
|
||||
inputs = (torch.randn(2, 256, 32, 32), torch.randn(2, 512, 16, 16),
|
||||
torch.randn(2, 1024, 8, 8), torch.randn(2, 2048, 4, 4))
|
||||
batch_data_samples = []
|
||||
img_meta = {
|
||||
'img_shape': (512, 512),
|
||||
'ori_shape': (480, 640),
|
||||
'pad_shape': (512, 512),
|
||||
'scale_factor': (1.425, 1.425),
|
||||
}
|
||||
for _ in range(2):
|
||||
data_sample = SegDataSample(
|
||||
gt_sem_seg=PixelData(data=torch.ones(512, 512).long()))
|
||||
data_sample.set_metainfo(img_meta)
|
||||
batch_data_samples.append(data_sample)
|
||||
train_cfg = {}
|
||||
losses = decode_head.loss(inputs, batch_data_samples, train_cfg)
|
||||
assert (loss in losses.keys()
|
||||
for loss in ('loss_cls', 'loss_mask', 'loss_dice'))
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import NLHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_nl_head():
|
||||
head = NLHead(in_channels=8, channels=4, num_classes=19)
|
||||
assert len(head.convs) == 2
|
||||
assert hasattr(head, 'nl_block')
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import FCNHead, OCRHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_ocr_head():
|
||||
|
||||
inputs = [torch.randn(1, 8, 23, 23)]
|
||||
ocr_head = OCRHead(
|
||||
in_channels=8, channels=4, num_classes=19, ocr_channels=8)
|
||||
fcn_head = FCNHead(in_channels=8, channels=4, num_classes=19)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(ocr_head, inputs)
|
||||
head, inputs = to_cuda(fcn_head, inputs)
|
||||
prev_output = fcn_head(inputs)
|
||||
output = ocr_head(inputs, prev_output)
|
||||
assert output.shape == (1, ocr_head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
def test_pidnet_head():
|
||||
init_default_scope('mmseg')
|
||||
|
||||
# Test PIDNet decode head Standard Forward
|
||||
norm_cfg = dict(type='BN', requires_grad=True)
|
||||
backbone_cfg = dict(
|
||||
type='PIDNet',
|
||||
in_channels=3,
|
||||
channels=32,
|
||||
ppm_channels=96,
|
||||
num_stem_blocks=2,
|
||||
num_branch_blocks=3,
|
||||
align_corners=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='ReLU', inplace=True))
|
||||
decode_head_cfg = dict(
|
||||
type='PIDHead',
|
||||
in_channels=128,
|
||||
channels=128,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
align_corners=True,
|
||||
loss_decode=[
|
||||
dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
class_weight=[
|
||||
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
|
||||
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
|
||||
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
|
||||
],
|
||||
loss_weight=0.4),
|
||||
dict(
|
||||
type='OhemCrossEntropy',
|
||||
thres=0.9,
|
||||
min_kept=131072,
|
||||
class_weight=[
|
||||
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
|
||||
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
|
||||
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
|
||||
],
|
||||
loss_weight=1.0),
|
||||
dict(type='BoundaryLoss', loss_weight=20.0),
|
||||
dict(
|
||||
type='OhemCrossEntropy',
|
||||
thres=0.9,
|
||||
min_kept=131072,
|
||||
class_weight=[
|
||||
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
|
||||
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
|
||||
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
|
||||
],
|
||||
loss_weight=1.0)
|
||||
])
|
||||
backbone = MODELS.build(backbone_cfg)
|
||||
head = MODELS.build(decode_head_cfg)
|
||||
|
||||
# Test train mode
|
||||
backbone.train()
|
||||
head.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 64, 128)
|
||||
feats = backbone(imgs)
|
||||
seg_logit = head(feats)
|
||||
|
||||
assert isinstance(seg_logit, tuple)
|
||||
assert len(seg_logit) == 3
|
||||
|
||||
p_logits, i_logits, d_logits = seg_logit
|
||||
assert p_logits.shape == (batch_size, 19, 8, 16)
|
||||
assert i_logits.shape == (batch_size, 19, 8, 16)
|
||||
assert d_logits.shape == (batch_size, 1, 8, 16)
|
||||
|
||||
# Test eval mode
|
||||
backbone.eval()
|
||||
head.eval()
|
||||
feats = backbone(imgs)
|
||||
seg_logit = head(feats)
|
||||
|
||||
assert isinstance(seg_logit, torch.Tensor)
|
||||
assert seg_logit.shape == (batch_size, 19, 8, 16)
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import PSAHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_psa_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# psa_type must be in 'bi-direction', 'collect', 'distribute'
|
||||
PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
psa_type='gather')
|
||||
|
||||
# test no norm_cfg
|
||||
head = PSAHead(
|
||||
in_channels=4, channels=2, num_classes=19, mask_size=(13, 13))
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
# test with norm_cfg
|
||||
head = PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
norm_cfg=dict(type='SyncBN'))
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
|
||||
# test 'bi-direction' psa_type
|
||||
inputs = [torch.randn(1, 4, 13, 13)]
|
||||
head = PSAHead(
|
||||
in_channels=4, channels=2, num_classes=19, mask_size=(13, 13))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 13, 13)
|
||||
|
||||
# test 'bi-direction' psa_type, shrink_factor=1
|
||||
inputs = [torch.randn(1, 4, 13, 13)]
|
||||
head = PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
shrink_factor=1)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 13, 13)
|
||||
|
||||
# test 'bi-direction' psa_type with soft_max
|
||||
inputs = [torch.randn(1, 4, 13, 13)]
|
||||
head = PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
psa_softmax=True)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 13, 13)
|
||||
|
||||
# test 'collect' psa_type
|
||||
inputs = [torch.randn(1, 4, 13, 13)]
|
||||
head = PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
psa_type='collect')
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 13, 13)
|
||||
|
||||
# test 'collect' psa_type, shrink_factor=1
|
||||
inputs = [torch.randn(1, 4, 13, 13)]
|
||||
head = PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
shrink_factor=1,
|
||||
psa_type='collect')
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 13, 13)
|
||||
|
||||
# test 'collect' psa_type, shrink_factor=1, compact=True
|
||||
inputs = [torch.randn(1, 4, 13, 13)]
|
||||
head = PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
psa_type='collect',
|
||||
shrink_factor=1,
|
||||
compact=True)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 13, 13)
|
||||
|
||||
# test 'distribute' psa_type
|
||||
inputs = [torch.randn(1, 4, 13, 13)]
|
||||
head = PSAHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
mask_size=(13, 13),
|
||||
psa_type='distribute')
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 13, 13)
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import PSPHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_psp_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# pool_scales must be list|tuple
|
||||
PSPHead(in_channels=4, channels=2, num_classes=19, pool_scales=1)
|
||||
|
||||
# test no norm_cfg
|
||||
head = PSPHead(in_channels=4, channels=2, num_classes=19)
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
# test with norm_cfg
|
||||
head = PSPHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN'))
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
|
||||
inputs = [torch.randn(1, 4, 23, 23)]
|
||||
head = PSPHead(
|
||||
in_channels=4, channels=2, num_classes=19, pool_scales=(1, 2, 3))
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert head.psp_modules[0][0].output_size == 1
|
||||
assert head.psp_modules[1][0].output_size == 2
|
||||
assert head.psp_modules[2][0].output_size == 3
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 23, 23)
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.models.decode_heads import SideAdapterCLIPHead
|
||||
from mmseg.structures import SegDataSample
|
||||
from .utils import list_to_cuda
|
||||
|
||||
|
||||
def test_san_head():
|
||||
H, W = (64, 64)
|
||||
clip_channels = 64
|
||||
img_channels = 4
|
||||
num_queries = 40
|
||||
out_dims = 64
|
||||
num_classes = 19
|
||||
cfg = dict(
|
||||
num_classes=num_classes,
|
||||
deep_supervision_idxs=[4],
|
||||
san_cfg=dict(
|
||||
in_channels=img_channels,
|
||||
embed_dims=128,
|
||||
clip_channels=clip_channels,
|
||||
num_queries=num_queries,
|
||||
cfg_encoder=dict(num_encode_layer=4, mlp_ratio=2, num_heads=2),
|
||||
cfg_decoder=dict(
|
||||
num_heads=4,
|
||||
num_layers=1,
|
||||
embed_channels=32,
|
||||
mlp_channels=32,
|
||||
num_mlp=2,
|
||||
rescale=True)),
|
||||
maskgen_cfg=dict(
|
||||
sos_token_num=num_queries,
|
||||
embed_dims=clip_channels,
|
||||
out_dims=out_dims,
|
||||
num_heads=4,
|
||||
mlp_ratio=2),
|
||||
train_cfg=dict(
|
||||
num_points=100,
|
||||
oversample_ratio=3.0,
|
||||
importance_sample_ratio=0.75,
|
||||
assigner=dict(
|
||||
type='HungarianAssigner',
|
||||
match_costs=[
|
||||
dict(type='ClassificationCost', weight=2.0),
|
||||
dict(
|
||||
type='CrossEntropyLossCost',
|
||||
weight=5.0,
|
||||
use_sigmoid=True),
|
||||
dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
|
||||
])),
|
||||
loss_decode=[
|
||||
dict(
|
||||
type='CrossEntropyLoss',
|
||||
loss_name='loss_cls_ce',
|
||||
loss_weight=2.0,
|
||||
class_weight=[1.0] * num_classes + [0.1]),
|
||||
dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
loss_name='loss_mask_ce',
|
||||
loss_weight=5.0),
|
||||
dict(
|
||||
type='DiceLoss',
|
||||
ignore_index=None,
|
||||
naive_dice=True,
|
||||
eps=1,
|
||||
loss_name='loss_mask_dice',
|
||||
loss_weight=5.0)
|
||||
])
|
||||
|
||||
cfg = Config(cfg)
|
||||
head = SideAdapterCLIPHead(**cfg)
|
||||
|
||||
inputs = torch.rand((2, img_channels, H, W))
|
||||
clip_feature = [[
|
||||
torch.rand((2, clip_channels, H // 2, W // 2)),
|
||||
torch.rand((2, clip_channels))
|
||||
],
|
||||
[
|
||||
torch.rand((2, clip_channels, H // 2, W // 2)),
|
||||
torch.rand((2, clip_channels))
|
||||
],
|
||||
[
|
||||
torch.rand((2, clip_channels, H // 2, W // 2)),
|
||||
torch.rand((2, clip_channels))
|
||||
],
|
||||
[
|
||||
torch.rand((2, clip_channels, H // 2, W // 2)),
|
||||
torch.rand((2, clip_channels))
|
||||
]]
|
||||
class_embed = torch.rand((num_classes + 1, out_dims))
|
||||
|
||||
data_samples = []
|
||||
for i in range(2):
|
||||
data_sample = SegDataSample()
|
||||
img_meta = {}
|
||||
img_meta['img_shape'] = (H, W)
|
||||
img_meta['ori_shape'] = (H, W)
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
data=torch.randint(0, num_classes, (1, H, W)))
|
||||
data_sample.set_metainfo(img_meta)
|
||||
data_samples.append(data_sample)
|
||||
|
||||
batch_img_metas = []
|
||||
for data_sample in data_samples:
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
head = head.cuda()
|
||||
data = list_to_cuda([inputs, clip_feature, class_embed])
|
||||
for data_sample in data_samples:
|
||||
data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda()
|
||||
else:
|
||||
data = [inputs, clip_feature, class_embed]
|
||||
|
||||
# loss test
|
||||
loss_dict = head.loss(data, data_samples, None)
|
||||
assert isinstance(loss_dict, dict)
|
||||
|
||||
# prediction test
|
||||
with torch.no_grad():
|
||||
seg_logits = head.predict(data, batch_img_metas, None)
|
||||
assert seg_logits.shape == torch.Size((2, num_classes, H, W))
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import SegformerHead
|
||||
|
||||
|
||||
def test_segformer_head():
|
||||
with pytest.raises(AssertionError):
|
||||
# `in_channels` must have same length as `in_index`
|
||||
SegformerHead(
|
||||
in_channels=(1, 2, 3), in_index=(0, 1), channels=5, num_classes=2)
|
||||
|
||||
H, W = (64, 64)
|
||||
in_channels = (32, 64, 160, 256)
|
||||
shapes = [(H // 2**(i + 2), W // 2**(i + 2))
|
||||
for i in range(len(in_channels))]
|
||||
model = SegformerHead(
|
||||
in_channels=in_channels,
|
||||
in_index=[0, 1, 2, 3],
|
||||
channels=256,
|
||||
num_classes=19)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
# in_index must match the input feature maps.
|
||||
inputs = [
|
||||
torch.randn((1, in_channel, *shape))
|
||||
for in_channel, shape in zip(in_channels, shapes)
|
||||
][:3]
|
||||
temp = model(inputs)
|
||||
|
||||
# Normal Input
|
||||
# ((1, 32, 16, 16), (1, 64, 8, 8), (1, 160, 4, 4), (1, 256, 2, 2)
|
||||
inputs = [
|
||||
torch.randn((1, in_channel, *shape))
|
||||
for in_channel, shape in zip(in_channels, shapes)
|
||||
]
|
||||
temp = model(inputs)
|
||||
|
||||
assert temp.shape == (1, 19, H // 4, W // 4)
|
||||
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import SegmenterMaskTransformerHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_segmenter_mask_transformer_head():
|
||||
head = SegmenterMaskTransformerHead(
|
||||
in_channels=2,
|
||||
channels=2,
|
||||
num_classes=150,
|
||||
num_layers=2,
|
||||
num_heads=3,
|
||||
embed_dims=192,
|
||||
dropout_ratio=0.0)
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
head.init_weights()
|
||||
|
||||
inputs = [torch.randn(1, 2, 32, 32)]
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 32, 32)
|
||||
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import SETRMLAHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_setr_mla_head(capsys):
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# MLA requires input multiple stage feature information.
|
||||
SETRMLAHead(in_channels=8, channels=4, num_classes=19, in_index=1)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# multiple in_indexs requires multiple in_channels.
|
||||
SETRMLAHead(
|
||||
in_channels=8, channels=4, num_classes=19, in_index=(0, 1, 2, 3))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# channels should be len(in_channels) * mla_channels
|
||||
SETRMLAHead(
|
||||
in_channels=(8, 8, 8, 8),
|
||||
channels=8,
|
||||
mla_channels=4,
|
||||
in_index=(0, 1, 2, 3),
|
||||
num_classes=19)
|
||||
|
||||
# test inference of MLA head
|
||||
img_size = (8, 8)
|
||||
patch_size = 4
|
||||
head = SETRMLAHead(
|
||||
in_channels=(8, 8, 8, 8),
|
||||
channels=16,
|
||||
mla_channels=4,
|
||||
in_index=(0, 1, 2, 3),
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='BN'))
|
||||
|
||||
h, w = img_size[0] // patch_size, img_size[1] // patch_size
|
||||
# Input square NCHW format feature information
|
||||
x = [
|
||||
torch.randn(1, 8, h, w),
|
||||
torch.randn(1, 8, h, w),
|
||||
torch.randn(1, 8, h, w),
|
||||
torch.randn(1, 8, h, w)
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
head, x = to_cuda(head, x)
|
||||
out = head(x)
|
||||
assert out.shape == (1, head.num_classes, h * 4, w * 4)
|
||||
|
||||
# Input non-square NCHW format feature information
|
||||
x = [
|
||||
torch.randn(1, 8, h, w * 2),
|
||||
torch.randn(1, 8, h, w * 2),
|
||||
torch.randn(1, 8, h, w * 2),
|
||||
torch.randn(1, 8, h, w * 2)
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
head, x = to_cuda(head, x)
|
||||
out = head(x)
|
||||
assert out.shape == (1, head.num_classes, h * 4, w * 8)
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import SETRUPHead
|
||||
from .utils import to_cuda
|
||||
|
||||
|
||||
def test_setr_up_head(capsys):
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# kernel_size must be [1/3]
|
||||
SETRUPHead(num_classes=19, kernel_size=2)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# in_channels must be int type and in_channels must be same
|
||||
# as embed_dim.
|
||||
SETRUPHead(in_channels=(4, 4), channels=2, num_classes=19)
|
||||
|
||||
# test init_cfg of head
|
||||
head = SETRUPHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
norm_cfg=dict(type='SyncBN'),
|
||||
num_classes=19,
|
||||
init_cfg=dict(type='Kaiming'))
|
||||
super(SETRUPHead, head).init_weights()
|
||||
|
||||
# test inference of Naive head
|
||||
# the auxiliary head of Naive head is same as Naive head
|
||||
img_size = (4, 4)
|
||||
patch_size = 2
|
||||
head = SETRUPHead(
|
||||
in_channels=4,
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
num_convs=1,
|
||||
up_scale=4,
|
||||
kernel_size=1,
|
||||
norm_cfg=dict(type='BN'))
|
||||
|
||||
h, w = img_size[0] // patch_size, img_size[1] // patch_size
|
||||
|
||||
# Input square NCHW format feature information
|
||||
x = [torch.randn(1, 4, h, w)]
|
||||
if torch.cuda.is_available():
|
||||
head, x = to_cuda(head, x)
|
||||
out = head(x)
|
||||
assert out.shape == (1, head.num_classes, h * 4, w * 4)
|
||||
|
||||
# Input non-square NCHW format feature information
|
||||
x = [torch.randn(1, 4, h, w * 2)]
|
||||
if torch.cuda.is_available():
|
||||
head, x = to_cuda(head, x)
|
||||
out = head(x)
|
||||
assert out.shape == (1, head.num_classes, h * 4, w * 8)
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.decode_heads import UPerHead
|
||||
from .utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_uper_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# fpn_in_channels must be list|tuple
|
||||
UPerHead(in_channels=4, channels=2, num_classes=19)
|
||||
|
||||
# test no norm_cfg
|
||||
head = UPerHead(
|
||||
in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1])
|
||||
assert not _conv_has_norm(head, sync_bn=False)
|
||||
|
||||
# test with norm_cfg
|
||||
head = UPerHead(
|
||||
in_channels=[4, 2],
|
||||
channels=2,
|
||||
num_classes=19,
|
||||
norm_cfg=dict(type='SyncBN'),
|
||||
in_index=[-2, -1])
|
||||
assert _conv_has_norm(head, sync_bn=True)
|
||||
|
||||
inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 2, 21, 21)]
|
||||
head = UPerHead(
|
||||
in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1])
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.models.decode_heads import VPDDepthHead
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestVPDDepthHead(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Set up common resources."""
|
||||
self.in_channels = [320, 640, 1280, 1280]
|
||||
self.max_depth = 10.0
|
||||
self.loss_decode = dict(
|
||||
type='SiLogLoss'
|
||||
) # Replace with your actual loss type and parameters
|
||||
self.vpd_depth_head = VPDDepthHead(
|
||||
max_depth=self.max_depth,
|
||||
in_channels=self.in_channels,
|
||||
loss_decode=self.loss_decode)
|
||||
|
||||
def test_forward(self):
|
||||
"""Test the forward method."""
|
||||
# Create a mock input tensor. Replace shape as per your needs.
|
||||
x = [
|
||||
torch.randn(1, 320, 32, 32),
|
||||
torch.randn(1, 640, 16, 16),
|
||||
torch.randn(1, 1280, 8, 8),
|
||||
torch.randn(1, 1280, 4, 4)
|
||||
]
|
||||
|
||||
output = self.vpd_depth_head.forward(x)
|
||||
print(output.shape)
|
||||
|
||||
self.assertEqual(output.shape, (1, 1, 256, 256))
|
||||
|
||||
def test_loss_by_feat(self):
|
||||
"""Test the loss_by_feat method."""
|
||||
# Create mock data for `pred_depth_map` and `batch_data_samples`.
|
||||
pred_depth_map = torch.randn(1, 1, 32, 32)
|
||||
gt_depth_map = PixelData(data=torch.rand(1, 32, 32))
|
||||
batch_data_samples = [SegDataSample(gt_depth_map=gt_depth_map)]
|
||||
|
||||
loss = self.vpd_depth_head.loss_by_feat(pred_depth_map,
|
||||
batch_data_samples)
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
31
Seg_All_In_One_MMSeg/tests/test_models/test_heads/utils.py
Normal file
31
Seg_All_In_One_MMSeg/tests/test_models/test_heads/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import SyncBatchNorm
|
||||
|
||||
|
||||
def _conv_has_norm(module, sync_bn):
|
||||
for m in module.modules():
|
||||
if isinstance(m, ConvModule):
|
||||
if not m.with_norm:
|
||||
return False
|
||||
if sync_bn:
|
||||
if not isinstance(m.bn, SyncBatchNorm):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def to_cuda(module, data):
|
||||
module = module.cuda()
|
||||
if isinstance(data, list):
|
||||
for i in range(len(data)):
|
||||
data[i] = data[i].cuda()
|
||||
return module, data
|
||||
|
||||
|
||||
def list_to_cuda(data):
|
||||
if isinstance(data, list):
|
||||
for i in range(len(data)):
|
||||
data[i] = list_to_cuda(data[i])
|
||||
return data
|
||||
else:
|
||||
return data.cuda()
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmseg.models.losses import CrossEntropyLoss, weight_reduce_loss
|
||||
|
||||
|
||||
def test_cross_entropy_loss_class_weights():
|
||||
loss_class = CrossEntropyLoss
|
||||
pred = torch.rand((1, 10, 4, 4))
|
||||
target = torch.randint(0, 10, (1, 4, 4))
|
||||
class_weight = torch.ones(10)
|
||||
avg_factor = target.numel()
|
||||
|
||||
cross_entropy_loss = F.cross_entropy(
|
||||
pred, target, weight=class_weight, reduction='none', ignore_index=-100)
|
||||
|
||||
expected_loss = weight_reduce_loss(
|
||||
cross_entropy_loss,
|
||||
weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=avg_factor)
|
||||
|
||||
# Test loss forward
|
||||
loss = loss_class(class_weight=class_weight.tolist())(pred, target)
|
||||
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert expected_loss == loss
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.losses import DiceLoss
|
||||
|
||||
|
||||
@pytest.mark.parametrize('naive_dice', [True, False])
|
||||
def test_dice_loss(naive_dice):
|
||||
loss_class = DiceLoss
|
||||
pred = torch.rand((1, 10, 4, 4))
|
||||
target = torch.randint(0, 10, (1, 4, 4))
|
||||
weight = torch.rand(1)
|
||||
# Test loss forward
|
||||
loss = loss_class(naive_dice=naive_dice)(pred, target)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with weight
|
||||
loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with reduction_override
|
||||
loss = loss_class(naive_dice=naive_dice)(
|
||||
pred, target, reduction_override='mean')
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with avg_factor
|
||||
loss = loss_class(naive_dice=naive_dice)(pred, target, avg_factor=10)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# loss can evaluate with avg_factor only if
|
||||
# reduction is None, 'none' or 'mean'.
|
||||
reduction_override = 'sum'
|
||||
loss_class(naive_dice=naive_dice)(
|
||||
pred, target, avg_factor=10, reduction_override=reduction_override)
|
||||
|
||||
# Test loss forward with avg_factor and reduction
|
||||
for reduction_override in [None, 'none', 'mean']:
|
||||
loss_class(naive_dice=naive_dice)(
|
||||
pred, target, avg_factor=10, reduction_override=reduction_override)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with has_acted=False and use_sigmoid=False
|
||||
for use_sigmoid in [True, False]:
|
||||
loss_class(
|
||||
use_sigmoid=use_sigmoid, activate=True,
|
||||
naive_dice=naive_dice)(pred, target)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with weight.ndim != loss.ndim
|
||||
with pytest.raises(AssertionError):
|
||||
weight = torch.rand((2, 8))
|
||||
loss_class(naive_dice=naive_dice)(pred, target, weight)
|
||||
|
||||
# Test loss forward with len(weight) != len(pred)
|
||||
with pytest.raises(AssertionError):
|
||||
weight = torch.rand(8)
|
||||
loss_class(naive_dice=naive_dice)(pred, target, weight)
|
||||
|
||||
# Test _expand_onehot_labels_dice
|
||||
pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
|
||||
target = torch.tensor([[[0, 0], [0, 1]]])
|
||||
target_onehot = torch.tensor([[[[1, 1], [1, 0]], [[0, 0], [0, 1]]]])
|
||||
weight = torch.rand(1)
|
||||
loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
|
||||
loss_onehot = loss_class(naive_dice=naive_dice)(pred, target_onehot,
|
||||
weight)
|
||||
assert torch.equal(loss, loss_onehot)
|
||||
|
||||
# Test Whether Loss is 0 when pred == target, eps == 0 and naive_dice=False
|
||||
target = torch.randint(0, 2, (1, 10, 4, 4))
|
||||
pred = target.float()
|
||||
target = target.sigmoid()
|
||||
weight = torch.rand(1)
|
||||
loss = loss_class(
|
||||
naive_dice=False, use_sigmoid=True, eps=0)(pred, target, weight)
|
||||
assert loss.item() == 0
|
||||
|
||||
# Test ignore_index when ignore_index is the only class
|
||||
with pytest.raises(AssertionError):
|
||||
pred = torch.ones((1, 1, 4, 4))
|
||||
target = torch.randint(0, 1, (1, 4, 4))
|
||||
weight = torch.rand(1)
|
||||
loss = loss_class(
|
||||
naive_dice=naive_dice, use_sigmoid=False, ignore_index=0,
|
||||
eps=0)(pred, target, weight)
|
||||
|
||||
# Test ignore_index with naive_dice = False
|
||||
pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
|
||||
target = torch.tensor([[[[1, 1], [1, 0]], [[1, 0], [0, 1]]]]).sigmoid()
|
||||
weight = torch.rand(1)
|
||||
loss = loss_class(
|
||||
naive_dice=False, use_sigmoid=True, ignore_index=1,
|
||||
eps=0)(pred, target, weight)
|
||||
assert loss.item() == 0
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.losses import HuasdorffDisstanceLoss
|
||||
|
||||
|
||||
def test_huasdorff_distance_loss():
|
||||
loss_class = HuasdorffDisstanceLoss
|
||||
pred = torch.rand((10, 8, 6, 6))
|
||||
target = torch.rand((10, 6, 6))
|
||||
class_weight = torch.rand(8)
|
||||
|
||||
# Test loss forward
|
||||
loss = loss_class()(pred, target)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with avg_factor
|
||||
loss = loss_class()(pred, target, avg_factor=10)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with avg_factor and reduction is None, 'sum' and 'mean'
|
||||
for reduction in [None, 'sum', 'mean']:
|
||||
loss = loss_class()(pred, target, avg_factor=10, reduction=reduction)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with class_weight
|
||||
with pytest.raises(AssertionError):
|
||||
loss_class(class_weight=class_weight)(pred, target)
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models.losses.kldiv_loss import KLDivLoss
|
||||
|
||||
|
||||
def test_kldiv_loss_with_none_reduction():
|
||||
loss_class = KLDivLoss
|
||||
pred = torch.rand((8, 5, 5))
|
||||
target = torch.rand((8, 5, 5))
|
||||
reduction = 'none'
|
||||
|
||||
# Test loss forward
|
||||
loss = loss_class(reduction=reduction)(pred, target)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.shape == (8, 5, 5), f'{loss.shape}'
|
||||
|
||||
|
||||
def test_kldiv_loss_with_mean_reduction():
|
||||
loss_class = KLDivLoss
|
||||
pred = torch.rand((8, 5, 5))
|
||||
target = torch.rand((8, 5, 5))
|
||||
reduction = 'mean'
|
||||
|
||||
# Test loss forward
|
||||
loss = loss_class(reduction=reduction)(pred, target)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.shape == (8, ), f'{loss.shape}'
|
||||
|
||||
|
||||
def test_kldiv_loss_with_sum_reduction():
|
||||
loss_class = KLDivLoss
|
||||
pred = torch.rand((8, 5, 5))
|
||||
target = torch.rand((8, 5, 5))
|
||||
reduction = 'sum'
|
||||
|
||||
# Test loss forward
|
||||
loss = loss_class(reduction=reduction)(pred, target)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.shape == (8, ), f'{loss.shape}'
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmseg.models.losses import SiLogLoss
|
||||
|
||||
|
||||
class TestSiLogLoss(TestCase):
|
||||
|
||||
def test_SiLogLoss_forward(self):
|
||||
pred = torch.tensor([[1.0, 2.0], [3.5, 4.0]], dtype=torch.float32)
|
||||
target = torch.tensor([[0.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
|
||||
weight = torch.tensor([1.0, 0.5], dtype=torch.float32)
|
||||
|
||||
loss_module = SiLogLoss()
|
||||
loss = loss_module.forward(pred, target, weight)
|
||||
|
||||
expected_loss = 0.02
|
||||
self.assertAlmostEqual(loss.item(), expected_loss, places=2)
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def test_tversky_lose():
|
||||
from mmseg.models import build_loss
|
||||
|
||||
# test alpha + beta != 1
|
||||
with pytest.raises(AssertionError):
|
||||
loss_cfg = dict(
|
||||
type='TverskyLoss',
|
||||
class_weight=[1.0, 2.0, 3.0],
|
||||
loss_weight=1.0,
|
||||
alpha=0.4,
|
||||
beta=0.7,
|
||||
loss_name='loss_tversky')
|
||||
tversky_loss = build_loss(loss_cfg)
|
||||
logits = torch.rand(8, 3, 4, 4)
|
||||
labels = (torch.rand(8, 4, 4) * 3).long()
|
||||
tversky_loss(logits, labels, ignore_index=1)
|
||||
|
||||
# test tversky loss
|
||||
loss_cfg = dict(
|
||||
type='TverskyLoss',
|
||||
class_weight=[1.0, 2.0, 3.0],
|
||||
loss_weight=1.0,
|
||||
ignore_index=1,
|
||||
loss_name='loss_tversky')
|
||||
tversky_loss = build_loss(loss_cfg)
|
||||
logits = torch.rand(8, 3, 4, 4)
|
||||
labels = (torch.rand(8, 4, 4) * 3).long()
|
||||
tversky_loss(logits, labels)
|
||||
|
||||
# test loss with class weights from file
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import mmengine
|
||||
import numpy as np
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
|
||||
mmengine.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl',
|
||||
'pkl') # from pkl file
|
||||
loss_cfg = dict(
|
||||
type='TverskyLoss',
|
||||
class_weight=f'{tmp_file.name}.pkl',
|
||||
loss_weight=1.0,
|
||||
ignore_index=1,
|
||||
loss_name='loss_tversky')
|
||||
tversky_loss = build_loss(loss_cfg)
|
||||
tversky_loss(logits, labels)
|
||||
|
||||
np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0])) # from npy file
|
||||
loss_cfg = dict(
|
||||
type='TverskyLoss',
|
||||
class_weight=f'{tmp_file.name}.pkl',
|
||||
loss_weight=1.0,
|
||||
ignore_index=1,
|
||||
loss_name='loss_tversky')
|
||||
tversky_loss = build_loss(loss_cfg)
|
||||
tversky_loss(logits, labels)
|
||||
tmp_file.close()
|
||||
os.remove(f'{tmp_file.name}.pkl')
|
||||
os.remove(f'{tmp_file.name}.npy')
|
||||
|
||||
# test tversky loss has name `loss_tversky`
|
||||
loss_cfg = dict(
|
||||
type='TverskyLoss',
|
||||
smooth=2,
|
||||
loss_weight=1.0,
|
||||
ignore_index=1,
|
||||
alpha=0.3,
|
||||
beta=0.7,
|
||||
loss_name='loss_tversky')
|
||||
tversky_loss = build_loss(loss_cfg)
|
||||
assert tversky_loss.loss_name == 'loss_tversky'
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models import Feature2Pyramid
|
||||
|
||||
|
||||
def test_feature2pyramid():
|
||||
# test
|
||||
rescales = [4, 2, 1, 0.5]
|
||||
embed_dim = 64
|
||||
inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))]
|
||||
|
||||
fpn = Feature2Pyramid(
|
||||
embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
|
||||
outputs = fpn(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 64, 128, 128])
|
||||
assert outputs[1].shape == torch.Size([1, 64, 64, 64])
|
||||
assert outputs[2].shape == torch.Size([1, 64, 32, 32])
|
||||
assert outputs[3].shape == torch.Size([1, 64, 16, 16])
|
||||
|
||||
# test rescales = [2, 1, 0.5, 0.25]
|
||||
rescales = [2, 1, 0.5, 0.25]
|
||||
inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))]
|
||||
|
||||
fpn = Feature2Pyramid(
|
||||
embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
|
||||
outputs = fpn(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 64, 64, 64])
|
||||
assert outputs[1].shape == torch.Size([1, 64, 32, 32])
|
||||
assert outputs[2].shape == torch.Size([1, 64, 16, 16])
|
||||
assert outputs[3].shape == torch.Size([1, 64, 8, 8])
|
||||
|
||||
# test rescales = [4, 2, 0.25, 0]
|
||||
rescales = [4, 2, 0.25, 0]
|
||||
with pytest.raises(KeyError):
|
||||
fpn = Feature2Pyramid(
|
||||
embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
|
||||
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models import FPN
|
||||
|
||||
|
||||
def test_fpn():
|
||||
in_channels = [64, 128, 256, 512]
|
||||
inputs = [
|
||||
torch.randn(1, c, 56 // 2**i, 56 // 2**i)
|
||||
for i, c in enumerate(in_channels)
|
||||
]
|
||||
|
||||
fpn = FPN(in_channels, 64, len(in_channels))
|
||||
outputs = fpn(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert outputs[1].shape == torch.Size([1, 64, 28, 28])
|
||||
assert outputs[2].shape == torch.Size([1, 64, 14, 14])
|
||||
assert outputs[3].shape == torch.Size([1, 64, 7, 7])
|
||||
|
||||
fpn = FPN(
|
||||
in_channels,
|
||||
64,
|
||||
len(in_channels),
|
||||
upsample_cfg=dict(mode='nearest', scale_factor=2.0))
|
||||
outputs = fpn(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 64, 56, 56])
|
||||
assert outputs[1].shape == torch.Size([1, 64, 28, 28])
|
||||
assert outputs[2].shape == torch.Size([1, 64, 14, 14])
|
||||
assert outputs[3].shape == torch.Size([1, 64, 7, 7])
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.necks import ICNeck
|
||||
from mmseg.models.necks.ic_neck import CascadeFeatureFusion
|
||||
from ..test_heads.utils import _conv_has_norm, to_cuda
|
||||
|
||||
|
||||
def test_ic_neck():
|
||||
# test with norm_cfg
|
||||
neck = ICNeck(
|
||||
in_channels=(4, 16, 16),
|
||||
out_channels=8,
|
||||
norm_cfg=dict(type='SyncBN'),
|
||||
align_corners=False)
|
||||
assert _conv_has_norm(neck, sync_bn=True)
|
||||
|
||||
inputs = [
|
||||
torch.randn(1, 4, 32, 64),
|
||||
torch.randn(1, 16, 16, 32),
|
||||
torch.randn(1, 16, 8, 16)
|
||||
]
|
||||
neck = ICNeck(
|
||||
in_channels=(4, 16, 16),
|
||||
out_channels=4,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
align_corners=False)
|
||||
if torch.cuda.is_available():
|
||||
neck, inputs = to_cuda(neck, inputs)
|
||||
|
||||
outputs = neck(inputs)
|
||||
assert outputs[0].shape == (1, 4, 16, 32)
|
||||
assert outputs[1].shape == (1, 4, 32, 64)
|
||||
assert outputs[1].shape == (1, 4, 32, 64)
|
||||
|
||||
|
||||
def test_ic_neck_cascade_feature_fusion():
|
||||
cff = CascadeFeatureFusion(64, 64, 32)
|
||||
assert cff.conv_low.in_channels == 64
|
||||
assert cff.conv_low.out_channels == 32
|
||||
assert cff.conv_high.in_channels == 64
|
||||
assert cff.conv_high.out_channels == 32
|
||||
|
||||
|
||||
def test_ic_neck_input_channels():
|
||||
with pytest.raises(AssertionError):
|
||||
# ICNet Neck input channel constraints.
|
||||
ICNeck(
|
||||
in_channels=(16, 64, 64, 64),
|
||||
out_channels=32,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
align_corners=False)
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.necks import JPU
|
||||
|
||||
|
||||
def test_fastfcn_neck():
|
||||
# Test FastFCN Standard Forward
|
||||
model = JPU(
|
||||
in_channels=(64, 128, 256),
|
||||
mid_channels=64,
|
||||
start_level=0,
|
||||
end_level=-1,
|
||||
dilations=(1, 2, 4, 8),
|
||||
)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 1
|
||||
input = [
|
||||
torch.randn(batch_size, 64, 64, 128),
|
||||
torch.randn(batch_size, 128, 32, 64),
|
||||
torch.randn(batch_size, 256, 16, 32)
|
||||
]
|
||||
feat = model(input)
|
||||
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([batch_size, 64, 64, 128])
|
||||
assert feat[1].shape == torch.Size([batch_size, 128, 32, 64])
|
||||
assert feat[2].shape == torch.Size([batch_size, 256, 64, 128])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# FastFCN input and in_channels constraints.
|
||||
JPU(in_channels=(256, 64, 128), start_level=0, end_level=5)
|
||||
|
||||
# Test not default start_level
|
||||
model = JPU(in_channels=(64, 128, 256), start_level=1, end_level=-1)
|
||||
input = [
|
||||
torch.randn(batch_size, 64, 64, 128),
|
||||
torch.randn(batch_size, 128, 32, 64),
|
||||
torch.randn(batch_size, 256, 16, 32)
|
||||
]
|
||||
feat = model(input)
|
||||
assert len(feat) == 2
|
||||
assert feat[0].shape == torch.Size([batch_size, 128, 32, 64])
|
||||
assert feat[1].shape == torch.Size([batch_size, 2048, 32, 64])
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models import MLANeck
|
||||
|
||||
|
||||
def test_mla():
|
||||
in_channels = [4, 4, 4, 4]
|
||||
mla = MLANeck(in_channels, 32)
|
||||
|
||||
inputs = [torch.randn(1, c, 12, 12) for i, c in enumerate(in_channels)]
|
||||
outputs = mla(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 32, 12, 12])
|
||||
assert outputs[1].shape == torch.Size([1, 32, 12, 12])
|
||||
assert outputs[2].shape == torch.Size([1, 32, 12, 12])
|
||||
assert outputs[3].shape == torch.Size([1, 32, 12, 12])
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.models import MultiLevelNeck
|
||||
|
||||
|
||||
def test_multilevel_neck():
|
||||
|
||||
# Test init_weights
|
||||
MultiLevelNeck([266], 32).init_weights()
|
||||
|
||||
# Test multi feature maps
|
||||
in_channels = [32, 64, 128, 256]
|
||||
inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)]
|
||||
|
||||
neck = MultiLevelNeck(in_channels, 32)
|
||||
outputs = neck(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 32, 7, 7])
|
||||
assert outputs[1].shape == torch.Size([1, 32, 14, 14])
|
||||
assert outputs[2].shape == torch.Size([1, 32, 28, 28])
|
||||
assert outputs[3].shape == torch.Size([1, 32, 56, 56])
|
||||
|
||||
# Test one feature map
|
||||
in_channels = [768]
|
||||
inputs = [torch.randn(1, 768, 14, 14)]
|
||||
|
||||
neck = MultiLevelNeck(in_channels, 32)
|
||||
outputs = neck(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 32, 7, 7])
|
||||
assert outputs[1].shape == torch.Size([1, 32, 14, 14])
|
||||
assert outputs[2].shape == torch.Size([1, 32, 28, 28])
|
||||
assert outputs[3].shape == torch.Size([1, 32, 56, 56])
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine import ConfigDict
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
from .utils import _segmentor_forward_train_test
|
||||
|
||||
|
||||
def test_cascade_encoder_decoder():
|
||||
|
||||
# test 1 decode head, w.o. aux head
|
||||
cfg = ConfigDict(
|
||||
type='CascadeEncoderDecoder',
|
||||
num_stages=2,
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=[
|
||||
dict(type='ExampleDecodeHead'),
|
||||
dict(type='ExampleCascadeDecodeHead')
|
||||
])
|
||||
cfg.test_cfg = ConfigDict(mode='whole')
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test slide mode
|
||||
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test 1 decode head, 1 aux head
|
||||
cfg = ConfigDict(
|
||||
type='CascadeEncoderDecoder',
|
||||
num_stages=2,
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=[
|
||||
dict(type='ExampleDecodeHead'),
|
||||
dict(type='ExampleCascadeDecodeHead')
|
||||
],
|
||||
auxiliary_head=dict(type='ExampleDecodeHead'))
|
||||
cfg.test_cfg = ConfigDict(mode='whole')
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test 1 decode head, 2 aux head
|
||||
cfg = ConfigDict(
|
||||
type='CascadeEncoderDecoder',
|
||||
num_stages=2,
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=[
|
||||
dict(type='ExampleDecodeHead'),
|
||||
dict(type='ExampleCascadeDecodeHead')
|
||||
],
|
||||
auxiliary_head=[
|
||||
dict(type='ExampleDecodeHead'),
|
||||
dict(type='ExampleDecodeHead')
|
||||
])
|
||||
cfg.test_cfg = ConfigDict(mode='whole')
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from os.path import dirname, join
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine import Config, ConfigDict
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
import mmseg
|
||||
from mmseg.models.segmentors import DepthEstimator
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestDepthEstimator(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
repo_dpath = dirname(dirname(mmseg.__file__))
|
||||
config_dpath = join(repo_dpath, 'configs/_base_/models/vpd_sd.py')
|
||||
vpd_cfg = Config.fromfile(config_dpath).stable_diffusion_cfg
|
||||
vpd_cfg.pop('checkpoint')
|
||||
|
||||
backbone_cfg = dict(
|
||||
type='VPD',
|
||||
diffusion_cfg=vpd_cfg,
|
||||
class_embed_path='https://download.openmmlab.com/mmsegmentation/'
|
||||
'v0.5/vpd/nyu_class_embeddings.pth',
|
||||
class_embed_select=True,
|
||||
pad_shape=64,
|
||||
unet_cfg=dict(use_attn=False),
|
||||
)
|
||||
|
||||
head_cfg = dict(
|
||||
type='VPDDepthHead',
|
||||
max_depth=10,
|
||||
)
|
||||
|
||||
self.model = DepthEstimator(
|
||||
backbone=backbone_cfg, decode_head=head_cfg)
|
||||
|
||||
inputs = torch.randn(1, 3, 64, 80)
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_depth_map = PixelData(data=torch.rand(1, 64, 80))
|
||||
data_sample.set_metainfo(dict(img_shape=(64, 80), ori_shape=(64, 80)))
|
||||
self.data = dict(inputs=inputs, data_samples=[data_sample])
|
||||
|
||||
def test_slide_flip_inference(self):
|
||||
|
||||
self.model.test_cfg = ConfigDict(
|
||||
dict(mode='slide_flip', crop_size=(64, 64), stride=(16, 16)))
|
||||
|
||||
with torch.no_grad():
|
||||
out = self.model.predict(**deepcopy(self.data))
|
||||
|
||||
self.assertEqual(len(out), 1)
|
||||
self.assertIn('pred_depth_map', out[0].keys())
|
||||
self.assertListEqual(list(out[0].pred_depth_map.shape), [64, 80])
|
||||
|
||||
def test__forward(self):
|
||||
data = deepcopy(self.data)
|
||||
data['inputs'] = data['inputs'][:, :, :64, :64]
|
||||
with torch.no_grad():
|
||||
out = self.model._forward(**data)
|
||||
self.assertListEqual(list(out.shape), [1, 1, 64, 64])
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine import ConfigDict
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.structures import SegDataSample
|
||||
from .utils import _segmentor_forward_train_test
|
||||
|
||||
|
||||
def test_encoder_decoder():
|
||||
|
||||
# test 1 decode head, w.o. aux head
|
||||
|
||||
cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(type='ExampleDecodeHead'),
|
||||
train_cfg=None,
|
||||
test_cfg=dict(mode='whole'))
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test out_channels == 1
|
||||
cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(
|
||||
type='ExampleDecodeHead', num_classes=2, out_channels=1),
|
||||
train_cfg=None,
|
||||
test_cfg=dict(mode='whole'))
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test slide mode
|
||||
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test 1 decode head, 1 aux head
|
||||
cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(type='ExampleDecodeHead'),
|
||||
auxiliary_head=dict(type='ExampleDecodeHead'))
|
||||
cfg.test_cfg = ConfigDict(mode='whole')
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test 1 decode head, 2 aux head
|
||||
cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(type='ExampleDecodeHead'),
|
||||
auxiliary_head=[
|
||||
dict(type='ExampleDecodeHead'),
|
||||
dict(type='ExampleDecodeHead')
|
||||
])
|
||||
cfg.test_cfg = ConfigDict(mode='whole')
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
|
||||
def test_postprocess_result():
|
||||
cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(type='ExampleDecodeHead'),
|
||||
train_cfg=None,
|
||||
test_cfg=dict(mode='whole'))
|
||||
model = build_segmentor(cfg)
|
||||
|
||||
# test postprocess
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
**{'data': torch.randint(0, 10, (1, 8, 8))})
|
||||
data_sample.set_metainfo({
|
||||
'padding_size': (0, 2, 0, 2),
|
||||
'ori_shape': (8, 8)
|
||||
})
|
||||
seg_logits = torch.zeros((1, 2, 10, 10))
|
||||
seg_logits[:, :, :8, :8] = 1
|
||||
data_samples = [data_sample]
|
||||
|
||||
outputs = model.postprocess_result(seg_logits, data_samples)
|
||||
assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
|
||||
assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
|
||||
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
**{'data': torch.randint(0, 10, (1, 8, 8))})
|
||||
data_sample.set_metainfo({
|
||||
'img_padding_size': (0, 2, 0, 2),
|
||||
'ori_shape': (8, 8)
|
||||
})
|
||||
|
||||
data_samples = [data_sample]
|
||||
outputs = model.postprocess_result(seg_logits, data_samples)
|
||||
assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
|
||||
assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
|
||||
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine import ConfigDict
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
from tests.test_models.test_segmentors.utils import \
|
||||
_segmentor_forward_train_test
|
||||
|
||||
|
||||
def test_multimodal_encoder_decoder():
|
||||
|
||||
cfg = ConfigDict(
|
||||
type='MultimodalEncoderDecoder',
|
||||
asymetric_input=False,
|
||||
image_encoder=dict(type='ExampleBackbone', out_indices=[1, 2, 3, 4]),
|
||||
text_encoder=dict(
|
||||
type='ExampleTextEncoder',
|
||||
vocabulary=['A', 'B', 'C'],
|
||||
output_dims=3),
|
||||
decode_head=dict(
|
||||
type='ExampleDecodeHead', out_channels=1, num_classes=2),
|
||||
train_cfg=None,
|
||||
test_cfg=dict(mode='whole'))
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from mmengine import ConfigDict
|
||||
from mmengine.model import BaseTTAModel
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from .utils import * # noqa: F401,F403
|
||||
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def test_encoder_decoder_tta():
|
||||
|
||||
segmentor_cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(type='ExampleDecodeHead'),
|
||||
train_cfg=None,
|
||||
test_cfg=dict(mode='whole'))
|
||||
|
||||
cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg)
|
||||
|
||||
model: BaseTTAModel = MODELS.build(cfg)
|
||||
|
||||
imgs = []
|
||||
data_samples = []
|
||||
directions = ['horizontal', 'vertical']
|
||||
for i in range(12):
|
||||
flip_direction = directions[0] if i % 3 == 0 else directions[1]
|
||||
imgs.append(torch.randn(1, 3, 10 + i, 10 + i))
|
||||
data_samples.append([
|
||||
SegDataSample(
|
||||
metainfo=dict(
|
||||
ori_shape=(10, 10),
|
||||
img_shape=(10 + i, 10 + i),
|
||||
flip=(i % 2 == 0),
|
||||
flip_direction=flip_direction,
|
||||
img_path=tempfile.mktemp()),
|
||||
gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10))))
|
||||
])
|
||||
|
||||
model.test_step(dict(inputs=imgs, data_samples=data_samples))
|
||||
|
||||
# test out_channels == 1
|
||||
segmentor_cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(
|
||||
type='ExampleDecodeHead',
|
||||
num_classes=2,
|
||||
out_channels=1,
|
||||
threshold=0.4),
|
||||
train_cfg=None,
|
||||
test_cfg=dict(mode='whole'))
|
||||
model.module = MODELS.build(segmentor_cfg)
|
||||
for data_sample in data_samples:
|
||||
data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10))
|
||||
model.test_step(dict(inputs=imgs, data_samples=data_samples))
|
||||
182
Seg_All_In_One_MMSeg/tests/test_models/test_segmentors/utils.py
Normal file
182
Seg_All_In_One_MMSeg/tests/test_models/test_segmentors/utils.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.structures import PixelData
|
||||
from torch import nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmseg.models import SegDataPreProcessor
|
||||
from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
input_shape (tuple):
|
||||
input batch dimensions
|
||||
|
||||
num_classes (int):
|
||||
number of semantic classes
|
||||
"""
|
||||
(N, C, H, W) = input_shape
|
||||
|
||||
imgs = torch.randn(*input_shape)
|
||||
segs = torch.randint(
|
||||
low=0, high=num_classes - 1, size=(N, H, W), dtype=torch.long)
|
||||
|
||||
img_metas = [{
|
||||
'img_shape': (H, W),
|
||||
'ori_shape': (H, W),
|
||||
'pad_shape': (H, W, C),
|
||||
'filename': '<demo>.png',
|
||||
'scale_factor': 1.0,
|
||||
'flip': False,
|
||||
'flip_direction': 'horizontal'
|
||||
} for _ in range(N)]
|
||||
|
||||
data_samples = [
|
||||
SegDataSample(
|
||||
gt_sem_seg=PixelData(data=segs[i]), metainfo=img_metas[i])
|
||||
for i in range(N)
|
||||
]
|
||||
|
||||
mm_inputs = {'imgs': torch.FloatTensor(imgs), 'data_samples': data_samples}
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ExampleBackbone(nn.Module):
|
||||
|
||||
def __init__(self, out_indices=None):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
self.out_indices = out_indices
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
if self.out_indices is None:
|
||||
return [self.conv(x)]
|
||||
else:
|
||||
outs = []
|
||||
for i in self.out_indices:
|
||||
outs.append(self.conv(x))
|
||||
return outs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ExampleDecodeHead(BaseDecodeHead):
|
||||
|
||||
def __init__(self, num_classes=19, out_channels=None, **kwargs):
|
||||
super().__init__(
|
||||
3, 3, num_classes=num_classes, out_channels=out_channels, **kwargs)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.cls_seg(inputs[0])
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ExampleTextEncoder(nn.Module):
|
||||
|
||||
def __init__(self, vocabulary=None, output_dims=None):
|
||||
super().__init__()
|
||||
self.vocabulary = vocabulary
|
||||
self.output_dims = output_dims
|
||||
|
||||
def forward(self):
|
||||
return torch.randn((len(self.vocabulary), self.output_dims))
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(3, 3, num_classes=19)
|
||||
|
||||
def forward(self, inputs, prev_out):
|
||||
return self.cls_seg(inputs[0])
|
||||
|
||||
|
||||
def _segmentor_forward_train_test(segmentor):
|
||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||
num_classes = segmentor.decode_head[-1].num_classes
|
||||
else:
|
||||
num_classes = segmentor.decode_head.num_classes
|
||||
# batch_size=2 for BatchNorm
|
||||
mm_inputs = _demo_mm_inputs(num_classes=num_classes)
|
||||
|
||||
# convert to cuda Tensor if applicable
|
||||
if torch.cuda.is_available():
|
||||
segmentor = segmentor.cuda()
|
||||
|
||||
# check data preprocessor
|
||||
if not hasattr(segmentor,
|
||||
'data_preprocessor') or segmentor.data_preprocessor is None:
|
||||
segmentor.data_preprocessor = SegDataPreProcessor()
|
||||
|
||||
mm_inputs = segmentor.data_preprocessor(mm_inputs, True)
|
||||
imgs = mm_inputs.pop('imgs')
|
||||
data_samples = mm_inputs.pop('data_samples')
|
||||
|
||||
# create optimizer wrapper
|
||||
optimizer = SGD(segmentor.parameters(), lr=0.1)
|
||||
optim_wrapper = OptimWrapper(optimizer)
|
||||
|
||||
# Test forward train
|
||||
losses = segmentor.forward(imgs, data_samples, mode='loss')
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
# Test train_step
|
||||
data_batch = dict(inputs=imgs, data_samples=data_samples)
|
||||
outputs = segmentor.train_step(data_batch, optim_wrapper)
|
||||
assert isinstance(outputs, dict)
|
||||
assert 'loss' in outputs
|
||||
|
||||
# Test val_step
|
||||
with torch.no_grad():
|
||||
segmentor.eval()
|
||||
data_batch = dict(inputs=imgs, data_samples=data_samples)
|
||||
outputs = segmentor.val_step(data_batch)
|
||||
assert isinstance(outputs, list)
|
||||
|
||||
# Test forward simple test
|
||||
with torch.no_grad():
|
||||
segmentor.eval()
|
||||
data_batch = dict(inputs=imgs, data_samples=data_samples)
|
||||
results = segmentor.forward(imgs, data_samples, mode='tensor')
|
||||
assert isinstance(results, torch.Tensor)
|
||||
|
||||
|
||||
def _segmentor_predict(segmentor):
|
||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||
num_classes = segmentor.decode_head[-1].num_classes
|
||||
else:
|
||||
num_classes = segmentor.decode_head.num_classes
|
||||
# batch_size=2 for BatchNorm
|
||||
mm_inputs = _demo_mm_inputs(num_classes=num_classes)
|
||||
|
||||
# convert to cuda Tensor if applicable
|
||||
if torch.cuda.is_available():
|
||||
segmentor = segmentor.cuda()
|
||||
|
||||
# check data preprocessor
|
||||
if not hasattr(segmentor,
|
||||
'data_preprocessor') or segmentor.data_preprocessor is None:
|
||||
segmentor.data_preprocessor = SegDataPreProcessor()
|
||||
|
||||
mm_inputs = segmentor.data_preprocessor(mm_inputs, True)
|
||||
imgs = mm_inputs.pop('imgs')
|
||||
data_samples = mm_inputs.pop('data_samples')
|
||||
|
||||
# Test predict
|
||||
with torch.no_grad():
|
||||
segmentor.eval()
|
||||
data_batch = dict(inputs=imgs, data_samples=data_samples)
|
||||
outputs = segmentor.predict(**data_batch)
|
||||
assert isinstance(outputs, list)
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user