first commit
This commit is contained in:
8
Seg_All_In_One_MMSeg/CITATION.cff
Normal file
8
Seg_All_In_One_MMSeg/CITATION.cff
Normal file
@@ -0,0 +1,8 @@
|
||||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "MMSegmentation Contributors"
|
||||
title: "OpenMMLab Semantic Segmentation Toolbox and Benchmark"
|
||||
date-released: 2020-07-10
|
||||
url: "https://github.com/open-mmlab/mmsegmentation"
|
||||
license: Apache-2.0
|
||||
203
Seg_All_In_One_MMSeg/LICENSE
Normal file
203
Seg_All_In_One_MMSeg/LICENSE
Normal file
@@ -0,0 +1,203 @@
|
||||
Copyright 2020 The MMSegmentation Authors. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 The MMSegmentation Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
5
Seg_All_In_One_MMSeg/MANIFEST.in
Normal file
5
Seg_All_In_One_MMSeg/MANIFEST.in
Normal file
@@ -0,0 +1,5 @@
|
||||
include requirements/*.txt
|
||||
include mmseg/.mim/model-index.yml
|
||||
include mmseg/utils/bpe_simple_vocab_16e6.txt.gz
|
||||
recursive-include mmseg/.mim/configs *.py *.yaml
|
||||
recursive-include mmseg/.mim/tools *.py *.sh
|
||||
@@ -0,0 +1,297 @@
|
||||
import os, requests, hashlib
|
||||
from tqdm import tqdm
|
||||
|
||||
### 链接获取网址:https://github.com/open-mmlab/mmcv/blob/master/mmcv/model_zoo/[deprecated.json | mmcls.json | open_mmlab.json | torchvision_0.12.json] ###
|
||||
|
||||
# open_mmlab JSON 数据
|
||||
open_mmlab_model_urls = {
|
||||
"vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
|
||||
"detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
|
||||
"detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
|
||||
"detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
|
||||
"detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
|
||||
"detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
|
||||
"resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
|
||||
"resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
|
||||
"resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
|
||||
"contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
|
||||
"detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
|
||||
"detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
|
||||
"jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
|
||||
"jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
|
||||
"jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
|
||||
"jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
|
||||
"jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
|
||||
"jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
|
||||
"msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
|
||||
"msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
|
||||
"msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
|
||||
"msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
|
||||
"msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
|
||||
"bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
|
||||
"kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
|
||||
"kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
|
||||
"res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
|
||||
"regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
|
||||
"regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
|
||||
"regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
|
||||
"regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
|
||||
"regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
|
||||
"regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
|
||||
"regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
|
||||
"regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
|
||||
"resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
|
||||
"resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
|
||||
"resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
|
||||
"mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
|
||||
"mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
|
||||
"mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
|
||||
"contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
|
||||
"contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
|
||||
"resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
|
||||
"resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
|
||||
"resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
|
||||
"darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
|
||||
"mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth",
|
||||
"pidnet-s": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth",
|
||||
"pidnet-m": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-m_imagenet1k_20230306-39893c52.pth",
|
||||
"pidnet-l": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-l_imagenet1k_20230306-67889109.pth",
|
||||
"ddrnet23-s": "https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/pretrain/ddrnet23s-in1kpre_3rdparty-1ccac5b1.pth",
|
||||
"ddrnet23": "https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/pretrain/ddrnet23-in1kpre_3rdparty-9ca29f62.pth",
|
||||
"stdc1": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth",
|
||||
"stdc2": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc2_20220308-7dbd9127.pth"
|
||||
}
|
||||
|
||||
# deprecated_model_urls = {{
|
||||
# "resnet50_caffe": "detectron/resnet50_caffe",
|
||||
# "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
|
||||
# "resnet101_caffe": "detectron/resnet101_caffe",
|
||||
# "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
|
||||
# }}
|
||||
|
||||
mmcls_model_urls = {
|
||||
"vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
|
||||
"vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
|
||||
"vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
|
||||
"vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
|
||||
"vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
|
||||
"vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
|
||||
"vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
|
||||
"vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
|
||||
"resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
|
||||
"resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
|
||||
"resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
|
||||
"resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
|
||||
"resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
|
||||
"resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
|
||||
"resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
|
||||
"resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
|
||||
"resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
|
||||
"resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
|
||||
"resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
|
||||
"resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
|
||||
"se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
|
||||
"se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
|
||||
"resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
|
||||
"resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
|
||||
"resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
|
||||
"resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
|
||||
"shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
|
||||
"shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
|
||||
"mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
|
||||
"mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
|
||||
"mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
|
||||
"repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
|
||||
"repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
|
||||
"repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
|
||||
"repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
|
||||
"repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
|
||||
"repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
|
||||
"repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
|
||||
"repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
|
||||
"repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
|
||||
"repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
|
||||
"repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
|
||||
"repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
|
||||
"res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
|
||||
"res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
|
||||
"res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
|
||||
"swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
|
||||
"swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
|
||||
"swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
|
||||
"swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
|
||||
"t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
|
||||
"t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
|
||||
"t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
|
||||
"tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
|
||||
"vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
|
||||
"vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
|
||||
"vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
|
||||
}
|
||||
|
||||
torchvision_012_model_urls = {
|
||||
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
|
||||
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
|
||||
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
|
||||
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
|
||||
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
|
||||
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
|
||||
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
|
||||
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
|
||||
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
|
||||
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
|
||||
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
|
||||
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
|
||||
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
|
||||
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
|
||||
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
|
||||
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
|
||||
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
|
||||
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
|
||||
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
|
||||
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
|
||||
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
|
||||
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
|
||||
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
|
||||
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
|
||||
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
|
||||
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
|
||||
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
|
||||
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
|
||||
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
|
||||
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
|
||||
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
|
||||
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
|
||||
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
|
||||
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
|
||||
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
|
||||
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
|
||||
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
|
||||
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
|
||||
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
|
||||
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
|
||||
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
|
||||
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
|
||||
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
|
||||
"shufflenetv2_x1.5": None,
|
||||
"shufflenetv2_x2.0": None,
|
||||
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
|
||||
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
|
||||
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
|
||||
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
|
||||
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
|
||||
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
|
||||
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
|
||||
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
|
||||
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
|
||||
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
|
||||
}
|
||||
|
||||
def calculate_file_hash(file_path, hash_algorithm='md5'):
|
||||
"""计算文件的哈希值,默认使用 MD5"""
|
||||
hash_func = hashlib.new(hash_algorithm)
|
||||
with open(file_path, 'rb') as f:
|
||||
while chunk := f.read(8192):
|
||||
hash_func.update(chunk)
|
||||
return hash_func.hexdigest()
|
||||
|
||||
def download_file(url, output_path):
|
||||
"""下载并保存文件,显示下载进度条"""
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
if response.status_code == 200:
|
||||
# 获取文件的总大小,以便确定进度条的总长度
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
|
||||
# 初始化 tqdm 进度条
|
||||
with tqdm(total=total_size, unit='B', unit_scale=True, desc=output_path, ncols=100) as pbar:
|
||||
with open(output_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
pbar.update(len(chunk)) # 更新进度条
|
||||
|
||||
print(f"Downloaded {output_path}")
|
||||
else:
|
||||
print(f"Failed to download {url}")
|
||||
|
||||
def file_exists_and_same(url, output_path):
|
||||
"""检查文件是否已存在并且相同"""
|
||||
if not os.path.exists(output_path):
|
||||
return False
|
||||
|
||||
# 计算远程文件的大小
|
||||
response = requests.head(url)
|
||||
remote_file_size = int(response.headers.get('Content-Length', 0))
|
||||
|
||||
# 获取本地文件的大小
|
||||
local_file_size = os.path.getsize(output_path)
|
||||
|
||||
# 如果文件大小不同,返回 False
|
||||
if local_file_size != remote_file_size:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
# # 如果大小相同,比较文件哈希值
|
||||
# remote_hash = requests.get(url + ".md5").text.strip() if url.endswith(".pth") else None
|
||||
# local_hash = calculate_file_hash(output_path) if remote_hash else None
|
||||
# print(remote_hash, local_hash)
|
||||
# # 返回 True 如果哈希值相同
|
||||
# return local_hash == remote_hash if remote_hash else False
|
||||
|
||||
def download_all_models(model_urls, output_dir):
|
||||
"""遍历JSON并下载文件"""
|
||||
for model_name, url in model_urls.items():
|
||||
# 如果url为空则继续
|
||||
if url == None:
|
||||
print(" ", end='')
|
||||
print(f"\033[91m{model_name}后URL为空,跳过下载!\033[0m")
|
||||
continue
|
||||
# 创建保存路径
|
||||
file_name = f"{model_name}.pth" # 将文件名按 key 进行命名
|
||||
output_path = os.path.join(output_dir, file_name)
|
||||
|
||||
# 获取 output_path 中的文件夹路径,并确保该路径存在
|
||||
output_folder = os.path.dirname(output_path)
|
||||
os.makedirs(output_folder, exist_ok=True) # 如果文件夹不存在则创建
|
||||
|
||||
# 检查文件是否已存在并且相同
|
||||
if file_exists_and_same(url, output_path):
|
||||
print(" ", end='')
|
||||
print(f"\033[93mFile {output_path} already exists and the size is same, skipping download.\033[0m")
|
||||
else:
|
||||
print(" ", end='')
|
||||
print(f"正在下载 {model_name} : {output_path} 中...")
|
||||
# 下载文件
|
||||
download_file(url, output_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
### 1.下载openmmlab数据集 ###
|
||||
# 创建存储下载内容的文件夹
|
||||
output_open_mmlab_dir = './My_Local_Model/open_mmlab'
|
||||
os.makedirs(output_open_mmlab_dir, exist_ok=True)
|
||||
# 执行下载
|
||||
print(f"\033[32m下载open_mmlab数据中...\033[0m")
|
||||
download_all_models(open_mmlab_model_urls, output_open_mmlab_dir)
|
||||
|
||||
# ### 2.下载deprecated数据集 ###
|
||||
# # 创建存储下载内容的文件夹
|
||||
# output_deprecated_dir = './My_Local_Model/deprecated'
|
||||
# os.makedirs(output_deprecated_dir, exist_ok=True)
|
||||
# # 执行下载
|
||||
# download_all_models(deprecated_model_urls, output_deprecated_dir)
|
||||
|
||||
### 3.下载mmcls数据集 ###
|
||||
# 创建存储下载内容的文件夹
|
||||
output_mmcls_dir = './My_Local_Model/mmcls'
|
||||
os.makedirs(output_mmcls_dir, exist_ok=True)
|
||||
# 执行下载
|
||||
download_all_models(mmcls_model_urls, output_mmcls_dir)
|
||||
|
||||
### 4.下载torchvision_012数据集 ###
|
||||
# 创建存储下载内容的文件夹
|
||||
output_torchvision_012_dir = './My_Local_Model/torchvision_012'
|
||||
os.makedirs(output_torchvision_012_dir, exist_ok=True)
|
||||
# 执行下载
|
||||
download_all_models(torchvision_012_model_urls, output_torchvision_012_dir)
|
||||
@@ -0,0 +1,628 @@
|
||||
{
|
||||
"publicdataset_cholecseg8k": {
|
||||
"train_imgs_num": 6464,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"10",
|
||||
"11",
|
||||
"12"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
],
|
||||
[
|
||||
0,
|
||||
160,
|
||||
233
|
||||
],
|
||||
[
|
||||
52,
|
||||
184,
|
||||
178
|
||||
],
|
||||
[
|
||||
90,
|
||||
120,
|
||||
41
|
||||
]
|
||||
],
|
||||
"palette_num": 13,
|
||||
"mean": [
|
||||
85.65740418979115,
|
||||
53.99282220050495,
|
||||
46.074045888534535
|
||||
],
|
||||
"std": [
|
||||
72.24589167201978,
|
||||
56.76979155397199,
|
||||
49.056637115061775
|
||||
],
|
||||
"imgs_num": 6464
|
||||
},
|
||||
"my_dataset_model": {
|
||||
"train_imgs_num": 631,
|
||||
"classes": [
|
||||
"背景",
|
||||
"肝脏",
|
||||
"胆囊",
|
||||
"分离钳",
|
||||
"止血海绵",
|
||||
"肝总管",
|
||||
"胆总管",
|
||||
"吸引器",
|
||||
"剪刀",
|
||||
"止血纱布",
|
||||
"生物夹",
|
||||
"无损伤钳",
|
||||
"喷洒",
|
||||
"胆囊管",
|
||||
"胆囊动脉",
|
||||
"电凝",
|
||||
"标本袋",
|
||||
"引流管",
|
||||
"纱布",
|
||||
"金属钛夹",
|
||||
"术中超声",
|
||||
"吻合器",
|
||||
"乳胶管",
|
||||
"推结器",
|
||||
"肝带",
|
||||
"钳夹",
|
||||
"超声刀",
|
||||
"脂肪",
|
||||
"双极电凝",
|
||||
"棉球",
|
||||
"血管阻断夹",
|
||||
"肿瘤",
|
||||
"针",
|
||||
"线",
|
||||
"韧带",
|
||||
"胆囊静脉"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
],
|
||||
[
|
||||
0,
|
||||
160,
|
||||
233
|
||||
],
|
||||
[
|
||||
52,
|
||||
184,
|
||||
178
|
||||
],
|
||||
[
|
||||
90,
|
||||
120,
|
||||
41
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
177,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
167,
|
||||
24,
|
||||
233
|
||||
],
|
||||
[
|
||||
112,
|
||||
113,
|
||||
150
|
||||
],
|
||||
[
|
||||
0,
|
||||
255,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
255,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
255,
|
||||
255
|
||||
],
|
||||
[
|
||||
138,
|
||||
251,
|
||||
213
|
||||
],
|
||||
[
|
||||
136,
|
||||
162,
|
||||
196
|
||||
],
|
||||
[
|
||||
197,
|
||||
83,
|
||||
181
|
||||
],
|
||||
[
|
||||
202,
|
||||
202,
|
||||
200
|
||||
],
|
||||
[
|
||||
113,
|
||||
102,
|
||||
140
|
||||
],
|
||||
[
|
||||
66,
|
||||
115,
|
||||
82
|
||||
],
|
||||
[
|
||||
240,
|
||||
16,
|
||||
116
|
||||
],
|
||||
[
|
||||
155,
|
||||
132,
|
||||
0
|
||||
],
|
||||
[
|
||||
155,
|
||||
62,
|
||||
0
|
||||
],
|
||||
[
|
||||
146,
|
||||
175,
|
||||
236
|
||||
],
|
||||
[
|
||||
255,
|
||||
172,
|
||||
159
|
||||
],
|
||||
[
|
||||
245,
|
||||
161,
|
||||
0
|
||||
],
|
||||
[
|
||||
134,
|
||||
124,
|
||||
118
|
||||
],
|
||||
[
|
||||
0,
|
||||
157,
|
||||
142
|
||||
],
|
||||
[
|
||||
181,
|
||||
85,
|
||||
105
|
||||
],
|
||||
[
|
||||
42,
|
||||
8,
|
||||
66
|
||||
]
|
||||
],
|
||||
"palette_num": 36,
|
||||
"mean": [
|
||||
94.94709810464319,
|
||||
61.729422339499315,
|
||||
75.93763705236911
|
||||
],
|
||||
"std": [
|
||||
44.00550608113231,
|
||||
42.695956669847746,
|
||||
44.99354156225513
|
||||
],
|
||||
"imgs_num": 2000
|
||||
},
|
||||
"publicdataset_autolaparo": {
|
||||
"train_imgs_num": 1440,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
]
|
||||
],
|
||||
"palette_num": 10,
|
||||
"mean": [
|
||||
123.62464353460942,
|
||||
85.34836259209033,
|
||||
82.31539425671558
|
||||
],
|
||||
"std": [
|
||||
47.172211618459315,
|
||||
47.08256715323592,
|
||||
48.135121265163605
|
||||
]
|
||||
},
|
||||
"publicdataset_endovis_2017": {
|
||||
"train_imgs_num": 1800,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
]
|
||||
],
|
||||
"palette_num": 8,
|
||||
"mean": [
|
||||
122.21429912990676,
|
||||
77.0821859677977,
|
||||
87.03836664626716
|
||||
],
|
||||
"std": [
|
||||
50.53335800365262,
|
||||
42.895340354037465,
|
||||
47.739426483390446
|
||||
]
|
||||
},
|
||||
"publicdataset_dresden": {
|
||||
"train_imgs_num": 17363,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"10"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
29,
|
||||
32,
|
||||
136
|
||||
],
|
||||
[
|
||||
160,
|
||||
15,
|
||||
95
|
||||
],
|
||||
[
|
||||
0,
|
||||
160,
|
||||
233
|
||||
]
|
||||
],
|
||||
"palette_num": 11,
|
||||
"mean": [
|
||||
103.172638338208,
|
||||
61.44762740851152,
|
||||
51.407770213021976
|
||||
],
|
||||
"std": [
|
||||
75.77031253622098,
|
||||
54.63616729031377,
|
||||
49.45572239497569
|
||||
]
|
||||
},
|
||||
"publicdataset_endovis_2018": {
|
||||
"train_imgs_num": 1800,
|
||||
"classes": [
|
||||
"背景",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7"
|
||||
],
|
||||
"palette": [
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
91,
|
||||
0
|
||||
],
|
||||
[
|
||||
255,
|
||||
234,
|
||||
0
|
||||
],
|
||||
[
|
||||
85,
|
||||
111,
|
||||
181
|
||||
],
|
||||
[
|
||||
181,
|
||||
227,
|
||||
14
|
||||
],
|
||||
[
|
||||
72,
|
||||
0,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
155,
|
||||
33
|
||||
],
|
||||
[
|
||||
255,
|
||||
0,
|
||||
255
|
||||
]
|
||||
],
|
||||
"palette_num": 8,
|
||||
"mean": [
|
||||
122.21429912990676,
|
||||
77.0821859677977,
|
||||
87.03836664626716
|
||||
],
|
||||
"std": [
|
||||
50.53335800365262,
|
||||
42.895340354037465,
|
||||
47.739426483390446
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "my_dataset_model",
|
||||
"dataset_class_name": "MyDataset_model",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/Seg_All_In_One_MMSeg/My_Data",
|
||||
"img_scale_width": 1920,
|
||||
"img_scale_height": 1080,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "A_Ori",
|
||||
"train_seg_map_path": "A_Label_GT_label_fold",
|
||||
"val_img_path": "A_Ori",
|
||||
"val_seg_map_path": "A_Label_GT_label_fold",
|
||||
"test_img_path": "A_Ori",
|
||||
"test_seg_map_path": "A_Label_GT_label_fold"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": [
|
||||
"背景", "肝脏", "胆囊", "分离钳", "止血海绵", "肝总管", "胆总管", "吸引器", "剪刀", "止血纱布", "生物夹", "无损伤钳", "喷洒",
|
||||
"胆囊管", "胆囊动脉", "电凝", "标本袋", "引流管", "纱布", "金属钛夹", "术中超声", "吻合器", "乳胶管", "推结器",
|
||||
"肝带", "钳夹", "超声刀", "脂肪", "双极电凝", "棉球", "血管阻断夹", "肿瘤", "针", "线", "韧带", "胆囊静脉"
|
||||
],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255],
|
||||
[29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178], [90, 120, 41], [255, 0, 0], [177, 0, 0],
|
||||
[167, 24, 233], [112, 113, 150], [0, 255, 0], [255, 255, 255], [0, 255, 255], [138, 251, 213], [136, 162, 196],
|
||||
[197, 83, 181], [202, 202, 200], [113, 102, 140], [66, 115, 82], [240, 16, 116], [155, 132, 0], [155, 62, 0],
|
||||
[146, 175, 236], [255, 172, 159], [245, 161, 0], [134, 124, 118], [0, 157, 142], [181, 85, 105], [42, 8, 66]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": "_gtFine_labelTrainIds.png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_autolaparo",
|
||||
"dataset_class_name": "PublicDataSet_AutoLaparo",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/2_AutoLaparo-10Type-1920x1080",
|
||||
"img_scale_width": 1920,
|
||||
"img_scale_height": 1080,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": ".png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_cholecseg8k",
|
||||
"dataset_class_name": "PublicDataSet_CholecSeg8k",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/1_CholecSeg8k-13Type-1920x1080",
|
||||
"img_scale_width": 1920,
|
||||
"img_scale_height": 1080,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95], [0, 160, 233], [52, 184, 178], [90, 120, 41]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": ".png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_dresden",
|
||||
"dataset_class_name": "PublicDataSet_Dresden",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/4_Dresden-11Type-512x512",
|
||||
"img_scale_width": 512,
|
||||
"img_scale_height": 512,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/test",
|
||||
"test_seg_map_path": "labels_GT/test"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255], [29, 32, 136], [160, 15, 95], [0, 160, 233]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".png",
|
||||
"seg_map_suffix": ".png",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_endovis_2017",
|
||||
"dataset_class_name": "PublicDataSet_Endovis_2017",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/3_1_Endovis_2017-8Type-512x512",
|
||||
"img_scale_width": 512,
|
||||
"img_scale_height": 512,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".bmp",
|
||||
"seg_map_suffix": ".bmp",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"____一、comment_dataset_info": "定义多个数据集文件名和对应类名",
|
||||
"dataset_info": {
|
||||
"dataset_file_name": "publicdataset_endovis_2018",
|
||||
"dataset_class_name": "PublicDataSet_Endovis_2018",
|
||||
"data_root": "/home/wkmgc/Desktop/Seg/DataSet_Public/3_2_Endovis_2018-8Type-512x512",
|
||||
"img_scale_width": 512,
|
||||
"img_scale_height": 512,
|
||||
"____#####comment_paths#####": "训练、验证、测试集所在文件夹",
|
||||
"paths": {
|
||||
"train_img_path": "images/train",
|
||||
"train_seg_map_path": "labels_GT/train",
|
||||
"val_img_path": "images/val",
|
||||
"val_seg_map_path": "labels_GT/val",
|
||||
"test_img_path": "images/val",
|
||||
"test_seg_map_path": "labels_GT/val"
|
||||
}
|
||||
},
|
||||
"____二、comment_label_info": "定义Label图片相关参数",
|
||||
"label_info": {
|
||||
"classes": ["背景", "1", "2", "3", "4", "5", "6", "7"],
|
||||
"palette": [
|
||||
[0, 0, 0], [255, 91, 0], [255, 234, 0], [85, 111, 181], [181, 227, 14], [72, 0, 255], [0, 155, 33], [255, 0, 255]
|
||||
],
|
||||
"____#####comment#####": "一般不太会变的参数",
|
||||
"img_suffix": ".bmp",
|
||||
"seg_map_suffix": ".bmp",
|
||||
"____#####comment_reduce_zero_label_1#####": "在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;",
|
||||
"____#####comment_reduce_zero_label_2#####": "在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】",
|
||||
"reduce_zero_label": false
|
||||
},
|
||||
"____三、comment_training_info": "定义训练相关参数",
|
||||
"training_info": {
|
||||
"crop_size_width": 256,
|
||||
"crop_size_height": 256,
|
||||
"train_batch_size": 16,
|
||||
"train_num_workers": 4,
|
||||
"val_and_test_batch_size": 1,
|
||||
"val_and_test_num_workers": 4
|
||||
}
|
||||
}
|
||||
|
||||
61
Seg_All_In_One_MMSeg/My_All_In_One/1_Initial_Data_All-ori.py
Normal file
61
Seg_All_In_One_MMSeg/My_All_In_One/1_Initial_Data_All-ori.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os
|
||||
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.1.定义各数据集相关参数 ###########
|
||||
# 可以定义多个数据集文件名 和 对应类名
|
||||
dataset_file_names = ["my_dataset_model"] # =['my_dataset', 'my_dataset_2'] # =["my_dataset"]
|
||||
dataset_class_names = ["MyDataset_model"] # =['MyDataset', 'MyDataset2'] # =["MyDataset"]
|
||||
dataset_file_name='my_dataset_model' # 数据集 文件名.py TODO
|
||||
dataset_class_name='MyDataset_model' # 数据集 类名称 TODO
|
||||
data_root='/home/audience/Desktop/Seg_data/Data' # 数据根目录
|
||||
img_scale=(1920, 1080) # 图片大小
|
||||
# 训练、验证、测试集所在文件夹
|
||||
train_img_path='A_Ori'
|
||||
train_seg_map_path='A_Label_GT_label_fold'
|
||||
val_img_path='A_Ori'
|
||||
val_seg_map_path='A_Label_GT_label_fold'
|
||||
test_img_path='A_Ori'
|
||||
test_seg_map_path='A_Label_GT_label_fold'
|
||||
|
||||
########### 1.2.定义Label图片相关参数 ###########
|
||||
classes = ['肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉']
|
||||
palette = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]]
|
||||
# 这里的classes一定是经过“Initial_Gen_mmseg_datasets_my_dataset.py”处理的
|
||||
classes_all = [
|
||||
['背景','肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉'],
|
||||
]
|
||||
palette_all = [
|
||||
[[0,0,0],[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]],
|
||||
]
|
||||
# 一般不太会变的参数
|
||||
img_suffix = ".png"
|
||||
seg_map_suffix = "_gtFine_labelTrainIds.png"
|
||||
reduce_zero_label = False # 在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】
|
||||
|
||||
########### 1.3.定义训练相关参数 ###########
|
||||
# 一般不太会变的参数
|
||||
crop_size=(512, 512) # 分割大小
|
||||
train_batch_size=4 # 训练batch
|
||||
train_num_workers=4 # 训练并行运行数量
|
||||
val_and_test_batch_size=1 # 验证集和测试集batch
|
||||
val_and_test_num_workers=4 # 验证集和测试集并行运行数量
|
||||
|
||||
########### 2.文件存储位置 ###########
|
||||
output_configs_base_datasets_my_dataset=f'./configs/_base_/datasets/{dataset_file_name}.py'
|
||||
output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
|
||||
output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
|
||||
output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
|
||||
|
||||
########### 3.运行程序生成配置文件 ###########
|
||||
success = generate_configs_base_datasets_my_dataset_file(output_file=output_configs_base_datasets_my_dataset, dataset_class_name=dataset_class_name , data_root=data_root, img_scale=img_scale, crop_size=crop_size, train_batch_size=train_batch_size, train_num_workers=train_num_workers, val_and_test_batch_size=val_and_test_batch_size, val_and_test_num_workers=val_and_test_num_workers, train_img_path=train_img_path, train_seg_map_path=train_seg_map_path, val_img_path=val_img_path, val_seg_map_path=val_seg_map_path, test_img_path=test_img_path, test_seg_map_path=test_seg_map_path)
|
||||
success, classes, palette = generate_mmseg_datasets_my_dataset_file(output_file=output_mmseg_datasets_dataset_file_name, dataset_class_name=dataset_class_name, classes=classes, palette=palette, img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, reduce_zero_label=reduce_zero_label)
|
||||
# 需要用到上一步的classes和palette
|
||||
success = generate_mmseg_datasets_init_file(output_file=output_mmseg_datasets_init, dataset_file_names=dataset_file_names, dataset_class_names=dataset_class_names)
|
||||
success = generate_mmseg_utils_class_names_file(output_file=output_mmseg_utils_class_names, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all)
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os, json
|
||||
from Initial_Data_Program.Initial_Data_Calculate_std_and_mean import calculate_pic_std_and_mean
|
||||
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
|
||||
|
||||
def load_json_files(directory, not_check_list=None):
|
||||
"""
|
||||
读取指定文件夹下的所有 JSON 文件,排除指定的文件。
|
||||
|
||||
:param directory: 要读取的目录
|
||||
:param not_check_list: 要排除的文件名列表,不包含路径。默认为空列表
|
||||
:return: 包含所有有效 JSON 数据的列表,每个元素为一个字典
|
||||
"""
|
||||
if not_check_list is None:
|
||||
not_check_list = [''] # 默认排除文件
|
||||
|
||||
# 获取所有 .json 文件,并排除在 not_check_list 中的文件
|
||||
json_files = [f for f in os.listdir(directory) if f.endswith('.json') and f not in not_check_list]
|
||||
|
||||
data_list = []
|
||||
|
||||
# 遍历每个 JSON 文件
|
||||
for json_file in json_files:
|
||||
json_path = os.path.join(directory, json_file)
|
||||
|
||||
try:
|
||||
# 打开并加载 JSON 文件
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
# 将文件名(去掉 .json)添加到数据中
|
||||
data['file_name_json'] = json_file.rstrip(".json")
|
||||
data_list.append(data)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError decoding JSON file: {json_path}\033[0m")
|
||||
except Exception as e:
|
||||
print(f"\033[91mError reading file {json_path}: {str(e)}\033[0m")
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def process_json_data(json_data):
|
||||
"""处理每个 JSON 数据,生成相应参数"""
|
||||
# 提取 dataset 信息
|
||||
dataset_file_name = json_data["dataset_info"]["dataset_file_name"]
|
||||
dataset_class_name = json_data["dataset_info"]["dataset_class_name"]
|
||||
data_root = json_data["dataset_info"]["data_root"]
|
||||
|
||||
# 转换 img_scale 为元组 (img_scale_width, img_scale_height)
|
||||
img_scale = (json_data["dataset_info"]["img_scale_width"], json_data["dataset_info"]["img_scale_height"])
|
||||
|
||||
# 提取其他必要信息
|
||||
train_img_path = json_data["dataset_info"]["paths"]["train_img_path"]
|
||||
train_seg_map_path = json_data["dataset_info"]["paths"]["train_seg_map_path"]
|
||||
val_img_path = json_data["dataset_info"]["paths"]["val_img_path"]
|
||||
val_seg_map_path = json_data["dataset_info"]["paths"]["val_seg_map_path"]
|
||||
test_img_path = json_data["dataset_info"]["paths"]["test_img_path"]
|
||||
test_seg_map_path = json_data["dataset_info"]["paths"]["test_seg_map_path"]
|
||||
|
||||
# 提取 label 相关信息
|
||||
classes = json_data["label_info"]["classes"]
|
||||
palette = json_data["label_info"]["palette"]
|
||||
img_suffix = json_data["label_info"]["img_suffix"]
|
||||
seg_map_suffix = json_data["label_info"]["seg_map_suffix"]
|
||||
reduce_zero_label = json_data["label_info"]["reduce_zero_label"]
|
||||
|
||||
# 提取训练相关参数
|
||||
# 转换 crop_size 为元组 (crop_size_width, crop_size_height)
|
||||
crop_size = (json_data["training_info"]["crop_size_width"], json_data["training_info"]["crop_size_height"])
|
||||
train_batch_size = json_data["training_info"]["train_batch_size"]
|
||||
train_num_workers = json_data["training_info"]["train_num_workers"]
|
||||
val_and_test_batch_size = json_data["training_info"]["val_and_test_batch_size"]
|
||||
val_and_test_num_workers = json_data["training_info"]["val_and_test_num_workers"]
|
||||
|
||||
return (dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,)
|
||||
|
||||
|
||||
def save_all_record_to_json(output_file, dataset_file_names, classes_all, palette_all, palette_num_all, mean_all, std_all):
|
||||
"""构建一个 JSON 文件,用于存储每个数据集的信息"""
|
||||
# 构建一个字典,用于存储每个数据集的信息
|
||||
data_record = {}
|
||||
|
||||
# 假设所有列表长度相同,遍历每个 dataset_file_name
|
||||
for i in range(len(dataset_file_names)):
|
||||
data_record[dataset_file_names[i]] = {
|
||||
"classes": classes_all[i],
|
||||
"palette": palette_all[i],
|
||||
"palette_num": palette_num_all[i],
|
||||
"mean": list(mean_all[i]),
|
||||
"std": list(std_all[i])
|
||||
}
|
||||
|
||||
# 将字典写入 JSON 文件
|
||||
with open(output_file, 'w', encoding='utf-8') as json_file:
|
||||
json.dump(data_record, json_file, ensure_ascii=False, indent=6)
|
||||
|
||||
print(f"\033[93m相关数据汇总到 {output_file} successfully!\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义训练文件夹路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json" # 记录所有数据
|
||||
|
||||
########### 2.1.遍历所有配置文件,并生成对应数据集 ###########
|
||||
# 定义保存信息的列表
|
||||
dataset_file_names = []
|
||||
dataset_class_names = []
|
||||
classes_all = []
|
||||
palette_all = []
|
||||
palette_num_all = []
|
||||
mean_all = []
|
||||
std_all = []
|
||||
|
||||
# 2.1.1. 从 ./1_Data_Parameter 文件夹读取所有 JSON 文件
|
||||
json_data_list = load_json_files(train_parameter_dir, not_check_list = [all_data_record_json])
|
||||
|
||||
# 2.1.2. 遍历每个 JSON 数据文件并生成对应配置文件
|
||||
for json_data in json_data_list:
|
||||
# A. 输出当前处理文件
|
||||
print(f"\033[32m正在处理{json_data['file_name_json']}.json文件\033[0m")
|
||||
|
||||
# B. 处理 JSON 数据并提取参数
|
||||
(dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,
|
||||
) = process_json_data(json_data)
|
||||
|
||||
# 保存文件名和类别名
|
||||
dataset_file_names.append(dataset_file_name)
|
||||
dataset_class_names.append(dataset_class_name)
|
||||
|
||||
# C. 文件存储位置
|
||||
output_configs_base_datasets_my_dataset = f'./configs/_base_/datasets/{dataset_file_name}.py'
|
||||
output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
|
||||
output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
|
||||
output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
|
||||
|
||||
# D. 运行程序生成配置文件
|
||||
# 生成 ./configs/_base_/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success = generate_configs_base_datasets_my_dataset_file(
|
||||
output_file=output_configs_base_datasets_my_dataset,
|
||||
dataset_class_name=dataset_class_name,
|
||||
data_root=data_root,
|
||||
img_scale=img_scale,
|
||||
crop_size=crop_size,
|
||||
train_batch_size=train_batch_size,
|
||||
train_num_workers=train_num_workers,
|
||||
val_and_test_batch_size=val_and_test_batch_size,
|
||||
val_and_test_num_workers=val_and_test_num_workers,
|
||||
train_img_path=train_img_path,
|
||||
train_seg_map_path=train_seg_map_path,
|
||||
val_img_path=val_img_path,
|
||||
val_seg_map_path=val_seg_map_path,
|
||||
test_img_path=test_img_path,
|
||||
test_seg_map_path=test_seg_map_path
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success, classes, palette = generate_mmseg_datasets_my_dataset_file(
|
||||
output_file=output_mmseg_datasets_dataset_file_name,
|
||||
dataset_class_name=dataset_class_name,
|
||||
classes=classes,
|
||||
palette=palette,
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label
|
||||
)
|
||||
|
||||
mean, std = calculate_pic_std_and_mean(
|
||||
dataset_dir = os.path.join(data_root, train_img_path)
|
||||
)
|
||||
|
||||
# 保存标注类名和颜色
|
||||
classes_all.append(classes)
|
||||
palette_all.append(palette)
|
||||
palette_num_all.append(len(classes))
|
||||
mean_all.append(mean)
|
||||
std_all.append(std)
|
||||
|
||||
########### 2.2.汇总所有信息运行生成 init 和 class_names 文件 ###########
|
||||
# 生成 ./mmseg/datasets/__init__.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_datasets_init_file(
|
||||
output_file=output_mmseg_datasets_init,
|
||||
dataset_file_names=dataset_file_names,
|
||||
dataset_class_names=dataset_class_names
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/utils/class_names.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_utils_class_names_file(
|
||||
output_file=output_mmseg_utils_class_names,
|
||||
dataset_file_names=dataset_file_names,
|
||||
classes_all=classes_all,
|
||||
palette_all=palette_all
|
||||
)
|
||||
|
||||
########### 2.2.汇总dataset_file_names、classes_all、palette_all、palette_num_all所有信息到My_All_In_One/1_Data_Parameter/All_Data_Record.json文件 ###########
|
||||
output_all_data_record = os.path.join(train_parameter_dir, all_data_record_json)
|
||||
# 调用函数保存数据
|
||||
success = save_all_record_to_json(output_file=output_all_data_record, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all, palette_num_all=palette_num_all, mean_all=mean_all, std_all=std_all)
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os, json
|
||||
from Initial_Data_Program.Initial_Data_Calculate_std_and_mean import calculate_pic_std_and_mean
|
||||
from Initial_Data_Program.Initial_Data_Gen_configs_base_datasets_my_dataset import generate_configs_base_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_utils_class_names import generate_mmseg_utils_class_names_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_my_dataset import generate_mmseg_datasets_my_dataset_file
|
||||
from Initial_Data_Program.Initial_Data_Gen_mmseg_datasets_init_ import generate_mmseg_datasets_init_file
|
||||
|
||||
def load_json_files(directory, not_check_list=None):
|
||||
"""
|
||||
读取指定文件夹下的所有 JSON 文件,排除指定的文件。
|
||||
|
||||
:param directory: 要读取的目录
|
||||
:param not_check_list: 要排除的文件名列表,不包含路径。默认为空列表
|
||||
:return: 包含所有有效 JSON 数据的列表,每个元素为一个字典
|
||||
"""
|
||||
if not_check_list is None:
|
||||
not_check_list = [''] # 默认排除文件
|
||||
|
||||
# 获取所有 .json 文件,并排除在 not_check_list 中的文件
|
||||
json_files = [f for f in os.listdir(directory) if f.endswith('.json') and f not in not_check_list]
|
||||
|
||||
data_list = []
|
||||
|
||||
# 遍历每个 JSON 文件
|
||||
for json_file in json_files:
|
||||
json_path = os.path.join(directory, json_file)
|
||||
|
||||
try:
|
||||
# 打开并加载 JSON 文件
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
# 将文件名(去掉 .json)添加到数据中
|
||||
data['file_name_json'] = json_file.rstrip(".json")
|
||||
data_list.append(data)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError decoding JSON file: {json_path}\033[0m")
|
||||
except Exception as e:
|
||||
print(f"\033[91mError reading file {json_path}: {str(e)}\033[0m")
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def process_json_data(json_data):
|
||||
"""处理每个 JSON 数据,生成相应参数"""
|
||||
# 提取 dataset 信息
|
||||
dataset_file_name = json_data["dataset_info"]["dataset_file_name"]
|
||||
dataset_class_name = json_data["dataset_info"]["dataset_class_name"]
|
||||
data_root = json_data["dataset_info"]["data_root"]
|
||||
|
||||
# 转换 img_scale 为元组 (img_scale_width, img_scale_height)
|
||||
img_scale = (json_data["dataset_info"]["img_scale_width"], json_data["dataset_info"]["img_scale_height"])
|
||||
|
||||
# 提取其他必要信息
|
||||
train_img_path = json_data["dataset_info"]["paths"]["train_img_path"]
|
||||
train_seg_map_path = json_data["dataset_info"]["paths"]["train_seg_map_path"]
|
||||
val_img_path = json_data["dataset_info"]["paths"]["val_img_path"]
|
||||
val_seg_map_path = json_data["dataset_info"]["paths"]["val_seg_map_path"]
|
||||
test_img_path = json_data["dataset_info"]["paths"]["test_img_path"]
|
||||
test_seg_map_path = json_data["dataset_info"]["paths"]["test_seg_map_path"]
|
||||
|
||||
# 提取 label 相关信息
|
||||
classes = json_data["label_info"]["classes"]
|
||||
palette = json_data["label_info"]["palette"]
|
||||
img_suffix = json_data["label_info"]["img_suffix"]
|
||||
seg_map_suffix = json_data["label_info"]["seg_map_suffix"]
|
||||
reduce_zero_label = json_data["label_info"]["reduce_zero_label"]
|
||||
|
||||
# 提取训练相关参数
|
||||
# 转换 crop_size 为元组 (crop_size_width, crop_size_height)
|
||||
crop_size = (json_data["training_info"]["crop_size_width"], json_data["training_info"]["crop_size_height"])
|
||||
train_batch_size = json_data["training_info"]["train_batch_size"]
|
||||
train_num_workers = json_data["training_info"]["train_num_workers"]
|
||||
val_and_test_batch_size = json_data["training_info"]["val_and_test_batch_size"]
|
||||
val_and_test_num_workers = json_data["training_info"]["val_and_test_num_workers"]
|
||||
|
||||
return (dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,)
|
||||
|
||||
|
||||
def save_all_record_to_json(output_file, dataset_file_names, classes_all, palette_all, palette_num_all, mean_all, std_all, train_imgs_num):
|
||||
"""构建一个 JSON 文件,用于存储每个数据集的信息"""
|
||||
# 构建一个字典,用于存储每个数据集的信息
|
||||
data_record = {}
|
||||
|
||||
# 假设所有列表长度相同,遍历每个 dataset_file_name
|
||||
for i in range(len(dataset_file_names)):
|
||||
data_record[dataset_file_names[i]] = {
|
||||
"classes": classes_all[i],
|
||||
"palette": palette_all[i],
|
||||
"palette_num": palette_num_all[i],
|
||||
"mean": list(mean_all[i]),
|
||||
"std": list(std_all[i]),
|
||||
"train_imgs_num": train_imgs_num[i]
|
||||
}
|
||||
|
||||
# 将字典写入 JSON 文件
|
||||
with open(output_file, 'w', encoding='utf-8') as json_file:
|
||||
json.dump(data_record, json_file, ensure_ascii=False, indent=6)
|
||||
|
||||
print(f"\033[93m相关数据汇总到 {output_file} successfully!\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义训练文件夹路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json" # 记录所有数据
|
||||
|
||||
########### 2.1.遍历所有配置文件,并生成对应数据集 ###########
|
||||
# 定义保存信息的列表
|
||||
dataset_file_names = []
|
||||
dataset_class_names = []
|
||||
classes_all = []
|
||||
palette_all = []
|
||||
palette_num_all = []
|
||||
mean_all = []
|
||||
std_all = []
|
||||
train_imgs_num = []
|
||||
|
||||
# 2.1.1. 从 ./1_Data_Parameter 文件夹读取所有 JSON 文件
|
||||
json_data_list = load_json_files(train_parameter_dir, not_check_list = [all_data_record_json])
|
||||
|
||||
# 2.1.2. 遍历每个 JSON 数据文件并生成对应配置文件
|
||||
for json_data in json_data_list:
|
||||
# A. 输出当前处理文件
|
||||
print(f"\033[32m正在处理{json_data['file_name_json']}.json文件\033[0m")
|
||||
|
||||
# B. 处理 JSON 数据并提取参数
|
||||
(dataset_file_name, dataset_class_name, data_root, img_scale, train_img_path, train_seg_map_path, val_img_path, val_seg_map_path, test_img_path, test_seg_map_path,
|
||||
classes, palette, img_suffix, seg_map_suffix, reduce_zero_label, crop_size, train_batch_size, train_num_workers, val_and_test_batch_size, val_and_test_num_workers,
|
||||
) = process_json_data(json_data)
|
||||
|
||||
# 保存文件名和类别名
|
||||
dataset_file_names.append(dataset_file_name)
|
||||
dataset_class_names.append(dataset_class_name)
|
||||
|
||||
# C. 文件存储位置
|
||||
output_configs_base_datasets_my_dataset = f'./configs/_base_/datasets/{dataset_file_name}.py'
|
||||
output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
|
||||
output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
|
||||
output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
|
||||
|
||||
# D. 运行程序生成配置文件
|
||||
# 生成 ./configs/_base_/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success = generate_configs_base_datasets_my_dataset_file(
|
||||
output_file=output_configs_base_datasets_my_dataset,
|
||||
dataset_class_name=dataset_class_name,
|
||||
data_root=data_root,
|
||||
img_scale=img_scale,
|
||||
crop_size=crop_size,
|
||||
train_batch_size=train_batch_size,
|
||||
train_num_workers=train_num_workers,
|
||||
val_and_test_batch_size=val_and_test_batch_size,
|
||||
val_and_test_num_workers=val_and_test_num_workers,
|
||||
train_img_path=train_img_path,
|
||||
train_seg_map_path=train_seg_map_path,
|
||||
val_img_path=val_img_path,
|
||||
val_seg_map_path=val_seg_map_path,
|
||||
test_img_path=test_img_path,
|
||||
test_seg_map_path=test_seg_map_path
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/datasets/{dataset_file_name}.py
|
||||
print(" ",end='')
|
||||
success, classes, palette = generate_mmseg_datasets_my_dataset_file(
|
||||
output_file=output_mmseg_datasets_dataset_file_name,
|
||||
dataset_class_name=dataset_class_name,
|
||||
classes=classes,
|
||||
palette=palette,
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label
|
||||
)
|
||||
|
||||
mean, std = calculate_pic_std_and_mean(
|
||||
dataset_dir = os.path.join(data_root, train_img_path)
|
||||
)
|
||||
|
||||
# 保存标注类名和颜色
|
||||
classes_all.append(classes)
|
||||
palette_all.append(palette)
|
||||
palette_num_all.append(len(classes))
|
||||
mean_all.append(mean)
|
||||
std_all.append(std)
|
||||
train_imgs_num.append(len(os.listdir(os.path.join(data_root, train_img_path))))
|
||||
|
||||
########### 2.2.汇总所有信息运行生成 init 和 class_names 文件 ###########
|
||||
# 生成 ./mmseg/datasets/__init__.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_datasets_init_file(
|
||||
output_file=output_mmseg_datasets_init,
|
||||
dataset_file_names=dataset_file_names,
|
||||
dataset_class_names=dataset_class_names
|
||||
)
|
||||
|
||||
# 生成 ./mmseg/utils/class_names.py
|
||||
print(" ",end='')
|
||||
success = generate_mmseg_utils_class_names_file(
|
||||
output_file=output_mmseg_utils_class_names,
|
||||
dataset_file_names=dataset_file_names,
|
||||
classes_all=classes_all,
|
||||
palette_all=palette_all
|
||||
)
|
||||
|
||||
########### 2.2.汇总dataset_file_names、classes_all、palette_all、palette_num_all所有信息到My_All_In_One/1_Data_Parameter/All_Data_Record.json文件 ###########
|
||||
output_all_data_record = os.path.join(train_parameter_dir, all_data_record_json)
|
||||
# 调用函数保存数据
|
||||
success = save_all_record_to_json(output_file=output_all_data_record, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all, palette_num_all=palette_num_all, mean_all=mean_all, std_all=std_all, train_imgs_num=train_imgs_num)
|
||||
|
||||
101
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ann_r50.py
Normal file
101
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ann_r50.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'ann_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:ann【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/ann/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'apcnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:apcnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/apcnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,176 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
# 交互式选择 decode_head 的函数
|
||||
def select_decode_head(decode_head_choose):
|
||||
print("可用的 decode head 选项:")
|
||||
for i, key in enumerate(decode_head_choose.keys()):
|
||||
print(f"{i + 1}. {key}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 提示用户输入选择
|
||||
choice = int(input("请选择需要的 decode head(输入编号):"))
|
||||
# 检查输入是否在有效范围内
|
||||
if 1 <= choice <= len(decode_head_choose):
|
||||
selected_key = list(decode_head_choose.keys())[choice - 1]
|
||||
print(f"你选择了: {selected_key}")
|
||||
return decode_head_choose[selected_key]
|
||||
else:
|
||||
print("输入的编号不正确,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的编号。")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'upernet_beit'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(640, 640)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ['pretrain/beit_base_patch16_224_pt22k_ft22k.pth', 'pretrain/beit_large_patch16_224_pt22k_ft22k.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list) # 需要选择是否用预训练模型
|
||||
|
||||
backbone = create_dict_by_kwargs(type='BEiT', img_size=crop_size,)
|
||||
|
||||
decode_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
|
||||
auxiliary_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori
|
||||
auxiliary_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1
|
||||
|
||||
if selected_model_name == 'pretrain/beit_base_patch16_224_pt22k_ft22k.pth':
|
||||
backbone_new=dict(
|
||||
# patch_size=16,
|
||||
# in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
# mlp_ratio=4,
|
||||
out_indices=(3, 5, 7, 11),
|
||||
# qv_bias=True,
|
||||
# attn_drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
# norm_cfg=dict(type='LN', eps=1e-6),
|
||||
# act_cfg=dict(type='GELU'),
|
||||
# norm_eval=False,
|
||||
init_values=0.1)
|
||||
neck = dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5])
|
||||
decode_head=dict(in_channels=[768, 768, 768, 768], num_classes=num_classes, channels=768, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
|
||||
auxiliary_head=dict(norm_cfg=norm_cfg, in_channels=768, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
|
||||
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='AmpOptimWrapper', # type='OptimWrapper', # TODO Ori
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05),
|
||||
constructor='LayerDecayOptimizerConstructor',
|
||||
paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
|
||||
|
||||
model_size = 'base'
|
||||
|
||||
train_dataloader, batch_size = generate_train_dataloader(2)
|
||||
|
||||
elif selected_model_name == 'pretrain/beit_large_patch16_224_pt22k_ft22k.pth':
|
||||
backbone_new=dict(
|
||||
embed_dims=1024,
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
# mlp_ratio=4,
|
||||
# qv_bias=True,
|
||||
init_values=1e-6,
|
||||
drop_path_rate=0.2,
|
||||
out_indices=[7, 11, 15, 23])
|
||||
neck=dict(type='Feature2Pyramid', embed_dim=1024, rescales=[4, 2, 1, 0.5])
|
||||
decode_head=dict(in_channels=[1024, 1024, 1024, 1024], num_classes=num_classes, channels=1024, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
|
||||
auxiliary_head=dict(norm_cfg=norm_cfg, in_channels=1024, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
|
||||
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='AmpOptimWrapper',
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05),
|
||||
constructor='LayerDecayOptimizerConstructor',
|
||||
paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95),
|
||||
accumulative_counts=2)
|
||||
|
||||
model_size = 'large'
|
||||
|
||||
train_dataloader, batch_size = generate_train_dataloader(1)
|
||||
|
||||
# 更新backbone
|
||||
backbone.update(backbone_new)
|
||||
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
pretrained = pretrained_pth,
|
||||
neck = neck,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = optim_wrapper
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:beit【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}-{model_size}_b{batch_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/beit/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader = train_dataloader)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,133 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'bisenetv1_r18-d32'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (1024, 1024)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list, need_select_pretrained = True) # 需要选择是否用预训练模型
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
if select_pretrained == True:
|
||||
pretrained_txt = 'Pre'
|
||||
else:
|
||||
pretrained_txt = 'NoPre'
|
||||
|
||||
# 模型信息
|
||||
if selected_model_name == 'openmmlab/resnet18_v1c' :
|
||||
backbone_context_channels = (128, 256, 512)
|
||||
backbone_spatial_channels = (64, 64, 64, 128)
|
||||
backbone_out_channels = 256
|
||||
decode_head_in_channels = 256
|
||||
decode_head_channels = 256
|
||||
auxiliary_head_in_channels = 128
|
||||
auxiliary_head_channels = 64
|
||||
|
||||
elif selected_model_name == 'openmmlab/resnet50_v1c' or selected_model_name == 'openmmlab/resnet101_v1c' :
|
||||
backbone_context_channels = (512, 1024, 2048)
|
||||
backbone_spatial_channels = (256, 256, 256, 512)
|
||||
backbone_out_channels = 1024
|
||||
decode_head_in_channels = 1024
|
||||
decode_head_channels = 1024
|
||||
auxiliary_head_in_channels = 512
|
||||
auxiliary_head_channels = 256
|
||||
|
||||
# 需要选择预训练模型
|
||||
if select_pretrained == True:
|
||||
backbone_backbone_cfg = create_dict_by_kwargs(type='ResNet', norm_cfg=norm_cfg, depth=depth, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
else:
|
||||
backbone_backbone_cfg = create_dict_by_kwargs(type='ResNet', norm_cfg=norm_cfg, depth=depth)
|
||||
|
||||
backbone = create_dict_by_kwargs(context_channels=backbone_context_channels, norm_cfg=norm_cfg, spatial_channels=backbone_spatial_channels, out_channels=backbone_out_channels, backbone_cfg=backbone_backbone_cfg)
|
||||
|
||||
# decode、auxiliary损失下载方式 TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head = create_dict_by_kwargs(type='FCNHead',in_channels=decode_head_in_channels, channels=decode_head_channels, num_classes=num_classes)
|
||||
|
||||
# auxiliary损失下载方式 TODO # 可更改
|
||||
# auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO # 可更改
|
||||
auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
auxiliary_head = [create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict),
|
||||
create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict)]
|
||||
# 获取原始auxiliary_head
|
||||
auxiliary_head = get_var_from_py_file(alg_file_pth, var_name="model")["auxiliary_head"]
|
||||
# 新的auxiliary_head
|
||||
auxiliary_head_new = [create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict),
|
||||
create_dict_by_kwargs(type='FCNHead', norm_cfg=norm_cfg, in_channels=auxiliary_head_in_channels, channels=auxiliary_head_channels, num_classes=num_classes, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict)]
|
||||
# 更新auxiliary_head
|
||||
auxiliary_head = update_list_dict_var(auxiliary_head, auxiliary_head_new)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:bisenetv1【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/bisenetv1/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_bisenetv2.py
Normal file
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_bisenetv2.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'bisenetv2'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# 3.3.2. 采样函数sampler
|
||||
selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
|
||||
if use_sampler == True:
|
||||
decode_head_sampler=selected_sampler_info
|
||||
auxiliary_head_sampler=selected_sampler_info
|
||||
use_sampler_txt = 'ohempSampler' # 标记
|
||||
else:
|
||||
decode_head_sampler=None
|
||||
auxiliary_head_sampler=None
|
||||
use_sampler_txt = ''
|
||||
|
||||
# 3.3.3. 解码器decode_head
|
||||
# decode、auxiliary损失下载方式 TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
decode_head = create_dict_by_kwargs(type='FCNHead', loss_decode=decode_head_loss_decode_dict, sampler=decode_head_sampler, num_classes=num_classes)
|
||||
|
||||
# 3.3.3. 辅助部分auxiliary_head
|
||||
# auxiliary损失下载方式 TODO # 可更改
|
||||
auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4, reduction="none") # DiceLoss损失函数 # TODO # 可更改
|
||||
# auxiliary_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
# 获取原始auxiliary_head
|
||||
auxiliary_head = get_var_from_py_file(alg_file_pth, var_name="model")["auxiliary_head"]
|
||||
# 新的auxiliary_head
|
||||
auxiliary_head_new = [create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,),
|
||||
create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,),
|
||||
create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,),
|
||||
create_dict_by_kwargs(type='FCNHead', sampler=auxiliary_head_sampler, norm_cfg=norm_cfg, align_corners=False, loss_decode=auxiliary_head_loss_decode_dict, num_classes=num_classes,)]
|
||||
# 更新auxiliary_head
|
||||
auxiliary_head = update_list_dict_var(auxiliary_head, auxiliary_head_new)
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_fcn_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/bisenetv2/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ccnet_r50.py
Normal file
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_ccnet_r50.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'ccnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:ccnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/ccnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,95 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'cgnet'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(680, 680), (512, 1024)]) # 选择切割大小
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['decode_head']
|
||||
|
||||
# Way Ori: loss_decode
|
||||
decode_head_loss_decode_dict = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0,
|
||||
class_weight=[
|
||||
2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
|
||||
10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
|
||||
10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
|
||||
10.396974, 10.055647
|
||||
]
|
||||
)
|
||||
# Way 1: loss_decode
|
||||
decode_head_loss_decode_dict = dict(_delete_=True, type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head['_delete_'] = True
|
||||
decode_head['loss_decode'] = decode_head_loss_decode_dict
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# 综合model
|
||||
model = dict(data_preprocessor = model_data_preprocessor, decode_head=decode_head)
|
||||
|
||||
# test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:cgnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_fcn_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/cgnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_danet_r50.py
Normal file
102
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_danet_r50.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'danet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:danet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/danet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,95 @@
|
||||
import os, sys, argparse, json
|
||||
import importlib.util
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'ddrnet'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# # 3.3.2. 骨架backbone、解码器decode_head
|
||||
model_list = ["openmmlab/ddrnet23-s", "openmmlab/ddrnet23"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
if selected_model_name == "openmmlab/ddrnet23-s":
|
||||
backbone = dict(channels=32, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(in_channels=32 * 4, channels=64, num_classes=num_classes)
|
||||
model_size = 'small'
|
||||
elif selected_model_name == "openmmlab/ddrnet23":
|
||||
backbone = dict(channels=64, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(in_channels=64 * 4, channels=128, num_classes=num_classes)
|
||||
model_size = 'normal'
|
||||
else:
|
||||
quit("Error: 未知的模型名称")
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
backbone = backbone,
|
||||
decode_head = decode_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/ddrnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,135 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'deeplabv3_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769),(1280,1280)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
type_model = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
# 3.3.2. 采样函数sampler
|
||||
selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
|
||||
if use_sampler == True:
|
||||
decode_head_sampler=selected_sampler_info
|
||||
auxiliary_head_sampler=selected_sampler_info
|
||||
use_sampler_txt = 'ohempSampler' # 标记
|
||||
backbone_dilations=(1, 1, 1, 2)
|
||||
backbone_strides=(1, 2, 2, 1)
|
||||
backbone_multi_grid=(1, 2, 4)
|
||||
decode_head_dilations=(1, 6, 12, 18)
|
||||
else:
|
||||
decode_head_sampler=None
|
||||
auxiliary_head_sampler=None
|
||||
use_sampler_txt = ''
|
||||
backbone_dilations=None # (1, 1, 2, 4),
|
||||
backbone_strides=None # (1, 2, 1, 1)
|
||||
backbone_multi_grid=None # 没有
|
||||
decode_head_dilations=None # (1, 12, 24, 36),
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
backbone = create_dict_by_kwargs(depth=depth, type=type_model, dilations=backbone_dilations, strides=backbone_strides, multi_grid=backbone_multi_grid) # generate_model_backbone(depth=depth,)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
if depth == 18:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, in_channels=512, channels=128, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, in_channels=2048, channels=512, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
if depth == 18:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=256, channels=64, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=1024, channels=256, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:deeplabv3【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/deeplabv3/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,135 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'deeplabv3plus_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769),(1280,1280)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
type_model = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
# 3.3.2. 采样函数sampler
|
||||
selected_sampler_name, use_sampler, selected_sampler_info = select_sampler(sampler_list=['OHEMPixelSampler'])
|
||||
if use_sampler == True:
|
||||
decode_head_sampler=selected_sampler_info
|
||||
auxiliary_head_sampler=selected_sampler_info
|
||||
use_sampler_txt = 'ohempSampler' # 标记
|
||||
backbone_dilations=(1, 1, 1, 2)
|
||||
backbone_strides=(1, 2, 2, 1)
|
||||
backbone_multi_grid=(1, 2, 4)
|
||||
decode_head_dilations=(1, 6, 12, 18)
|
||||
else:
|
||||
decode_head_sampler=None
|
||||
auxiliary_head_sampler=None
|
||||
use_sampler_txt = ''
|
||||
backbone_dilations=None # (1, 1, 2, 4),
|
||||
backbone_strides=None # (1, 2, 1, 1)
|
||||
backbone_multi_grid=None # 没有
|
||||
decode_head_dilations=None # (1, 12, 24, 36),
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
backbone = create_dict_by_kwargs(depth=depth, type=type_model, dilations=backbone_dilations, strides=backbone_strides, multi_grid=backbone_multi_grid) # generate_model_backbone(depth=depth,)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
if depth == 18:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, c1_in_channels=64, c1_channels=12, in_channels=512, channels=128, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
decode_head = create_dict_by_kwargs(sampler=decode_head_sampler, dilations=decode_head_dilations, c1_in_channels=256, c1_channels=48, in_channels=2048, channels=512, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
if depth == 18:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=256, channels=64, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=1024, channels=256, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:deeplabv3plus【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_{use_sampler_txt+'_' if use_sampler_txt != '' else ''}g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/deeplabv3plus/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,102 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'dnl_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:dnlnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/dnlnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
100
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_dpt_vit.py
Normal file
100
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_dpt_vit.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'dpt_vit-b16'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(1024, 1024)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
model_list = ['pretrain/vit-b16_p16_224-80ecf9dd.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
model_data_preprocessor = data_preprocessor
|
||||
|
||||
# 3.3.2. 解码器decode_head
|
||||
# decode、auxiliary损失下载方式 TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # 默认 TODO
|
||||
decode_head = create_dict_by_kwargs(loss_decode=decode_head_loss_decode_dict, num_classes=num_classes)
|
||||
|
||||
# 3.3.3. 辅助部分auxiliary_head [空]
|
||||
auxiliary_head = None
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
pretrained = pretrained,
|
||||
decode_head = decode_head,
|
||||
# auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper(type_of_back_bone = "Vit")
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 3.5. 生成train_dataloader部分[1个batch_size就要20G] ###########
|
||||
train_dataloader, batch_size = generate_train_dataloader(batch_size_default=1)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:dpt_vit【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_g{GPU_num}_b{batch_size}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/dpt/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader=train_dataloader)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'emanet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:emanet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/emanet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'encnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:encnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/encnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,83 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'erfnet_fcn'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512,1024), (512, 512)]) # 选择切割大小
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
# decode_head_loss_decode_dict = dict(type='CrossEntropyLoss',
|
||||
# use_sigmoid=False,
|
||||
# loss_weight=1.0)
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head=dict(loss_decode=decode_head_loss_decode_dict)
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# 综合model
|
||||
model = dict(data_preprocessor = model_data_preprocessor, decode_head=decode_head)
|
||||
|
||||
# test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:erfnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_fcn_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/erfnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,89 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fast_scnn'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512,1024), (512, 512)]) # 选择切割大小
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
# 3.1. base、norm_cfg、data_preprocessor部分
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
# 3.2. model部分
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# decode损失
|
||||
decode_head_loss_decode_dict = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
decode_head=dict(loss_decode=decode_head_loss_decode_dict)
|
||||
|
||||
# auxiliary损失
|
||||
auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
|
||||
for i in range(len(auxiliary_head)):
|
||||
auxiliary_head[i]['loss_decode']['use_sigmoid']=False
|
||||
auxiliary_head[i]['loss_decode']['type']='DiceLoss'
|
||||
auxiliary_head[i]['num_classes']=num_classes
|
||||
auxiliary_head[i]['norm_cfg']=norm_cfg
|
||||
|
||||
# 综合model
|
||||
model = dict(data_preprocessor = model_data_preprocessor, decode_head = decode_head, auxiliary_head=auxiliary_head)
|
||||
|
||||
# 3.3. optim_wrapper、param_scheduler
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:fastscnn【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/fastscnn/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,189 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
# 交互式选择 decode_head 的函数
|
||||
def select_decode_head(decode_head_choose):
|
||||
print("可用的 decode head 选项:")
|
||||
for i, key in enumerate(decode_head_choose.keys()):
|
||||
print(f"{i + 1}. {key}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 提示用户输入选择
|
||||
choice = int(input("请选择需要的 decode head(输入编号):"))
|
||||
# 检查输入是否在有效范围内
|
||||
if 1 <= choice <= len(decode_head_choose):
|
||||
selected_key = list(decode_head_choose.keys())[choice - 1]
|
||||
print(f"你选择了: {selected_key}")
|
||||
return decode_head_choose[selected_key]
|
||||
else:
|
||||
print("输入的编号不正确,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的编号。")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fastfcn_r50-d32_jpu_psp'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (1024, 1024)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list) # 需要选择是否用预训练模型
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
backbone = create_dict_by_kwargs(norm_cfg=norm_cfg, depth=depth)
|
||||
|
||||
# 定义 decode_head 的选项字典
|
||||
decode_head_choose = {
|
||||
'psp-PSPHead': {
|
||||
'decode_head':dict(
|
||||
type='PSPHead',
|
||||
in_channels=2048,
|
||||
in_index=2,
|
||||
channels=512,
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg, # 假设 norm_cfg 预定义
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
),
|
||||
'decode_head_name':'aspp'
|
||||
},
|
||||
'enc-EncHead': {
|
||||
'decode_head':dict(
|
||||
_delete_=True,
|
||||
type='EncHead',
|
||||
in_channels=[512, 1024, 2048],
|
||||
in_index=(0, 1, 2),
|
||||
channels=512,
|
||||
num_codes=32,
|
||||
use_se_loss=True,
|
||||
add_lateral=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
||||
loss_se_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)
|
||||
),
|
||||
'decode_head_name':'aspp'
|
||||
},
|
||||
'aspp-ASPPHead': {
|
||||
'decode_head':dict(
|
||||
_delete_=True,
|
||||
type='ASPPHead',
|
||||
in_channels=2048,
|
||||
in_index=2,
|
||||
channels=512,
|
||||
dilations=(1, 12, 24, 36),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
|
||||
),
|
||||
'decode_head_name':'aspp'
|
||||
}
|
||||
}
|
||||
|
||||
decode_head_dict = select_decode_head(decode_head_choose=decode_head_choose)
|
||||
decode_head = decode_head_dict['decode_head']
|
||||
decode_head_name = decode_head_dict['decode_head_name']
|
||||
|
||||
decode_head_loss_decode_type = 'DiceLoss' # Way 1: 更改 Loss
|
||||
# decode_head_loss_decode_type = 'CrossEntropyLoss' # Way ori: 不更改
|
||||
|
||||
# 修改decode_head中loss_decode type
|
||||
if 'loss_se_decode' in decode_head.keys():
|
||||
decode_head['loss_se_decode']['type'] = decode_head_loss_decode_type
|
||||
if 'loss_decode' in decode_head.keys():
|
||||
decode_head['loss_decode']['type'] = decode_head_loss_decode_type
|
||||
|
||||
# auxiliary 为 None
|
||||
auxiliary_head_loss_decode_dict = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
|
||||
auxiliary_head = dict(loss_decode=auxiliary_head_loss_decode_dict, norm_cfg=norm_cfg, num_classes=num_classes) # Way 1: 更改 Loss
|
||||
# auxiliary_head = dict(norm_cfg=norm_cfg, num_classes=num_classes) # Way ori: 不更改
|
||||
|
||||
# 综合model
|
||||
if auxiliary_head == None:
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head
|
||||
)
|
||||
else:
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:fastfcn【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_jpu_{decode_head_name}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/fastfcn/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
132
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_fcn_r18.py
Normal file
132
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_fcn_r18.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fcn_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
# 如果crop_size更大则调整dilations和strides
|
||||
if crop_size[0]*crop_size[1] > 512*512:
|
||||
backbone_dilations = (1, 1, 1, 2) # 每一步更精细信息
|
||||
backbone_strides = (1, 2, 2, 1) # 步子走的大
|
||||
decode_head_dilation = 6 # 每一步更多信息
|
||||
auxiliary_head_dilation = 6 # 每一步更多信息
|
||||
else:
|
||||
backbone_dilations = (1, 1, 2, 4) # 每一步更多信息
|
||||
backbone_strides = (1, 2, 1, 1) # 步子走的小
|
||||
decode_head_dilation = 1 # 每一步更精细信息
|
||||
auxiliary_head_dilation = 1 # 每一步更精细信息
|
||||
|
||||
model_list = ['torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', 'open-mmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
model_type = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 模型信息
|
||||
if str(depth) == '18' :
|
||||
backbone_out_channels = 256
|
||||
decode_head_in_channels = 512 #
|
||||
decode_head_channels = 128 #
|
||||
auxiliary_head_in_channels = 256 #
|
||||
auxiliary_head_channels = 64 #
|
||||
|
||||
elif str(depth) == '50' or str(depth) == '101':
|
||||
backbone_out_channels = 1024
|
||||
decode_head_in_channels = 2048 #
|
||||
decode_head_channels = 512 #
|
||||
auxiliary_head_in_channels = 1024 #
|
||||
auxiliary_head_channels = 256 #
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = create_dict_by_kwargs(type=model_type, depth=depth, strides = backbone_strides, dilations=backbone_dilations)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = create_dict_by_kwargs(dilation=decode_head_dilation, channels=decode_head_channels, in_channels=decode_head_in_channels, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = create_dict_by_kwargs(dilation=auxiliary_head_dilation, channels=auxiliary_head_channels, in_channels=auxiliary_head_in_channels, num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:fcn【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/fcn/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
106
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_gcnet_r50.py
Normal file
106
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_gcnet_r50.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'gcnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:gcnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/gcnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
144
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_hrnet_fcn.py
Normal file
144
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_hrnet_fcn.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'fcn_hr18'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
# # 如果crop_size更大则调整dilations和strides
|
||||
# if crop_size[0]*crop_size[1] > 512*512:
|
||||
# backbone_dilations = (1, 1, 1, 2) # 每一步更精细信息
|
||||
# backbone_strides = (1, 2, 2, 1) # 步子走的大
|
||||
# decode_head_dilation = 6 # 每一步更多信息
|
||||
# auxiliary_head_dilation = 6 # 每一步更多信息
|
||||
# else:
|
||||
# backbone_dilations = (1, 1, 2, 4) # 每一步更多信息
|
||||
# backbone_strides = (1, 2, 1, 1) # 步子走的小
|
||||
# decode_head_dilation = 1 # 每一步更精细信息
|
||||
# auxiliary_head_dilation = 1 # 每一步更精细信息
|
||||
|
||||
model_list = ['open-mmlab://msra/hrnetv2_w18', 'open-mmlab://msra/hrnetv2_w18_small', 'open-mmlab://msra/hrnetv2_w48']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 模型信息
|
||||
if selected_model_name == 'open-mmlab://msra/hrnetv2_w18':
|
||||
backbone_extra=dict(
|
||||
stage1=dict(num_blocks=(4, ), num_channels=(64, )),
|
||||
stage2=dict(num_blocks=(4, 4), num_channels=(18, 36)),
|
||||
stage3=dict(num_modules=4, num_blocks=(4, 4, 4), num_channels=(18, 36, 72)),
|
||||
stage4=dict(num_modules=3, num_blocks=(4, 4, 4, 4), num_channels=(18, 36, 72, 144)))
|
||||
|
||||
decode_head_channels = dict(in_channels=[18, 36, 72, 144], channels=sum([18, 36, 72, 144]))
|
||||
alg_text = 'w18'
|
||||
elif selected_model_name == 'open-mmlab://msra/hrnetv2_w18_small':
|
||||
backbone_extra=dict(
|
||||
stage1=dict(num_blocks=(2, )),
|
||||
stage2=dict(num_blocks=(2, 2)),
|
||||
stage3=dict(num_modules=3, num_blocks=(2, 2, 2)),
|
||||
stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2)))
|
||||
|
||||
decode_head_channels = dict(in_channels=[18, 36, 72, 144], channels=sum([18, 36, 72, 144])) # 同上
|
||||
alg_text = 'w18-small'
|
||||
elif selected_model_name == 'open-mmlab://msra/hrnetv2_w48':
|
||||
# 改变channel
|
||||
backbone_extra=dict(
|
||||
stage2=dict(num_channels=(48, 96)),
|
||||
stage3=dict(num_channels=(48, 96, 192)),
|
||||
stage4=dict(num_channels=(48, 96, 192, 384)))
|
||||
|
||||
decode_head_channels = dict(in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384]))
|
||||
alg_text = 'w48'
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # Way: Ori TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # Way: 1 TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = create_dict_by_kwargs(extra = backbone_extra)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
decode_head.update(decode_head_channels) # 加入channels相关信息
|
||||
|
||||
auxiliary_head = None
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
# auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:hrnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{alg_text}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/hrnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
124
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_icnet_r18.py
Normal file
124
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_icnet_r18.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'icnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(832, 832), (512, 512)]) # 选择切割大小
|
||||
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list, need_select_pretrained=True)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 需要预训练模型
|
||||
if select_pretrained == True:
|
||||
backbone_backbone_cfg_init_cfg = dict(type='Pretrained', checkpoint=pretrained_pth)
|
||||
else:
|
||||
backbone_cfg_init_cfg = None
|
||||
backbone_backbone_cfg = create_dict_by_kwargs(depth=depth, init_cfg=backbone_backbone_cfg_init_cfg)
|
||||
|
||||
# 模型信息
|
||||
if str(depth) == '18' :
|
||||
backbone_layer_channels = (128, 512)
|
||||
elif str(depth) == '50' or str(depth) == '101':
|
||||
backbone_layer_channels = (512, 2048)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1: decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori: decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1: DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori: DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
backbone = create_dict_by_kwargs(backbone_cfg = backbone_backbone_cfg, layer_channels=backbone_layer_channels, norm_cfg=norm_cfg, align_corners=align_corners)
|
||||
|
||||
neck = create_dict_by_kwargs(norm_cfg=norm_cfg, align_corners=align_corners)
|
||||
|
||||
decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
|
||||
for i in range(len(auxiliary_head)):
|
||||
auxiliary_head[i].update(num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:icnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/icnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,106 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'isanet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
# test_cfg_mode = None
|
||||
# test_cfg_crop_div_stride = crop_size
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:isanet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/isanet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
276
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_knet.py
Normal file
276
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_knet.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'knet_r50-d8_my'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name) # _base_无算法
|
||||
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (640,640)]) # 选择切割大小
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# 选择对应的模型
|
||||
config_dict = {
|
||||
'DeepLabV3': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
_delete_ = True,
|
||||
type='ASPPHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
dilations=(1, 12, 24, 36),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 2, 4),
|
||||
'backbone_strides': (1, 2, 1, 1),
|
||||
'auxiliary_head_in_channels':1024
|
||||
},
|
||||
'PSPNet': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
_delete_ = True,
|
||||
type='PSPHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 2, 4),
|
||||
'backbone_strides': (1, 2, 1, 1),
|
||||
'auxiliary_head_in_channels':1024
|
||||
},
|
||||
'FCN': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
_delete_ = True,
|
||||
type='FCNHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
num_convs=2,
|
||||
concat_input=True,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 2, 4),
|
||||
'backbone_strides': (1, 2, 1, 1),
|
||||
'auxiliary_head_in_channels':1024
|
||||
},
|
||||
'UPerNet': {
|
||||
'decode_head_kernel_generate_head_dict': dict(
|
||||
type='UPerHead',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
in_index=[0, 1, 2, 3],
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False),
|
||||
'backbone_dilations': (1, 1, 1, 1),
|
||||
'backbone_strides': (1, 2, 2, 2),
|
||||
'auxiliary_head_in_channels':1024
|
||||
}
|
||||
}
|
||||
|
||||
models = ['DeepLabV3', 'PSPNet', 'FCN', 'UPerNet']
|
||||
print("请你选择对应的模型:")
|
||||
while True:
|
||||
for index, model in enumerate(models, 1):
|
||||
print(f"{index}. {model}")
|
||||
try:
|
||||
user_input = int(input("Enter your choice (1-4): "))
|
||||
if 1 <= user_input <= 4:
|
||||
model_choice = models[user_input - 1]
|
||||
break
|
||||
else:
|
||||
print("Invalid input, please enter a number between 1 and 4.")
|
||||
except ValueError:
|
||||
print("Invalid input, please enter a valid number.")
|
||||
|
||||
decode_head_kernel_generate_head_dict = config_dict[model_choice]['decode_head_kernel_generate_head_dict']
|
||||
backbone_dilations = config_dict[model_choice]['backbone_dilations']
|
||||
backbone_strides = config_dict[model_choice]['backbone_strides']
|
||||
auxiliary_head_in_channels = config_dict[model_choice]['auxiliary_head_in_channels']
|
||||
|
||||
model_list = ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', "pretrain/swin_tiny-f41b89d3.pth", "pretrain/swin_large-d5bdebaf.pth"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
if selected_model_info['type'] == 'ResNetV1c':
|
||||
depth = selected_model_info['depth']
|
||||
backbone = create_dict_by_kwargs(depth = depth, norm_cfg=norm_cfg, dilations=backbone_dilations, strides=backbone_strides)
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
embed_dims=96, #
|
||||
depths=[2, 2, 6, 2], #
|
||||
num_heads=[3, 6, 12, 24], #
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3, #
|
||||
use_abs_pos_embed=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3))
|
||||
if selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
embed_dims=192, #
|
||||
depths=[2, 2, 18, 2], #
|
||||
num_heads=[6, 12, 24, 48], #
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.4, #
|
||||
use_abs_pos_embed=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3))
|
||||
backbone.update(create_dict_by_kwargs(norm_cfg=dict(type='LN'))) # TODO Transformer一类的内容norm_cfg都设置为LN TODO # 没有 dilations 和 strides
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# 如果使用slide模式,则开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
else:
|
||||
align_corners = False
|
||||
test_slide = 'no_testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
|
||||
# 3.2. decode_head
|
||||
decode_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['decode_head']
|
||||
|
||||
decode_head_kernel_update_head_list = decode_head['kernel_update_head']
|
||||
for i in range(len(decode_head_kernel_update_head_list)):
|
||||
decode_head_kernel_update_head_list[i]['num_classes'] = num_classes
|
||||
decode_head['kernel_update_head'] = decode_head_kernel_update_head_list
|
||||
|
||||
decode_head_kernel_generate_head_dict['num_classes'] = num_classes
|
||||
if model_choice == 'UPerNet':
|
||||
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = [96, 192, 384, 768]
|
||||
auxiliary_head_in_channels = 384
|
||||
elif selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = [192, 384, 768, 1536]
|
||||
auxiliary_head_in_channels = 768
|
||||
elif model_choice in ['DeepLabV3', 'PSPNet', 'FCN']:
|
||||
if selected_model_name == 'pretrain/swin_tiny-f41b89d3.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = 768
|
||||
auxiliary_head_in_channels = 384
|
||||
elif selected_model_name == 'pretrain/swin_large-d5bdebaf.pth':
|
||||
decode_head_kernel_generate_head_dict['in_channels'] = 1536
|
||||
auxiliary_head_in_channels = 768
|
||||
|
||||
|
||||
decode_head_kernel_generate_head_dict_loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_kernel_generate_head_dict_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
decode_head_kernel_generate_head_dict['loss_decode'] = decode_head_kernel_generate_head_dict_loss_decode
|
||||
|
||||
decode_head['kernel_generate_head'] = decode_head_kernel_generate_head_dict
|
||||
|
||||
decode_head['kernel_generate_head']['align_corners'] = align_corners
|
||||
|
||||
# 3.2. auxiliary_head
|
||||
auxiliary_head = get_var_from_py_file(os.path.join('./configs/_base_/models', alg_file_name+'.py'), 'model')['auxiliary_head']
|
||||
auxiliary_head_loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori
|
||||
auxiliary_head_loss_decode = dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1
|
||||
auxiliary_head['loss_decode'] = auxiliary_head_loss_decode
|
||||
auxiliary_head['align_corners'] = align_corners
|
||||
auxiliary_head['in_channels'] = auxiliary_head_in_channels
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
pretrained = pretrained_pth,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
# 4. train_dataloader
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
|
||||
if selected_model_info['type'] == 'ResNetV1c':
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
optim_wrapper = generate_optim_wrapper('swin')
|
||||
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:knet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
if selected_model_info['type'] == 'ResNetV1c':
|
||||
alg_file_name = f"{alg_name}_r{depth}_{model_choice}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{model_choice}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/knet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,136 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
# 交互式选择 decode_head 的函数
|
||||
def select_decode_head(decode_head_choose):
|
||||
print("可用的 decode head 选项:")
|
||||
for i, key in enumerate(decode_head_choose.keys()):
|
||||
print(f"{i + 1}. {key}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 提示用户输入选择
|
||||
choice = int(input("请选择需要的 decode head(输入编号):"))
|
||||
# 检查输入是否在有效范围内
|
||||
if 1 <= choice <= len(decode_head_choose):
|
||||
selected_key = list(decode_head_choose.keys())[choice - 1]
|
||||
print(f"你选择了: {selected_key}")
|
||||
return decode_head_choose[selected_key]
|
||||
else:
|
||||
print("输入的编号不正确,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的编号。")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'upernet_mae'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
# 获取backbone模型、是否需要预训练
|
||||
model_list = ["mae_pretrain_vit_base_mmcls.pth"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list) # 需要选择是否用预训练模型
|
||||
|
||||
backbone = create_dict_by_kwargs(type='MAE', img_size=crop_size, init_values=1.0)
|
||||
|
||||
decode_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
|
||||
auxiliary_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # Way Ori
|
||||
auxiliary_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1
|
||||
|
||||
neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5])
|
||||
decode_head=dict(in_channels=[768, 768, 768, 768], channels=768, num_classes=num_classes, norm_cfg=norm_cfg, loss_decode=decode_head_loss_decode)
|
||||
auxiliary_head=dict(in_channels=768, norm_cfg=norm_cfg, num_classes=num_classes, loss_decode = auxiliary_head_loss_decode)
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size, select_slide=True)
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode='slide', crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='AmpOptimWrapper', # type='OptimWrapper', # TODO Ori
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=3e-5, betas=(0.9, 0.999), weight_decay=0.05),
|
||||
constructor='LayerDecayOptimizerConstructor',
|
||||
paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
|
||||
|
||||
model_size = 'base'
|
||||
|
||||
train_dataloader, batch_size = generate_train_dataloader(4)
|
||||
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
pretrained = pretrained_pth,
|
||||
decode_head = decode_head,
|
||||
neck = neck,
|
||||
auxiliary_head = auxiliary_head,
|
||||
test_cfg = test_cfg
|
||||
)
|
||||
|
||||
########### 3.4. 生成fp16部分 ###########
|
||||
fp16 = dict(loss_scale='dynamic') # mixed precision
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = optim_wrapper
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:beit【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}-{model_size}_b{batch_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-testslide.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/mae/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, fp16=fp16,optim_wrapper=optim_wrapper, param_scheduler=param_scheduler, train_dataloader = train_dataloader)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,284 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'mask2former_my'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name) # _base_无算法
|
||||
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (640, 640), (512, 1024)]) # 选择切割大小
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# 3.1. 选择backbone
|
||||
model_list = ['torchvision://resnet50', 'torchvision://resnet101', 'pretrain/swin_large-6580f57d.pth', 'pretrain/swin_base-e5c09f74.pth', 'pretrain/swin_small-7ba6d6dd.pth', 'pretrain/swin_tiny-1cdeb081.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
depth = selected_model_info['depth']
|
||||
backbone = create_dict_by_kwargs(depth = depth, type=selected_model_info['type'], norm_cfg=norm_cfg, num_stages=4, frozen_stages=-1, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
in_channels = [256, 512, 1024, 2048]
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
if selected_model_name == 'pretrain/swin_tiny-1cdeb081.pth':
|
||||
in_channels = [96, 192, 384, 768]
|
||||
depths=[2, 2, 6, 2]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_small-7ba6d6dd.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [96, 192, 384, 768]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_base-e5c09f74.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [128, 256, 512, 1024]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=128,
|
||||
depths=depths,
|
||||
num_heads=[4, 8, 16, 32],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_large-6580f57d.pth':
|
||||
in_channels = [192, 384, 768, 1536]
|
||||
depths=[2, 2, 18, 2]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=192,
|
||||
depths=depths,
|
||||
num_heads=[6, 12, 24, 48],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
# frozen_stages=-1,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
# set all layers in backbone to lr_mult=0.1
|
||||
# set all norm layers, position_embeding,
|
||||
# query_embeding, level_embeding to decay_multi=0.0
|
||||
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
custom_keys = {
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'backbone.patch_embed.norm': backbone_norm_multi,
|
||||
'backbone.norm': backbone_norm_multi,
|
||||
'absolute_pos_embed': backbone_embed_multi,
|
||||
'relative_position_bias_table': backbone_embed_multi,
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi
|
||||
}
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
|
||||
for stage_id, num_blocks in enumerate(depths)
|
||||
for block_id in range(num_blocks)
|
||||
})
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
|
||||
for stage_id in range(len(depths) - 1)
|
||||
})
|
||||
|
||||
# 3.2. decode_head
|
||||
decode_head = dict(num_classes=num_classes, in_channels=in_channels,
|
||||
loss_cls=dict(class_weight=[1.0] * num_classes + [0.1]))
|
||||
|
||||
# 3.3. optimizer部分
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
# optimizer
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi,
|
||||
},
|
||||
norm_decay_mult=0.0))
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=4, num_workers_default=4)
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
# set all layers in backbone to lr_mult=0.1
|
||||
# set all norm layers, position_embeding,
|
||||
# query_embeding, level_embeding to decay_multi=0.0
|
||||
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
custom_keys = {
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'backbone.patch_embed.norm': backbone_norm_multi,
|
||||
'backbone.norm': backbone_norm_multi,
|
||||
'absolute_pos_embed': backbone_embed_multi,
|
||||
'relative_position_bias_table': backbone_embed_multi,
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi
|
||||
}
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
|
||||
for stage_id, num_blocks in enumerate(depths)
|
||||
for block_id in range(num_blocks)
|
||||
})
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
|
||||
for stage_id in range(len(depths) - 1)
|
||||
})
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg={'custom_keys':custom_keys, 'norm_decay_mult':0.0})
|
||||
|
||||
# 更新train_dataloader
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
|
||||
|
||||
# 3.4. train_dataloader部分
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomChoiceResize',
|
||||
scales=[int(max(crop_size[0], crop_size[1]) * x * 0.1) for x in range(5, 21)],
|
||||
resize_type='ResizeShortestEdge',
|
||||
max_size=max(crop_size[0], crop_size[1])*4),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
|
||||
|
||||
# 3.5. param_scheduler部分
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head
|
||||
)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:mask2former【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/mask2former/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,246 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_5_train_dataloader import generate_train_dataloader
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'maskformer_my'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name) # _base_无算法
|
||||
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (640, 640), (512, 1024)]) # 选择切割大小
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# 3.1. 选择backbone
|
||||
model_list = ['torchvision://resnet50', 'torchvision://resnet101', 'pretrain/swin_large-6580f57d.pth', 'pretrain/swin_base-e5c09f74.pth', 'pretrain/swin_small-7ba6d6dd.pth', 'pretrain/swin_tiny-1cdeb081.pth']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
depth = selected_model_info['depth']
|
||||
backbone = create_dict_by_kwargs(depth = depth, type=selected_model_info['type'], norm_cfg=norm_cfg, num_stages=4, frozen_stages=-1, style='pytorch', init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
in_channels = [256, 512, 1024, 2048]
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
if selected_model_name == 'pretrain/swin_tiny-1cdeb081.pth':
|
||||
in_channels = [96, 192, 384, 768]
|
||||
depths=[2, 2, 6, 2]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_small-7ba6d6dd.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [96, 192, 384, 768]
|
||||
backbone = dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=224,
|
||||
embed_dims=96,
|
||||
depths=depths,
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_base-e5c09f74.pth':
|
||||
depths=[2, 2, 18, 2]
|
||||
in_channels = [128, 256, 512, 1024]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=128,
|
||||
depths=depths,
|
||||
num_heads=[4, 8, 16, 32],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
if selected_model_name == 'pretrain/swin_large-6580f57d.pth':
|
||||
in_channels = [192, 384, 768, 1536]
|
||||
depths=[2, 2, 18, 2]
|
||||
backbone = dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dims=192,
|
||||
depths=depths,
|
||||
num_heads=[6, 12, 24, 48],
|
||||
window_size=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True,
|
||||
act_cfg=dict(type='GELU'), # ADD
|
||||
use_abs_pos_embed=False, # ADD
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
with_cp=False,
|
||||
init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
|
||||
# 3.2. decode_head
|
||||
decode_head = dict(num_classes=num_classes, in_channels=in_channels,
|
||||
loss_cls=dict(class_weight=[1.0] * num_classes + [0.1]))
|
||||
|
||||
# 3.3. optimizer部分
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
# optimizer
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi,
|
||||
},
|
||||
norm_decay_mult=0.0))
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=4, num_workers_default=4)
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
# set all layers in backbone to lr_mult=0.1
|
||||
# set all norm layers, position_embeding,
|
||||
# query_embeding, level_embeding to decay_multi=0.0
|
||||
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
|
||||
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
|
||||
custom_keys = {
|
||||
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
|
||||
'backbone.patch_embed.norm': backbone_norm_multi,
|
||||
'backbone.norm': backbone_norm_multi,
|
||||
'absolute_pos_embed': backbone_embed_multi,
|
||||
'relative_position_bias_table': backbone_embed_multi,
|
||||
'query_embed': embed_multi,
|
||||
'query_feat': embed_multi,
|
||||
'level_embed': embed_multi
|
||||
}
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
|
||||
for stage_id, num_blocks in enumerate(depths)
|
||||
for block_id in range(num_blocks)
|
||||
})
|
||||
custom_keys.update({
|
||||
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
|
||||
for stage_id in range(len(depths) - 1)
|
||||
})
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg={'custom_keys':custom_keys, 'norm_decay_mult':0.0})
|
||||
|
||||
# 更新train_dataloader
|
||||
train_dataloader, batch_size, num_workers = generate_train_dataloader(batch_size_default=2, num_workers_default=2)
|
||||
|
||||
# 3.5. param_scheduler部分
|
||||
param_scheduler = generate_param_scheduler(train_time_or_epoch_k)
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
# 综合model
|
||||
model = dict(
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head
|
||||
)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:maskformer【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
if selected_model_info['type'] == 'ResNet':
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
elif selected_model_info['type'] == 'swin':
|
||||
alg_file_name = f"{alg_name}_swin_{selected_model_info['size']}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/maskformer/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, train_dataloader=train_dataloader, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
152
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pidnet.py
Normal file
152
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pidnet.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import os, sys, argparse, json
|
||||
import importlib.util
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
def get_train_pipeline_from_config(dataset_file_name: str):
|
||||
"""
|
||||
动态加载指定的配置文件并读取其中的 'train_pipeline' 列表。
|
||||
|
||||
Args:
|
||||
dataset_file_name (str): 数据集配置文件的名字 (不包含.py后缀)。
|
||||
|
||||
Returns:
|
||||
list | None: 如果成功找到,则返回 train_pipeline 列表;
|
||||
如果文件不存在或文件中没有 train_pipeline 变量,则返回 None。
|
||||
"""
|
||||
# 1. 构建配置文件的相对路径
|
||||
# 使用 os.path.join 来确保路径在不同操作系统上都是正确的
|
||||
config_path = os.path.join('configs/_base_/datasets/', f'{dataset_file_name}.py')
|
||||
|
||||
# 2. 检查文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
print(f"\033[30错误:配置文件不存在于 '{config_path}'\033[0m")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 3. 动态加载 .py 文件作为一个模块
|
||||
# 创建一个模块规范 (spec)
|
||||
# 模块名可以是任意的,这里用文件名以防冲突
|
||||
spec = importlib.util.spec_from_file_location(dataset_file_name, config_path)
|
||||
|
||||
# 根据规范创建一个模块对象
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# 执行模块代码,使其所有变量(如 train_pipeline)都加载到模块对象中
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
# 4. 从加载的模块中获取 train_pipeline 变量
|
||||
# 使用 getattr 来安全地获取,如果不存在,可以设置一个默认值
|
||||
train_pipeline = getattr(config_module, 'train_pipeline', None)
|
||||
|
||||
if train_pipeline is None:
|
||||
print(f"错误:在 '{config_path}' 文件中未找到 'train_pipeline' 变量。")
|
||||
return None
|
||||
|
||||
return train_pipeline
|
||||
|
||||
except Exception as e:
|
||||
print(f"读取配置文件时发生未知错误: {e}")
|
||||
return None
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'pidnet'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.1. 生成train_pipeline部分 ###########
|
||||
config_path = os.path.join('../_base_/datasets/', f'{dataset_file_name}.py')
|
||||
|
||||
train_pipeline = get_train_pipeline_from_config(dataset_file_name)
|
||||
train_pipeline.insert(-1, dict(type='GenerateEdge', edge_width=4)) # For pidnet
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# # 3.3.2. 骨架backbone、解码器decode_head
|
||||
model_list = ['openmmlab/pidnet-s', 'openmmlab/pidnet-m', 'openmmlab/pidnet-l']
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
if selected_model_name == 'openmmlab/pidnet-s':
|
||||
backbone = dict(channels=32, ppm_channels=96, num_stem_blocks=2, num_branch_blocks=3, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(num_classes=num_classes, in_channels=128, channels=128)
|
||||
model_size = 'small'
|
||||
elif selected_model_name == 'openmmlab/pidnet-m':
|
||||
backbone = dict(channels=64, ppm_channels=96, num_stem_blocks=2, num_branch_blocks=3, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(num_classes=num_classes, in_channels=256, channels=128)
|
||||
model_size = 'middle'
|
||||
elif selected_model_name == 'openmmlab/pidnet-l':
|
||||
backbone = dict(channels=64, ppm_channels=112, num_stem_blocks=3, num_branch_blocks=4, init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
decode_head = dict(num_classes=num_classes, in_channels=256, channels=256)
|
||||
model_size = 'large'
|
||||
else:
|
||||
quit("Error: 未知的模型名称")
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
backbone = backbone,
|
||||
decode_head = decode_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/pidnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, train_pipeline = train_pipeline, train_dataloader = train_dataloader, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pspnet.py
Normal file
111
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_pspnet.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'pspnet_r50-d8'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769),(1280,1280)]) # 选择切割大小
|
||||
model_list = ['openmmlab/resnet18_v1c', 'openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c', 'torchvision://resnet18', 'torchvision://resnet50', 'torchvision://resnet101', ]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
depth = selected_model_info['depth']
|
||||
type_model = selected_model_info['type']
|
||||
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size)
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
# decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数 # TODO
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数 # TODO
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
backbone = create_dict_by_kwargs(depth=depth, type=type_model) # generate_model_backbone(depth=depth,)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
if depth == 18:
|
||||
decode_head = create_dict_by_kwargs(in_channels=512, channels=128, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
decode_head = create_dict_by_kwargs(in_channels=2048, channels=512, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
|
||||
if depth == 18:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=256, channels=64, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
elif depth == 101 or depth == 50:
|
||||
auxiliary_head = create_dict_by_kwargs(in_channels=1024, channels=256, num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, norm_cfg=norm_cfg,)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:pspnet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_r{depth}_{'Pre' if select_pretrained else 'NoPre'}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/pspnet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, test_cfg=test_cfg, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
139
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_stdc.py
Normal file
139
Seg_All_In_One_MMSeg/My_All_In_One/2_Alg_Program/my_stdc.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import os, sys, argparse, json
|
||||
import importlib.util
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs, get_var_from_file, update_list_dict_var, get_var_from_py_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
def get_pretrained_model_choice():
|
||||
"""
|
||||
提示用户选择是否使用预训练模型,并返回他们的选择。
|
||||
|
||||
返回:
|
||||
str: 如果用户选择 '是' 或默认,则返回 '是'。
|
||||
如果用户选择 '否',则返回 '否'。
|
||||
"""
|
||||
while True:
|
||||
# 如果用户直接按回车,input()返回空字符串,or "1" 使其默认值为 "1"
|
||||
choice = input("可用的预训练模型选项:\n1. 是\n2. 否\n请选择是否使用预训练模型(默认1): ") or "1"
|
||||
|
||||
if choice == "1":
|
||||
print("您已选择:1. 是")
|
||||
return True
|
||||
elif choice == "2":
|
||||
print("您已选择:2. 否")
|
||||
return False
|
||||
else:
|
||||
# 如果输入了其他无效内容(如"3"),则提示重新输入
|
||||
print("\n无效输入,请输入 1 或 2。\n")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'stdc'
|
||||
alg_file_pth = f"configs/_base_/models/{alg_file_name}.py"
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
########### 3.1. 生成_base_部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
########### 3.1. 生成norm_cfg部分 ###########
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
########### 3.2. 生成data_preprocessor部分 ###########
|
||||
crop_size = select_crop_size(predefined_options=[(512, 512)]) # 选择切割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
########### 3.3. 生成model部分 ###########
|
||||
# 3.3.1. 预处理data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
|
||||
# # 3.3.2. 骨架backbone、解码器decode_head
|
||||
decode_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori
|
||||
decode_head_loss_decode=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1
|
||||
|
||||
# auxiliary_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori: DiceLoss损失函数 # TODO
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1: DiceLoss损失函数 # TODO
|
||||
|
||||
auxiliary_head = get_var_from_py_file(alg_file_pth, 'model')['auxiliary_head']
|
||||
for i in range(len(auxiliary_head)):
|
||||
if auxiliary_head[i]['type'] == 'FCNHead':
|
||||
auxiliary_head[i].update(num_classes=num_classes, loss_decode=auxiliary_head_loss_decode_dict)
|
||||
|
||||
model_list = ["openmmlab/stdc1", "openmmlab/stdc2"]
|
||||
selected_model_name, select_pretrained, pretrained_pth, selected_model_info = select_pretrained_model(model_list=model_list)
|
||||
use_pretrained = get_pretrained_model_choice()
|
||||
if selected_model_name == "openmmlab/stdc1":
|
||||
if use_pretrained:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet1'), init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
Pre = "Pre"
|
||||
else:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet1'))
|
||||
Pre = "NoPre"
|
||||
decode_head = dict(num_classes=num_classes, loss_decode=decode_head_loss_decode)
|
||||
model_size = 'V1_'+Pre
|
||||
elif selected_model_name == "openmmlab/stdc2":
|
||||
if use_pretrained:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet2'), init_cfg=dict(type='Pretrained', checkpoint=pretrained_pth))
|
||||
Pre = "Pre"
|
||||
else:
|
||||
backbone = dict(backbone_cfg=dict(stdc_type='STDCNet2'))
|
||||
Pre = "NoPre"
|
||||
decode_head = dict(num_classes=num_classes, loss_decode=decode_head_loss_decode)
|
||||
model_size = 'V2_'+Pre
|
||||
else:
|
||||
quit("Error: 未知的模型名称")
|
||||
|
||||
# 3.3.4. 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
backbone = backbone,
|
||||
decode_head = decode_head,
|
||||
)
|
||||
|
||||
########### 3.4. 生成optim、param_scheduler部分 ###########
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 4.1. 算法名称解析:bisenetv2【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_{model_size}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/stdc/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 4.2. 将信息临时写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,105 @@
|
||||
import os, sys, argparse, json
|
||||
# 添加上级目录到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Select_Tool import select_crop_size, select_pretrained_model, select_test_cfg_slide, select_sampler
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, create_dict_by_kwargs
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_4_optimizer import generate_optim_wrapper, generate_param_scheduler
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-算法默认参数 ###########
|
||||
alg_file_name = 'deeplabv3_unet_s5-d16'
|
||||
|
||||
########### 2.参数解析 ###########
|
||||
parser = argparse.ArgumentParser(description="Run algorithm with specified parameters.")
|
||||
|
||||
# 添加命令行参数
|
||||
parser.add_argument('--alg_name', type=str, required=True, help="Algorithm name (e.g., 'ann')")
|
||||
parser.add_argument('--dataset_file_name', type=str, required=True, help="Dataset file name (e.g., 'my_dataset_model')")
|
||||
parser.add_argument('--mean', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--std', type=str, required=True, help="Mean of dataset (e.g., [1.2, 1.2, 1.2])")
|
||||
parser.add_argument('--dataset_num_classes', type=int, required=True, help="Dataset class num (e.g., 36)")
|
||||
parser.add_argument('--GPU_num', type=int, required=True, help="Num of GPU to train (e.g., 1)")
|
||||
parser.add_argument('--schedule_file_name', type=str, required=True, help="Schedule file name (e.g., 'schedule_4k_check_400')")
|
||||
parser.add_argument('--train_type', type=str, required=True, help="Schedule train type (e.g., Epoch or Iteration)")
|
||||
parser.add_argument('--train_time_or_epoch_k', type=int, required=True, help="Schedule train time k (e.g., 4)")
|
||||
|
||||
# 解析命令行参数
|
||||
args = parser.parse_args()
|
||||
|
||||
# 变量赋值
|
||||
alg_name = args.alg_name
|
||||
dataset_file_name = args.dataset_file_name
|
||||
mean = json.loads(args.mean) # 使用json.loads转换成list
|
||||
std = json.loads(args.std) # 使用json.loads转换成list
|
||||
num_classes = args.dataset_num_classes
|
||||
schedule_file_name = args.schedule_file_name
|
||||
train_type = args.train_type
|
||||
train_time_or_epoch_k = args.train_time_or_epoch_k
|
||||
GPU_num = args.GPU_num
|
||||
|
||||
|
||||
########### 3. 设定参数 ###########
|
||||
# 3.1. 自选参数
|
||||
crop_size = select_crop_size(predefined_options=[(128, 128), (256, 256)]) # 选择切割大小
|
||||
test_cfg_mode, test_cfg_crop_div_stride = select_test_cfg_slide(crop_size, select_slide=True) # 默认选择滑动模式
|
||||
|
||||
# 3.2. 自动生成参数
|
||||
# decode、auxiliary损失下载方式
|
||||
decode_head_loss_decode_dict=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0) # Way Ori: decode损失函数 # TODO
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # Way 1: decode损失函数 # TODO
|
||||
decode_head_loss_decode_dict = [
|
||||
dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
|
||||
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # Way 1: DiceLoss损失函数 # TODO
|
||||
|
||||
# 开启align_corners
|
||||
if test_cfg_mode == "slide":
|
||||
align_corners = True
|
||||
test_slide = 'testslide' # 算法名进行标记
|
||||
|
||||
|
||||
########### 3.生成各个部分 ###########
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size, mean = mean, std = std)
|
||||
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# TODO TODO decode_head_loss_decode_dict
|
||||
decode_head = create_dict_by_kwargs(num_classes=num_classes, loss_decode=decode_head_loss_decode_dict, align_corners=align_corners, norm_cfg=norm_cfg,)
|
||||
|
||||
# auxiliary_head_loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
|
||||
auxiliary_head = create_dict_by_kwargs(num_classes=num_classes, norm_cfg=norm_cfg, loss_decode=auxiliary_head_loss_decode_dict)
|
||||
|
||||
test_cfg = generate_model_test_cfg(test_cfg_mode=test_cfg_mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
|
||||
# 综合model
|
||||
model = dict(
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
test_cfg = test_cfg,
|
||||
auxiliary_head = auxiliary_head,
|
||||
)
|
||||
|
||||
|
||||
optim_wrapper = generate_optim_wrapper()
|
||||
param_scheduler = generate_param_scheduler(train_type, train_time_or_epoch_k)
|
||||
|
||||
########### 4.程序写入 ###########
|
||||
# 算法名称解析:unet【Alg】 _ r50【pretrained模型深度】 - d8【_base_/models】 _ 4【GPU数量】 x b2【Batch Size大小】 - 40k【schedule】 _ cityscapes【数据集】 - 512x1024【crop_size】.py
|
||||
alg_file_name = f"{alg_name}_g{GPU_num}-{schedule_file_name.lstrip('schedule_')}_{dataset_file_name}-{crop_size[0]}x{crop_size[1]}-{test_slide}.py"
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/unet/', alg_file_name)
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, norm_cfg=norm_cfg, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model, optim_wrapper=optim_wrapper, param_scheduler=param_scheduler)
|
||||
|
||||
# 将信息写入文件
|
||||
with open("_temp_.txt", "w") as file:
|
||||
json.dump({'alg_infos':{'alg_file_name':alg_file_name,'alg_file_pth':output_configs_alg_my_alg}}, file)
|
||||
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
import os
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.1.定义算法基本参数 ###########
|
||||
# A. generate_base_config
|
||||
alg_name = 'ann' # ./configs中算法简称
|
||||
alg_file_name = 'ann_r50-d8' # 算法根文件
|
||||
dataset_file_name = 'my_dataset_model' # 数据文件
|
||||
schedule_file_name = 'schedule_4k_check_400' # schedule文件
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
|
||||
# B. generate_norm_cfg
|
||||
GPU_num = 1 # GPU数量
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
|
||||
# C. generate_data_preprocessor
|
||||
crop_size = (512,512) # 分割大小
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size)
|
||||
|
||||
# D. generate_model
|
||||
# D.1. pretrained
|
||||
pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth' # 预训练模型位置
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
# D.2. backbone
|
||||
depth = 50 # 模型深度
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
|
||||
# D.3. data_preprocessor
|
||||
model_data_preprocessor = data_preprocessor # 修改data_preprocessor
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# D.4. decode_head、auxiliary_head
|
||||
num_classes=36 # 分类数
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # decode损失函数
|
||||
align_corners=False # 是否需要角对齐
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # DiceLoss损失函数
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
# D.5. train_cfg
|
||||
# train_cfg = generate_model_train_cfg()
|
||||
|
||||
# D.6. test_cfg
|
||||
# test_cfg = generate_model_test_cfg()
|
||||
|
||||
# E. 综合model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
# train_cfg = train_cfg,
|
||||
# test_cfg = test_cfg,
|
||||
)
|
||||
|
||||
########### 2.文件存储 ###########
|
||||
# output_configs_alg_my_alg = os.path.join(f'my_{alg_name}.py')
|
||||
output_configs_alg_my_alg = os.path.join(f'./configs/{alg_name}/', f'my_{alg_name}.py')
|
||||
|
||||
write_config_to_file(output_configs_alg_my_alg, _base_=_base_, crop_size=crop_size, data_preprocessor=data_preprocessor, model=model)
|
||||
|
||||
@@ -0,0 +1,301 @@
|
||||
import os, json, subprocess
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXk import generate_times_configs_base_schedules_schedule_file
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXe import generate_epochs_configs_base_schedules_schedule_file
|
||||
from datetime import datetime
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_gpu_info
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
|
||||
def load_json_file(file_name):
|
||||
"""
|
||||
读取并返回指定 JSON 文件的内容。
|
||||
|
||||
:param file_name: 要读取的 JSON 文件路径
|
||||
:return: JSON 文件的内容作为字典或列表
|
||||
"""
|
||||
try:
|
||||
# 打开并读取 JSON 文件
|
||||
with open(file_name, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
except FileNotFoundError:
|
||||
print(f"\033[91mError: File {file_name} not found.\033[0m")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError: Failed to decode JSON from {file_name}.\033[0m")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"\033[91mAn error occurred while reading {file_name}: {str(e)}\033[0m")
|
||||
return None
|
||||
|
||||
def process_data_record_json_data(json_data):
|
||||
"""
|
||||
将 JSON 数据直接存入一个大的字典 all_data_record。
|
||||
:param json_data: 从 JSON 文件加载的数据
|
||||
:return: all_data_record 字典
|
||||
"""
|
||||
all_data_record = {}
|
||||
|
||||
# 遍历 json_data 中的每个键(即每个数据集)
|
||||
for dataset_name, dataset_info in json_data.items():
|
||||
all_data_record[dataset_name] = {
|
||||
"classes": dataset_info['classes'],
|
||||
"palette": dataset_info['palette'],
|
||||
"palette_num": dataset_info['palette_num'],
|
||||
"mean": dataset_info['mean'],
|
||||
"std": dataset_info['std'],
|
||||
}
|
||||
|
||||
return all_data_record
|
||||
|
||||
# 选择要处理的数据集
|
||||
def select_dataset(all_data_record):
|
||||
"""
|
||||
让用户从 all_data_record 中选择数据集,并返回 dataset_file_name 和对应的 palette_num。
|
||||
|
||||
:param all_data_record: 包含所有数据集信息的字典
|
||||
:return: 选定的 dataset_file_name 和 palette_num
|
||||
"""
|
||||
# 获取所有数据集的名称
|
||||
dataset_names = list(all_data_record.keys())
|
||||
|
||||
# 显示可用的数据集,并让用户选择
|
||||
print("选择可用数据集:")
|
||||
for i, name in enumerate(dataset_names):
|
||||
print(f"{i + 1}. {name} - {all_data_record[name]['palette_num']}类")
|
||||
|
||||
# 用户输入选择的数据集编号
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择数据集编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(dataset_names):
|
||||
dataset_file_name = dataset_names[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(dataset_names)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 获取对应的 palette_num
|
||||
palette_num = all_data_record[dataset_file_name]['palette_num']
|
||||
mean = all_data_record[dataset_file_name]['mean']
|
||||
std = all_data_record[dataset_file_name]['std']
|
||||
|
||||
print(f" 已选择数据集: {dataset_file_name} ,其对应的分类数为: {palette_num}")
|
||||
|
||||
return dataset_file_name, palette_num, mean, std
|
||||
|
||||
# 选择对应的算法
|
||||
def select_alg(alg_directory):
|
||||
"""
|
||||
选择一个算法文件并运行它。
|
||||
|
||||
:return: 选定的算法文件名和相对路径
|
||||
"""
|
||||
# 获取 ./Alg 目录下的所有 Python 文件
|
||||
algorithms = [f for f in os.listdir(alg_directory) if f.endswith('.py')]
|
||||
|
||||
# 显示可用的算法文件
|
||||
print("选择可用算法:")
|
||||
for i, alg in enumerate(algorithms):
|
||||
print(f"{i + 1}. {alg}")
|
||||
|
||||
# 用户选择算法
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择算法编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(algorithms):
|
||||
selected_alg = algorithms[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(algorithms)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 生成选定算法的相对路径
|
||||
relative_alg_path = os.path.join(alg_directory, selected_alg)
|
||||
print(f" 已选择算法: {selected_alg} {relative_alg_path}")
|
||||
|
||||
return selected_alg, relative_alg_path
|
||||
|
||||
# 选择计算用的GPU
|
||||
def select_GPU():
|
||||
"""
|
||||
让用户选择要使用的 GPU 数量,并返回相应的 GPU 列表(以 0,1 格式输出)。
|
||||
如果用户没有输入任何选择,默认为选择一块 GPU,编号为 0。
|
||||
:return: 用户选择的 GPU 列表和数量
|
||||
"""
|
||||
# 获取 GPU 信息
|
||||
gpu_info = get_gpu_info()
|
||||
|
||||
if not gpu_info:
|
||||
return None, 0
|
||||
|
||||
# 显示 GPU 信息
|
||||
print("可用的 GPU 列表:")
|
||||
for idx, mem_free in gpu_info:
|
||||
print(f"GPU {idx}: 剩余显存 {mem_free} MB")
|
||||
|
||||
# 提供默认选择 GPU 个数和编号
|
||||
default_num_gpus = 1
|
||||
default_gpu_idx = 0
|
||||
|
||||
# 用户选择 GPU 个数(默认选择 1)
|
||||
try:
|
||||
num_gpus = input(f"请选择使用的 GPU 个数 (1-{len(gpu_info)}, 默认为 1): ")
|
||||
num_gpus = int(num_gpus) if num_gpus else default_num_gpus
|
||||
if not (1 <= num_gpus <= len(gpu_info)):
|
||||
print(f"输入无效,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
|
||||
# 用户选择 GPU 编号(如果选择多个 GPU,逐个选择编号)
|
||||
selected_gpus = []
|
||||
for i in range(num_gpus):
|
||||
try:
|
||||
gpu_idx = input(f"请输入要使用的第 {i + 1} 个 GPU 的编号 (默认为 {default_gpu_idx}): ")
|
||||
gpu_idx = int(gpu_idx) if gpu_idx else default_gpu_idx
|
||||
if gpu_idx in [idx for idx, _ in gpu_info] and gpu_idx not in selected_gpus:
|
||||
selected_gpus.append(gpu_idx)
|
||||
else:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
|
||||
# 返回以 0,1 格式的 GPU 列表和 GPU 数量
|
||||
gpu_list_str = ','.join(map(str, selected_gpus))
|
||||
print(f"\n已选择 GPU: {gpu_list_str},共 {num_gpus} 块 GPU")
|
||||
|
||||
return gpu_list_str, num_gpus
|
||||
|
||||
# 选择训练批次相关信息
|
||||
def select_schedule():
|
||||
"""
|
||||
选择训练时间、验证次数、日志间隔,并生成训练计划配置文件。
|
||||
提供默认值:train_time_k=40, check_num=10, loggerhook_interval=50
|
||||
"""
|
||||
# 交互式输入 train_time_k,默认值 40
|
||||
try:
|
||||
train_time_k = input("请输入训练次数(k为单位,默认40k):").strip()
|
||||
train_time_k = int(train_time_k) if train_time_k else 40
|
||||
except ValueError:
|
||||
print("无效输入,使用默认训练时间 40k")
|
||||
train_time_k = 40
|
||||
|
||||
# 交互式输入 check_num,默认值 10
|
||||
try:
|
||||
check_num = input("请输入检查点数量(默认10):").strip()
|
||||
check_num = int(check_num) if check_num else 10
|
||||
except ValueError:
|
||||
print("无效输入,使用默认检查点数量 20")
|
||||
check_num = 20
|
||||
|
||||
# 交互式输入 loggerhook_interval,默认值 50
|
||||
try:
|
||||
loggerhook_interval = input("请输入日志间隔(默认50次迭代):").strip()
|
||||
loggerhook_interval = int(loggerhook_interval) if loggerhook_interval else 50
|
||||
except ValueError:
|
||||
print("无效输入,使用默认日志间隔 50 次迭代")
|
||||
loggerhook_interval = 50
|
||||
|
||||
# 计算验证比例和检查点间隔
|
||||
val_proportion = 1 / check_num # 验证比例
|
||||
checkpoint_interval = int(train_time_k * 1000 * val_proportion) # 计算保存检查点的间隔
|
||||
|
||||
# 生成文件名和路径
|
||||
output_configs_base_schedules_schedules_Timek = os.path.join('./configs/_base_/schedules/', f'schedule_{train_time_k}k_check_{checkpoint_interval}.py')
|
||||
schedule_file_name = f'schedule_{train_time_k}k_check_{checkpoint_interval}.py'
|
||||
|
||||
# 调用生成配置文件函数
|
||||
generate_times_configs_base_schedules_schedule_file(
|
||||
output_file=output_configs_base_schedules_schedules_Timek,
|
||||
train_time_k=train_time_k,
|
||||
val_proportion=val_proportion,
|
||||
loggerhook_interval=loggerhook_interval
|
||||
)
|
||||
|
||||
return schedule_file_name, train_time_k
|
||||
|
||||
# 运行对应的代码去生成算法配置文件
|
||||
def run_alg_to_gen_alg(relative_alg_path, mean, std, alg_name, dataset_file_name, num_classes, GPU_num, schedule_file_name, train_time_k):
|
||||
try:
|
||||
# 构建命令和参数
|
||||
cmd = [
|
||||
'python', relative_alg_path,
|
||||
'--alg_name', alg_name,
|
||||
'--dataset_file_name', dataset_file_name,
|
||||
'--mean', str(mean),
|
||||
'--std', str(std),
|
||||
'--dataset_num_classes', str(num_classes),
|
||||
'--GPU_num', str(GPU_num),
|
||||
'--schedule_file_name', schedule_file_name,
|
||||
'--train_time_k', str(train_time_k)
|
||||
]
|
||||
|
||||
# 打印当前执行的命令
|
||||
print(f"\n正在运行: \033[33m{' '.join(cmd)}\033[0m")
|
||||
|
||||
# 运行 Python 脚本并传递参数
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
print(f"\033[92m算法生成器 {relative_alg_path} 运行成功。\033[0m")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"\033[91m算法生成器 {relative_alg_path} 运行失败,错误信息: {e}\033[0m")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义数据集、算法信息路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json"
|
||||
all_data_record_file = os.path.join(train_parameter_dir, all_data_record_json) # 数据集信息记录路径
|
||||
|
||||
alg_directory = './My_All_In_One/2_Alg_Program' # 算法配置生成路径
|
||||
|
||||
work_dir_base = './work_dirs' # 工作路径
|
||||
|
||||
########### 2.获取现有数据集信息 ###########
|
||||
print(f"\033[36m{'='*10} 一、选择训练数据集 {'='*10}\033[0m")
|
||||
data_record_json_data = load_json_file(file_name = all_data_record_file) # 加载数据集信息
|
||||
all_data_record = process_data_record_json_data(json_data = data_record_json_data) # 分析数据集信息
|
||||
dataset_file_name, num_classes, mean, std = select_dataset(all_data_record = all_data_record) # 选择特定数据集
|
||||
|
||||
########### 3.获取现有GPU信息 ###########
|
||||
print(f"\033[36m{'='*10} 二、选择训练GPU {'='*10}\033[0m")
|
||||
gpu_list_str, GPU_num = select_GPU() # 选择GPU
|
||||
|
||||
########### 4.获取训练批次信息 ###########
|
||||
print(f"\033[36m{'='*10} 三、选择训练批次信息 {'='*10}\033[0m")
|
||||
schedule_file_name, train_time_k = select_schedule() # 选择训练批次
|
||||
schedule_file_name = schedule_file_name.rstrip('.py')
|
||||
|
||||
########### 5.获取现有算法信息 ###########
|
||||
print(f"\033[36m{'='*10} 四、选择训练算法 {'='*10}\033[0m")
|
||||
selected_alg, relative_alg_path = select_alg(alg_directory=alg_directory) # 选择算法
|
||||
alg_name = selected_alg.rstrip(".py")
|
||||
|
||||
########### 6.运行选定的算法 ###########
|
||||
print(f"\033[36m{'='*5} 生成训练算法配置 {'='*5}\033[0m")
|
||||
run_alg_to_gen_alg(relative_alg_path=relative_alg_path, alg_name=alg_name, dataset_file_name=dataset_file_name, num_classes=num_classes, GPU_num=GPU_num, mean=mean, std=std,schedule_file_name=schedule_file_name, train_time_k=train_time_k)
|
||||
|
||||
# 算法相关信息
|
||||
with open("_temp_.txt", "r") as file:
|
||||
data = json.load(file)
|
||||
alg_file_name = data['alg_infos']['alg_file_name']
|
||||
alg_file_pth = data['alg_infos']['alg_file_pth']
|
||||
os.remove("_temp_.txt")
|
||||
|
||||
########### 7.输出工作、算法目录 ###########
|
||||
# 获取并打印当前的年月日时分秒
|
||||
data_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
work_dir_file = os.path.join(work_dir_base, f"{dataset_file_name}-Class_{num_classes}-Alg_{alg_name}-AlgName_{alg_file_name}-Card_{GPU_num}-Data_{data_now}")
|
||||
|
||||
print(f"\033[36m训练指令:python tools/train.py {alg_file_pth} --work-dir {work_dir_file}")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
import os, json, subprocess
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXk import generate_times_configs_base_schedules_schedule_file
|
||||
from Initial_Schedule_Program.Initial_Train_Gen_configs_base_schedules_schedule_XXe import generate_epochs_configs_base_schedules_schedule_file
|
||||
from datetime import datetime
|
||||
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_Tool import format_all_data, write_config_to_file, get_gpu_info
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_0_base_ import generate_base_config
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_1_norm_cfg import generate_norm_cfg
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_2_data_preprocessor import generate_data_preprocessor
|
||||
from Initial_Alg_Program.Initial_Alg_Gen_3_model import generate_model_pretrained, generate_model_backbone, generate_model_data_preprocessor, generate_model_decode_head, generate_model_auxiliary_head, generate_model_train_cfg, generate_model_test_cfg
|
||||
|
||||
def load_json_file(file_name):
|
||||
"""
|
||||
读取并返回指定 JSON 文件的内容。
|
||||
|
||||
:param file_name: 要读取的 JSON 文件路径
|
||||
:return: JSON 文件的内容作为字典或列表
|
||||
"""
|
||||
try:
|
||||
# 打开并读取 JSON 文件
|
||||
with open(file_name, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
except FileNotFoundError:
|
||||
print(f"\033[91mError: File {file_name} not found.\033[0m")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
print(f"\033[91mError: Failed to decode JSON from {file_name}.\033[0m")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"\033[91mAn error occurred while reading {file_name}: {str(e)}\033[0m")
|
||||
return None
|
||||
|
||||
def process_data_record_json_data(json_data):
|
||||
"""
|
||||
将 JSON 数据直接存入一个大的字典 all_data_record。
|
||||
:param json_data: 从 JSON 文件加载的数据
|
||||
:return: all_data_record 字典
|
||||
"""
|
||||
all_data_record = {}
|
||||
|
||||
# 遍历 json_data 中的每个键(即每个数据集)
|
||||
for dataset_name, dataset_info in json_data.items():
|
||||
all_data_record[dataset_name] = {
|
||||
"classes": dataset_info['classes'],
|
||||
"palette": dataset_info['palette'],
|
||||
"palette_num": dataset_info['palette_num'],
|
||||
"mean": dataset_info['mean'],
|
||||
"std": dataset_info['std'],
|
||||
"train_imgs_num": dataset_info['train_imgs_num']
|
||||
}
|
||||
|
||||
return all_data_record
|
||||
|
||||
# 选择要处理的数据集
|
||||
def select_dataset(all_data_record):
|
||||
"""
|
||||
让用户从 all_data_record 中选择数据集,并返回 dataset_file_name 和对应的 palette_num。
|
||||
|
||||
:param all_data_record: 包含所有数据集信息的字典
|
||||
:return: 选定的 dataset_file_name 和 palette_num
|
||||
"""
|
||||
# 获取所有数据集的名称
|
||||
dataset_names = list(all_data_record.keys())
|
||||
|
||||
# 显示可用的数据集,并让用户选择
|
||||
print("选择可用数据集:")
|
||||
for i, name in enumerate(dataset_names):
|
||||
print(f"{i + 1}. {name} - {all_data_record[name]['palette_num']}类")
|
||||
|
||||
# 用户输入选择的数据集编号
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择数据集编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(dataset_names):
|
||||
dataset_file_name = dataset_names[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(dataset_names)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 获取对应的 palette_num
|
||||
palette_num = all_data_record[dataset_file_name]['palette_num']
|
||||
mean = all_data_record[dataset_file_name]['mean']
|
||||
std = all_data_record[dataset_file_name]['std']
|
||||
train_imgs_num = all_data_record[dataset_file_name]['train_imgs_num']
|
||||
|
||||
print(f" 已选择数据集: {dataset_file_name} ,其对应的分类数为: {palette_num}")
|
||||
|
||||
return dataset_file_name, palette_num, mean, std, train_imgs_num
|
||||
|
||||
# 选择对应的算法
|
||||
def select_alg(alg_directory):
|
||||
"""
|
||||
选择一个算法文件并运行它。
|
||||
|
||||
:return: 选定的算法文件名和相对路径
|
||||
"""
|
||||
# 获取 ./Alg 目录下的所有 Python 文件
|
||||
algorithms = [f for f in os.listdir(alg_directory) if f.endswith('.py')]
|
||||
|
||||
# 显示可用的算法文件
|
||||
print("选择可用算法:")
|
||||
for i, alg in enumerate(algorithms):
|
||||
print(f"{i + 1}. {alg}")
|
||||
|
||||
# 用户选择算法
|
||||
while True:
|
||||
try:
|
||||
selection = int(input("请选择算法编号(输入数字):")) - 1
|
||||
if 0 <= selection < len(algorithms):
|
||||
selected_alg = algorithms[selection]
|
||||
break
|
||||
else:
|
||||
print(f"输入的编号无效,请输入 1 到 {len(algorithms)} 之间的数字。")
|
||||
except ValueError:
|
||||
print("无效输入,请输入数字。")
|
||||
|
||||
# 生成选定算法的相对路径
|
||||
relative_alg_path = os.path.join(alg_directory, selected_alg)
|
||||
print(f" 已选择算法: {selected_alg} {relative_alg_path}")
|
||||
|
||||
return selected_alg, relative_alg_path
|
||||
|
||||
# 选择计算用的GPU
|
||||
def select_GPU():
|
||||
"""
|
||||
让用户选择要使用的 GPU 数量,并返回相应的 GPU 列表(以 0,1 格式输出)。
|
||||
如果用户没有输入任何选择,默认为选择一块 GPU,编号为 0。
|
||||
:return: 用户选择的 GPU 列表和数量
|
||||
"""
|
||||
# 获取 GPU 信息
|
||||
gpu_info = get_gpu_info()
|
||||
|
||||
if not gpu_info:
|
||||
return None, 0
|
||||
|
||||
# 显示 GPU 信息
|
||||
print("可用的 GPU 列表:")
|
||||
for idx, mem_free in gpu_info:
|
||||
print(f"GPU {idx}: 剩余显存 {mem_free} MB")
|
||||
|
||||
# 提供默认选择 GPU 个数和编号
|
||||
default_num_gpus = 1
|
||||
default_gpu_idx = 0
|
||||
|
||||
# 用户选择 GPU 个数(默认选择 1)
|
||||
try:
|
||||
num_gpus = input(f"请选择使用的 GPU 个数 (1-{len(gpu_info)}, 默认为 1): ")
|
||||
num_gpus = int(num_gpus) if num_gpus else default_num_gpus
|
||||
if not (1 <= num_gpus <= len(gpu_info)):
|
||||
print(f"输入无效,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 1 个 GPU。")
|
||||
num_gpus = default_num_gpus
|
||||
|
||||
# 用户选择 GPU 编号(如果选择多个 GPU,逐个选择编号)
|
||||
selected_gpus = []
|
||||
for i in range(num_gpus):
|
||||
try:
|
||||
gpu_idx = input(f"请输入要使用的第 {i + 1} 个 GPU 的编号 (默认为 {default_gpu_idx}): ")
|
||||
gpu_idx = int(gpu_idx) if gpu_idx else default_gpu_idx
|
||||
if gpu_idx in [idx for idx, _ in gpu_info] and gpu_idx not in selected_gpus:
|
||||
selected_gpus.append(gpu_idx)
|
||||
else:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认 GPU {default_gpu_idx}")
|
||||
selected_gpus.append(default_gpu_idx)
|
||||
|
||||
# 返回以 0,1 格式的 GPU 列表和 GPU 数量
|
||||
gpu_list_str = ','.join(map(str, selected_gpus))
|
||||
print(f"\n已选择 GPU: {gpu_list_str},共 {num_gpus} 块 GPU")
|
||||
|
||||
return gpu_list_str, num_gpus
|
||||
|
||||
# 选择训练批次相关信息
|
||||
def select_schedule(train_imgs_num):
|
||||
"""
|
||||
让用户选择训练模式(Iteration 或 Epoch),并根据选择收集参数,
|
||||
最终生成对应的训练计划 schedule 配置文件。
|
||||
"""
|
||||
# 1. 让用户选择模式
|
||||
while True:
|
||||
mode = input("请选择训练模式 (1: Iteration, 2: Epoch) [默认: 2]: ").strip()
|
||||
if mode in ['1', '2', '']:
|
||||
mode = '2' if mode == '' else mode
|
||||
break
|
||||
else:
|
||||
print("无效输入,请输入 1 或 2。")
|
||||
|
||||
# --- 模式 1: 基于 Iteration 的配置 ---
|
||||
if mode == '1':
|
||||
print("\n--- 您已选择 Iteration 模式 ---")
|
||||
try:
|
||||
train_time_or_epoch_k = input("请输入训练次数 (k为单位, 默认40k): ").strip()
|
||||
train_time_or_epoch_k = int(train_time_or_epoch_k) if train_time_or_epoch_k else 40
|
||||
except ValueError:
|
||||
print("无效输入,使用默认训练时间 40k")
|
||||
train_time_or_epoch_k = 40
|
||||
|
||||
try:
|
||||
check_num = input("请输入总的验证/保存次数 (默认10): ").strip()
|
||||
check_num = int(check_num) if check_num else 10
|
||||
except ValueError:
|
||||
print("无效输入,使用默认次数 10")
|
||||
check_num = 10
|
||||
|
||||
try:
|
||||
loggerhook_interval = input("请输入日志间隔 (默认50次迭代): ").strip()
|
||||
loggerhook_interval = int(loggerhook_interval) if loggerhook_interval else 50
|
||||
except ValueError:
|
||||
print("无效输入,使用默认日志间隔 50")
|
||||
loggerhook_interval = 50
|
||||
|
||||
val_proportion = 1 / check_num
|
||||
# 根据验证比例计算间隔,确保至少为1
|
||||
interval = max(1, int(train_time_or_epoch_k * 1000 * val_proportion))
|
||||
|
||||
schedule_file_name = f'schedule_{train_time_or_epoch_k}k_check_{interval}.py'
|
||||
output_path = os.path.join('./configs/_base_/schedules/', schedule_file_name)
|
||||
|
||||
generate_times_configs_base_schedules_schedule_file(
|
||||
output_file=output_path,
|
||||
train_time_or_epoch_k=train_time_or_epoch_k,
|
||||
val_proportion=val_proportion,
|
||||
loggerhook_interval=loggerhook_interval
|
||||
)
|
||||
# 返回文件名和训练时长
|
||||
return schedule_file_name, "Iteration", train_time_or_epoch_k
|
||||
|
||||
# --- 模式 2: 基于 Epoch 的配置 ---
|
||||
elif mode == '2':
|
||||
print("\n--- 您已选择 Epoch 模式 ---")
|
||||
try:
|
||||
max_epochs = input("请输入训练总轮数 (Epochs, 默认300): ").strip()
|
||||
max_epochs = int(max_epochs) if max_epochs else 300
|
||||
except ValueError:
|
||||
print("无效输入,使用默认轮数 300")
|
||||
max_epochs = 300
|
||||
|
||||
try:
|
||||
val_interval = input("请输入验证间隔的轮次数 (Epochs, 默认1): ").strip()
|
||||
val_interval = int(val_interval) if val_interval else 1
|
||||
except ValueError:
|
||||
print("无效输入,使用默认验证间隔 1")
|
||||
val_interval = 1
|
||||
|
||||
try:
|
||||
checkpoint_interval = input("请输入模型保存间隔的轮次数 (Epochs, 默认10): ").strip()
|
||||
checkpoint_interval = int(checkpoint_interval) if checkpoint_interval else 10
|
||||
except ValueError:
|
||||
print("无效输入,使用默认保存间隔 10")
|
||||
checkpoint_interval = 10
|
||||
|
||||
try:
|
||||
loggerhook_interval_default = train_imgs_num // 16
|
||||
loggerhook_interval = input(f"请输入日志间隔 (默认{loggerhook_interval_default}次迭代): ").strip()
|
||||
loggerhook_interval = int(loggerhook_interval) if loggerhook_interval else loggerhook_interval_default
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认日志间隔 {loggerhook_interval_default}")
|
||||
loggerhook_interval = loggerhook_interval_default
|
||||
|
||||
schedule_file_name = f'schedule_{max_epochs}e_val{val_interval}_check{checkpoint_interval}.py'
|
||||
output_path = os.path.join('./configs/_base_/schedules/', schedule_file_name)
|
||||
|
||||
# 注意:这里调用的是基于Epoch的生成函数
|
||||
generate_epochs_configs_base_schedules_schedule_file(
|
||||
output_file=output_path,
|
||||
max_epochs=max_epochs,
|
||||
val_interval=val_interval,
|
||||
checkpoint_interval=checkpoint_interval,
|
||||
loggerhook_interval=loggerhook_interval
|
||||
)
|
||||
# 返回文件名和训练时长
|
||||
return schedule_file_name, "Epoch", max_epochs
|
||||
|
||||
# 运行对应的代码去生成算法配置文件
|
||||
def run_alg_to_gen_alg(relative_alg_path, mean, std, alg_name, dataset_file_name, num_classes, GPU_num, schedule_file_name, train_type, train_time_or_epoch_k):
|
||||
try:
|
||||
# 构建命令和参数
|
||||
cmd = [
|
||||
'python', relative_alg_path,
|
||||
'--alg_name', alg_name,
|
||||
'--dataset_file_name', dataset_file_name,
|
||||
'--mean', str(mean),
|
||||
'--std', str(std),
|
||||
'--dataset_num_classes', str(num_classes),
|
||||
'--GPU_num', str(GPU_num),
|
||||
'--schedule_file_name', schedule_file_name,
|
||||
'--train_type', train_type,
|
||||
'--train_time_or_epoch_k', str(train_time_or_epoch_k)
|
||||
]
|
||||
|
||||
# 打印当前执行的命令
|
||||
print(f"\n正在运行: \033[33m{' '.join(cmd)}\033[0m")
|
||||
|
||||
# 运行 Python 脚本并传递参数
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
print(f"\033[92m算法生成器 {relative_alg_path} 运行成功。\033[0m")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"\033[91m算法生成器 {relative_alg_path} 运行失败,错误信息: {e}\033[0m")
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 1.超参数-定义数据集、算法信息路径 ###########
|
||||
train_parameter_dir = './My_All_In_One/1_Data_Parameter'
|
||||
all_data_record_json = "All_Data_Record.json"
|
||||
all_data_record_file = os.path.join(train_parameter_dir, all_data_record_json) # 数据集信息记录路径
|
||||
|
||||
alg_directory = './My_All_In_One/2_Alg_Program' # 算法配置生成路径
|
||||
|
||||
work_dir_base = '../DataSet_Public_outputs' # 工作路径
|
||||
|
||||
########### 2.获取现有数据集信息 ###########
|
||||
print(f"\033[36m{'='*10} 一、选择训练数据集 {'='*10}\033[0m")
|
||||
data_record_json_data = load_json_file(file_name = all_data_record_file) # 加载数据集信息
|
||||
all_data_record = process_data_record_json_data(json_data = data_record_json_data) # 分析数据集信息
|
||||
dataset_file_name, num_classes, mean, std, train_imgs_num = select_dataset(all_data_record = all_data_record) # 选择特定数据集
|
||||
|
||||
########### 3.获取现有GPU信息 ###########
|
||||
print(f"\033[36m{'='*10} 二、选择训练GPU {'='*10}\033[0m")
|
||||
gpu_list_str, GPU_num = select_GPU() # 选择GPU
|
||||
|
||||
########### 4.获取训练批次信息 ###########
|
||||
print(f"\033[36m{'='*10} 三、选择训练批次信息 {'='*10}\033[0m")
|
||||
schedule_file_name, train_type, train_time_or_epoch_k = select_schedule(train_imgs_num) # 选择训练批次
|
||||
schedule_file_name = schedule_file_name.rstrip('.py')
|
||||
|
||||
########### 5.获取现有算法信息 ###########
|
||||
print(f"\033[36m{'='*10} 四、选择训练算法 {'='*10}\033[0m")
|
||||
selected_alg, relative_alg_path = select_alg(alg_directory=alg_directory) # 选择算法
|
||||
alg_name = selected_alg.rstrip(".py")
|
||||
|
||||
########### 6.运行选定的算法 ###########
|
||||
print(f"\033[36m{'='*5} 生成训练算法配置 {'='*5}\033[0m")
|
||||
run_alg_to_gen_alg(relative_alg_path=relative_alg_path, alg_name=alg_name, dataset_file_name=dataset_file_name, num_classes=num_classes, GPU_num=GPU_num, mean=mean, std=std,schedule_file_name=schedule_file_name, train_type = train_type, train_time_or_epoch_k=train_time_or_epoch_k)
|
||||
|
||||
# 算法相关信息
|
||||
with open("_temp_.txt", "r") as file:
|
||||
data = json.load(file)
|
||||
alg_file_name = data['alg_infos']['alg_file_name']
|
||||
alg_file_pth = data['alg_infos']['alg_file_pth']
|
||||
os.remove("_temp_.txt")
|
||||
|
||||
########### 7.输出工作、算法目录 ###########
|
||||
# 获取并打印当前的年月日时分秒
|
||||
data_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
work_dir_file = os.path.join(work_dir_base, f"{dataset_file_name}-Class_{num_classes}-Alg_{alg_name}-AlgName_{alg_file_name}-Card_{GPU_num}-Data_{data_now}")
|
||||
|
||||
print(f"\033[36m训练指令:python tools/train.py {alg_file_pth} --work-dir {work_dir_file}")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
def find_specific_epochs(root_folder):
|
||||
"""
|
||||
遍历指定文件夹,查找所有 epoch_XXX.pth 文件,
|
||||
其中 XXX 为整数且不能被 10 整除,并返回它们的绝对路径。
|
||||
|
||||
:param root_folder: 要搜索的根文件夹路径
|
||||
:return: 一个包含符合条件文件绝对路径的列表
|
||||
"""
|
||||
matching_files = []
|
||||
|
||||
for dirpath, _, filenames in os.walk(root_folder):
|
||||
for filename in filenames:
|
||||
if filename.startswith('epoch_') and filename.endswith('.pth'):
|
||||
try:
|
||||
number_str = filename[len('epoch_'):-len('.pth')]
|
||||
if number_str.isdigit():
|
||||
epoch_number = int(number_str)
|
||||
if epoch_number % 10 != 0:
|
||||
full_path = os.path.join(dirpath, filename)
|
||||
matching_files.append(os.path.abspath(full_path))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return matching_files
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 1. 获取当前脚本文件所在的绝对目录
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 2. 以脚本目录为基准,构建目标路径
|
||||
# V1
|
||||
target_directory = os.path.abspath(os.path.join(script_dir, '../../Hardisk'))
|
||||
# V2
|
||||
# target_directory = os.path.abspath(os.path.join(script_dir, '../../DataSet_Public_outputs'))
|
||||
|
||||
# 1. 查找符合条件的文件
|
||||
found_paths = find_specific_epochs(target_directory)
|
||||
|
||||
if not found_paths:
|
||||
print(f"在 '{os.path.abspath(target_directory)}' 及其子目录中没有找到符合条件的文件。")
|
||||
sys.exit(0)
|
||||
|
||||
# 2. 列出所有找到的文件,并请求用户确认
|
||||
print("找到了以下符合条件的文件:")
|
||||
for path in found_paths:
|
||||
print(path)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("警告:接下来的操作将永久删除以上列出的所有文件!")
|
||||
print("="*50 + "\n")
|
||||
|
||||
# 3. 获取用户输入
|
||||
try:
|
||||
# 使用 strip() 去除首尾空格,使用 lower() 转换为小写
|
||||
confirm = input("您确定要删除这 {} 个文件吗? (请输入 'yes' 进行确认,输入其他任何内容则取消): ".format(len(found_paths)))
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n操作被用户中断。")
|
||||
sys.exit(1)
|
||||
|
||||
# 4. 根据用户的输入执行操作
|
||||
if confirm.lower().strip() == 'yes':
|
||||
print("\n正在开始删除文件...")
|
||||
deleted_count = 0
|
||||
error_count = 0
|
||||
for path in found_paths:
|
||||
try:
|
||||
os.remove(path)
|
||||
print(f"已删除: {path}")
|
||||
deleted_count += 1
|
||||
except OSError as e:
|
||||
print(f"删除失败: {path} (原因: {e})")
|
||||
error_count += 1
|
||||
print(f"\n操作完成。成功删除 {deleted_count} 个文件,{error_count} 个文件删除失败。")
|
||||
else:
|
||||
print("\n操作已取消,没有文件被删除。")
|
||||
@@ -0,0 +1,99 @@
|
||||
#!/bin/bash
|
||||
|
||||
# --- 脚本说明 ---
|
||||
# 功能: 使用 rsync 将指定格式的文件夹从源目录同步到目标目录。
|
||||
# - 脚本可以从任何位置安全执行。
|
||||
# - 在同步前会检查源文件/目录是否存在,如果不存在则跳过该任务。
|
||||
#
|
||||
# rsync 参数说明:
|
||||
# -a: 归档模式,保留文件属性。
|
||||
# -v: 详细模式。
|
||||
# -h: 人性化显示大小。
|
||||
# --progress: 显示传输进度。
|
||||
# --stats: 显示任务统计。
|
||||
|
||||
# --- 路径配置 ---
|
||||
# 获取脚本文件所在的绝对目录路径,确保路径的准确性。
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
# 基于脚本所在目录设置源和目标的绝对路径
|
||||
SOURCE_DIR="$SCRIPT_DIR/../../DataSet_Public_outputs"
|
||||
DEST_DIR="$SCRIPT_DIR/../../Hardisk"
|
||||
|
||||
# --- 主程序 ---
|
||||
echo "脚本执行目录已锁定为: $SCRIPT_DIR"
|
||||
echo "解析后的源目录 (Source): $SOURCE_DIR"
|
||||
echo "解析后的目标根目录 (Destination): $DEST_DIR"
|
||||
echo "========================================"
|
||||
echo "开始执行 rsync 同步任务(带安全检查)..."
|
||||
echo ""
|
||||
|
||||
# --- 任务处理函数 (简化代码) ---
|
||||
# 定义一个函数来处理每个同步任务,避免代码重复
|
||||
# 参数1: 任务名称 (例如: 1_cholecseg8k)
|
||||
# 参数2: 源文件/目录的通配符模式
|
||||
# 参数3: 目标目录
|
||||
handle_sync_task() {
|
||||
local task_name="$1"
|
||||
local src_pattern="$2"
|
||||
local dest_dir="$3"
|
||||
|
||||
echo "--> 正在检查任务: $task_name"
|
||||
|
||||
# 将通配符匹配到的文件存入数组
|
||||
local sources=($src_pattern)
|
||||
|
||||
# 检查数组的第一个元素是否存在。如果通配符没有匹配到任何文件,
|
||||
# bash会把通配符本身作为字符串返回,而这个字符串命名的文件通常不存在。
|
||||
if [ -e "${sources[0]}" ]; then
|
||||
echo " 源文件已找到,准备同步..."
|
||||
echo " 目标目录: $dest_dir"
|
||||
|
||||
# 创建目标目录(如果不存在)
|
||||
mkdir -p "$dest_dir"
|
||||
|
||||
# 执行 rsync
|
||||
rsync -avh --progress --stats "${sources[@]}" "$dest_dir/"
|
||||
echo "--> 任务 '$task_name' 同步完成。"
|
||||
else
|
||||
echo " 警告: 源路径 '$src_pattern' 未匹配到任何文件,已跳过此任务。"
|
||||
fi
|
||||
echo "----------------------------------------"
|
||||
}
|
||||
|
||||
|
||||
# --- 任务列表 ---
|
||||
|
||||
# 任务1: 同步 1_cholecseg8k
|
||||
handle_sync_task \
|
||||
"1_cholecseg8k" \
|
||||
"$SOURCE_DIR/1_cholecseg8k-Class_13-Alg*" \
|
||||
"$DEST_DIR/1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg"
|
||||
|
||||
# 任务2: 同步 2_autolaparo
|
||||
handle_sync_task \
|
||||
"2_autolaparo" \
|
||||
"$SOURCE_DIR/2_autolaparo-Class_10-Alg*" \
|
||||
"$DEST_DIR/2_AutoLaparo-10Type-1920x1080_outputs-MMSeg"
|
||||
|
||||
# 任务3: 同步 3_1_endovis_2017
|
||||
handle_sync_task \
|
||||
"3_1_endovis_2017" \
|
||||
"$SOURCE_DIR/3_1_endovis_2017-Class_8-Alg*" \
|
||||
"$DEST_DIR/3_1_Endovis_2017-8Type-512x512_outputs-MMSeg"
|
||||
|
||||
# 任务4: 同步 3_2_endovis_2018
|
||||
handle_sync_task \
|
||||
"3_2_endovis_2018" \
|
||||
"$SOURCE_DIR/3_2_endovis_2018-Class_8-Alg*" \
|
||||
"$DEST_DIR/3_2_Endovis_2018-8Type-512x512_outputs-MMSeg"
|
||||
|
||||
# 任务5: 同步 4_dresden
|
||||
handle_sync_task \
|
||||
"4_dresden" \
|
||||
"$SOURCE_DIR/4_dresden-Class_11-Alg*" \
|
||||
"$DEST_DIR/4_Dresden-11Type-512x512_outputs-MMSeg"
|
||||
|
||||
|
||||
echo "========================================"
|
||||
echo "所有 rsync 同步任务已执行完毕!"
|
||||
@@ -0,0 +1,427 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import subprocess
|
||||
import csv
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import Runner, load_checkpoint
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_model_files(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件、最佳检查点和日志文件。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional]: 包含 'config', 'checkpoint', 'log' 路径的字典,
|
||||
如果缺少任何必要文件,则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
config_path = config_files[0]
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'best.pth')
|
||||
if not os.path.exists(checkpoint_path):
|
||||
epoch_files = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
|
||||
if not epoch_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到 'best.pth' 或 'epoch_*.pth' 检查点文件。")
|
||||
return None
|
||||
|
||||
# 通过正则表达式从文件名中提取周期数并找到最大的
|
||||
latest_epoch = -1
|
||||
latest_file = None
|
||||
for f in epoch_files:
|
||||
match = re.search(r'epoch_(\d+)\.pth', os.path.basename(f))
|
||||
if match:
|
||||
epoch_num = int(match.group(1))
|
||||
if epoch_num > latest_epoch:
|
||||
latest_epoch = epoch_num
|
||||
latest_file = f
|
||||
|
||||
if latest_file:
|
||||
checkpoint_path = latest_file
|
||||
else:
|
||||
logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
|
||||
return None
|
||||
|
||||
return {'config': config_path, 'checkpoint': checkpoint_path}
|
||||
|
||||
def find_model_config(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件 (.py)。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 配置文件的路径,如果未找到则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
return config_files[0]
|
||||
|
||||
def get_shape_from_path(path: str):
|
||||
"""
|
||||
从文件夹路径中通过正则表达式提取分辨率 (宽x高)。
|
||||
|
||||
Args:
|
||||
path (str): 数据集文件夹的路径。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, int]]: 一个包含 (高度, 宽度) 的元组,如果未找到则返回 None。
|
||||
注意:工具需要 H W 格式。
|
||||
"""
|
||||
match = re.search(r'(\d+)x(\d+)', os.path.basename(path))
|
||||
if match:
|
||||
width, height = int(match.group(1)), int(match.group(2))
|
||||
return (height, width) # 返回 H, W
|
||||
return None
|
||||
|
||||
def get_flops_and_params(config_path: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
运行 mmsegmentation 的 get_flops.py 工具并解析其输出。
|
||||
此版本适配了新版的直接输出格式 (例如 "Flops: 0.118T")。
|
||||
|
||||
Args:
|
||||
config_path (str): 模型的 .py 配置文件路径。
|
||||
shape (Tuple[int, int]): 输入图像的 (H, W) 元组。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, str]]: 包含 'params' 和 'flops' 的字典,如果失败则返回 None。
|
||||
"""
|
||||
# 检查工具脚本是否存在
|
||||
tool_script = 'tools/analysis_tools/get_flops.py'
|
||||
if not os.path.exists(tool_script):
|
||||
logging.error(f"错误: '{tool_script}' 未找到。请确保在 MMSegmentation 项目的根目录下运行此脚本。")
|
||||
return None
|
||||
|
||||
# 构建命令行
|
||||
command = [
|
||||
'python', tool_script, config_path,
|
||||
'--shape', str(shape[0]), str(shape[1])
|
||||
]
|
||||
|
||||
logging.info(f"执行命令: {' '.join(command)}")
|
||||
|
||||
try:
|
||||
# 执行命令并捕获输出
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
|
||||
output = result.stdout
|
||||
|
||||
# 使用新的正则表达式来匹配更新后的输出格式
|
||||
flops_match = re.search(r"Flops:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
params_match = re.search(r"Params:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
|
||||
if flops_match and params_match:
|
||||
raw_flops_str = flops_match.group(1).strip()
|
||||
params = params_match.group(1).strip()
|
||||
# --- 开始单位换算 ---
|
||||
value_str = raw_flops_str.rstrip('TGMKtgmk').strip()
|
||||
unit = raw_flops_str[-1].upper() if raw_flops_str[-1].isalpha() else 'G'
|
||||
try:
|
||||
value = float(value_str)
|
||||
if unit == 'T':
|
||||
value_in_g = value * 1000
|
||||
elif unit == 'M':
|
||||
value_in_g = value / 1000
|
||||
elif unit == 'K':
|
||||
value_in_g = value / 1_000_000
|
||||
else: # 默认单位是 G
|
||||
value_in_g = value
|
||||
# 使用 :g 格式化可以去除末尾多余的0
|
||||
flops = f"{value_in_g:g} G"
|
||||
except ValueError:
|
||||
flops = raw_flops_str # 如果转换失败,则使用原始值
|
||||
# --- 单位换算结束 ---
|
||||
|
||||
logging.info(f"✅ 解析成功: FLOPs={flops} (原始值: {raw_flops_str}), Params={params}")
|
||||
return {'flops': flops, 'params': params}
|
||||
else:
|
||||
logging.warning(f"❌ 无法从命令输出中解析 FLOPs 或 Params。请检查以下输出内容:\n---\n{output}\n---")
|
||||
return None
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"执行 get_flops.py 时出错。返回码: {e.returncode}")
|
||||
logging.error(f"错误输出 (stderr):\n---\n{e.stderr}\n---")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("错误: 'python' 命令未找到。请确保 Python 环境已正确配置。")
|
||||
return None
|
||||
|
||||
def get_benchmark_stats(config_path: str, checkpoint_path: str, repeat_times: int) -> Optional[Dict[str, float]]:
|
||||
"""
|
||||
Runs inference benchmark based on the logic from benchmark.py.
|
||||
|
||||
Args:
|
||||
config_path (str): Path to the model's .py config file.
|
||||
checkpoint_path (str): Path to the model's .pth checkpoint file.
|
||||
repeat_times (int): Number of times to run the benchmark.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, float]]: Dict containing 'average_fps' and 'fps_variance', or None on failure.
|
||||
"""
|
||||
try:
|
||||
cfg = Config.fromfile(config_path)
|
||||
init_default_scope(cfg.get('default_scope', 'mmseg'))
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
cfg.model.pretrained = None
|
||||
cfg.test_dataloader.batch_size = 1 # Crucial for FPS measurement
|
||||
|
||||
overall_fps_list = []
|
||||
for time_index in range(repeat_times):
|
||||
logging.info(f"--- Starting Benchmark Run {time_index + 1}/{repeat_times} ---")
|
||||
|
||||
data_loader = Runner.build_dataloader(cfg.test_dataloader)
|
||||
|
||||
cfg.model.train_cfg = None
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda(0)
|
||||
else:
|
||||
logging.warning("CUDA is not available. Benchmarking on CPU, results may be slow.")
|
||||
|
||||
model = revert_sync_batchnorm(model)
|
||||
model.eval()
|
||||
|
||||
num_warmup = 5
|
||||
pure_inf_time = 0
|
||||
total_iters = 100 # Reduced from 200 for faster script execution
|
||||
|
||||
for i, data in enumerate(data_loader):
|
||||
data = model.data_preprocessor(data, True)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
model(data['inputs'], data['data_samples'], mode='predict')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start_time
|
||||
|
||||
if i >= num_warmup:
|
||||
pure_inf_time += elapsed
|
||||
|
||||
if (i + 1) == total_iters:
|
||||
fps = (total_iters - num_warmup) / pure_inf_time
|
||||
logging.info(f"Run {time_index + 1} Overall FPS: {fps:.2f} img/s")
|
||||
overall_fps_list.append(fps)
|
||||
break
|
||||
|
||||
if not overall_fps_list:
|
||||
logging.error("Benchmark failed to produce any results.")
|
||||
return None
|
||||
|
||||
avg_fps = round(np.mean(overall_fps_list), 2)
|
||||
fps_var = round(np.var(overall_fps_list), 4)
|
||||
|
||||
logging.info(f"✅ Benchmark Complete: Average FPS={avg_fps}, Variance={fps_var}")
|
||||
return {'average_fps': avg_fps, 'fps_variance': fps_var}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"An exception occurred during benchmarking: {e}")
|
||||
return None
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
# --- 开始交互式选择修改 (V2 - 两级菜单) ---
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# 1. 定义有效的数据集文件夹白名单
|
||||
VALID_DATASET_FOLDERS = [
|
||||
'1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
]
|
||||
|
||||
# 2. 查找存在的、有效的数据集目录
|
||||
existing_dataset_dirs = [
|
||||
os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
|
||||
if os.path.isdir(os.path.join(input_root, d))
|
||||
]
|
||||
|
||||
if not existing_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
# 3. 第一级菜单:选择数据集
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = [] # 初始化最终要处理的目录列表
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
# 4. 查找选定数据集下的所有算法子目录
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
|
||||
else:
|
||||
# 5. 第二级菜单:选择算法
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
# 6. 根据第二级选择,最终确定 model_dirs
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]] # 将单个路径放入列表中
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
|
||||
# --- 交互式选择修改结束 ---
|
||||
results: List[Dict[str, str]] = []
|
||||
# 修改后的循环,将遍历经过用户筛选后的 model_dirs 列表
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"--- 开始处理模型: {model_name} ---")
|
||||
files = find_model_files(model_dir)
|
||||
if not files:
|
||||
logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
|
||||
continue
|
||||
|
||||
# 构建输出目录
|
||||
# 从模型名中提取数据集标识作为Key
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
}
|
||||
# 使用提取的Key(字符串)进行查询,并为默认值也使用该Key
|
||||
output_dataset_folder = dataset_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
# 尝试从数据集文件夹名称中获取分辨率
|
||||
input_shape = get_shape_from_path(selected_dataset_dir)
|
||||
if not input_shape:
|
||||
# 如果无法自动提取,要求用户输入
|
||||
logging.warning(f"无法从文件夹 '{os.path.basename(selected_dataset_dir)}' 名称中自动检测分辨率。")
|
||||
try:
|
||||
h_str = input("请输入默认测试高度 (H),例如 512: ").strip()
|
||||
w_str = input("请输入默认测试宽度 (W),例如 512: ").strip()
|
||||
input_shape = (int(h_str), int(w_str))
|
||||
except ValueError:
|
||||
logging.error("输入无效,必须是整数。程序退出。")
|
||||
return
|
||||
logging.info(f"将使用输入形状 (H, W): {input_shape} 进行计算。")
|
||||
|
||||
# 加载配置
|
||||
config_file = find_model_config(model_dir)
|
||||
|
||||
# 获取 FLOPs 和 Params
|
||||
flops_and_params_stats = get_flops_and_params(config_file, input_shape)
|
||||
benchmark_stats = get_benchmark_stats(files['config'], files['checkpoint'], args.repeat_times)
|
||||
if flops_and_params_stats:
|
||||
short_model_name = model_name.split('Alg_', 1)[1]
|
||||
results.append({
|
||||
'Model': short_model_name,
|
||||
'Params': flops_and_params_stats['params'] if flops_and_params_stats else 'N/A' ,
|
||||
'FLOPs': flops_and_params_stats['flops'] if flops_and_params_stats else 'N/A' ,
|
||||
'Input_Shape (HxW)': f"{input_shape[0]}x{input_shape[1]}",
|
||||
'Average_FPS': benchmark_stats['average_fps'] if benchmark_stats else 'N/A',
|
||||
'FPS_Variance': benchmark_stats['fps_variance'] if benchmark_stats else 'N/A'
|
||||
})
|
||||
else:
|
||||
logging.warning(f"未能获取模型 {model_name} 的统计信息。")
|
||||
|
||||
# --- 将结果写入 CSV 文件 ---
|
||||
if not results:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# 新建文件夹并保存 CSV
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_flops_params_fps_summary.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)', 'Average_FPS', 'FPS_Variance']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
writer.writerows(results)
|
||||
|
||||
logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含已训练模型文件夹的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--repeat-times',
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of times to repeat the benchmark for averaging."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,332 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import subprocess
|
||||
import csv
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_model_files(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件、最佳检查点和日志文件。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional]: 包含 'config', 'checkpoint', 'log' 路径的字典,
|
||||
如果缺少任何必要文件,则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
config_path = config_files[0]
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'best.pth')
|
||||
if not os.path.exists(checkpoint_path):
|
||||
epoch_files = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
|
||||
if not epoch_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到 'best.pth' 或 'epoch_*.pth' 检查点文件。")
|
||||
return None
|
||||
|
||||
# 通过正则表达式从文件名中提取周期数并找到最大的
|
||||
latest_epoch = -1
|
||||
latest_file = None
|
||||
for f in epoch_files:
|
||||
match = re.search(r'epoch_(\d+)\.pth', os.path.basename(f))
|
||||
if match:
|
||||
epoch_num = int(match.group(1))
|
||||
if epoch_num > latest_epoch:
|
||||
latest_epoch = epoch_num
|
||||
latest_file = f
|
||||
|
||||
if latest_file:
|
||||
checkpoint_path = latest_file
|
||||
else:
|
||||
logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
|
||||
return None
|
||||
|
||||
return {'config': config_path, 'checkpoint': checkpoint_path}
|
||||
|
||||
def find_model_config(model_dir: str):
|
||||
"""
|
||||
在给定的模型目录中查找配置文件 (.py)。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 配置文件的路径,如果未找到则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
return config_files[0]
|
||||
|
||||
def get_shape_from_path(path: str):
|
||||
"""
|
||||
从文件夹路径中通过正则表达式提取分辨率 (宽x高)。
|
||||
|
||||
Args:
|
||||
path (str): 数据集文件夹的路径。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, int]]: 一个包含 (高度, 宽度) 的元组,如果未找到则返回 None。
|
||||
注意:工具需要 H W 格式。
|
||||
"""
|
||||
match = re.search(r'(\d+)x(\d+)', os.path.basename(path))
|
||||
if match:
|
||||
width, height = int(match.group(1)), int(match.group(2))
|
||||
return (height, width) # 返回 H, W
|
||||
return None
|
||||
|
||||
def get_flops_and_params(config_path: str, shape: Tuple[int, int]) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
运行 mmsegmentation 的 get_flops.py 工具并解析其输出。
|
||||
此版本适配了新版的直接输出格式 (例如 "Flops: 0.118T")。
|
||||
|
||||
Args:
|
||||
config_path (str): 模型的 .py 配置文件路径。
|
||||
shape (Tuple[int, int]): 输入图像的 (H, W) 元组。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, str]]: 包含 'params' 和 'flops' 的字典,如果失败则返回 None。
|
||||
"""
|
||||
# 检查工具脚本是否存在
|
||||
tool_script = 'tools/analysis_tools/get_flops.py'
|
||||
if not os.path.exists(tool_script):
|
||||
logging.error(f"错误: '{tool_script}' 未找到。请确保在 MMSegmentation 项目的根目录下运行此脚本。")
|
||||
return None
|
||||
|
||||
# 构建命令行
|
||||
command = [
|
||||
'python', tool_script, config_path,
|
||||
'--shape', str(shape[0]), str(shape[1])
|
||||
]
|
||||
|
||||
logging.info(f"执行命令: {' '.join(command)}")
|
||||
|
||||
try:
|
||||
# 执行命令并捕获输出
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
|
||||
output = result.stdout
|
||||
|
||||
# 使用新的正则表达式来匹配更新后的输出格式
|
||||
flops_match = re.search(r"Flops:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
params_match = re.search(r"Params:\s*([0-9.]+\s*[TGMK]?)", output)
|
||||
|
||||
if flops_match and params_match:
|
||||
raw_flops_str = flops_match.group(1).strip()
|
||||
params = params_match.group(1).strip()
|
||||
# --- 开始单位换算 ---
|
||||
value_str = raw_flops_str.rstrip('TGMKtgmk').strip()
|
||||
unit = raw_flops_str[-1].upper() if raw_flops_str[-1].isalpha() else 'G'
|
||||
try:
|
||||
value = float(value_str)
|
||||
if unit == 'T':
|
||||
value_in_g = value * 1000
|
||||
elif unit == 'M':
|
||||
value_in_g = value / 1000
|
||||
elif unit == 'K':
|
||||
value_in_g = value / 1_000_000
|
||||
else: # 默认单位是 G
|
||||
value_in_g = value
|
||||
# 使用 :g 格式化可以去除末尾多余的0
|
||||
flops = f"{value_in_g:g} G"
|
||||
except ValueError:
|
||||
flops = raw_flops_str # 如果转换失败,则使用原始值
|
||||
# --- 单位换算结束 ---
|
||||
|
||||
logging.info(f"✅ 解析成功: FLOPs={flops} (原始值: {raw_flops_str}), Params={params}")
|
||||
return {'flops': flops, 'params': params}
|
||||
else:
|
||||
logging.warning(f"❌ 无法从命令输出中解析 FLOPs 或 Params。请检查以下输出内容:\n---\n{output}\n---")
|
||||
return None
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"执行 get_flops.py 时出错。返回码: {e.returncode}")
|
||||
logging.error(f"错误输出 (stderr):\n---\n{e.stderr}\n---")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("错误: 'python' 命令未找到。请确保 Python 环境已正确配置。")
|
||||
return None
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
# --- 开始交互式选择修改 (V2 - 两级菜单) ---
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# 1. 定义有效的数据集文件夹白名单
|
||||
VALID_DATASET_FOLDERS = [
|
||||
'1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
]
|
||||
|
||||
# 2. 查找存在的、有效的数据集目录
|
||||
existing_dataset_dirs = [
|
||||
os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
|
||||
if os.path.isdir(os.path.join(input_root, d))
|
||||
]
|
||||
|
||||
if not existing_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
# 3. 第一级菜单:选择数据集
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = [] # 初始化最终要处理的目录列表
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
# 4. 查找选定数据集下的所有算法子目录
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
|
||||
else:
|
||||
# 5. 第二级菜单:选择算法
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
# 6. 根据第二级选择,最终确定 model_dirs
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]] # 将单个路径放入列表中
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
|
||||
# --- 交互式选择修改结束 ---
|
||||
results: List[Dict[str, str]] = []
|
||||
# 修改后的循环,将遍历经过用户筛选后的 model_dirs 列表
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"--- 开始处理模型: {model_name} ---")
|
||||
files = find_model_files(model_dir)
|
||||
if not files:
|
||||
logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
|
||||
continue
|
||||
|
||||
# 构建输出目录
|
||||
# 从模型名中提取数据集标识作为Key
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1280x1024_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_EndoVis_2017-7Type-1280x1024_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_EndoVis_2018-11Type-1280x1024_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-6Type-1920x1080_outputs-MMSeg'
|
||||
}
|
||||
# 使用提取的Key(字符串)进行查询,并为默认值也使用该Key
|
||||
output_dataset_folder = dataset_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
# 尝试从数据集文件夹名称中获取分辨率
|
||||
input_shape = get_shape_from_path(selected_dataset_dir)
|
||||
if not input_shape:
|
||||
# 如果无法自动提取,要求用户输入
|
||||
logging.warning(f"无法从文件夹 '{os.path.basename(selected_dataset_dir)}' 名称中自动检测分辨率。")
|
||||
try:
|
||||
h_str = input("请输入默认测试高度 (H),例如 512: ").strip()
|
||||
w_str = input("请输入默认测试宽度 (W),例如 512: ").strip()
|
||||
input_shape = (int(h_str), int(w_str))
|
||||
except ValueError:
|
||||
logging.error("输入无效,必须是整数。程序退出。")
|
||||
return
|
||||
logging.info(f"将使用输入形状 (H, W): {input_shape} 进行计算。")
|
||||
|
||||
# 加载配置
|
||||
config_file = find_model_config(model_dir)
|
||||
|
||||
# 获取 FLOPs 和 Params
|
||||
stats = get_flops_and_params(config_file, input_shape)
|
||||
if stats:
|
||||
short_model_name = model_name.split('Alg_', 1)[1]
|
||||
results.append({
|
||||
'Model': short_model_name,
|
||||
'Params': stats['params'],
|
||||
'FLOPs': stats['flops'],
|
||||
'Input_Shape (HxW)': f"{input_shape[0]}x{input_shape[1]}"
|
||||
})
|
||||
else:
|
||||
logging.warning(f"未能获取模型 {model_name} 的统计信息。")
|
||||
|
||||
# --- 将结果写入 CSV 文件 ---
|
||||
if not results:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# 新建文件夹并保存 CSV
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_flops_params_summary.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = ['Model', 'Params', 'FLOPs', 'Input_Shape (HxW)']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
writer.writerows(results)
|
||||
|
||||
logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含已训练模型文件夹的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,322 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import csv
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
# TODO 这个是获取最后一次结果的 TODO
|
||||
|
||||
# --- 配置日志记录 ---
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
|
||||
"""
|
||||
查找给定算法目录中所有的.log文件,并按从新到旧的顺序排列。
|
||||
|
||||
Args:
|
||||
algorithm_dir (str): 算法的根目录。
|
||||
|
||||
Returns:
|
||||
List[str]: 按时间倒序排列的日志文件路径列表。
|
||||
"""
|
||||
try:
|
||||
subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
|
||||
except FileNotFoundError:
|
||||
logging.error(f"算法目录不存在: {algorithm_dir}")
|
||||
return []
|
||||
|
||||
if not subdirs:
|
||||
logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
|
||||
return []
|
||||
|
||||
# 按名称倒序排序,最新的目录会排在最前面
|
||||
sorted_subdirs = sorted(subdirs, reverse=True)
|
||||
|
||||
log_files = []
|
||||
for subdir_name in sorted_subdirs:
|
||||
subdir_path = os.path.join(algorithm_dir, subdir_name)
|
||||
logs_in_subdir = glob.glob(os.path.join(subdir_path, '*.log'))
|
||||
if logs_in_subdir:
|
||||
# 假设每个子目录只有一个log文件
|
||||
log_files.append(logs_in_subdir[0])
|
||||
|
||||
return log_files
|
||||
|
||||
def parse_log_metrics(log_path: str) -> Optional[Dict]:
|
||||
"""
|
||||
解析日志文件,提取最后一次完整验证(validation)的结果及其对应的Epoch。
|
||||
|
||||
Args:
|
||||
log_path (str): 日志文件的路径。
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 包含 'epoch', 'summary', 'class_wise' 指标的字典,如果解析失败则返回 None。
|
||||
"""
|
||||
try:
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except IOError as e:
|
||||
logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
|
||||
return None
|
||||
# 定义正则表达式
|
||||
summary_pattern = re.compile(
|
||||
r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*([\d.]+)\s+mIoU:\s*([\d.]+)\s+mAcc:\s*([\d.]+)"
|
||||
)
|
||||
class_table_pattern = re.compile(
|
||||
# --- 使用新的模式匹配可变长度的顶部边框 ---
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
r"\|.*?Class.*?\|.*?IoU.*?\|.*?Acc.*?\|\n"
|
||||
# --- 匹配中间边框 ---
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
# --- 捕获表格主体 ---
|
||||
r"((?:\|.*?\|.*?\|.*?\|\n)+)"
|
||||
# --- 匹配底部边框 ---
|
||||
r"\+(?:-+\+)+\s*\n",
|
||||
re.MULTILINE
|
||||
)
|
||||
epoch_pattern = re.compile(r"Saving checkpoint at (\d+) epochs|resumed epoch: (\d+)")
|
||||
|
||||
# 查找所有匹配项
|
||||
summary_matches = list(re.finditer(summary_pattern, content))
|
||||
table_matches = list(re.finditer(class_table_pattern, content))
|
||||
epoch_matches = list(re.finditer(epoch_pattern, content))
|
||||
|
||||
if not summary_matches or not table_matches:
|
||||
logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中未能找到完整的验证结果。")
|
||||
return None
|
||||
|
||||
last_summary_match = summary_matches[-1]
|
||||
last_table_match = None
|
||||
for table in reversed(table_matches):
|
||||
if table.end() < last_summary_match.start():
|
||||
last_table_match = table
|
||||
break
|
||||
|
||||
if not last_table_match:
|
||||
logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中找到总结行但未能匹配到对应的类别表格。")
|
||||
return None
|
||||
|
||||
# 寻找关联的最新Epoch
|
||||
last_epoch = "N/A"
|
||||
latest_epoch_num = -1
|
||||
for epoch_match in epoch_matches:
|
||||
if epoch_match.end() < last_table_match.start():
|
||||
epoch_str = epoch_match.group(1) or epoch_match.group(2)
|
||||
if epoch_str:
|
||||
epoch_num = int(epoch_str)
|
||||
if epoch_num > latest_epoch_num:
|
||||
latest_epoch_num = epoch_num
|
||||
last_epoch = f"epoch_{epoch_num}"
|
||||
|
||||
# 解析数据
|
||||
summary_groups = last_summary_match.groups()
|
||||
results = {
|
||||
'epoch': last_epoch,
|
||||
'summary': {
|
||||
'aAcc': summary_groups[0],
|
||||
'mIoU': summary_groups[1],
|
||||
'mAcc': summary_groups[2]
|
||||
},
|
||||
'class_wise': []
|
||||
}
|
||||
|
||||
table_content = last_table_match.group(0)
|
||||
row_pattern = re.compile(r"\|\s*([\w\s]+?)\s*\|\s*([\d.]+)\s*\|\s*([\d.]+)\s*\|")
|
||||
for line in table_content.strip().split('\n'):
|
||||
row_match = row_pattern.match(line)
|
||||
if row_match:
|
||||
class_name, iou, acc = row_match.groups()
|
||||
results['class_wise'].append({
|
||||
'Class': class_name.strip(),
|
||||
'IoU': iou,
|
||||
'Acc': acc
|
||||
})
|
||||
|
||||
if results['class_wise']:
|
||||
logging.info(f"✅ 成功从 {os.path.basename(log_path)} 中解析出 Epoch '{last_epoch}' 的指标。")
|
||||
return results
|
||||
else:
|
||||
logging.warning(f"❌ 在 {os.path.basename(log_path)} 中未能解析出任何类别行。")
|
||||
return None
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# --- 交互式菜单 ---
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not all_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = []
|
||||
selected_dataset_dir = None
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
|
||||
return
|
||||
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]]
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
return
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
return
|
||||
|
||||
# --- 开始处理选定的算法 (逻辑已修改) ---
|
||||
csv_rows = []
|
||||
output_dataset_folder = ""
|
||||
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"\n--- 开始处理算法: {model_name} ---")
|
||||
|
||||
# (路径构建代码保持不变)
|
||||
if not output_dataset_folder:
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_folder_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
}
|
||||
output_dataset_folder = dataset_folder_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
# --- 新的循环查找逻辑 ---
|
||||
# 1. 获取所有按时间倒序排列的日志文件
|
||||
all_logs_sorted = find_all_log_files_sorted(model_dir)
|
||||
|
||||
if not all_logs_sorted:
|
||||
logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
|
||||
continue
|
||||
|
||||
# 2. 循环尝试解析,直到成功或全部失败
|
||||
metrics = None
|
||||
for log_file_path in all_logs_sorted:
|
||||
logging.info(f"正在尝试解析: {os.path.relpath(log_file_path)}")
|
||||
metrics = parse_log_metrics(log_file_path)
|
||||
if metrics:
|
||||
logging.info(f"在 {os.path.basename(log_file_path)} 中成功找到并解析了指标。")
|
||||
break # 找到后立即跳出循环
|
||||
|
||||
# 3. 如果所有日志都尝试失败,则跳过此算法
|
||||
if not metrics:
|
||||
logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的指标。❌❌❌")
|
||||
continue
|
||||
|
||||
# --- 创建一个 "宽" 格式的行 ---
|
||||
summary = metrics['summary']
|
||||
short_model_name = model_name.split('Alg_', 1)[1]
|
||||
row_data = {
|
||||
'Algorithm': short_model_name,
|
||||
'Epoch': metrics['epoch'],
|
||||
'mIoU': summary['mIoU'],
|
||||
'mAcc': summary['mAcc'],
|
||||
'aAcc': summary['aAcc']
|
||||
}
|
||||
|
||||
# 将每个类别的IoU和Acc作为新列添加到行数据中
|
||||
for class_data in metrics['class_wise']:
|
||||
class_name = class_data['Class'].replace(' ', '_') # 清理类名以用作表头
|
||||
row_data[f'{class_name}_IoU'] = class_data['IoU']
|
||||
row_data[f'{class_name}_Acc'] = class_data['Acc']
|
||||
|
||||
csv_rows.append(row_data)
|
||||
|
||||
# --- 将结果写入 CSV 文件 (逻辑已修改) ---
|
||||
if not csv_rows:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# --- 动态生成并排序表头 ---
|
||||
# 基础列保持固定顺序
|
||||
base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
|
||||
# 从第一个结果中获取所有与类别相关的列名,并按字母排序
|
||||
first_row_keys = csv_rows[0].keys()
|
||||
class_fieldnames = sorted([key for key in first_row_keys if key not in base_fieldnames])
|
||||
# 最终的完整表头
|
||||
final_fieldnames = base_fieldnames + class_fieldnames
|
||||
|
||||
# --- 构建输出路径并写入文件 ---
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
# 在文件名中加入 "_wide" 以区分格式
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_rows)
|
||||
|
||||
logging.info(f"=== 全部处理完成!结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 最终指标提取脚本 (V2)")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,366 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import csv
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
# TODO 这个是获取最后一次结果的 TODO
|
||||
|
||||
# --- 配置日志记录 ---
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
|
||||
"""
|
||||
查找给定算法目录中所有的.log文件,并按从新到旧的顺序排列。
|
||||
|
||||
Args:
|
||||
algorithm_dir (str): 算法的根目录。
|
||||
|
||||
Returns:
|
||||
List[str]: 按时间倒序排列的日志文件路径列表。
|
||||
"""
|
||||
try:
|
||||
subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
|
||||
except FileNotFoundError:
|
||||
logging.error(f"算法目录不存在: {algorithm_dir}")
|
||||
return []
|
||||
|
||||
if not subdirs:
|
||||
logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
|
||||
return []
|
||||
|
||||
# 按名称倒序排序,最新的目录会排在最前面
|
||||
sorted_subdirs = sorted(subdirs, reverse=True)
|
||||
|
||||
log_files = []
|
||||
for subdir_name in sorted_subdirs:
|
||||
subdir_path = os.path.join(algorithm_dir, subdir_name)
|
||||
logs_in_subdir = glob.glob(os.path.join(subdir_path, '*.log'))
|
||||
if logs_in_subdir:
|
||||
# 假设每个子目录只有一个log文件
|
||||
log_files.append(logs_in_subdir[0])
|
||||
|
||||
return log_files
|
||||
|
||||
def get_max_epochs(config_path: str) -> Optional[int]:
|
||||
"""
|
||||
从config.py文件中解析train_cfg字典以获取max_epochs的值。
|
||||
|
||||
Args:
|
||||
config_path (str): config.py文件的路径。
|
||||
|
||||
Returns:
|
||||
Optional[int]: max_epochs的值,如果找不到则返回None。
|
||||
"""
|
||||
if not os.path.exists(config_path):
|
||||
logging.error(f"配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 使用正则表达式查找 max_epochs
|
||||
match = re.search(r"train_cfg\s*=\s*dict\(.*?max_epochs\s*=\s*(\d+),.*?\)", content, re.DOTALL)
|
||||
|
||||
if match:
|
||||
max_epochs = int(match.group(1))
|
||||
logging.info(f"从 {os.path.basename(config_path)} 中成功读取 max_epochs: {max_epochs}")
|
||||
return max_epochs
|
||||
else:
|
||||
logging.warning(f"在 {config_path} 中未找到 'max_epochs'。")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"解析 {config_path} 时出错: {e}")
|
||||
return None
|
||||
|
||||
def parse_log_metrics(log_path: str, max_epochs: int) -> List[Dict]:
|
||||
"""
|
||||
解析日志文件,提取所有完整验证(validation)的结果,并计算其对应的Epoch。
|
||||
|
||||
Args:
|
||||
log_path (str): 日志文件的路径。
|
||||
max_epochs (int): 从config.py中读取的最大epoch数。
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含每次验证的 'epoch', 'summary', 'class_wise' 指标的字典列表。
|
||||
"""
|
||||
try:
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except IOError as e:
|
||||
logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
|
||||
return []
|
||||
|
||||
# 定义所有需要的正则表达式
|
||||
summary_pattern = re.compile(
|
||||
r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*([\d.]+)\s+mIoU:\s*([\d.]+)\s+mAcc:\s*([\d.]+)"
|
||||
)
|
||||
class_table_pattern = re.compile(
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
r"\|.*?Class.*?\|.*?IoU.*?\|.*?Acc.*?\|\n"
|
||||
r"\+(?:-+\+)+\s*\n"
|
||||
r"((?:\|.*?\|.*?\|.*?\|\n)+)"
|
||||
r"\+(?:-+\+)+\s*\n",
|
||||
re.MULTILINE
|
||||
)
|
||||
train_iter_pattern = re.compile(r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]")
|
||||
|
||||
# 查找所有匹配项
|
||||
summary_matches = list(re.finditer(summary_pattern, content))
|
||||
table_matches = list(re.finditer(class_table_pattern, content))
|
||||
train_iter_matches = list(re.finditer(train_iter_pattern, content))
|
||||
|
||||
all_metrics = []
|
||||
|
||||
# 遍历每一次的总结行 (summary)
|
||||
for i, summary_match in enumerate(summary_matches):
|
||||
# 寻找与总结行对应的类别表格
|
||||
# 表格应该出现在总结行之前
|
||||
last_table_match = None
|
||||
for table in reversed(table_matches):
|
||||
if table.end() < summary_match.start():
|
||||
# 确保这个表格没有被上一个总结行用过
|
||||
is_already_used = False
|
||||
if i > 0:
|
||||
if table.end() < summary_matches[i-1].start():
|
||||
is_already_used = True
|
||||
if not is_already_used:
|
||||
last_table_match = table
|
||||
break
|
||||
|
||||
if not last_table_match:
|
||||
continue
|
||||
|
||||
# 寻找表格前最近的 Iter(train) 行来计算epoch
|
||||
last_train_iter_match = None
|
||||
for train_iter in reversed(train_iter_matches):
|
||||
if train_iter.end() < last_table_match.start():
|
||||
last_train_iter_match = train_iter
|
||||
break
|
||||
|
||||
epoch = "N/A"
|
||||
if last_train_iter_match and max_epochs:
|
||||
current_iter, total_iters = last_train_iter_match.groups()
|
||||
try:
|
||||
# 根据公式计算epoch
|
||||
epoch = int(int(current_iter) / int(total_iters) * max_epochs)
|
||||
except (ValueError, ZeroDivisionError) as e:
|
||||
logging.warning(f"Epoch 计算失败: {e}")
|
||||
|
||||
# 解析总结指标
|
||||
summary_groups = summary_match.groups()
|
||||
results = {
|
||||
'epoch': epoch,
|
||||
'summary': {
|
||||
'aAcc': summary_groups[0],
|
||||
'mIoU': summary_groups[1],
|
||||
'mAcc': summary_groups[2]
|
||||
},
|
||||
'class_wise': []
|
||||
}
|
||||
|
||||
# 解析每个类别的数据
|
||||
table_content = last_table_match.group(1)
|
||||
row_pattern = re.compile(r"\|\s*([\w\s.-]+?)\s*\|\s*([\d.]+)\s*\|\s*([\d.]+)\s*\|")
|
||||
for line in table_content.strip().split('\n'):
|
||||
row_match = row_pattern.match(line)
|
||||
if row_match:
|
||||
class_name, iou, acc = row_match.groups()
|
||||
results['class_wise'].append({
|
||||
'Class': class_name.strip(),
|
||||
'IoU': iou,
|
||||
'Acc': acc
|
||||
})
|
||||
|
||||
if results['class_wise']:
|
||||
all_metrics.append(results)
|
||||
|
||||
if all_metrics:
|
||||
logging.info(f"✅ 成功从 {os.path.basename(log_path)} 中解析出 {len(all_metrics)} 组指标。")
|
||||
else:
|
||||
logging.warning(f"❌ 在日志 {os.path.basename(log_path)} 中未能找到完整的验证结果。")
|
||||
|
||||
return all_metrics
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# --- 交互式菜单 (这部分保持不变) ---
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not all_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = []
|
||||
selected_dataset_dir = None
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
|
||||
return
|
||||
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]]
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
return
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
return
|
||||
|
||||
# --- 开始处理选定的算法 ---
|
||||
csv_rows = []
|
||||
output_dataset_folder = ""
|
||||
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"\n--- 开始处理算法: {model_name} ---")
|
||||
|
||||
if not output_dataset_folder:
|
||||
dataset_key = model_name.split('-')[0]
|
||||
# ... (这部分路径映射逻辑保持不变)
|
||||
output_dataset_folder = f"{dataset_key}_outputs-MMSeg"
|
||||
|
||||
all_logs_sorted = find_all_log_files_sorted(model_dir)
|
||||
|
||||
if not all_logs_sorted:
|
||||
logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
|
||||
continue
|
||||
|
||||
# --- 新逻辑:获取max_epochs ---
|
||||
# 假设同一算法下所有训练的config是相同的,因此我们从最新的log对应的config读取
|
||||
latest_log_path = all_logs_sorted[0]
|
||||
config_path = os.path.join(os.path.dirname(latest_log_path), 'vis_data', 'config.py')
|
||||
max_epochs = get_max_epochs(config_path)
|
||||
if max_epochs is None:
|
||||
logging.error(f"无法为算法 {model_name} 找到 max_epochs,将跳过。")
|
||||
continue
|
||||
|
||||
# --- 新逻辑:聚合所有日志的所有指标 ---
|
||||
all_metrics_for_model = []
|
||||
for log_file_path in all_logs_sorted:
|
||||
logging.info(f"正在解析: {os.path.relpath(log_file_path)}")
|
||||
# 传递max_epochs
|
||||
metrics_from_log = parse_log_metrics(log_file_path, max_epochs)
|
||||
if metrics_from_log:
|
||||
all_metrics_for_model.extend(metrics_from_log)
|
||||
|
||||
if not all_metrics_for_model:
|
||||
logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的指标。❌❌❌")
|
||||
continue
|
||||
|
||||
# --- 新逻辑:选择mIoU最高的记录 ---
|
||||
try:
|
||||
best_metric = max(all_metrics_for_model, key=lambda x: float(x['summary']['mIoU']))
|
||||
logging.info(f"找到了最佳指标: Epoch '{best_metric['epoch']}', mIoU: {best_metric['summary']['mIoU']}")
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.error(f"为算法 {model_name} 寻找最佳mIoU时出错: {e}")
|
||||
continue
|
||||
|
||||
# --- 创建一个 "宽" 格式的行 ---
|
||||
summary = best_metric['summary']
|
||||
short_model_name = model_name.split('Alg_', 1)[1] if 'Alg_' in model_name else model_name
|
||||
row_data = {
|
||||
'Algorithm': short_model_name,
|
||||
'Epoch': best_metric['epoch'],
|
||||
'mIoU': summary['mIoU'],
|
||||
'mAcc': summary['mAcc'],
|
||||
'aAcc': summary['aAcc']
|
||||
}
|
||||
|
||||
for class_data in best_metric['class_wise']:
|
||||
class_name = class_data['Class'].replace(' ', '_')
|
||||
row_data[f'{class_name}_IoU'] = class_data['IoU']
|
||||
row_data[f'{class_name}_Acc'] = class_data['Acc']
|
||||
|
||||
csv_rows.append(row_data)
|
||||
|
||||
# --- 将结果写入 CSV 文件 (这部分保持不变) ---
|
||||
if not csv_rows:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
base_fieldnames = ['Algorithm', 'Epoch', 'mIoU', 'mAcc', 'aAcc']
|
||||
first_row_keys = csv_rows[0].keys()
|
||||
class_fieldnames = sorted([key for key in first_row_keys if key not in base_fieldnames])
|
||||
final_fieldnames = base_fieldnames + class_fieldnames
|
||||
|
||||
final_output_dir = os.path.join(output_root, os.path.basename(selected_dataset_dir))
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_metrics_summary_wide.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_rows)
|
||||
|
||||
logging.info(f"\n=== 全部处理完成!最佳结果已成功保存到: {output_csv_path} ===")
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 最终指标提取脚本 (V2)")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,249 @@
|
||||
import os
|
||||
import glob
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import re
|
||||
|
||||
def get_model_family(model_name):
|
||||
"""
|
||||
根据模型名称提取模型族。
|
||||
例如: 'my_bisenetv1_r50' -> 'my_bisenetv1'
|
||||
'my_fast_scnn' -> 'my_fast_scnn'
|
||||
"""
|
||||
# 使用正则表达式匹配,将 _rXX 或 _dXX 等后缀去掉
|
||||
match = re.match(r'^(.*?)_r\d+$', model_name)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return model_name
|
||||
|
||||
def select_dataset(results_dir):
|
||||
"""
|
||||
扫描目录,让用户交互式选择一个数据集。
|
||||
"""
|
||||
print("正在扫描可用的数据集...")
|
||||
try:
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(results_dir, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
except Exception as e:
|
||||
print(f"扫描目录 '{results_dir}' 时出错: {e}")
|
||||
return None, None
|
||||
|
||||
if not all_dataset_dirs:
|
||||
print(f"在 '{results_dir}' 中未找到任何数据集目录 (以 '_outputs-MMSeg' 结尾)。")
|
||||
print("请确保脚本与 'BestMode_Predict_Results_DataSet_Public' 文件夹在同一级目录下。")
|
||||
return None, None
|
||||
|
||||
print("\n请选择要可视化的数据集:")
|
||||
for i, dir_path in enumerate(all_dataset_dirs):
|
||||
dataset_name = os.path.basename(dir_path).replace('_outputs-MMSeg', '')
|
||||
print(f" [{i+1}] {dataset_name}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"\n请输入选项编号 (1-{len(all_dataset_dirs)}): ")
|
||||
choice_idx = int(choice) - 1
|
||||
if 0 <= choice_idx < len(all_dataset_dirs):
|
||||
selected_dir = all_dataset_dirs[choice_idx]
|
||||
dataset_name = os.path.basename(selected_dir).replace('_outputs-MMSeg', '')
|
||||
return selected_dir, dataset_name
|
||||
else:
|
||||
print("无效的选项,请输入列表中的编号。")
|
||||
except (ValueError, IndexError):
|
||||
print("无效的输入,请输入一个数字编号。")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n操作已取消。")
|
||||
return None, None
|
||||
|
||||
def plot_performance_speed(selected_dir, dataset_name):
|
||||
"""
|
||||
根据选定的数据集目录,加载数据并生成图表。
|
||||
"""
|
||||
print(f"\n正在为数据集 '{dataset_name}' 生成图表...")
|
||||
|
||||
# 构建文件路径
|
||||
metrics_file = os.path.join(selected_dir, f"{dataset_name}_metrics_summary_wide.csv")
|
||||
fps_file = os.path.join(selected_dir, f"{dataset_name}_flops_params_fps_summary.csv")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(metrics_file) or not os.path.exists(fps_file):
|
||||
print(f"错误: 在目录 '{selected_dir}' 中缺少所需的数据文件。")
|
||||
print(f" - 检查是否存在: {os.path.basename(metrics_file)}")
|
||||
print(f" - 检查是否存在: {os.path.basename(fps_file)}")
|
||||
return
|
||||
|
||||
# 加载数据
|
||||
try:
|
||||
metrics_df = pd.read_csv(metrics_file)
|
||||
# 只保留最新的epoch结果,避免重复
|
||||
metrics_df = metrics_df.sort_values('Epoch', ascending=False).drop_duplicates('Algorithm')
|
||||
|
||||
fps_df = pd.read_csv(fps_file)
|
||||
except FileNotFoundError as e:
|
||||
print(f"错误: 无法找到文件 {e.filename}")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"读取CSV文件时出错: {e}")
|
||||
return
|
||||
|
||||
# 合并两个DataFrame
|
||||
# metrics_df中的'Algorithm'列对应fps_df中的'Model'列
|
||||
merged_df = pd.merge(metrics_df, fps_df, left_on='Algorithm', right_on='Model')
|
||||
|
||||
if merged_df.empty:
|
||||
print("错误: 数据合并失败。请检查 'Algorithm' 和 'Model' 列中的模型名称是否匹配。")
|
||||
return
|
||||
|
||||
# 调用新函数来创建并保存摘要表格
|
||||
T1_create_and_save_summary_table(merged_df, selected_dir, dataset_name)
|
||||
|
||||
# 调用新函数来提取和保存所有IoU数据
|
||||
T2_extract_and_save_iou_data(metrics_df, selected_dir, dataset_name)
|
||||
|
||||
# 提取模型族
|
||||
merged_df['Family'] = merged_df['Model'].apply(get_model_family)
|
||||
|
||||
# --- 绘图 ---
|
||||
plt.style.use('seaborn-v0_8-whitegrid')
|
||||
fig, ax = plt.subplots(figsize=(16, 10))
|
||||
|
||||
# 定义颜色和标记
|
||||
families = sorted(merged_df['Family'].unique())
|
||||
palette = sns.color_palette("husl", len(families))
|
||||
markers = ['o', 's', 'X', 'D', '^', 'P', '*', 'v', '<', '>']
|
||||
|
||||
# 循环绘制每个模型族
|
||||
for i, family in enumerate(families):
|
||||
family_df = merged_df[merged_df['Family'] == family].sort_values('Average_FPS')
|
||||
color = palette[i]
|
||||
marker = markers[i % len(markers)]
|
||||
|
||||
# 绘制散点
|
||||
ax.scatter(family_df['Average_FPS'], family_df['mIoU'],
|
||||
color=color, marker=marker, s=150, label=family, zorder=3)
|
||||
|
||||
# 如果族内有多个模型,则用线连接
|
||||
if len(family_df) > 1:
|
||||
ax.plot(family_df['Average_FPS'], family_df['mIoU'],
|
||||
color=color, linestyle='--', linewidth=1.5, zorder=2)
|
||||
|
||||
# 在每个点旁边添加模型全名注释
|
||||
for j, row in family_df.iterrows():
|
||||
ax.text(row['Average_FPS'] * 1.01, row['mIoU'], row['Model'],
|
||||
fontsize=9, verticalalignment='center')
|
||||
|
||||
# 设置图表属性
|
||||
ax.set_title(f'Model Performance vs. Inference Speed ({dataset_name})', fontsize=18, pad=20)
|
||||
ax.set_xlabel('Inference Speed (FPS)', fontsize=14)
|
||||
ax.set_ylabel('Mean IoU (%)', fontsize=14)
|
||||
ax.legend(title='Model Family', bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 0.88, 1]) # 调整布局为图例留出空间
|
||||
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
|
||||
|
||||
# 保存并显示图表
|
||||
output_filename_png = f"F1_{dataset_name}_mIoU_vs_FPS.png"
|
||||
save_file_path_png = os.path.join(selected_dir, output_filename_png)
|
||||
plt.savefig(save_file_path_png, dpi=600)
|
||||
output_filename_svg = f"F1_{dataset_name}_mIoU_vs_FPS.svg"
|
||||
save_file_path_svg = os.path.join(selected_dir, output_filename_svg)
|
||||
plt.savefig(save_file_path_svg)
|
||||
print(f"\n图表已成功生成并保存为: {save_file_path_svg} 和 {save_file_path_png}")
|
||||
plt.show()
|
||||
|
||||
def T1_create_and_save_summary_table(merged_df, output_dir, dataset_name):
|
||||
"""
|
||||
根据合并后的数据创建、格式化并保存性能摘要表格。
|
||||
|
||||
Args:
|
||||
merged_df (pd.DataFrame): 包含所有模型指标和性能数据的DataFrame。
|
||||
output_dir (str): 保存CSV文件的目标目录。
|
||||
dataset_name (str): 数据集的名称,用于生成文件名。
|
||||
"""
|
||||
print("正在创建摘要表格...")
|
||||
|
||||
# 检查所需列是否存在
|
||||
required_columns = ['Model', 'mIoU', 'mAcc', 'aAcc', 'Average_FPS', 'FLOPs', 'Params']
|
||||
if not all(col in merged_df.columns for col in required_columns):
|
||||
print("错误: DataFrame中缺少必要的列。请检查CSV文件内容。")
|
||||
return
|
||||
|
||||
# 提取并复制数据,避免修改原始DataFrame
|
||||
summary_df = merged_df[required_columns].copy()
|
||||
|
||||
# 清理和转换数据
|
||||
# 将 '118 G' -> 118.0
|
||||
summary_df['FLOPs'] = summary_df['FLOPs'].astype(str).str.replace(' G', '', regex=False).astype(float)
|
||||
# 将 '13.274M' -> 13.274
|
||||
summary_df['Params'] = summary_df['Params'].astype(str).str.replace('M', '', regex=False).astype(float)
|
||||
|
||||
# 按照用户的要求重命名列
|
||||
summary_df.rename(columns={
|
||||
'Average_FPS': 'FPS',
|
||||
'FLOPs': 'G(GFLOPS)',
|
||||
'Params': 'Para(Params)'
|
||||
}, inplace=True)
|
||||
|
||||
# 按 mIoU 降序排序
|
||||
summary_df = summary_df.sort_values(by='mIoU', ascending=False)
|
||||
|
||||
# 保存表格到CSV文件
|
||||
summary_filename = f"T1_{dataset_name}_performance_summary.csv"
|
||||
summary_save_path = os.path.join(output_dir, summary_filename)
|
||||
|
||||
try:
|
||||
summary_df.to_csv(summary_save_path, index=False, float_format='%.3f')
|
||||
print(f"摘要表格已成功保存到: {summary_save_path}")
|
||||
except Exception as e:
|
||||
print(f"保存摘要表格时出错: {e}")
|
||||
|
||||
def T2_extract_and_save_iou_data(metrics_df, output_dir, dataset_name):
|
||||
"""
|
||||
从 metrics DataFrame 中提取所有 mIoU 和 Class_IoU,并保存到新的CSV文件。
|
||||
|
||||
Args:
|
||||
metrics_df (pd.DataFrame): 包含所有指标的原始DataFrame。
|
||||
output_dir (str): 保存CSV文件的目标目录。
|
||||
dataset_name (str): 数据集的名称,用于生成文件名。
|
||||
"""
|
||||
print("正在提取所有 mIoU 和 Class_IoU 数据...")
|
||||
|
||||
# 检查'Algorithm'列是否存在
|
||||
if 'Algorithm' not in metrics_df.columns:
|
||||
print("错误: 'Algorithm' 列未找到,无法继续。")
|
||||
return
|
||||
|
||||
# 找出所有与IoU相关的列
|
||||
# 包括 'mIoU' 以及所有以 '_IoU' 结尾的列
|
||||
iou_columns = ['Algorithm', 'mIoU'] + [col for col in metrics_df.columns if col.endswith('_IoU') and col != 'mIoU']
|
||||
|
||||
# 移除重复的列名(以防万一)
|
||||
iou_columns = list(dict.fromkeys(iou_columns))
|
||||
|
||||
# 提取数据
|
||||
iou_df = metrics_df[iou_columns].copy()
|
||||
|
||||
# 按 mIoU 降序排序,便于查看
|
||||
iou_df = iou_df.sort_values(by='mIoU', ascending=False)
|
||||
|
||||
# 定义并保存文件
|
||||
iou_filename = f"T2_{dataset_name}_all_iou_summary.csv"
|
||||
iou_save_path = os.path.join(output_dir, iou_filename)
|
||||
|
||||
try:
|
||||
iou_df.to_csv(iou_save_path, index=False, float_format='%.2f')
|
||||
print(f"所有IoU数据已成功保存到: {iou_save_path}")
|
||||
except Exception as e:
|
||||
print(f"保存IoU数据时出错: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 设置包含所有数据集结果的根目录
|
||||
results_root_dir = '../BestMode_Predict_Results_DataSet_Public'
|
||||
|
||||
# 启动交互式选择
|
||||
selected_directory, selected_dataset_name = select_dataset(results_root_dir)
|
||||
|
||||
# 如果用户成功选择,则生成图表
|
||||
if selected_directory and selected_dataset_name:
|
||||
plot_performance_speed(selected_directory, selected_dataset_name)
|
||||
@@ -0,0 +1,413 @@
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import argparse
|
||||
import re
|
||||
import csv
|
||||
from typing import Dict, Optional, List
|
||||
from collections import defaultdict
|
||||
|
||||
# --- 新增导入:用于绘图 ---
|
||||
try:
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
MATPLOTLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MATPLOTLIB_AVAILABLE = False
|
||||
logging.warning(
|
||||
"未找到 'pandas' 或 'matplotlib' 库。"
|
||||
"脚本将只生成CSV文件,无法自动绘图。"
|
||||
"请运行 'pip install pandas matplotlib' 来安装它们。"
|
||||
)
|
||||
# --- 导入结束 ---
|
||||
|
||||
|
||||
# --- 配置日志记录 ---
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# --- 辅助函数 (与 4_3 版本相同) ---
|
||||
|
||||
def find_all_log_files_sorted(algorithm_dir: str) -> List[str]:
|
||||
"""
|
||||
查找给定算法目录中所有的.log文件,并按从新到旧的顺序排列。
|
||||
"""
|
||||
try:
|
||||
subdirs = [d for d in os.listdir(algorithm_dir) if os.path.isdir(os.path.join(algorithm_dir, d))]
|
||||
except FileNotFoundError:
|
||||
logging.error(f"算法目录不存在: {algorithm_dir}")
|
||||
return []
|
||||
|
||||
if not subdirs:
|
||||
logging.warning(f"在目录 {algorithm_dir} 中未找到任何时间戳子目录。")
|
||||
return []
|
||||
|
||||
sorted_subdirs = sorted(subdirs, reverse=True)
|
||||
log_files = []
|
||||
for subdir_name in sorted_subdirs:
|
||||
subdir_path = os.path.join(algorithm_dir, subdir_name)
|
||||
logs_in_subdir = glob.glob(os.path.join(subdir_path, '*.log'))
|
||||
if logs_in_subdir:
|
||||
log_files.append(logs_in_subdir[0])
|
||||
|
||||
return log_files
|
||||
|
||||
def get_max_epochs(config_path: str) -> Optional[int]:
|
||||
"""
|
||||
从config.py文件中解析train_cfg字典以获取max_epochs的值。
|
||||
"""
|
||||
if not os.path.exists(config_path):
|
||||
logging.error(f"配置文件不存在: {config_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
match = re.search(r"train_cfg\s*=\s*dict\(.*?max_epochs\s*=\s*(\d+),.*?\)", content, re.DOTALL)
|
||||
|
||||
if match:
|
||||
max_epochs = int(match.group(1))
|
||||
logging.info(f"从 {os.path.basename(config_path)} 中成功读取 max_epochs: {max_epochs}")
|
||||
return max_epochs
|
||||
else:
|
||||
logging.warning(f"在 {config_path} 中未找到 'max_epochs'。")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"解析 {config_path} 时出错: {e}")
|
||||
return None
|
||||
|
||||
def parse_log_file_data(log_path: str, max_epochs: int) -> Dict:
|
||||
"""
|
||||
解析日志文件,提取所有训练损失和所有验证mIoU。
|
||||
"""
|
||||
try:
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except IOError as e:
|
||||
logging.error(f"无法读取日志文件: {log_path}。错误: {e}")
|
||||
return {'training_losses': {}, 'validation_mious': []}
|
||||
|
||||
training_losses_by_epoch = defaultdict(list)
|
||||
validation_mious = []
|
||||
|
||||
# 1. 提取训练损失
|
||||
total_iters_match = re.search(r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]", content)
|
||||
total_iters = 0
|
||||
if total_iters_match:
|
||||
try:
|
||||
total_iters = int(total_iters_match.group(2))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if total_iters > 0:
|
||||
train_loss_pattern = re.compile(
|
||||
r"Iter\(train\)\s*\[\s*(\d+)\s*/\s*(\d+)\]"
|
||||
r"(?:.*?)loss:\s*([\d\.]+)"
|
||||
)
|
||||
for match in re.finditer(train_loss_pattern, content):
|
||||
try:
|
||||
current_iter = int(match.group(1))
|
||||
loss = float(match.group(3))
|
||||
epoch = int((current_iter / total_iters) * max_epochs)
|
||||
training_losses_by_epoch[epoch].append(loss)
|
||||
except (ValueError, ZeroDivisionError) as e:
|
||||
logging.warning(f"解析训练损失时出错: {e}")
|
||||
else:
|
||||
logging.warning(f"在 {os.path.basename(log_path)} 中未找到有效的 'Iter(train)' 行来确定总迭代次数。")
|
||||
|
||||
# 2. 提取验证 mIoU
|
||||
val_summary_pattern = re.compile(
|
||||
r"Iter\(val\) \[\d+/\d+\]\s+aAcc:\s*[\d.]+\s+mIoU:\s*([\d.]+)\s+mAcc:\s*[\d.]+"
|
||||
)
|
||||
for match in re.finditer(val_summary_pattern, content):
|
||||
try:
|
||||
miou = float(match.group(1))
|
||||
validation_mious.append(miou)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
log_name = os.path.basename(log_path)
|
||||
if training_losses_by_epoch:
|
||||
logging.info(f"✅ 从 {log_name} 提取了 {len(training_losses_by_epoch)} 个Epoch的训练损失。")
|
||||
if validation_mious:
|
||||
logging.info(f"✅ 从 {log_name} 提取了 {len(validation_mious)} 次验证的mIoU。")
|
||||
if not training_losses_by_epoch and not validation_mious:
|
||||
logging.warning(f"❌ 在 {log_name} 中未找到训练损失或验证mIoU。")
|
||||
|
||||
return {
|
||||
'training_losses': training_losses_by_epoch,
|
||||
'validation_mious': validation_mious
|
||||
}
|
||||
|
||||
# --- 新增:从 4_4 脚本中合并过来的绘图函数 ---
|
||||
|
||||
def plot_loss_curves(csv_path: str):
|
||||
"""
|
||||
读取_training_loss_summary.csv文件并绘制训练损失曲线。
|
||||
(此版本已根据用户需求修改)
|
||||
|
||||
Args:
|
||||
csv_path (str): 输入的CSV文件路径。
|
||||
"""
|
||||
if not MATPLOTLIB_AVAILABLE:
|
||||
logging.warning("由于缺少 'pandas' 或 'matplotlib',跳过绘图。")
|
||||
return
|
||||
|
||||
if not os.path.exists(csv_path):
|
||||
logging.error(f"[绘图] 文件未找到: {csv_path}")
|
||||
return
|
||||
|
||||
logging.info(f"[绘图] 正在读取数据: {os.path.basename(csv_path)}")
|
||||
try:
|
||||
df = pd.read_csv(csv_path)
|
||||
except Exception as e:
|
||||
logging.error(f"[绘图] 读取CSV时出错: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
df = df.set_index('Algorithm')
|
||||
except KeyError:
|
||||
logging.error("[绘图] CSV文件中未找到 'Algorithm' 列。")
|
||||
return
|
||||
|
||||
loss_cols = [col for col in df.columns if col.startswith('Epoch_') and col.endswith('_Loss')]
|
||||
|
||||
if not loss_cols:
|
||||
logging.warning("[绘图] 在CSV中未找到任何 'Epoch_X_Loss' 列。")
|
||||
return
|
||||
|
||||
try:
|
||||
epochs = [int(col.split('_')[1]) for col in loss_cols]
|
||||
except (ValueError, IndexError) as e:
|
||||
logging.error(f"[绘图] 解析Epoch列名时出错: {e}。列名格式应为 'Epoch_N_Loss'。")
|
||||
return
|
||||
|
||||
df_losses = df[loss_cols].apply(pd.to_numeric, errors='coerce')
|
||||
|
||||
# --- 开始绘图 ---
|
||||
logging.info("[绘图] 正在生成图表...")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(14, 8))
|
||||
|
||||
for alg_name, row in df_losses.iterrows():
|
||||
# --- 修改点 1 ---
|
||||
# 移除 marker 和 markersize,使用实线 (linestyle='-')
|
||||
ax.plot(epochs, row.values, label=alg_name, linestyle='-')
|
||||
|
||||
# --- 设置图表样式 ---
|
||||
ax.set_xlabel('Epoch', fontsize=12)
|
||||
ax.set_ylabel('Average Training Loss', fontsize=12)
|
||||
ax.set_title(f'Training Loss per Epoch\n(Source: {os.path.basename(csv_path)})', fontsize=14)
|
||||
ax.grid(True, linestyle=':', alpha=0.7)
|
||||
|
||||
# --- 修改点 2 ---
|
||||
# 将Y轴(纵轴)的范围设置为 0 到 5
|
||||
ax.set_ylim(bottom=0, top=5) # TODO TODO
|
||||
|
||||
# 将图例放在图表右侧外部
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left', title="Algorithms")
|
||||
|
||||
fig.tight_layout(rect=[0, 0, 0.85, 1])
|
||||
|
||||
# --- 保存图表 ---
|
||||
output_png_path = os.path.splitext(csv_path)[0] + '.png'
|
||||
try:
|
||||
plt.savefig(output_png_path, dpi=150, bbox_inches='tight')
|
||||
logging.info(f"✅ 图表已成功保存到: {output_png_path}")
|
||||
except IOError as e:
|
||||
logging.error(f"[绘图] 保存图像时出错: {e}")
|
||||
finally:
|
||||
plt.close(fig) # 释放内存
|
||||
|
||||
# --- 主函数 (小幅修改以调用绘图) ---
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# --- 交互式菜单 (保持不变) ---
|
||||
all_dataset_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(input_root, '*_outputs-MMSeg')) if os.path.isdir(d)
|
||||
])
|
||||
if not all_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(all_dataset_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = []
|
||||
selected_dataset_dir = None
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。")
|
||||
return
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]]
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
return
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
return
|
||||
# --- 交互式菜单结束 ---
|
||||
|
||||
|
||||
# --- 处理逻辑 (保持不变) ---
|
||||
csv_rows = []
|
||||
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"\n--- 开始处理算法: {model_name} ---")
|
||||
|
||||
all_logs_sorted = find_all_log_files_sorted(model_dir)
|
||||
|
||||
if not all_logs_sorted:
|
||||
logging.warning(f"跳过算法 {model_name},因为未找到任何日志文件。")
|
||||
continue
|
||||
|
||||
latest_log_path = all_logs_sorted[0]
|
||||
config_path = os.path.join(os.path.dirname(latest_log_path), 'vis_data', 'config.py')
|
||||
max_epochs = get_max_epochs(config_path)
|
||||
if max_epochs is None:
|
||||
logging.error(f"无法为算法 {model_name} 找到 max_epochs,将跳过。")
|
||||
continue
|
||||
|
||||
all_losses_for_model = defaultdict(list)
|
||||
all_mious_for_model = []
|
||||
|
||||
for log_file_path in all_logs_sorted:
|
||||
logging.info(f"正在解析: {os.path.relpath(log_file_path)}")
|
||||
parsed_data = parse_log_file_data(log_file_path, max_epochs)
|
||||
for epoch, losses in parsed_data['training_losses'].items():
|
||||
all_losses_for_model[epoch].extend(losses)
|
||||
if parsed_data['validation_mious']:
|
||||
all_mious_for_model.extend(parsed_data['validation_mious'])
|
||||
|
||||
if not all_losses_for_model and not all_mious_for_model:
|
||||
logging.warning(f"❌❌❌跳过算法 {model_name},因为在其所有日志文件中都未能找到有效的训练或验证数据。❌❌❌")
|
||||
continue
|
||||
|
||||
best_miou = 0.0
|
||||
if all_mious_for_model:
|
||||
try:
|
||||
best_miou = max(all_mious_for_model)
|
||||
logging.info(f"找到了最佳 mIoU: {best_miou:.4f}")
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.error(f"为算法 {model_name} 寻找最佳mIoU时出错: {e}")
|
||||
else:
|
||||
logging.warning(f"算法 {model_name} 没有找到任何mIoU数据。")
|
||||
|
||||
short_model_name = model_name.split('Alg_', 1)[1] if 'Alg_' in model_name else model_name
|
||||
row_data = {
|
||||
'Algorithm': short_model_name
|
||||
}
|
||||
|
||||
sorted_epochs = sorted(all_losses_for_model.keys())
|
||||
if not sorted_epochs:
|
||||
logging.warning(f"算法 {model_name} 没有找到任何训练损失数据。")
|
||||
|
||||
for epoch in sorted_epochs:
|
||||
losses = all_losses_for_model[epoch]
|
||||
avg_loss = sum(losses) / len(losses)
|
||||
row_data[f'Epoch_{epoch}_Loss'] = f"{avg_loss:.4f}"
|
||||
|
||||
row_data['Best_mIoU'] = f"{best_miou:.4f}"
|
||||
csv_rows.append(row_data)
|
||||
# --- 处理逻辑结束 ---
|
||||
|
||||
|
||||
# --- 写入 CSV 文件 (修改后) ---
|
||||
if not csv_rows:
|
||||
logging.info("没有成功获取任何模型的统计数据,不生成 CSV 文件。")
|
||||
return
|
||||
|
||||
# 动态生成所有列名
|
||||
all_fieldnames_set = set()
|
||||
for row in csv_rows:
|
||||
all_fieldnames_set.update(row.keys())
|
||||
|
||||
base_fields = ['Algorithm']
|
||||
miou_field = ['Best_mIoU']
|
||||
loss_fields = [f for f in all_fieldnames_set if f.startswith('Epoch_')]
|
||||
|
||||
try:
|
||||
loss_fields.sort(key=lambda x: int(x.split('_')[1]))
|
||||
except (ValueError, IndexError):
|
||||
logging.error("排序Epoch列名时出错,将按字母顺序排序。")
|
||||
loss_fields.sort()
|
||||
|
||||
final_fieldnames = base_fields + loss_fields + miou_field
|
||||
|
||||
final_output_dir = os.path.join(output_root, os.path.basename(selected_dataset_dir))
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
dataset_name = os.path.basename(selected_dataset_dir).split('_outputs-MMSeg')[0]
|
||||
output_csv_path = os.path.join(final_output_dir, f'{dataset_name}_training_loss_summary.csv')
|
||||
|
||||
try:
|
||||
with open(output_csv_path, 'w', newline='', encoding='utf-8-sig') as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=final_fieldnames, extrasaction='ignore')
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_rows)
|
||||
|
||||
logging.info(f"\n=== CSV文件已成功保存到: {output_csv_path} ===")
|
||||
|
||||
# --- 新增:调用绘图函数 ---
|
||||
# 仅在CSV成功写入后才尝试绘图
|
||||
plot_loss_curves(output_csv_path)
|
||||
# --- 新增结束 ---
|
||||
|
||||
except IOError as e:
|
||||
logging.error(f"无法写入 CSV 文件: {output_csv_path}。错误: {e}")
|
||||
except Exception as e_plot:
|
||||
# 捕获绘图时可能发生的其他错误
|
||||
logging.error(f"在主流程中调用绘图时出错: {e_plot}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="MMSegmentation 训练损失提取与绘图脚本 (V3-Integrated)"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含数据集输出文件夹 (例如 '..._outputs-MMSeg') 的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果 (CSV和PNG) 的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,28 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file
|
||||
|
||||
# 生成 _base_ 变量,传入算法配置、数据集配置、schedule配置
|
||||
def generate_base_config(alg_file_name, dataset_file_name, schedule_file_name):
|
||||
if alg_file_name != None:
|
||||
base_config =[
|
||||
f'../_base_/models/{alg_file_name}.py',
|
||||
f'../_base_/datasets/{dataset_file_name}.py', #换成自己定义的数据集
|
||||
f'../_base_/default_runtime.py',
|
||||
f'../_base_/schedules/{schedule_file_name}.py'
|
||||
]
|
||||
else:
|
||||
base_config =[
|
||||
f'../_base_/datasets/{dataset_file_name}.py', #换成自己定义的数据集
|
||||
f'../_base_/default_runtime.py',
|
||||
f'../_base_/schedules/{schedule_file_name}.py'
|
||||
]
|
||||
|
||||
return base_config
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例用法
|
||||
alg_file_name = 'ann_r50-d8' # 算法根文件
|
||||
dataset_file_name = 'my_dataset_model' # 数据文件
|
||||
schedule_file_name = 'schedule_4k_check_400' # schedule文件
|
||||
|
||||
_base_ = generate_base_config(alg_file_name=alg_file_name, dataset_file_name=dataset_file_name, schedule_file_name=schedule_file_name)
|
||||
print(_base_)
|
||||
@@ -0,0 +1,19 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file
|
||||
|
||||
# 单卡 norm_cfg = dict(type='BN')
|
||||
# 多卡 norm_cfg = dict(type='SyncBN')
|
||||
def generate_norm_cfg(GPU_num = 2):
|
||||
GPU_num = int(GPU_num)
|
||||
if GPU_num == 1:
|
||||
return dict(type='BN')
|
||||
elif GPU_num > 1:
|
||||
return dict(type='SyncBN')
|
||||
else:
|
||||
raise ValueError("GPU_num需要为大于等于1的整数数值")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例用法
|
||||
GPU_num = 1
|
||||
|
||||
norm_cfg = generate_norm_cfg(GPU_num=GPU_num)
|
||||
print(norm_cfg)
|
||||
@@ -0,0 +1,23 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file
|
||||
|
||||
# crop_size 数据预处理是分割大小
|
||||
# crop_size = (512,512)
|
||||
def generate_data_preprocessor(crop_size=None, mean=None, std=None, bgr_to_rgb=False):
|
||||
|
||||
data_preprocessor = {}
|
||||
if crop_size != None:
|
||||
data_preprocessor['size']=crop_size
|
||||
if mean != None:
|
||||
data_preprocessor['mean']=mean
|
||||
if std != None:
|
||||
data_preprocessor['std']=std
|
||||
if bgr_to_rgb != None:
|
||||
data_preprocessor['bgr_to_rgb']=bgr_to_rgb
|
||||
return data_preprocessor
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例用法
|
||||
crop_size = (512,512)
|
||||
|
||||
data_preprocessor = generate_data_preprocessor(crop_size = crop_size)
|
||||
print(data_preprocessor)
|
||||
@@ -0,0 +1,119 @@
|
||||
from .Initial_Alg_Gen_Tool import get_var_from_file, format_dict
|
||||
|
||||
# 1. 修改pretrained
|
||||
# pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth')
|
||||
def generate_model_pretrained(pretrained_pth=None):
|
||||
pretrained = pretrained_pth
|
||||
if pretrained_pth == None:
|
||||
return None
|
||||
return pretrained
|
||||
|
||||
# 2. 修改backbone
|
||||
# depth=50
|
||||
def generate_model_backbone(depth=None):
|
||||
backbone = {}
|
||||
if depth != None:
|
||||
backbone['depth']=depth
|
||||
return backbone
|
||||
|
||||
# 3. 修改model_data_preprocessor
|
||||
# model_data_preprocessor='data_preprocessor'
|
||||
def generate_model_data_preprocessor(model_data_preprocessor=None):
|
||||
if model_data_preprocessor != None:
|
||||
model_data_preprocessor = model_data_preprocessor
|
||||
return model_data_preprocessor
|
||||
|
||||
# 4. 修改decode_head
|
||||
# num_classes='36' # 分割数
|
||||
# decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # 直接传入dict
|
||||
# align_corners=False是否需要角对齐
|
||||
def generate_model_decode_head(num_classes=None, decode_head_loss_decode_dict=None, align_corners=None):
|
||||
decode_head = {}
|
||||
if num_classes != None:
|
||||
decode_head['num_classes']=num_classes
|
||||
if decode_head_loss_decode_dict != None:
|
||||
decode_head['loss_decode']=decode_head_loss_decode_dict
|
||||
if align_corners != None:
|
||||
decode_head['align_corners']=align_corners
|
||||
|
||||
return decode_head
|
||||
|
||||
# 5. 修改auxiliary_head
|
||||
# num_classes='36' # 分割数
|
||||
# auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4) # 直接传入dict
|
||||
# align_corners=False是否需要角对齐
|
||||
def generate_model_auxiliary_head(num_classes=None, auxiliary_head_loss_decode_dict=None, align_corners=None):
|
||||
auxiliary_head = {}
|
||||
if num_classes != None:
|
||||
auxiliary_head['num_classes']=num_classes
|
||||
if auxiliary_head_loss_decode_dict != None:
|
||||
auxiliary_head['loss_decode']=auxiliary_head_loss_decode_dict
|
||||
if align_corners != None:
|
||||
auxiliary_head['align_corners']=align_corners
|
||||
|
||||
return auxiliary_head
|
||||
|
||||
# 6. 修改train_cfg
|
||||
def generate_model_train_cfg():
|
||||
train_cfg = {}
|
||||
|
||||
return train_cfg
|
||||
|
||||
# 6. 修改test_cfg
|
||||
# mode='slide'
|
||||
# crop_size = (767,767)
|
||||
# test_cfg_crop_div_stride = 1.5
|
||||
def generate_model_test_cfg(test_cfg_mode=None, crop_size=None, test_cfg_crop_div_stride=None):
|
||||
test_cfg = {}
|
||||
if test_cfg_mode != None:
|
||||
test_cfg['mode'] = test_cfg_mode
|
||||
if crop_size != None:
|
||||
test_cfg['crop_size'] = crop_size
|
||||
if test_cfg_crop_div_stride != None:
|
||||
test_cfg['stride'] = (int(crop_size[0]/test_cfg_crop_div_stride), int(crop_size[1]/test_cfg_crop_div_stride))
|
||||
|
||||
return test_cfg
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 1. 修改pretrained
|
||||
pretrained_pth = './My_Local_Model/open_mmlab/resnet50_v1c.pth'
|
||||
pretrained = generate_model_pretrained(pretrained_pth = pretrained_pth)
|
||||
|
||||
# 2. 修改backbone
|
||||
depth = 50
|
||||
backbone = generate_model_backbone(depth=depth)
|
||||
|
||||
# 3. 修改data_preprocessor
|
||||
model_data_preprocessor = 'data_preprocessor'
|
||||
model_data_preprocessor = generate_model_data_preprocessor(model_data_preprocessor=model_data_preprocessor)
|
||||
|
||||
# 4.5. 修改decode_head、auxiliary_head
|
||||
num_classes='36' # 分割数
|
||||
decode_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=1.0) # 直接传入dict
|
||||
align_corners=False # 是否需要角对齐
|
||||
auxiliary_head_loss_decode_dict=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.4)
|
||||
|
||||
decode_head = generate_model_decode_head(num_classes=num_classes, decode_head_loss_decode_dict=decode_head_loss_decode_dict, align_corners=align_corners)
|
||||
auxiliary_head = generate_model_auxiliary_head(num_classes=num_classes, auxiliary_head_loss_decode_dict=auxiliary_head_loss_decode_dict, align_corners=align_corners)
|
||||
|
||||
# 6. 修改train_cfg
|
||||
train_cfg = generate_model_train_cfg()
|
||||
|
||||
# 6. 修改test_cfg
|
||||
mode='slide'
|
||||
crop_size = (767,767)
|
||||
test_cfg_crop_div_stride = 1.5
|
||||
test_cfg = generate_model_test_cfg(mode=mode, crop_size=crop_size, test_cfg_crop_div_stride=test_cfg_crop_div_stride)
|
||||
|
||||
# 汇总为model
|
||||
model = dict(
|
||||
pretrained = pretrained,
|
||||
backbone = backbone,
|
||||
data_preprocessor = model_data_preprocessor,
|
||||
decode_head = decode_head,
|
||||
auxiliary_head = auxiliary_head,
|
||||
train_cfg = train_cfg,
|
||||
test_cfg = test_cfg,
|
||||
)
|
||||
model_ = format_dict(model)
|
||||
print(model_)
|
||||
@@ -0,0 +1,173 @@
|
||||
# optimizer(优化器设计)TODO
|
||||
# type_of_back_bone = "Vit"
|
||||
def generate_optim_wrapper(type_of_back_bone=None):
|
||||
|
||||
# optim_wrapper 配置-1
|
||||
optim_wrapper_1 = {
|
||||
'type': 'OptimWrapper', # 表示这是一个优化器包装器(OptimWrapper)
|
||||
'_delete_': True, # 通常用于在继承配置时删除旧的优化器配置,替换为新的优化器配置
|
||||
'optimizer': {
|
||||
'type': 'AdamW', # 优化器类型为 AdamW
|
||||
'lr': 0.0001, # 学习率,通常设定为一个较小的值
|
||||
'weight_decay': 0.0005 # 权重衰减系数
|
||||
},
|
||||
'clip_grad': {
|
||||
'max_norm': 1, # 梯度裁剪的最大范数
|
||||
'norm_type': 2 # L2 范数
|
||||
}
|
||||
}
|
||||
# optim_wrapper 配置-2
|
||||
optim_wrapper_2 = {
|
||||
'type': 'OptimWrapper', # 表示这是一个优化器包装器(OptimWrapper)
|
||||
'_delete_': True, # 通常用于在继承配置时删除旧的优化器配置,替换为新的优化器配置
|
||||
'optimizer': {
|
||||
'type': 'SGD', # 优化器类型为 SGD
|
||||
'lr': 0.05, # 学习率
|
||||
'weight_decay': 0.0005, # 权重衰减系数
|
||||
'momentum': 0.9
|
||||
},
|
||||
'clip_grad': {
|
||||
'max_norm': 1, # 梯度裁剪的最大范数
|
||||
'norm_type': 2 # L2 范数
|
||||
}
|
||||
}
|
||||
|
||||
optim_wrapper_list = [optim_wrapper_1, optim_wrapper_2]
|
||||
|
||||
# 打印所有可用的优化器选项
|
||||
while True:
|
||||
print("请选择 optim_wrapper (按 Enter 使用默认):")
|
||||
for i, scheduler in enumerate(optim_wrapper_list, 1):
|
||||
optimizer_type = scheduler['optimizer']['type']
|
||||
learning_rate = scheduler['optimizer']['lr']
|
||||
print(f"{i}. Optimizer: {optimizer_type}, LR: {learning_rate}")
|
||||
|
||||
# 如果是vit网络,则需要额外的paramwise_cfg
|
||||
if type_of_back_bone != None and (type_of_back_bone.lower() == 'vit' or type_of_back_bone.lower() == 'visiontransformer'):
|
||||
custom_keys={
|
||||
'pos_embed': dict(decay_mult=0.), # 位置嵌入(positional embeddings)。decay_mult=0. 意味着对这些嵌入不应用权重衰减
|
||||
'cls_token': dict(decay_mult=0.), # 是在某些模型(如 Transformer 或 BERT)中,用于分类任务的特定 token
|
||||
'norm': dict(decay_mult=0.) # 对归一化层的参数也禁用了权重衰减
|
||||
}
|
||||
optim_wrapper_list[i-1]['paramwise_cfg'] = dict(custom_keys)
|
||||
optim_wrapper_list[i-1]['_delete_'] = True
|
||||
|
||||
if type_of_back_bone != None and (type_of_back_bone.lower() == 'swin'):
|
||||
custom_keys={
|
||||
'absolute_pos_embed': dict(decay_mult=0.),
|
||||
'relative_position_bias_table': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.)
|
||||
}
|
||||
optim_wrapper_list[i-1]['paramwise_cfg'] = dict(custom_keys)
|
||||
optim_wrapper_list[i-1]['_delete_'] = True
|
||||
|
||||
choice = input(f"输入 1 到 {len(optim_wrapper_list)} 进行选择(默认 1): ")
|
||||
|
||||
# 如果用户按下 Enter 或选择 1,默认返回 optim_wrapper_1
|
||||
if choice == '' or choice == '1':
|
||||
return optim_wrapper_list[0]
|
||||
elif choice.isdigit() and 1 <= int(choice) <= len(optim_wrapper_list):
|
||||
return optim_wrapper_list[int(choice) - 1]
|
||||
else:
|
||||
print(f"无效输入,请输入 1 到 {len(optim_wrapper_list)} 或按 Enter 使用默认值")
|
||||
|
||||
def generate_param_scheduler(train_type, train_time_or_epoch):
|
||||
"""
|
||||
根据训练模式(epoch或iteration)动态生成并选择学习率调度器。
|
||||
|
||||
Args:
|
||||
train_type (str): 训练模式, 'epoch' 或 'iteration'。
|
||||
train_time_or_epoch (int): 训练的总轮数(epochs)或总迭代次数(k)。
|
||||
"""
|
||||
# 1. 根据训练模式确定核心参数
|
||||
if train_type.lower() == 'epoch':
|
||||
by_epoch = True
|
||||
end_value = train_time_or_epoch
|
||||
warmup_end = 10 # 为epoch模式设置一个合理的10个epoch的预热期
|
||||
# 按比例调整MultiStepLR的milestones
|
||||
milestones = [int(end_value * 0.75), int(end_value * 0.9)]
|
||||
else: # 默认为 iteration 模式
|
||||
by_epoch = False
|
||||
end_value = train_time_or_epoch * 1000
|
||||
warmup_end = 1500 # iteration模式使用原有的1500次迭代预热
|
||||
# MultiStepLR的milestones, 同样适配总时长
|
||||
milestones = [int(end_value * 0.75), int(end_value * 0.9)]
|
||||
|
||||
# 2. 使用动态参数定义调度器模板
|
||||
# param_scheduler 配置-1
|
||||
param_scheduler_1 = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-6,
|
||||
by_epoch=by_epoch,
|
||||
begin=0,
|
||||
end=warmup_end),
|
||||
dict(
|
||||
type='PolyLR',
|
||||
power=0.9,
|
||||
begin=warmup_end,
|
||||
end=end_value,
|
||||
eta_min=1e-5,
|
||||
by_epoch=by_epoch)
|
||||
]
|
||||
# param_scheduler 配置-2
|
||||
param_scheduler_2 = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-6,
|
||||
by_epoch=by_epoch,
|
||||
begin=0,
|
||||
end=warmup_end),
|
||||
dict(
|
||||
type='PolyLR',
|
||||
power=1.0,
|
||||
begin=warmup_end,
|
||||
end=end_value,
|
||||
eta_min=0.0,
|
||||
by_epoch=by_epoch)
|
||||
]
|
||||
# param_scheduler 配置-3
|
||||
param_scheduler_3 = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.001,
|
||||
by_epoch=by_epoch,
|
||||
begin=0,
|
||||
end=warmup_end),
|
||||
dict(
|
||||
type='MultiStepLR',
|
||||
begin=warmup_end,
|
||||
end=end_value,
|
||||
milestones=milestones,
|
||||
by_epoch=by_epoch)
|
||||
]
|
||||
param_scheduler_list = [param_scheduler_1, param_scheduler_2, param_scheduler_3]
|
||||
|
||||
# 3. 打印并让用户选择 (这部分逻辑不变)
|
||||
while True:
|
||||
print("请选择 param_scheduler (学习率调度器):")
|
||||
for i, schedulers in enumerate(param_scheduler_list, 1):
|
||||
print(f"Scheduler {i}:")
|
||||
for scheduler in schedulers:
|
||||
# 提取字段并处理缺省情况
|
||||
scheduler_type = scheduler.get('type', '/')
|
||||
begin = scheduler.get('begin', '/')
|
||||
end = scheduler.get('end', '/')
|
||||
power = scheduler.get('power', '/')
|
||||
eta_min = scheduler.get('eta_min', '/')
|
||||
milestones_val = scheduler.get('milestones', '/')
|
||||
|
||||
# 根据类型决定显示哪些信息
|
||||
if scheduler_type == 'MultiStepLR':
|
||||
print(f" - {scheduler_type}: begin={begin}, end={end}, milestones={milestones_val}")
|
||||
else:
|
||||
print(f" - {scheduler_type}: begin={begin}, end={end}, power={power}, eta_min={eta_min}")
|
||||
|
||||
choice = input(f"输入 1 到 {len(param_scheduler_list)} 进行选择(默认 1): ").strip()
|
||||
|
||||
if choice == '' or choice == '1':
|
||||
return param_scheduler_list[0]
|
||||
elif choice.isdigit() and 1 <= int(choice) <= len(param_scheduler_list):
|
||||
return param_scheduler_list[int(choice) - 1]
|
||||
else:
|
||||
print(f"无效输入,请输入 1 到 {len(param_scheduler_list)} 或按 Enter 使用默认值\n")
|
||||
@@ -0,0 +1,59 @@
|
||||
# train_dataloader TODO
|
||||
def generate_train_dataloader(batch_size_default=None, num_workers_default=None):
|
||||
batch_size = None
|
||||
if batch_size_default != None:
|
||||
while True:
|
||||
user_input = input(f"请输入 batch size (默认为 {batch_size_default}): ")
|
||||
|
||||
# 如果用户没有输入内容,使用默认值
|
||||
if not user_input.strip():
|
||||
batch_size = batch_size_default
|
||||
break
|
||||
|
||||
# 尝试将输入转换为整数
|
||||
try:
|
||||
batch_size = int(user_input)
|
||||
if batch_size > 0:
|
||||
break # 输入正确,退出循环
|
||||
else:
|
||||
print("Batch size 必须是正整数,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入一个有效的整数。")
|
||||
print(f"将train_dataloader的batch_size设置为{batch_size}")
|
||||
|
||||
num_workers = None
|
||||
if num_workers_default != None:
|
||||
while True:
|
||||
user_input = input(f"请输入 num workers (默认为 {num_workers_default}): ")
|
||||
|
||||
# 如果用户没有输入内容,使用默认值
|
||||
if not user_input.strip():
|
||||
num_workers = num_workers_default
|
||||
break
|
||||
|
||||
# 尝试将输入转换为整数
|
||||
try:
|
||||
num_workers = int(user_input)
|
||||
if num_workers > 0:
|
||||
break # 输入正确,退出循环
|
||||
else:
|
||||
print("Num workers 必须是正整数,请重新输入。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入一个有效的整数。")
|
||||
print(f"将train_dataloader的num_workers设置为{num_workers}")
|
||||
|
||||
# 返回包含 batch_size 的字典
|
||||
train_dataloader = {}
|
||||
if num_workers == None:
|
||||
train_dataloader['batch_size'] = batch_size
|
||||
return train_dataloader, batch_size
|
||||
|
||||
if batch_size == None:
|
||||
train_dataloader['num_workers'] = num_workers
|
||||
return train_dataloader, num_workers
|
||||
|
||||
train_dataloader['batch_size'] = batch_size
|
||||
train_dataloader['num_workers'] = num_workers
|
||||
return train_dataloader, batch_size, num_workers
|
||||
|
||||
|
||||
@@ -0,0 +1,264 @@
|
||||
import ast, subprocess, os
|
||||
|
||||
# 从文件中获取特定变量信息,返回变量直和类型
|
||||
def get_var_from_file(filename, var_name="norm_cfg"):
|
||||
with open(filename, 'r', encoding='utf-8') as file:
|
||||
code = file.read()
|
||||
|
||||
# 解析文件的 AST 树
|
||||
tree = ast.parse(code)
|
||||
|
||||
# 遍历 AST 树,查找指定变量
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == var_name:
|
||||
# 将 AST 节点转换为 Python 对象
|
||||
var_value = ast.literal_eval(node.value)
|
||||
# 返回变量的值和类型
|
||||
return var_value, type(var_value)
|
||||
|
||||
# 如果没有找到指定的变量,返回 None 和 None
|
||||
return None, None
|
||||
|
||||
def get_var_from_py_file(file_path, var_name="auxiliary_head"):
|
||||
# 定义一个字典用于保存文件中的变量
|
||||
context = {}
|
||||
|
||||
# 读取并执行文件
|
||||
with open(file_path, 'r') as file:
|
||||
exec(file.read(), context)
|
||||
|
||||
# 获取 auxiliary_head 变量
|
||||
if var_name in context:
|
||||
return context[var_name]
|
||||
else:
|
||||
raise AttributeError(f"文件中没有找到 {var_name} 变量")
|
||||
|
||||
# 更新list的dict变量
|
||||
def update_list_dict_var(var, var_new):
|
||||
# 判断 var 和 var_new 的长度是否相等
|
||||
if len(var) != len(var_new):
|
||||
raise ValueError(f"var {len(var)} 和 var_new {len(var_new)} 的大小不相等,无法更新")
|
||||
|
||||
# 遍历 var_new,按索引更新 var
|
||||
for i, new_entry in enumerate(var_new):
|
||||
# 将 new_entry 中的键值覆盖 var[i] 中的相同键
|
||||
var[i].update(new_entry)
|
||||
return var
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例调用
|
||||
var_value, var_type = get_var_from_file('./configs/_base_/models', 'norm_cfg')
|
||||
if var_value is not None:
|
||||
print(f"变量名: norm_cfg\n值: {var_value}\n类型: {var_type}")
|
||||
else:
|
||||
print("未找到指定变量")
|
||||
|
||||
# 将文件以dict格式输出
|
||||
def format_dict(d, indent_level=1):
|
||||
formatted_items = []
|
||||
indent = ' ' * indent_level # 根据缩进级别生成空格
|
||||
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
# 如果值是字典,递归调用 format_dict_as_func 并增加缩进级别
|
||||
formatted_value = format_dict(value, indent_level + 1)
|
||||
formatted_items.append(f"{indent}{key}={formatted_value}")
|
||||
elif isinstance(value, str):
|
||||
# 如果是字符串,格式化时加引号
|
||||
formatted_items.append(f"{indent}{key}='{value}'")
|
||||
else:
|
||||
# 其他类型(如数值等)不加引号
|
||||
formatted_items.append(f"{indent}{key}={value}")
|
||||
|
||||
# 将所有键值对合成为多行的格式,并加上结尾逗号
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
|
||||
# 返回更美观的 dict() 的格式,保留缩进和换行
|
||||
return f"dict(\n{formatted_str},\n{' ' * (indent_level - 1)})"
|
||||
|
||||
# 所有格式正确输出
|
||||
def format_all_data_old(data, indent_level=0):
|
||||
indent = ' ' * indent_level # 根据缩进级别生成空格
|
||||
if isinstance(data, dict):
|
||||
formatted_items = []
|
||||
for key, value in data.items():
|
||||
formatted_value = format_all_data(value, indent_level + 1)
|
||||
formatted_items.append(f"{indent} {key}={formatted_value}")
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
return f"dict(\n{formatted_str},\n{indent})"
|
||||
|
||||
elif isinstance(data, list):
|
||||
formatted_items = [f"{indent} {format_all_data(item, indent_level + 1)}" for item in data]
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
return f"[\n{formatted_str},\n{indent}]"
|
||||
|
||||
elif isinstance(data, tuple):
|
||||
return f"{data}"
|
||||
|
||||
elif isinstance(data, str):
|
||||
# 如果字符串包含单引号,则使用双引号,否则使用单引号
|
||||
if "'" in data:
|
||||
return f'"{data}"'
|
||||
else:
|
||||
return f"'{data}'"
|
||||
|
||||
else:
|
||||
return str(data)
|
||||
|
||||
# 所有格式正确输出,加入键值中带有"."的处理
|
||||
def format_all_data(data, indent_level=0):
|
||||
indent = ' ' * indent_level # 根据缩进级别生成空格
|
||||
dot_key_items = [] # 用于存储带点号的键
|
||||
regular_items = [] # 用于存储普通键
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
formatted_value = format_all_data(value, indent_level + 1)
|
||||
# 如果键包含点号,收集到 dot_key_items 中
|
||||
if '.' in key:
|
||||
dot_key_items.append(f"('{key}', {formatted_value})")
|
||||
else:
|
||||
regular_items.append(f"{indent} {key}={formatted_value}")
|
||||
|
||||
# 生成带点号键的部分,并放在最前面
|
||||
if dot_key_items:
|
||||
dot_key_str = f"{indent} [{', '.join(dot_key_items)}]"
|
||||
else:
|
||||
dot_key_str = ""
|
||||
|
||||
# 生成普通键的部分
|
||||
regular_str = ",\n".join(regular_items)
|
||||
|
||||
# 返回最终的组合字符串,带点号键放在普通键前面
|
||||
if dot_key_str:
|
||||
return f"dict(\n{dot_key_str},\n{regular_str},\n{indent})"
|
||||
else:
|
||||
return f"dict(\n{regular_str},\n{indent})"
|
||||
|
||||
elif isinstance(data, list):
|
||||
formatted_items = [f"{indent} {format_all_data(item, indent_level + 1)}" for item in data]
|
||||
formatted_str = ",\n".join(formatted_items)
|
||||
return f"[\n{formatted_str},\n{indent}]"
|
||||
|
||||
elif isinstance(data, tuple):
|
||||
return f"{data}"
|
||||
|
||||
elif isinstance(data, str):
|
||||
# 如果字符串包含单引号,则使用双引号,否则使用单引号
|
||||
if "'" in data:
|
||||
return f'"{data}"'
|
||||
else:
|
||||
return f"'{data}'"
|
||||
|
||||
else:
|
||||
return str(data)
|
||||
|
||||
# # 批量将参数内容写入文件 # V1 传统版
|
||||
# def write_config_to_file(output_file, **kwargs):
|
||||
# """
|
||||
# 将传入的任意数量的参数写入指定文件,并格式化输出。
|
||||
|
||||
# :param output_file: 要写入的文件路径
|
||||
# :param kwargs: 任意数量的配置项,格式为 key=value
|
||||
# """
|
||||
# with open(output_file, 'w', encoding='utf-8') as file:
|
||||
# # 遍历 kwargs,将每个 key, value 写入文件
|
||||
# for key, value in kwargs.items():
|
||||
# file.write(f"{key} = {format_all_data(value)}\n\n")
|
||||
|
||||
# # 打印成功信息
|
||||
# print(f"\033[93m{output_file} file generated successfully\033[0m")
|
||||
|
||||
# 批量将参数内容写入文件 # V2 加入训练过程可视化 TODO
|
||||
def format_all_data(value):
|
||||
"""
|
||||
一个辅助函数,用于将 Python 对象格式化为适合写入配置文件的字符串。
|
||||
使用 repr() 可以确保字符串、列表、字典等都保持其 Python 语法格式。
|
||||
"""
|
||||
return repr(value)
|
||||
def write_config_to_file(output_file, **kwargs):
|
||||
"""
|
||||
将传入的任意数量的参数以及一个自动生成的 visualizer 配置写入指定文件。
|
||||
|
||||
:param output_file: 要写入的文件路径 (e.g., 'configs/exp1.py')
|
||||
:param kwargs: 任意数量的配置项,格式为 key=value
|
||||
"""
|
||||
# --- 1. 从 output_file 路径中提取实验名称 ---
|
||||
# 首先获取基本文件名 (e.g., 'exp1.py')
|
||||
base_name = os.path.basename(output_file)
|
||||
# 然后去掉文件扩展名,得到纯净的实验名 (e.g., 'exp1')
|
||||
experiment_name, _ = os.path.splitext(base_name)
|
||||
|
||||
# --- 2. 构建 visualizer 配置字典 ---
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
dict(type='TensorboardVisBackend'),
|
||||
dict(
|
||||
type='WandbVisBackend',
|
||||
init_kwargs=dict(
|
||||
project='Seg_MMSeg_Test', # 你的 wandb 项目名称
|
||||
name=experiment_name # 使用上面提取的文件名作为实验名
|
||||
)
|
||||
)
|
||||
]
|
||||
visualizer = dict(
|
||||
name='visualizer',
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=vis_backends
|
||||
)
|
||||
|
||||
# --- 3. 将自动生成的 visualizer 添加到要写入的内容中 ---
|
||||
# 如果 kwargs 中已经有 'visualizer',它将会被新的配置覆盖
|
||||
kwargs['vis_backends'] = vis_backends
|
||||
kwargs['visualizer'] = visualizer
|
||||
|
||||
# --- 4. 将所有配置项写入文件 ---
|
||||
with open(output_file, 'w', encoding='utf-8') as file:
|
||||
# 遍历所有配置项(包括新加入的 visualizer),写入文件
|
||||
for key, value in kwargs.items():
|
||||
# 使用 format_all_data 保证格式正确
|
||||
file.write(f"{key} = {format_all_data(value)}\n\n")
|
||||
|
||||
# 打印成功信息
|
||||
print(f"\033[93mConfiguration saved to '{output_file}' successfully.\033[0m")
|
||||
print(f"\033[96mWandB experiment name will be: '{experiment_name}'\033[0m")
|
||||
|
||||
# 获取系统中所有 GPU 的可用显存信息。
|
||||
def get_gpu_info():
|
||||
"""
|
||||
获取系统中所有 GPU 的可用显存信息。
|
||||
:return: GPU 显存信息的列表,列表中的每一项是一个 (GPU编号, 剩余显存) 元组
|
||||
"""
|
||||
try:
|
||||
# 使用 nvidia-smi 命令获取 GPU 显存信息
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=index,memory.free', '--format=csv,noheader,nounits'],
|
||||
stdout=subprocess.PIPE,
|
||||
encoding='utf-8'
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
gpu_info = []
|
||||
lines = result.stdout.strip().split('\n')
|
||||
for line in lines:
|
||||
index, memory_free = line.split(',')
|
||||
gpu_info.append((int(index.strip()), int(memory_free.strip())))
|
||||
|
||||
return gpu_info
|
||||
except FileNotFoundError:
|
||||
print("\033[91mError: nvidia-smi 命令未找到,请确保 NVIDIA 驱动正确安装。\033[0m")
|
||||
return []
|
||||
|
||||
# 批量生成dict
|
||||
def create_dict_by_kwargs(**kwargs):
|
||||
# 批量生成字典,kwargs 会自动收集所有传入的命名参数
|
||||
return {key: value for key, value in kwargs.items() if value is not None}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例字典
|
||||
my_dict = {'pretrained': './My_Local_Model/open_mmlab/resnet50_v1c.pth'}
|
||||
|
||||
# 打印成 dict() 的格式
|
||||
print(format_dict(my_dict))
|
||||
@@ -0,0 +1,236 @@
|
||||
# 选择crop_size大小
|
||||
def select_crop_size(predefined_options=[(512, 512), (512, 1024), (769, 769)]):
|
||||
"""
|
||||
让用户选择裁剪大小。默认选择 (512, 512),用户可以选择其他预定义的裁剪大小,
|
||||
或者输入自定义的大小,且自定义的数字不能为负数。
|
||||
|
||||
:return: 选择的裁剪大小 (width, height)
|
||||
"""
|
||||
# 预定义的裁剪大小选项
|
||||
predefined_options = {str(i+1): option for i, option in enumerate(predefined_options)}
|
||||
|
||||
# 显示可选的裁剪大小
|
||||
print("可用的裁剪大小选项:")
|
||||
for key, value in predefined_options.items():
|
||||
print(f"{key}. {value}", end=" ")
|
||||
print(f"{len(predefined_options)+1}. 自定义大小")
|
||||
|
||||
# 用户选择
|
||||
choice = input("请选择裁剪大小选项 (默认 1): ").strip()
|
||||
|
||||
# 如果用户没有输入,使用默认值
|
||||
if choice == "" or choice == "1":
|
||||
return predefined_options["1"]
|
||||
|
||||
# 如果用户选择了预定义的选项
|
||||
if choice in predefined_options:
|
||||
return predefined_options[choice]
|
||||
|
||||
# 如果用户选择自定义大小
|
||||
if choice == f"{len(predefined_options)+1}":
|
||||
while True:
|
||||
try:
|
||||
width = int(input("请输入裁剪宽度 (正整数): "))
|
||||
height = int(input("请输入裁剪高度 (正整数): "))
|
||||
if width > 0 and height > 0:
|
||||
return (width, height)
|
||||
else:
|
||||
print("宽度和高度必须是正整数。")
|
||||
except ValueError:
|
||||
print("输入无效,请输入有效的正整数。")
|
||||
|
||||
# 如果用户输入无效,返回默认值
|
||||
print("无效选择,使用默认值 (512, 512)")
|
||||
return predefined_options["1"]
|
||||
|
||||
# 大字典,包含所有模型的信息
|
||||
pretrained_models_dict = {
|
||||
"openmmlab/resnet18_v1c": {'pth': './My_Local_Model/open_mmlab/resnet18_v1c.pth', 'depth': 18, 'type':'ResNetV1c'},
|
||||
"openmmlab/resnet50_v1c": {'pth': './My_Local_Model/open_mmlab/resnet50_v1c.pth', 'depth': 50, 'type':'ResNetV1c'},
|
||||
"openmmlab/resnet101_v1c": {'pth': './My_Local_Model/open_mmlab/resnet101_v1c.pth', 'depth': 101, 'type':'ResNetV1c'},
|
||||
"torchvision://resnet18": {'pth': './My_Local_Model/torchvision_012/resnet18.pth', 'depth': 18, 'type':'ResNet'},
|
||||
"torchvision://resnet50": {'pth': './My_Local_Model/torchvision_012/resnet50.pth', 'depth': 50, 'type':'ResNet'},
|
||||
"torchvision://resnet101": {'pth': './My_Local_Model/torchvision_012/resnet101.pth', 'depth': 101, 'type':'ResNet'},
|
||||
"openmmlab/pidnet-s":{'pth': './My_Local_Model/open_mmlab/pidnet-s.pth', 'type':'pidnet', 'size':'small'},
|
||||
"openmmlab/pidnet-m":{'pth': './My_Local_Model/open_mmlab/pidnet-m.pth', 'type':'pidnet', 'size':'medium'},
|
||||
"openmmlab/pidnet-l":{'pth': './My_Local_Model/open_mmlab/pidnet-l.pth', 'type':'pidnet', 'size':'large'},
|
||||
"openmmlab/ddrnet23-s":{'pth': './My_Local_Model/open_mmlab/ddrnet23-s.pth', 'type':'ddrnet', 'size':'small'},
|
||||
"openmmlab/ddrnet23":{'pth': './My_Local_Model/open_mmlab/ddrnet23.pth', 'type':'ddrnet', 'size':'normal'},
|
||||
"openmmlab/stdc1":{'pth': './My_Local_Model/open_mmlab/stdc1.pth', 'type':'stdc', 'size':'V1'},
|
||||
"openmmlab/stdc2":{'pth': './My_Local_Model/open_mmlab/stdc2.pth', 'type':'stdc', 'size':'V2'},
|
||||
"pretrain/vit-b16_p16_224-80ecf9dd.pth":{'pth': './My_Local_Model/pretrain/vit-b16_p16_224-80ecf9dd.pth'},
|
||||
"pretrain/beit_base_patch16_224_pt22k_ft22k.pth":{'pth': './My_Local_Model/pretrain/beit_base_patch16_224_pt22k_ft22k.pth'},
|
||||
"pretrain/beit_large_patch16_224_pt22k_ft22k.pth":{'pth': './My_Local_Model/pretrain/beit_large_patch16_224_pt22k_ft22k.pth'},
|
||||
"pretrain/swin_large-d5bdebaf.pth":{'pth': './My_Local_Model/pretrain/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth', 'type':'swin', 'size':'large'},
|
||||
"pretrain/swin_tiny-f41b89d3.pth":{'pth': './My_Local_Model/pretrain/swin_tiny_patch4_window7_224_20220308-f41b89d3.pth', 'type':'swin', 'size':'tiny'},
|
||||
"mae_pretrain_vit_base_mmcls.pth":{'pth': './My_Local_Model/pretrain/mae_pretrain_vit_base_mmcls.pth'},
|
||||
"open-mmlab://msra/hrnetv2_w18":{'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w18.pth'},
|
||||
"open-mmlab://msra/hrnetv2_w18_small":{'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w18_small.pth'},
|
||||
'open-mmlab://msra/hrnetv2_w48':{'pth': './My_Local_Model/open_mmlab/msra/hrnetv2_w48.pth'},
|
||||
"pretrain/swin_large-6580f57d.pth":{'pth': './My_Local_Model/pretrain/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth', 'type':'swin', 'size':'large'}, # mask2former
|
||||
"pretrain/swin_base-e5c09f74.pth":{'pth': './My_Local_Model/pretrain/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth', 'type':'swin', 'size':'base'}, # mask2former
|
||||
"pretrain/swin_small-7ba6d6dd.pth":{'pth': './My_Local_Model/pretrain/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth', 'type':'swin', 'size':'small'}, # mask2former
|
||||
"pretrain/swin_tiny-1cdeb081.pth":{'pth': './My_Local_Model/pretrain/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth', 'type':'swin', 'size':'tiny'}, # mask2former
|
||||
# 这里可以包含更多模型信息...
|
||||
}
|
||||
|
||||
# 选择预训练模型
|
||||
def select_pretrained_model(model_list, need_select_pretrained = False, pretrained_models_dict = pretrained_models_dict):
|
||||
"""
|
||||
让用户从给定的模型列表中选择预训练模型、是否选择预训练(否-默认开启预训练),并返回对应的 pth 路径和 其他 信息。
|
||||
|
||||
:param model_list: 可用的模型名称列表,例如 ['openmmlab/resnet50_v1c', 'openmmlab/resnet101_v1c']
|
||||
:return: (pretrained_pth, depth)
|
||||
"""
|
||||
|
||||
# 过滤传入的模型列表,确保它们在字典中有信息
|
||||
valid_models = {key: pretrained_models_dict[key] for key in model_list if key in pretrained_models_dict}
|
||||
|
||||
if not valid_models:
|
||||
print("错误:提供的模型列表中没有有效的模型信息。")
|
||||
return None, None
|
||||
|
||||
# 显示可用的预训练模型
|
||||
print("可用的预训练模型类型:")
|
||||
for i, (model_name, model_info) in enumerate(valid_models.items(), 1):
|
||||
print(f"{i}. {model_name}")
|
||||
|
||||
# 用户选择
|
||||
choice = input(f"请选择预训练模型编号 (1-{len(valid_models)}, 默认 1): ").strip()
|
||||
|
||||
# 如果用户没有输入,或输入无效,使用默认值
|
||||
if not choice.isdigit() or not (1 <= int(choice) <= len(valid_models)):
|
||||
choice = "1"
|
||||
|
||||
# 获取用户选择的模型信息
|
||||
selected_model_name = list(valid_models.keys())[int(choice) - 1]
|
||||
selected_model_info = valid_models[selected_model_name]
|
||||
|
||||
# TODO 选择特定信息
|
||||
# # 返回模型的 pth 路径和 depth
|
||||
pretrained_pth = selected_model_info['pth']
|
||||
# depth = selected_model_info['depth']
|
||||
|
||||
# print(f" 已选择模型: {selected_model_name} depth: {depth} pth: {pretrained_pth}")
|
||||
print(f" 已选择模型: {selected_model_name} pth: {pretrained_pth}")
|
||||
|
||||
if need_select_pretrained == False:
|
||||
print("默认开启预训练")
|
||||
select_pretrained = True
|
||||
return selected_model_name, select_pretrained, pretrained_pth, selected_model_info
|
||||
else:
|
||||
# 提示用户输入 True 或 False,或者直接按 Enter 默认使用预训练模型
|
||||
choice = input("是否使用预训练模型?输入 Y(使用)或 N(不使用),直接按 Enter 默认使用:True:")
|
||||
while True:
|
||||
# 如果用户没有输入,默认使用预训练模型
|
||||
if choice == '':
|
||||
select_pretrained = True
|
||||
break
|
||||
# 转换输入为布尔值
|
||||
elif choice.lower() == 'y':
|
||||
select_pretrained = True
|
||||
break
|
||||
elif choice.lower() == 'n':
|
||||
select_pretrained = False
|
||||
break
|
||||
else:
|
||||
print("无效输入,请输入 'True' 或 'False',或直接按 Enter 选择默认值")
|
||||
|
||||
return selected_model_name, select_pretrained, pretrained_pth, selected_model_info
|
||||
|
||||
# 大字典,包含所有模型的信息
|
||||
samplers_dict = {
|
||||
"OHEMPixelSampler": dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
|
||||
# 这里可以包含更多采样函数模型信息...
|
||||
}
|
||||
|
||||
# 选择采样函数
|
||||
def select_sampler(sampler_list, samplers_dict = samplers_dict):
|
||||
|
||||
# 提示用户输入 True 或 False,或者直接按 Enter 默认使用预训练模型
|
||||
choice = input("是否使用采样函数?输入 Y(使用)或 N(不使用),直接按 Enter 默认不使用:False:")
|
||||
while True:
|
||||
# 如果用户没有输入,默认使用预训练模型
|
||||
if choice == '':
|
||||
use_sampler = False
|
||||
return None, use_sampler, None
|
||||
# 转换输入为布尔值
|
||||
elif choice.lower() == 'y':
|
||||
use_sampler = True
|
||||
break
|
||||
elif choice.lower() == 'n':
|
||||
use_sampler = False
|
||||
return None, use_sampler, None
|
||||
else:
|
||||
print("无效输入,请输入 'True' 或 'False',或直接按 Enter 选择默认值")
|
||||
|
||||
# 过滤传入的模型列表,确保它们在字典中有信息
|
||||
valid_samplers = {key: samplers_dict[key] for key in sampler_list if key in samplers_dict}
|
||||
|
||||
if not valid_samplers:
|
||||
print("错误:提供的采样函数列表中没有有效的采样函数信息。")
|
||||
return None, None
|
||||
|
||||
# 如果只有一个可用的采样函数,直接选择
|
||||
if len(valid_samplers) != 1:
|
||||
# 显示可用的采样函数
|
||||
print("可用的采样函数:")
|
||||
for i, (sampler_name, sampler_info) in enumerate(valid_samplers.items(), 1):
|
||||
print(f"{i}. {sampler_name}")
|
||||
|
||||
# 用户选择
|
||||
choice = input(f"请选择采样函数编号 (1-{len(valid_samplers)}, 默认 1): ").strip()
|
||||
|
||||
# 如果用户没有输入,或输入无效,使用默认值
|
||||
if not choice.isdigit() or not (1 <= int(choice) <= len(valid_samplers)):
|
||||
choice = "1"
|
||||
else:
|
||||
choice = "1"
|
||||
|
||||
# 获取用户选择的模型信息
|
||||
selected_sampler_name = list(valid_samplers.keys())[int(choice) - 1]
|
||||
selected_sampler_info = valid_samplers[selected_sampler_name]
|
||||
|
||||
print(f" 已选择采样函数: {selected_sampler_name}")
|
||||
|
||||
return selected_sampler_name, use_sampler, selected_sampler_info
|
||||
|
||||
# 选择test_cfg中是否滑动,是否默认选择select_slide
|
||||
def select_test_cfg_slide(crop_size, select_slide=False):
|
||||
"""
|
||||
让用户选择是否使用滑动窗口,并根据选择设置相应的模式和参数。
|
||||
|
||||
:param crop_size: 输入的裁剪大小 (宽, 高)
|
||||
:return: test_cfg_mode, test_cfg_crop_div_stride, crop_size
|
||||
"""
|
||||
if select_slide == False:
|
||||
# 提示用户是否选择滑动窗口模式
|
||||
use_slide = input("是否选择滑动窗口模式?(y/n, 默认 n): ").strip().lower()
|
||||
else:
|
||||
use_slide = 'y' # 使用滑动窗口模式
|
||||
|
||||
# 默认不使用滑动窗口模式
|
||||
if use_slide == 'y':
|
||||
test_cfg_mode = 'slide'
|
||||
# 提示用户输入 test_cfg_crop_div_stride,默认值为 1.5
|
||||
try:
|
||||
test_cfg_crop_div_stride = input("请输入滑动窗口的 stride 除以 crop_size 比例 (默认 1.5): ").strip()
|
||||
test_cfg_crop_div_stride = float(test_cfg_crop_div_stride) if test_cfg_crop_div_stride else 1.5
|
||||
except ValueError:
|
||||
print("输入无效,使用默认比例 1.5")
|
||||
test_cfg_crop_div_stride = 1.5
|
||||
|
||||
# 计算 stride
|
||||
stride = tuple(int(c / test_cfg_crop_div_stride) for c in crop_size)
|
||||
print(f" 已选择滑动窗口模式: {test_cfg_mode}")
|
||||
print(f"crop_size: {crop_size}")
|
||||
print(f"stride: {stride}")
|
||||
else:
|
||||
test_cfg_mode = None # 默认模式
|
||||
test_cfg_crop_div_stride = None
|
||||
stride = None
|
||||
print(f" 关闭滑动窗口模式")
|
||||
|
||||
return test_cfg_mode, test_cfg_crop_div_stride
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from PIL import Image
|
||||
import os
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
def calculate_pic_std_and_mean(dataset_dir = r'./My_Data/A_Ori'):
|
||||
# 获取所有jpg图像文件
|
||||
image_files = [os.path.join(dataset_dir, filename) for filename in os.listdir(dataset_dir) if filename.lower().endswith(('.jpg', '.png', '.tiff', '.jpeg', '.bmp'))]
|
||||
|
||||
# 初始化用于存储累积的像素值
|
||||
sum_pixels_normalized = np.zeros(3)
|
||||
sum_squared_pixels_normalized = np.zeros(3)
|
||||
num_pixels = 0
|
||||
|
||||
# 使用tqdm创建一个进度条
|
||||
for image_file in tqdm(image_files, desc="Calculating mean and std"):
|
||||
image = Image.open(image_file).convert('RGB') # 确保图像为RGB
|
||||
image = np.array(image) # 原图像像素范围[0, 255]
|
||||
|
||||
# 归一化到[0, 1]范围
|
||||
image_normalized = image / 255.0
|
||||
|
||||
# 累积归一化像素值和归一化像素平方值
|
||||
sum_pixels_normalized += np.sum(image_normalized, axis=(0, 1)) # 按通道累积
|
||||
sum_squared_pixels_normalized += np.sum(image_normalized ** 2, axis=(0, 1)) # 按通道累积像素平方值
|
||||
num_pixels += image.shape[0] * image.shape[1] # 累积总像素数
|
||||
|
||||
# 计算整个数据集的归一化后的均值
|
||||
mean_normalized = sum_pixels_normalized / num_pixels
|
||||
|
||||
# 计算整个数据集的归一化后的标准差
|
||||
variance_normalized = sum_squared_pixels_normalized / num_pixels - mean_normalized ** 2
|
||||
variance_normalized = np.maximum(variance_normalized, 0) # 防止负数
|
||||
std_normalized = np.sqrt(variance_normalized)
|
||||
|
||||
# 反归一化回[0, 255]范围
|
||||
mean = mean_normalized * 255
|
||||
std = std_normalized * 255
|
||||
|
||||
print(f"\033[93m计算得图片均值-Mean: {mean} 计算得图片方差-Std: {std}\033[0m")
|
||||
|
||||
return mean, std
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例调用
|
||||
calculate_pic_std_and_mean()
|
||||
@@ -0,0 +1,121 @@
|
||||
import os
|
||||
|
||||
def generate_configs_base_datasets_my_dataset_file(
|
||||
output_file='./configs/_base_/datasets/my_dataset_model.py',
|
||||
dataset_class_name='MyDataset_model',
|
||||
data_root='/home/audience/Desktop/Seg_data/Data',
|
||||
img_scale=(1920, 1080),
|
||||
crop_size=(512, 512),
|
||||
train_batch_size=4,
|
||||
train_num_workers=4,
|
||||
val_and_test_batch_size=1,
|
||||
val_and_test_num_workers=4,
|
||||
train_img_path='A_Ori',
|
||||
train_seg_map_path='A_Label_GT_label_fold',
|
||||
val_img_path='A_Ori',
|
||||
val_seg_map_path='A_Label_GT_label_fold',
|
||||
test_img_path='A_Ori',
|
||||
test_seg_map_path='A_Label_GT_label_fold',
|
||||
):
|
||||
# 定义模板
|
||||
dataset_config_template = f"""# dataset settings
|
||||
dataset_class_name = '{dataset_class_name}' # TODO 上一步中你定义的数据集的名字
|
||||
data_root = '{data_root}' # TODO 数据集存储路径
|
||||
# img_norm_cfg = dict(
|
||||
# mean=[33.30, 35.03, 47.23], std=[48.00, 50.4, 60.51], to_rgb=True) # TODO 数据集的均值和标准差,空引用默认的,也可以网上搜代码计算
|
||||
img_scale = {img_scale} # img_scale图像尺寸 TODO (1920,1080)
|
||||
crop_size = {crop_size} # 数据增强时裁剪的大小 TODO 之后可以修改
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'), # ", reduce_zero_label=False" TODO 是否忽略0直选项
|
||||
dict(type='RandomResize', scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
# dict(type='GenerateEdge', edge_width=4), # For pidnet
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=img_scale, keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict( # Train dataloader config
|
||||
batch_size={train_batch_size}, # Batch size of a single GPU TODO
|
||||
num_workers={train_num_workers}, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate training speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=True), # Randomly shuffle during training.
|
||||
dataset=dict( # Train dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='{train_img_path}',
|
||||
seg_map_path='{train_seg_map_path}'),
|
||||
pipeline=train_pipeline)) # Processing pipeline. This is passed by the train_pipeline created before.
|
||||
val_dataloader = dict(
|
||||
batch_size={val_and_test_batch_size}, # Batch size of a single GPU
|
||||
num_workers={val_and_test_num_workers}, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate testing speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=False), # Not shuffle during validation and testing.
|
||||
dataset=dict( # Test dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='{val_img_path}',
|
||||
seg_map_path='{val_seg_map_path}'),
|
||||
pipeline=test_pipeline)) # Processing pipeline. This is passed by the test_pipeline created before.
|
||||
test_dataloader = dict(
|
||||
batch_size={val_and_test_batch_size}, # Batch size of a single GPU
|
||||
num_workers={val_and_test_num_workers}, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate testing speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=False), # Not shuffle during validation and testing.
|
||||
dataset=dict( # Test dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='{test_img_path}',
|
||||
seg_map_path='{test_seg_map_path}'),
|
||||
pipeline=test_pipeline)) # Processing pipeline. This is passed by the test_pipeline created before.
|
||||
# The metric to measure the accuracy. Here, we use IoUMetric.
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
"""
|
||||
|
||||
# 创建目录(如果不存在的话)
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
|
||||
# 写入文件
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(dataset_config_template)
|
||||
print(f"\033[93m{output_file} file generated successfully\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 定义各参数 ###########
|
||||
dataset_file_name='my_dataset_model' # 数据集 文件名.py
|
||||
dataset_class_name='MyDataset_model' # 数据集 类名称
|
||||
data_root='/home/audience/Desktop/Seg_data/Data' # 数据根目录
|
||||
img_scale=(1920, 1080) # 图片大小
|
||||
# 训练、验证、测试集所在文件夹
|
||||
train_img_path='A_Ori'
|
||||
train_seg_map_path='A_Label_GT_label_fold'
|
||||
val_img_path='A_Ori'
|
||||
val_seg_map_path='A_Label_GT_label_fold'
|
||||
test_img_path='A_Ori'
|
||||
test_seg_map_path='A_Label_GT_label_fold'
|
||||
|
||||
# 一般不太会变的参数
|
||||
crop_size=(512, 512) # 分割大小
|
||||
train_batch_size=4 # 训练batch
|
||||
train_num_workers=4 # 训练并行运行数量
|
||||
val_and_test_batch_size=1 # 验证集和测试集batch
|
||||
val_and_test_num_workers=4 # 验证集和测试集并行运行数量
|
||||
|
||||
########### 文件存储位置 ###########
|
||||
output_configs_base_datasets_my_dataset=f'./configs/_base_/datasets/{dataset_file_name}.py'
|
||||
|
||||
# 使用默认变量生成配置文件
|
||||
success = generate_configs_base_datasets_my_dataset_file(output_file=output_configs_base_datasets_my_dataset, dataset_class_name=dataset_class_name , data_root=data_root, img_scale=img_scale, crop_size=crop_size, train_batch_size=train_batch_size, train_num_workers=train_num_workers, val_and_test_batch_size=val_and_test_batch_size, val_and_test_num_workers=val_and_test_num_workers, train_img_path=train_img_path, train_seg_map_path=train_seg_map_path, val_img_path=val_img_path, val_seg_map_path=val_seg_map_path, test_img_path=test_img_path, test_seg_map_path=test_seg_map_path)
|
||||
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
|
||||
def generate_mmseg_datasets_init_file(output_file, dataset_file_names=["my_data_set_model"], dataset_class_names=["MyDataset_model"]):
|
||||
# 判断数据集文件名 和 对应类名大小是否相同 # 不同则输出错误内容,并退出
|
||||
if len(dataset_file_names) != len(dataset_class_names):
|
||||
print(f"\033[91mInitial_Gen_mmseg_datasets_init_.py程序中MyDataset_File_name大小和MyDataset_Class_name大小不匹配,函数退出!\033[0m")
|
||||
return False
|
||||
|
||||
# 模板开头部分
|
||||
header = """# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# yapf: disable
|
||||
from .ade import ADE20KDataset
|
||||
from .basesegdataset import BaseCDDataset, BaseSegDataset
|
||||
from .bdd100k import BDD100KDataset
|
||||
from .chase_db1 import ChaseDB1Dataset
|
||||
from .cityscapes import CityscapesDataset
|
||||
from .coco_stuff import COCOStuffDataset
|
||||
from .dark_zurich import DarkZurichDataset
|
||||
from .dataset_wrappers import MultiImageMixDataset
|
||||
from .decathlon import DecathlonDataset
|
||||
from .drive import DRIVEDataset
|
||||
from .dsdl import DSDLSegDataset
|
||||
from .hrf import HRFDataset
|
||||
from .hsi_drive import HSIDrive20Dataset
|
||||
from .isaid import iSAIDDataset
|
||||
from .isprs import ISPRSDataset
|
||||
from .levir import LEVIRCDDataset
|
||||
from .lip import LIPDataset
|
||||
from .loveda import LoveDADataset
|
||||
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
|
||||
from .night_driving import NightDrivingDataset
|
||||
from .nyu import NYUDataset
|
||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||
from .potsdam import PotsdamDataset
|
||||
from .refuge import REFUGEDataset
|
||||
from .stare import STAREDataset
|
||||
from .synapse import SynapseDataset
|
||||
"""
|
||||
# 增加多个 dataset_file_names imports
|
||||
imports = ""
|
||||
for dataset_file, dataset_name in zip(dataset_file_names, dataset_class_names):
|
||||
imports += f"from .{dataset_file} import {dataset_name} # TODO\n"
|
||||
|
||||
# 中间固定部分
|
||||
middle = """# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
|
||||
LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray, LoadMultipleRSImageFromFile,
|
||||
LoadSingleRSImageFromFile, PackSegInputs,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
from .voc import PascalVOCDataset
|
||||
|
||||
# yapf: enable
|
||||
__all__ = [
|
||||
"""
|
||||
# 增加多个 vars.MyDataset_C 到 __all__
|
||||
all_datasets = ""
|
||||
for dataset_name in dataset_class_names:
|
||||
all_datasets += f" '{dataset_name}', # TODO\n"
|
||||
|
||||
# __all__ 中其他固定部分
|
||||
all_fixed = """ 'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
|
||||
'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
|
||||
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
|
||||
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
|
||||
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
|
||||
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
|
||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
|
||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
|
||||
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
|
||||
'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
|
||||
'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
|
||||
'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
|
||||
'NYUDataset', 'HSIDrive20Dataset'
|
||||
]
|
||||
"""
|
||||
|
||||
# 拼接完整内容
|
||||
content = header + imports + middle + all_datasets + all_fixed
|
||||
|
||||
# 写入文件
|
||||
with open(output_file, 'w', encoding='utf-8') as file:
|
||||
file.write(content)
|
||||
|
||||
print(f"\033[93m{output_file} file generated successfully\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
########### 定义各参数 ###########
|
||||
# 可以定义多个数据集文件名 和 对应类名
|
||||
dataset_file_names = ["my_dataset_model"] # =['my_dataset', 'my_dataset_2'] # =["my_dataset"]
|
||||
dataset_class_names = ["MyDataset_model"] # =['MyDataset', 'MyDataset2'] # =["MyDataset"]
|
||||
|
||||
########### 文件存储位置 ###########
|
||||
output_mmseg_datasets_init = os.path.join('./mmseg/datasets/__init__.py')
|
||||
|
||||
# 生成 ./mmseg/datasets/__init__.py 文件
|
||||
success = generate_mmseg_datasets_init_file(output_file=output_mmseg_datasets_init, dataset_file_names=dataset_file_names, dataset_class_names=dataset_class_names)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
|
||||
def generate_mmseg_datasets_my_dataset_file(output_file='./mmseg/datasets/my_dataset_model.py', dataset_class_name = "MyDataset_model", classes=['背景'], palette=[[0,0,0]], img_suffix=".png", seg_map_suffix="_gtFine_labelTrainIds.png", reduce_zero_label=False):
|
||||
# 先判断 classes 和 palette 的大小是否一致
|
||||
if len(classes) != len(palette):
|
||||
print(f"\033[91mInitial_Gen_mmseg_datasets_my_dataset.py程序中classes大小和palette大小不匹配,函数退出!\033[0m")
|
||||
return False
|
||||
|
||||
# 判断是否有 '背景' 和 [0, 0, 0]
|
||||
if '背景' not in classes and 'background' not in classes and 'bg' not in classes and [0, 0, 0] not in palette:
|
||||
# 循环提示用户直到输入有效的值
|
||||
while True:
|
||||
print(f"现有Clas为:{classes}")
|
||||
print(f"现有palette为:{palette}")
|
||||
user_input = input("是否加入背景[0,0,0]到调色板第一位,Y加入,N不加入: ").strip().lower()
|
||||
if user_input == 'y':
|
||||
classes.insert(0,'背景')
|
||||
palette.insert(0,[0, 0, 0])
|
||||
print("已加入背景和调色板。")
|
||||
break
|
||||
elif user_input == 'n':
|
||||
print("未加入背景和调色板。")
|
||||
break
|
||||
else:
|
||||
print("无效输入,请输入Y或N。")
|
||||
|
||||
# 使用 f-string 进行模板文本替换
|
||||
template = f'''# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# import mmengine.fileio as fileio
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class {dataset_class_name}(BaseSegDataset): # 表示你定义的数据的名字,顺便取一个名字即可
|
||||
"""{dataset_class_name} dataset.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes={classes}, # 背景最好放到第一个
|
||||
palette={palette}) # TODO 标注类型和颜色
|
||||
|
||||
def __init__(self,
|
||||
img_suffix='{img_suffix}', # TODO mask图像类型
|
||||
seg_map_suffix='{seg_map_suffix}', # TODO mask图像后缀
|
||||
reduce_zero_label={reduce_zero_label}, # TODO 在第 0 类为无意义黑边时,使用reduce_zero_label = True将其和待分类内容分开;在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
|
||||
# assert fileio.exists(
|
||||
# self.data_prefix['img_path'], backend_args=self.backend_args)
|
||||
'''
|
||||
|
||||
# 将生成的程序内容写入文件
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(template)
|
||||
|
||||
print(f"\033[93m{output_file} file generated successfully, 其中标签共{len(classes)}类\033[0m")
|
||||
return True, classes, palette
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
########### 定义各参数 ###########
|
||||
# 定义要替换的内容
|
||||
dataset_file_name = "my_dataset_model"
|
||||
dataset_class_name = "MyDataset_model"
|
||||
classes = ['肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉']
|
||||
palette = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]]
|
||||
|
||||
# 一般不太会变的参数
|
||||
img_suffix = ".png"
|
||||
seg_map_suffix = "_gtFine_labelTrainIds.png"
|
||||
reduce_zero_label = False # 在第 0 类为 background 类别的数据集上,如果您最终是需要将背景和您的其余类别分开时,是不需要使用reduce_zero_label的 【reduce_zero_label = False】
|
||||
|
||||
########### 文件存储位置 ###########
|
||||
output_mmseg_datasets_dataset_file_name = os.path.join(f'./mmseg/datasets/{dataset_file_name}.py')
|
||||
|
||||
# 生成程序文件
|
||||
success, classes, palette = generate_mmseg_datasets_my_dataset_file(output_file=output_mmseg_datasets_dataset_file_name, dataset_class_name=dataset_class_name, classes=classes, palette=palette, img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, reduce_zero_label=reduce_zero_label)
|
||||
@@ -0,0 +1,617 @@
|
||||
import os
|
||||
|
||||
def generate_mmseg_utils_class_names_file(output_file=f'./mmseg/utils/class_names.py', dataset_file_names=['my_dataset_model', 'my_dataset_model_2'], classes_all=[['背景_1'], ['背景_2']], palette_all=[[[0,0,0]], [[0,0,0]]]):
|
||||
# 检查 dataset_file_names、classes_all 和 palette_all 的长度是否一致
|
||||
if len(dataset_file_names) != len(classes_all) or len(dataset_file_names) != len(palette_all):
|
||||
print(f"\033[91mInitial_Gen_mmseg_utils_class_names.py程序中 dataset_file_names 数量 {len(dataset_file_names)} 和 classes_all {len(classes_all)} 或 palette_all {len(palette_all)} 大小不匹配,函数退出!\033[0m")
|
||||
return False
|
||||
|
||||
# 逐一检查 classes 和 palette 的大小是否一致
|
||||
for i, (classes, palette) in enumerate(zip(classes_all, palette_all)):
|
||||
len_classes = len(classes)
|
||||
len_palette = len(palette)
|
||||
if len_classes != len_palette:
|
||||
print(f"\033[91mInitial_Gen_mmseg_utils_class_names.py程序中 {dataset_file_names[i]} 的 classes 大小 {len_classes} 和 palette 大小 {len_palette} 不匹配,函数退出!\033[0m")
|
||||
print(f"\033[91m具体 classes: \033[0m{classes}")
|
||||
print(f"\033[91m具体 palette: \033[0m{palette}")
|
||||
return False
|
||||
|
||||
# 初始化内容变量,用于存储多组 classes 和 palette
|
||||
content = '''# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.utils import is_str
|
||||
|
||||
|
||||
def cityscapes_classes():
|
||||
"""Cityscapes class names for external use."""
|
||||
return [
|
||||
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
||||
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
||||
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
||||
'bicycle'
|
||||
]
|
||||
|
||||
|
||||
def ade_classes():
|
||||
"""ADE20K class names for external use."""
|
||||
return [
|
||||
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
|
||||
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
|
||||
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
|
||||
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
|
||||
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
|
||||
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
|
||||
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
|
||||
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
|
||||
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
|
||||
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
|
||||
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
|
||||
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
|
||||
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
|
||||
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
|
||||
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
|
||||
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
|
||||
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
|
||||
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
|
||||
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
|
||||
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
|
||||
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
|
||||
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
|
||||
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
|
||||
'clock', 'flag'
|
||||
]
|
||||
|
||||
|
||||
def voc_classes():
|
||||
"""Pascal VOC class names for external use."""
|
||||
return [
|
||||
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
|
||||
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
|
||||
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
|
||||
'tvmonitor'
|
||||
]
|
||||
|
||||
|
||||
def pcontext_classes():
|
||||
"""Pascal Context class names for external use."""
|
||||
return [
|
||||
'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird',
|
||||
'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat',
|
||||
'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain',
|
||||
'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
|
||||
'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
|
||||
'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
|
||||
'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track',
|
||||
'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window',
|
||||
'wood'
|
||||
]
|
||||
|
||||
|
||||
def cocostuff_classes():
|
||||
"""CocoStuff class names for external use."""
|
||||
return [
|
||||
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
|
||||
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
|
||||
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
|
||||
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
|
||||
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
||||
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
|
||||
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
|
||||
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
||||
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
|
||||
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
|
||||
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
||||
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
|
||||
'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
|
||||
'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
|
||||
'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
|
||||
'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
|
||||
'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
|
||||
'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
|
||||
'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
|
||||
'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
|
||||
'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
|
||||
'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
|
||||
'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
|
||||
'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
|
||||
'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
|
||||
'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
|
||||
'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
|
||||
'window-blind', 'window-other', 'wood'
|
||||
]
|
||||
|
||||
|
||||
def loveda_classes():
|
||||
"""LoveDA class names for external use."""
|
||||
return [
|
||||
'background', 'building', 'road', 'water', 'barren', 'forest',
|
||||
'agricultural'
|
||||
]
|
||||
|
||||
|
||||
def potsdam_classes():
|
||||
"""Potsdam class names for external use."""
|
||||
return [
|
||||
'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
|
||||
'clutter'
|
||||
]
|
||||
|
||||
|
||||
def vaihingen_classes():
|
||||
"""Vaihingen class names for external use."""
|
||||
return [
|
||||
'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
|
||||
'clutter'
|
||||
]
|
||||
|
||||
|
||||
def isaid_classes():
|
||||
"""iSAID class names for external use."""
|
||||
return [
|
||||
'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
|
||||
'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
|
||||
'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
|
||||
'Soccer_ball_field', 'plane', 'Harbor'
|
||||
]
|
||||
|
||||
|
||||
def stare_classes():
|
||||
"""stare class names for external use."""
|
||||
return ['background', 'vessel']
|
||||
|
||||
|
||||
def mapillary_v1_classes():
|
||||
"""mapillary_v1 class names for external use."""
|
||||
return [
|
||||
'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
|
||||
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
|
||||
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk',
|
||||
'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
|
||||
'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General',
|
||||
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
|
||||
'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin',
|
||||
'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
|
||||
'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame',
|
||||
'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)',
|
||||
'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car',
|
||||
'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer',
|
||||
'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'
|
||||
]
|
||||
|
||||
|
||||
def mapillary_v1_palette():
|
||||
"""mapillary_v1_ palette for external use."""
|
||||
return [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
|
||||
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
|
||||
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
|
||||
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
|
||||
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
|
||||
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
|
||||
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
|
||||
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
|
||||
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
|
||||
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
|
||||
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
|
||||
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
|
||||
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]]
|
||||
|
||||
|
||||
def mapillary_v2_classes():
|
||||
"""mapillary_v2 class names for external use."""
|
||||
return [
|
||||
'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb',
|
||||
'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side',
|
||||
'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane',
|
||||
'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking',
|
||||
'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road',
|
||||
'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island',
|
||||
'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group',
|
||||
'Bicyclist', 'Motorcyclist', 'Other Rider',
|
||||
'Lane Marking - Dashed Line', 'Lane Marking - Straight Line',
|
||||
'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous',
|
||||
'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)',
|
||||
'Lane Marking - Arrow (Right)',
|
||||
'Lane Marking - Arrow (Split Left or Straight)',
|
||||
'Lane Marking - Arrow (Split Right or Straight)',
|
||||
'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)',
|
||||
'Lane Marking - Hatched (Chevron)',
|
||||
'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
|
||||
'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
|
||||
'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
|
||||
'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk',
|
||||
'Lane Marking (only) - Other', 'Lane Marking (only) - Test',
|
||||
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
|
||||
'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera',
|
||||
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter',
|
||||
'Phone Booth', 'Pothole', 'Signage - Advertisement',
|
||||
'Signage - Ambiguous', 'Signage - Back', 'Signage - Information',
|
||||
'Signage - Other', 'Signage - Store', 'Street Light', 'Pole',
|
||||
'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone',
|
||||
'Traffic Light - General (Single)', 'Traffic Light - Pedestrians',
|
||||
'Traffic Light - General (Upright)',
|
||||
'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
|
||||
'Traffic Light - Other', 'Traffic Sign - Ambiguous',
|
||||
'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
|
||||
'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
|
||||
'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
|
||||
'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
|
||||
'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
|
||||
'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
|
||||
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled'
|
||||
]
|
||||
|
||||
|
||||
def mapillary_v2_palette():
|
||||
"""mapillary_v2_ palette for external use."""
|
||||
return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
|
||||
[196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150],
|
||||
[250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35],
|
||||
[102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||
[250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110],
|
||||
[244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70],
|
||||
[150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60],
|
||||
[255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255],
|
||||
[255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26],
|
||||
[250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21],
|
||||
[250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18],
|
||||
[250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255],
|
||||
[250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255],
|
||||
[255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64],
|
||||
[230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||
[100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30],
|
||||
[40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
|
||||
[142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
|
||||
[250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30],
|
||||
[210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128],
|
||||
[0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30],
|
||||
[250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30],
|
||||
[192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0],
|
||||
[220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0],
|
||||
[140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100],
|
||||
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64],
|
||||
[0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
|
||||
[32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
|
||||
[111, 111, 0], [0, 0, 0]]
|
||||
|
||||
|
||||
def cityscapes_palette():
|
||||
"""Cityscapes palette for external use."""
|
||||
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
||||
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
||||
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
||||
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
|
||||
[0, 0, 230], [119, 11, 32]]
|
||||
|
||||
|
||||
def ade_palette():
|
||||
"""ADE20K palette for external use."""
|
||||
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
||||
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
||||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
||||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
||||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
||||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
||||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
||||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
||||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
||||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
||||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
||||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
||||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
||||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
||||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
||||
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
||||
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
||||
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
||||
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
||||
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
||||
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
||||
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
||||
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
||||
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
||||
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
||||
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
||||
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
||||
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
||||
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
||||
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
||||
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
||||
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
||||
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
||||
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
||||
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
||||
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
||||
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
||||
[102, 255, 0], [92, 0, 255]]
|
||||
|
||||
|
||||
def voc_palette():
|
||||
"""Pascal VOC palette for external use."""
|
||||
return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
|
||||
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
|
||||
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
|
||||
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
|
||||
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
|
||||
|
||||
|
||||
def pcontext_palette():
|
||||
"""Pascal Context palette for external use."""
|
||||
return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
|
||||
[120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
|
||||
[4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
|
||||
[120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
|
||||
[204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
|
||||
[61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
|
||||
[255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
|
||||
[112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
|
||||
[10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
|
||||
[102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
|
||||
[0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
|
||||
[235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
|
||||
[250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
|
||||
[255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
|
||||
[0, 235, 255], [0, 173, 255], [31, 0, 255]]
|
||||
|
||||
|
||||
def cocostuff_palette():
|
||||
"""CocoStuff palette for external use."""
|
||||
return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
|
||||
[0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
|
||||
[0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
|
||||
[0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
|
||||
[0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
|
||||
[128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
|
||||
[64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0],
|
||||
[0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0],
|
||||
[192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32],
|
||||
[0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
|
||||
[128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64],
|
||||
[192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32],
|
||||
[64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
|
||||
[0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64],
|
||||
[128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32],
|
||||
[64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128],
|
||||
[128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0],
|
||||
[128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96],
|
||||
[64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0],
|
||||
[0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0],
|
||||
[192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96],
|
||||
[64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128],
|
||||
[128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64],
|
||||
[192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96],
|
||||
[0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0],
|
||||
[64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64],
|
||||
[128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96],
|
||||
[0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128],
|
||||
[192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0],
|
||||
[128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32],
|
||||
[0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64],
|
||||
[64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0],
|
||||
[192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32],
|
||||
[0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192],
|
||||
[192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64],
|
||||
[192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32],
|
||||
[64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64],
|
||||
[64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64],
|
||||
[128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32],
|
||||
[64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192],
|
||||
[192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0],
|
||||
[128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96],
|
||||
[64, 160, 64], [64, 64, 0]]
|
||||
|
||||
|
||||
def loveda_palette():
|
||||
"""LoveDA palette for external use."""
|
||||
return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
|
||||
[159, 129, 183], [0, 255, 0], [255, 195, 128]]
|
||||
|
||||
|
||||
def potsdam_palette():
|
||||
"""Potsdam palette for external use."""
|
||||
return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
|
||||
[255, 255, 0], [255, 0, 0]]
|
||||
|
||||
|
||||
def vaihingen_palette():
|
||||
"""Vaihingen palette for external use."""
|
||||
return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
|
||||
[255, 255, 0], [255, 0, 0]]
|
||||
|
||||
|
||||
def isaid_palette():
|
||||
"""iSAID palette for external use."""
|
||||
return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
|
||||
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127,
|
||||
127], [0, 0, 127],
|
||||
[0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191],
|
||||
[0, 127, 255], [0, 100, 155]]
|
||||
|
||||
|
||||
def stare_palette():
|
||||
"""STARE palette for external use."""
|
||||
return [[120, 120, 120], [6, 230, 230]]
|
||||
|
||||
|
||||
def synapse_palette():
|
||||
"""Synapse palette for external use."""
|
||||
return [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255],
|
||||
[255, 0, 255], [255, 255, 0], [60, 255, 255], [240, 240, 240]]
|
||||
|
||||
|
||||
def synapse_classes():
|
||||
"""Synapse class names for external use."""
|
||||
return [
|
||||
'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney',
|
||||
'liver', 'pancreas', 'spleen', 'stomach'
|
||||
]
|
||||
|
||||
|
||||
def lip_classes():
|
||||
"""LIP class names for external use."""
|
||||
return [
|
||||
'background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes',
|
||||
'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt',
|
||||
'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe',
|
||||
'rightShoe'
|
||||
]
|
||||
|
||||
|
||||
def lip_palette():
|
||||
"""LIP palette for external use."""
|
||||
return [
|
||||
'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'UpperClothes',
|
||||
'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt',
|
||||
'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe',
|
||||
'Right-shoe'
|
||||
]
|
||||
|
||||
|
||||
def bdd100k_classes():
|
||||
"""BDD100K class names for external use(the class name is compatible with
|
||||
Cityscapes )."""
|
||||
return [
|
||||
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
||||
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
||||
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
||||
'bicycle'
|
||||
]
|
||||
|
||||
|
||||
def bdd100k_palette():
|
||||
"""bdd100k palette for external use(same with cityscapes)"""
|
||||
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
||||
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
||||
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
||||
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
|
||||
[0, 0, 230], [119, 11, 32]]
|
||||
|
||||
|
||||
def hsidrive_classes():
|
||||
"""HSI Drive 2.0 class names for external use."""
|
||||
return [
|
||||
'unlabelled', 'road', 'road marks', 'vegetation', 'painted metal',
|
||||
'sky', 'concrete', 'pedestrian', 'water', 'unpainted metal', 'glass'
|
||||
]
|
||||
|
||||
|
||||
def hsidrive_palette():
|
||||
"""HSI Drive 2.0 palette for external use."""
|
||||
return [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0], [255, 0, 0],
|
||||
[0, 0, 255], [102, 51, 0], [255, 255, 0], [0, 207, 250],
|
||||
[255, 166, 0], [0, 204, 204]]
|
||||
|
||||
|
||||
'''
|
||||
|
||||
# 生成多组 classes 和 palette 定义
|
||||
for dataset_file_name, classes, palette in zip(dataset_file_names, classes_all, palette_all):
|
||||
# 每一组使用 f-string 进行模板替换
|
||||
content += f'''def {dataset_file_name}_classes(): # TODO
|
||||
"""{dataset_file_name} class names for external use."""
|
||||
return {classes}
|
||||
|
||||
|
||||
def {dataset_file_name}_palette(): # TODO
|
||||
"""{dataset_file_name} palette for external use."""
|
||||
return {palette}
|
||||
|
||||
|
||||
'''
|
||||
|
||||
# dataset_aliases 和 get_classes, get_palette 函数的定义
|
||||
content += '''dataset_aliases = {
|
||||
'''
|
||||
|
||||
for dataset_file_name in dataset_file_names:
|
||||
content += f" '{dataset_file_name}': ['{dataset_file_name}'], # TODO\n"
|
||||
|
||||
# 追加固定的 hsidrive 映射
|
||||
content += ''' 'cityscapes': ['cityscapes'],
|
||||
'ade': ['ade', 'ade20k'],
|
||||
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
|
||||
'pcontext': ['pcontext', 'pascal_context', 'voc2010'],
|
||||
'loveda': ['loveda'],
|
||||
'potsdam': ['potsdam'],
|
||||
'vaihingen': ['vaihingen'],
|
||||
'cocostuff': [
|
||||
'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
|
||||
'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
|
||||
'coco_stuff164k'
|
||||
],
|
||||
'isaid': ['isaid', 'iSAID'],
|
||||
'stare': ['stare', 'STARE'],
|
||||
'lip': ['LIP', 'lip'],
|
||||
'mapillary_v1': ['mapillary_v1'],
|
||||
'mapillary_v2': ['mapillary_v2'],
|
||||
'bdd100k': ['bdd100k'],
|
||||
'hsidrive': [
|
||||
'hsidrive', 'HSIDrive', 'HSI-Drive', 'hsidrive20', 'HSIDrive20',
|
||||
'HSI-Drive20'
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def get_classes(dataset):
|
||||
"""Get class names of a dataset."""
|
||||
alias2name = {}
|
||||
for name, aliases in dataset_aliases.items():
|
||||
for alias in aliases:
|
||||
alias2name[alias] = name
|
||||
|
||||
if isinstance(dataset, str):
|
||||
if dataset in alias2name:
|
||||
labels = eval(alias2name[dataset] + '_classes()')
|
||||
else:
|
||||
raise ValueError(f'Unrecognized dataset: {dataset}')
|
||||
else:
|
||||
raise TypeError(f'dataset must a str, but got {type(dataset)}')
|
||||
return labels
|
||||
|
||||
|
||||
def get_palette(dataset):
|
||||
"""Get class palette (RGB) of a dataset."""
|
||||
alias2name = {}
|
||||
for name, aliases in dataset_aliases.items():
|
||||
for alias in aliases:
|
||||
alias2name[alias] = name
|
||||
|
||||
if isinstance(dataset, str):
|
||||
if dataset in alias2name:
|
||||
labels = eval(alias2name[dataset] + '_palette()')
|
||||
else:
|
||||
raise ValueError(f'Unrecognized dataset: {dataset}')
|
||||
else:
|
||||
raise TypeError(f'dataset must a str, but got {type(dataset)}')
|
||||
return labels
|
||||
'''
|
||||
|
||||
# 将生成的程序内容写入文件
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
|
||||
print(f"\033[93m{output_file} file generated successfully\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
########### 定义各参数 ###########
|
||||
# 可以定义多个数据集文件名 和 对应的classes和palette
|
||||
dataset_file_names = ["my_dataset_model"]
|
||||
# 这里的classes一定是经过“Initial_Gen_mmseg_datasets_my_dataset.py”处理的
|
||||
classes_all = [
|
||||
['背景','肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉'],
|
||||
]
|
||||
palette_all = [
|
||||
[[0,0,0],[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118],[0,157,142],[181,85,105],[42,8,66]],
|
||||
]
|
||||
|
||||
########### 文件存储位置 ###########
|
||||
output_mmseg_utils_class_names = f'./mmseg/utils/class_names.py'
|
||||
|
||||
# 生成程序文件
|
||||
success = generate_mmseg_utils_class_names_file(output_file=output_mmseg_utils_class_names, dataset_file_names=dataset_file_names, classes_all=classes_all, palette_all=palette_all)
|
||||
@@ -0,0 +1,82 @@
|
||||
import os
|
||||
|
||||
def generate_epochs_configs_base_schedules_schedule_file(output_file, max_epochs=300, val_interval=1, checkpoint_interval=10, loggerhook_interval=300):
|
||||
"""
|
||||
生成基于 Epoch 的 MMSegmentation schedule 配置文件。
|
||||
|
||||
Args:
|
||||
output_file (str): 输出配置文件的路径。
|
||||
max_epochs (int): 最大训练轮次。
|
||||
val_interval (int): 验证间隔的轮次数。
|
||||
checkpoint_interval (int): 保存模型权重间隔的轮次数。
|
||||
loggerhook_interval (int): 日志打印的迭代间隔。
|
||||
"""
|
||||
|
||||
# 定义 Epoch-based 的模板
|
||||
schedule_config_template = f"""# optimizer
|
||||
# For SGD.
|
||||
# optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
# For AdamW.
|
||||
optimizer = dict(type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.01)
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
|
||||
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='PolyLR',
|
||||
eta_min=1e-4,
|
||||
power=0.9,
|
||||
begin=0,
|
||||
end={max_epochs},
|
||||
by_epoch=True)
|
||||
]
|
||||
|
||||
# training schedule for {max_epochs} epochs
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs={max_epochs}, val_interval={val_interval})
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
|
||||
default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval={loggerhook_interval}, log_metric_by_epoch=True),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=True, interval={checkpoint_interval}, save_best='mIoU', rule='greater'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
"""
|
||||
|
||||
# --- 将生成的内容输出到界面 ---
|
||||
print("\033[92m--- Generated Configuration Content ---\033[0m")
|
||||
print(schedule_config_template)
|
||||
print("\033[92m-------------------------------------\033[0m")
|
||||
|
||||
# 创建目录(如果不存在的话)
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
|
||||
# 写入文件
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(schedule_config_template)
|
||||
|
||||
print(f"\033[93m{output_file} file generated successfully\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 定义各参数 ###########
|
||||
max_epochs = 300 # 训练总轮数
|
||||
val_interval = 1 # 每 1 个 epoch 验证一次
|
||||
checkpoint_interval = 10 # 每 10 个 epoch 保存一次模型
|
||||
loggerhook_interval = max_epochs # 日志间隔1次迭代
|
||||
|
||||
########### 文件存储位置 ###########
|
||||
# 文件名将反映 epoch 数量, e.g., schedule_300e.py
|
||||
output_filename = f'schedule_{max_epochs}e.py'
|
||||
output_configs_base_schedules_file = os.path.join('./configs/_base_/schedules/', output_filename)
|
||||
|
||||
# 使用变量生成配置文件
|
||||
success = generate_epoch_based_schedule_file(
|
||||
output_file=output_configs_base_schedules_file,
|
||||
max_epochs=max_epochs,
|
||||
val_interval=val_interval,
|
||||
checkpoint_interval=checkpoint_interval,
|
||||
loggerhook_interval=loggerhook_interval
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
|
||||
def generate_times_configs_base_schedules_schedule_file(output_file, train_time_k=4, val_proportion=1/10, loggerhook_interval=50):
|
||||
|
||||
# 定义模板
|
||||
schedule_config_template = f"""# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='PolyLR',
|
||||
eta_min=1e-4,
|
||||
power=0.9,
|
||||
begin=0,
|
||||
end={int(train_time_k*1000)},
|
||||
by_epoch=False)
|
||||
]
|
||||
# training schedule for {train_time_k}k
|
||||
train_cfg = dict(
|
||||
type='IterBasedTrainLoop', max_iters={int(train_time_k*1000)}, val_interval={int(train_time_k*1000*val_proportion)})
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval={loggerhook_interval}, log_metric_by_epoch=False),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval={int(train_time_k*1000*val_proportion)}, , save_best='mIoU', rule='greater'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
"""
|
||||
|
||||
# 创建目录(如果不存在的话)
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
|
||||
# 写入文件
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(schedule_config_template)
|
||||
|
||||
print(f"\033[93m{output_file} file generated successfully\033[0m")
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
########### 定义各参数 ###########
|
||||
train_time_k = 4 # 训练轮数(以k为单位)
|
||||
|
||||
# 一般不太会变的参数
|
||||
val_proportion = 1/10 # 验证比例
|
||||
loggerhook_interval = 50 # 日志间隔50次迭代
|
||||
|
||||
checkpoint_interval = int(train_time_k*1000*val_proportion) # 训练多少轮保存pth # TODO 可自定义
|
||||
|
||||
########### 文件存储位置 ###########
|
||||
output_configs_base_schedules_schedules_Timek = os.path.join('./configs/_base_/schedules/', f'schedule_{train_time_k}k.py')
|
||||
|
||||
# 使用变量生成配置文件
|
||||
success = generate_configs_base_schedules_schedule_file(output_file=output_configs_base_schedules_schedules_Timek, train_time_k=train_time_k, val_proportion=val_proportion, loggerhook_interval=loggerhook_interval)
|
||||
540
Seg_All_In_One_MMSeg/My_All_In_One/x4_Predict_V1-.py
Normal file
540
Seg_All_In_One_MMSeg/My_All_In_One/x4_Predict_V1-.py
Normal file
@@ -0,0 +1,540 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MMSegmentation 自动化推理、评估与分析脚本
|
||||
|
||||
本脚本旨在为基于 MMSegmentation 训练的语义分割模型提供一个全面、自动化的后处理流水线。
|
||||
它能够系统地发现指定目录下的训练产物,并为每个模型执行一系列详尽的分析任务,最终生成结构化的报告。
|
||||
|
||||
核心功能:
|
||||
1. 模型发现与检查点选择:
|
||||
- 自动扫描指定输入目录(默认为 './Outputs_mmseg')下的所有子文件夹。
|
||||
- 智能选择用于推理的权重文件:优先选择 'best.pth',若不存在,则选择周期数(epoch)最大的检查点文件。
|
||||
- 定位与模型对应的配置文件(.py)和训练日志文件(.json)。
|
||||
|
||||
2. 验证集推理与结果生成:
|
||||
- 对每个模型的验证集进行推理,并使用 tqdm 显示进度条。
|
||||
- 生成并保存原始的、未经着色的预测掩码图(predicted_raw_masks)。
|
||||
- 生成并保存包含“原始图像-预测图(着色后)-真值图标注(着色后)”的三图对比分析图像(prediction_analysis)。
|
||||
|
||||
3. 综合性能评估:
|
||||
- 计算并记录多项关键评估指标,包括 IoU (mIoU), F1-Score (mFscore/mDice), Accuracy (aAcc/mAcc), Recall (mRecall), 和 Precision (mPrecision)。
|
||||
- 深入计算每个类别的混淆矩阵基本元素:真正例(TP)、真负例(TN)、假正例(FP)、假负例(FN)。这一功能通过对 IoUMetric 内部数据的再处理实现,提供了标准评估流程之外的精细化分析能力。
|
||||
|
||||
4. 模型复杂度分析:
|
||||
- 计算模型的参数量(Parameters)和计算量(FLOPs),以评估其资源消耗。
|
||||
- 该计算基于 MMEngine 提供的底层分析工具,确保了与 MMSegmentation 官方工具的一致性。
|
||||
|
||||
5. 训练过程可视化:
|
||||
- 解析训练日志文件(.json),提取损失(loss)、准确率(accuracy)、交并比(mIoU)等关键指标在训练过程中的变化。
|
||||
- 使用 Matplotlib 和 Seaborn 绘制指标变化曲线图,并保存为图像文件,便于直观分析模型的收敛情况。
|
||||
|
||||
6. 结构化结果输出:
|
||||
- 所有生成的分析结果(指标、图像、日志图表)均被保存在一个结构化的输出目录中(默认为 './BestMode_Predict_Results_DataSet_Public')。
|
||||
- 每个模型的输出都存放在以其原始文件夹命名的子目录中,保持清晰的对应关系。
|
||||
- 最终的量化指标(包括总体指标、模型复杂度、逐类TP/TN/FP/FN)被整合并保存到一个名为 'test_set_metrics.csv' 的文件中,采用长格式(long format)存储,便于后续的数据处理和比较。
|
||||
|
||||
使用方法:
|
||||
直接在终端运行此脚本。可以通过命令行参数指定输入和输出的根目录。
|
||||
python predict.py --input_dir./Outputs_mmseg --output_dir./BestMode_Predict_Results_DataSet_Public
|
||||
|
||||
依赖:
|
||||
- torch, torchvision
|
||||
- mmengine
|
||||
- mmcv
|
||||
- mmsegmentation>=1.0.0
|
||||
- numpy
|
||||
- pandas
|
||||
- matplotlib
|
||||
- seaborn
|
||||
- tqdm
|
||||
- opencv-python
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import mmcv
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from tqdm import tqdm
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.analysis import get_model_complexity_info
|
||||
from mmseg.apis import init_model, inference_model
|
||||
from mmseg.evaluation.metrics import IoUMetric
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
# --- 全局配置 ---
|
||||
# 设置日志记录
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# 自动选择设备
|
||||
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
logging.info(f"使用设备: {DEVICE}")
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_model_files(model_dir: str) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
在给定的模型目录中查找配置文件、最佳检查点和日志文件。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional]: 包含 'config', 'checkpoint', 'log' 路径的字典,
|
||||
如果缺少任何必要文件,则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
config_path = config_files[0]
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'best.pth')
|
||||
if not os.path.exists(checkpoint_path):
|
||||
epoch_files = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
|
||||
if not epoch_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到 'best.pth' 或 'epoch_*.pth' 检查点文件。")
|
||||
return None
|
||||
|
||||
# 通过正则表达式从文件名中提取周期数并找到最大的
|
||||
latest_epoch = -1
|
||||
latest_file = None
|
||||
for f in epoch_files:
|
||||
match = re.search(r'epoch_(\d+)\.pth', os.path.basename(f))
|
||||
if match:
|
||||
epoch_num = int(match.group(1))
|
||||
if epoch_num > latest_epoch:
|
||||
latest_epoch = epoch_num
|
||||
latest_file = f
|
||||
|
||||
if latest_file:
|
||||
checkpoint_path = latest_file
|
||||
else:
|
||||
logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
|
||||
return None
|
||||
|
||||
|
||||
# log_files = glob.glob(os.path.join(model_dir, 'vis_data', '*.json'))
|
||||
# if not log_files:
|
||||
# logging.warning(f"在目录 {model_dir}/vis_data/ 中未找到日志文件 (.json)。")
|
||||
# return None
|
||||
# # 通常只有一个json日志文件,取第一个
|
||||
# log_path = log_files
|
||||
|
||||
# return {'config': config_path, 'checkpoint': checkpoint_path, 'log': log_path}
|
||||
return {'config': config_path, 'checkpoint': checkpoint_path}
|
||||
|
||||
def calculate_model_complexity(cfg: Config) -> Dict[str, float]:
|
||||
"""
|
||||
计算模型的 FLOPs 和参数量。
|
||||
|
||||
Args:
|
||||
cfg (Config): 加载的模型配置对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: 包含 'flops_G' 和 'params_M' 的字典。
|
||||
"""
|
||||
try:
|
||||
model = MODELS.build(cfg.model)
|
||||
model.eval()
|
||||
|
||||
# 从配置中获取输入尺寸,如果不存在则使用默认值
|
||||
if hasattr(cfg, 'crop_size'):
|
||||
input_shape = (3, *cfg.crop_size)
|
||||
else:
|
||||
# 尝试从test_pipeline中查找
|
||||
input_shape = None
|
||||
if hasattr(cfg, 'test_pipeline'):
|
||||
for transform in cfg.test_pipeline:
|
||||
if transform['type'] == 'Resize':
|
||||
input_shape = (3, transform['scale'][1], transform['scale'])
|
||||
break
|
||||
if input_shape is None:
|
||||
input_shape = (3, 512, 512)
|
||||
logging.info(f"配置中未找到 'crop_size' 或 'Resize',使用默认输入尺寸 {input_shape} 进行 FLOPs 计算。")
|
||||
|
||||
# 使用 MMEngine 的内置工具进行分析
|
||||
result = get_model_complexity_info(model, input_shape)
|
||||
|
||||
flops_str = result.get('flops_str', '0.0G')
|
||||
params_str = result.get('params_str', '0.0M')
|
||||
|
||||
# 解析字符串以获取数值
|
||||
flops_g = float(re.search(r'(\d+\.?\d*)', flops_str).group(1))
|
||||
params_m = float(re.search(r'(\d+\.?\d*)', params_str).group(1))
|
||||
|
||||
# 单位转换
|
||||
if 'G' not in flops_str.upper():
|
||||
flops_g /= 1e3 # 假设是 MMac
|
||||
if 'M' not in params_str.upper():
|
||||
params_m /= 1e6 # 假设是 K
|
||||
|
||||
return {'flops_G': flops_g, 'params_M': params_m}
|
||||
except Exception as e:
|
||||
logging.error(f"计算模型复杂度时出错: {e}")
|
||||
return {'flops_G': 0.0, 'params_M': 0.0}
|
||||
|
||||
def analyze_training_log(log_path: str, output_dir: str):
|
||||
"""
|
||||
解析训练日志文件并绘制指标曲线。
|
||||
|
||||
Args:
|
||||
log_path (str): 训练日志文件 (.json) 的路径。
|
||||
output_dir (str): 保存绘图的目录。
|
||||
"""
|
||||
try:
|
||||
log_data = []
|
||||
with open(log_path, 'r') as f:
|
||||
for line in f:
|
||||
log_data.append(json.loads(line))
|
||||
|
||||
df = pd.DataFrame(log_data)
|
||||
|
||||
# 提取关键指标
|
||||
metrics_to_plot = {
|
||||
'mIoU': 'mIoU',
|
||||
'mAcc': 'mAcc',
|
||||
'aAcc': 'aAcc',
|
||||
'loss': 'loss'
|
||||
}
|
||||
|
||||
# 筛选出验证和训练周期的行
|
||||
val_df = df[df['mode'] == 'val'].dropna(subset=['step'])
|
||||
train_df = df[df['mode'] == 'train'].dropna(subset=['step'])
|
||||
|
||||
sns.set_theme(style="whitegrid")
|
||||
|
||||
for key, name in metrics_to_plot.items():
|
||||
plt.figure(figsize=(12, 6))
|
||||
|
||||
if key in val_df.columns:
|
||||
sns.lineplot(data=val_df, x='step', y=key, label=f'Validation {name}')
|
||||
|
||||
if key in train_df.columns and key == 'loss':
|
||||
# 训练 loss 通常波动较大,可以进行平滑处理
|
||||
train_df[f'{key}_smooth'] = train_df[key].rolling(window=50, min_periods=1).mean()
|
||||
sns.lineplot(data=train_df, x='step', y=f'{key}_smooth', label=f'Training {name} (Smoothed)')
|
||||
|
||||
plt.title(f'{name} Curve during Training')
|
||||
plt.xlabel('Training Steps')
|
||||
plt.ylabel(name)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
save_path = os.path.join(output_dir, f'{name}_curve.png')
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
logging.info(f"已保存日志图表: {save_path}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"分析训练日志 {log_path} 时出错: {e}")
|
||||
|
||||
def run_inference_and_evaluation(cfg: Config, checkpoint_path: str, output_dir: str) -> Dict:
|
||||
"""
|
||||
执行模型推理、评估和可视化。
|
||||
|
||||
该函数通过一次性遍历验证数据集来高效地完成所有任务,避免了重复的数据加载和模型前向传播。
|
||||
|
||||
Args:
|
||||
cfg (Config): 加载的模型配置对象。
|
||||
checkpoint_path (str): 模型检查点文件的路径。
|
||||
output_dir (str): 保存所有输出的根目录。
|
||||
|
||||
Returns:
|
||||
Dict: 包含评估指标和逐类 TP/TN/FP/FN 计数的字典。
|
||||
"""
|
||||
# --- 1. 初始化 ---
|
||||
model = init_model(cfg, checkpoint_path, device=DEVICE)
|
||||
|
||||
# 以编程方式构建数据加载器,确保与训练时的数据预处理流程一致
|
||||
val_dataloader = Runner.build_dataloader(cfg.val_dataloader)
|
||||
|
||||
# 实例化评估器和可视化器
|
||||
metric = IoUMetric(iou_metrics=['mIoU'])
|
||||
visualizer = SegLocalVisualizer(save_dir=os.path.join(output_dir, 'prediction_analysis'))
|
||||
|
||||
# 关键步骤:为可视化器设置数据集元信息(类别名和调色板)
|
||||
visualizer.dataset_meta = val_dataloader.dataset.metainfo
|
||||
|
||||
metric.dataset_meta = val_dataloader.dataset.metainfo
|
||||
|
||||
# --- 2. 创建输出子目录 ---
|
||||
raw_mask_dir = os.path.join(output_dir, 'predicted_raw_masks')
|
||||
viz_dir = os.path.join(output_dir, 'prediction_analysis')
|
||||
os.makedirs(raw_mask_dir, exist_ok=True)
|
||||
os.makedirs(viz_dir, exist_ok=True)
|
||||
|
||||
# --- 3. 推理循环 ---
|
||||
model.eval()
|
||||
metric.results = [] # 清空历史结果
|
||||
with torch.no_grad():
|
||||
for data in tqdm(val_dataloader, desc=f"推理与评估: {os.path.basename(output_dir)}"):
|
||||
# 将数据移动到指定设备
|
||||
# inputs = data['inputs'].to(DEVICE) # 有误
|
||||
inputs = data['inputs'][0].to(DEVICE).float().unsqueeze(0)
|
||||
data_samples =data['data_samples']
|
||||
|
||||
# 执行推理
|
||||
result = model(inputs, data_samples=data_samples, mode='predict')
|
||||
|
||||
# 将预测结果合并到 GT 样本中,以便于评估和可视化
|
||||
for i in range(len(data_samples)):
|
||||
data_samples[i].pred_sem_seg = result[i].pred_sem_seg
|
||||
|
||||
# 使用评估器的 process 方法处理一个批次的结果
|
||||
predictions_as_dicts = [r.to_dict() for r in result]
|
||||
metric.process(data, predictions_as_dicts)
|
||||
|
||||
# 保存原始预测掩码和可视化结果
|
||||
for sample in data_samples:
|
||||
pred_mask = sample.pred_sem_seg.data.squeeze().cpu().numpy().astype(np.uint8)
|
||||
img_filename = os.path.basename(sample.img_path)
|
||||
cv2.imwrite(os.path.join(raw_mask_dir, img_filename), pred_mask)
|
||||
|
||||
# 生成并保存三图对比的可视化结果
|
||||
image = mmcv.imread(sample.img_path)
|
||||
visualizer.add_datasample(
|
||||
name=os.path.splitext(img_filename),
|
||||
image=image,
|
||||
data_sample=sample,
|
||||
show=False
|
||||
)
|
||||
|
||||
# --- 4. 计算最终指标 ---
|
||||
metrics_summary = metric.compute_metrics(metric.results)
|
||||
|
||||
# --- 5. 派生 TP, TN, FP, FN ---
|
||||
if not metric.results:
|
||||
logging.error("评估结果为空,无法计算 TP/TN/FP/FN。")
|
||||
return {'summary': metrics_summary, 'per_class': {}}
|
||||
|
||||
# 聚合所有批次的结果
|
||||
total_area_intersect = torch.stack([res for res in metric.results]).sum(0)
|
||||
total_area_union = torch.stack([res[1] for res in metric.results]).sum(0)
|
||||
total_area_pred_label = torch.stack([res[2] for res in metric.results]).sum(0)
|
||||
total_area_label = torch.stack([res[3] for res in metric.results]).sum(0)
|
||||
|
||||
# TP: 预测为正类,实际也为正类 (交集) [4]
|
||||
tp = total_area_intersect.numpy()
|
||||
# FP: 预测为正类,实际为负类 (预测区域 - 交集) [4]
|
||||
fp = (total_area_pred_label - total_area_intersect).numpy()
|
||||
# FN: 预测为负类,实际为正类 (真值区域 - 交集) [4]
|
||||
fn = (total_area_label - total_area_intersect).numpy()
|
||||
|
||||
# TN: 预测为负类,实际也为负类
|
||||
# TN_i = total_valid_pixels - (TP_i + FP_i + FN_i)
|
||||
# total_valid_pixels = sum of all elements in confusion matrix = sum of all ground truth pixels
|
||||
total_valid_pixels = total_area_label.sum().item() * len(val_dataloader.dataset) / sum(len(b['inputs']) for b in val_dataloader) # 估算
|
||||
union = total_area_pred_label + total_area_label - total_area_intersect
|
||||
tn = total_valid_pixels - union.numpy()
|
||||
|
||||
per_class_metrics = {'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn}
|
||||
|
||||
return {'summary': metrics_summary, 'per_class': per_class_metrics}
|
||||
|
||||
def consolidate_and_save_results(metrics_data: Dict, complexity_data: Dict, metainfo: Dict, output_dir: str):
|
||||
"""
|
||||
整合所有量化指标并保存为 CSV 文件。
|
||||
|
||||
Args:
|
||||
metrics_data (Dict): 来自 run_inference_and_evaluation 的指标数据。
|
||||
complexity_data (Dict): 来自 calculate_model_complexity 的复杂度数据。
|
||||
metainfo (Dict): 数据集的元信息,包含类别名。
|
||||
output_dir (str): 保存 CSV 文件的目录。
|
||||
"""
|
||||
num_classes = len(metainfo['classes'])
|
||||
class_names = metainfo['classes']
|
||||
|
||||
rows = []
|
||||
|
||||
# 添加总体指标
|
||||
summary = metrics_data.get('summary', {})
|
||||
rows.append({'metric': 'iou_score', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mIoU', 0)})
|
||||
rows.append({'metric': 'f1_score', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mFscore', summary.get('mDice', 0))})
|
||||
rows.append({'metric': 'accuracy', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('aAcc', 0)})
|
||||
rows.append({'metric': 'recall', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mRecall', 0)})
|
||||
rows.append({'metric': 'precision', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mPrecision', 0)})
|
||||
|
||||
# 添加模型复杂度
|
||||
rows.append({'metric': 'params_M', 'class_id': -1, 'class_name': 'N/A', 'value': complexity_data.get('params_M', 0)})
|
||||
rows.append({'metric': 'flops_G', 'class_id': -1, 'class_name': 'N/A', 'value': complexity_data.get('flops_G', 0)})
|
||||
|
||||
# 添加逐类 TP/TN/FP/FN
|
||||
per_class = metrics_data.get('per_class', {})
|
||||
for metric_name, values in per_class.items():
|
||||
if len(values) == num_classes:
|
||||
for i in range(num_classes):
|
||||
rows.append({
|
||||
'metric': metric_name,
|
||||
'class_id': i,
|
||||
'class_name': class_names[i],
|
||||
'value': values[i]
|
||||
})
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
save_path = os.path.join(output_dir, 'prediction_analysis', 'test_set_metrics.csv')
|
||||
df.to_csv(save_path, index=False)
|
||||
logging.info(f"所有量化指标已保存至: {save_path}")
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
# --- 开始交互式选择修改 (V2 - 两级菜单) ---
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# 1. 定义有效的数据集文件夹白名单
|
||||
VALID_DATASET_FOLDERS = [
|
||||
'1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
]
|
||||
|
||||
# 2. 查找存在的、有效的数据集目录
|
||||
existing_dataset_dirs = [
|
||||
os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
|
||||
if os.path.isdir(os.path.join(input_root, d))
|
||||
]
|
||||
|
||||
if not existing_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
# 3. 第一级菜单:选择数据集
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = [] # 初始化最终要处理的目录列表
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
# 4. 查找选定数据集下的所有算法子目录
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
|
||||
else:
|
||||
# 5. 第二级菜单:选择算法
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
# 6. 根据第二级选择,最终确定 model_dirs
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]] # 将单个路径放入列表中
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
|
||||
# --- 交互式选择修改结束 ---
|
||||
|
||||
# 修改后的循环,将遍历经过用户筛选后的 model_dirs 列表
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"--- 开始处理模型: {model_name} ---")
|
||||
|
||||
files = find_model_files(model_dir)
|
||||
if not files:
|
||||
logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
|
||||
continue
|
||||
|
||||
# 构建输出目录
|
||||
# 从模型名中提取数据集标识作为Key
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1280x1024_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_EndoVis_2017-7Type-1280x1024_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_EndoVis_2018-11Type-1280x1024_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-6Type-1920x1080_outputs-MMSeg'
|
||||
}
|
||||
# 使用提取的Key(字符串)进行查询,并为默认值也使用该Key
|
||||
output_dataset_folder = dataset_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder, model_name)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
logging.info(f"结果将保存至: {final_output_dir}")
|
||||
|
||||
try:
|
||||
# 加载配置
|
||||
cfg = Config.fromfile(files['config'])
|
||||
|
||||
# 1. 执行推理、评估和可视化
|
||||
metrics_data = run_inference_and_evaluation(cfg, files['checkpoint'], final_output_dir)
|
||||
|
||||
# 2. 分析训练日志
|
||||
analyze_training_log(files['log'], os.path.join(final_output_dir, 'prediction_analysis'))
|
||||
|
||||
# 3. 计算模型复杂度
|
||||
complexity_data = calculate_model_complexity(cfg)
|
||||
|
||||
# 4. 整合并保存所有量化结果
|
||||
val_dataloader = Runner.build_dataloader(cfg.val_dataloader)
|
||||
metainfo = val_dataloader.dataset.metainfo
|
||||
|
||||
consolidate_and_save_results(metrics_data, complexity_data, metainfo, final_output_dir)
|
||||
|
||||
logging.info(f"--- 模型 {model_name} 处理完成 ---")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"处理模型 {model_name} 时发生严重错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含已训练模型文件夹的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
613
Seg_All_In_One_MMSeg/My_All_In_One/x4_Predict_V2-.py
Normal file
613
Seg_All_In_One_MMSeg/My_All_In_One/x4_Predict_V2-.py
Normal file
@@ -0,0 +1,613 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MMSegmentation 自动化推理、评估与分析脚本
|
||||
|
||||
本脚本旨在为基于 MMSegmentation 训练的语义分割模型提供一个全面、自动化的后处理流水线。
|
||||
它能够系统地发现指定目录下的训练产物,并为每个模型执行一系列详尽的分析任务,最终生成结构化的报告。
|
||||
|
||||
核心功能:
|
||||
1. 模型发现与检查点选择:
|
||||
- 自动扫描指定输入目录(默认为 './Outputs_mmseg')下的所有子文件夹。
|
||||
- 智能选择用于推理的权重文件:优先选择 'best.pth',若不存在,则选择周期数(epoch)最大的检查点文件。
|
||||
- 定位与模型对应的配置文件(.py)和训练日志文件(.json)。
|
||||
|
||||
2. 验证集推理与结果生成:
|
||||
- 对每个模型的验证集进行推理,并使用 tqdm 显示进度条。
|
||||
- 生成并保存原始的、未经着色的预测掩码图(predicted_raw_masks)。
|
||||
- 生成并保存包含“原始图像-预测图(着色后)-真值图标注(着色后)”的三图对比分析图像(prediction_analysis)。
|
||||
|
||||
3. 综合性能评估:
|
||||
- 计算并记录多项关键评估指标,包括 IoU (mIoU), F1-Score (mFscore/mDice), Accuracy (aAcc/mAcc), Recall (mRecall), 和 Precision (mPrecision)。
|
||||
- 深入计算每个类别的混淆矩阵基本元素:真正例(TP)、真负例(TN)、假正例(FP)、假负例(FN)。这一功能通过对 IoUMetric 内部数据的再处理实现,提供了标准评估流程之外的精细化分析能力。
|
||||
|
||||
4. 模型复杂度分析:
|
||||
- 计算模型的参数量(Parameters)和计算量(FLOPs),以评估其资源消耗。
|
||||
- 该计算基于 MMEngine 提供的底层分析工具,确保了与 MMSegmentation 官方工具的一致性。
|
||||
|
||||
5. 训练过程可视化:
|
||||
- 解析训练日志文件(.json),提取损失(loss)、准确率(accuracy)、交并比(mIoU)等关键指标在训练过程中的变化。
|
||||
- 使用 Matplotlib 和 Seaborn 绘制指标变化曲线图,并保存为图像文件,便于直观分析模型的收敛情况。
|
||||
|
||||
6. 结构化结果输出:
|
||||
- 所有生成的分析结果(指标、图像、日志图表)均被保存在一个结构化的输出目录中(默认为 './BestMode_Predict_Results_DataSet_Public')。
|
||||
- 每个模型的输出都存放在以其原始文件夹命名的子目录中,保持清晰的对应关系。
|
||||
- 最终的量化指标(包括总体指标、模型复杂度、逐类TP/TN/FP/FN)被整合并保存到一个名为 'test_set_metrics.csv' 的文件中,采用长格式(long format)存储,便于后续的数据处理和比较。
|
||||
|
||||
使用方法:
|
||||
直接在终端运行此脚本。可以通过命令行参数指定输入和输出的根目录。
|
||||
python predict.py --input_dir./Outputs_mmseg --output_dir./BestMode_Predict_Results_DataSet_Public
|
||||
|
||||
依赖:
|
||||
- torch, torchvision
|
||||
- mmengine
|
||||
- mmcv
|
||||
- mmsegmentation>=1.0.0
|
||||
- numpy
|
||||
- pandas
|
||||
- matplotlib
|
||||
- seaborn
|
||||
- tqdm
|
||||
- opencv-python
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import mmcv
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from tqdm import tqdm
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.analysis import get_model_complexity_info
|
||||
from mmseg.apis import init_model, inference_model
|
||||
from mmseg.evaluation.metrics import IoUMetric
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
# --- 全局配置 ---
|
||||
# 设置日志记录
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# 自动选择设备
|
||||
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
logging.info(f"使用设备: {DEVICE}")
|
||||
|
||||
# --- 辅助函数 ---
|
||||
def find_model_files(model_dir: str) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
在给定的模型目录中查找配置文件、最佳检查点和日志文件。
|
||||
|
||||
Args:
|
||||
model_dir (str): 模型的根目录。
|
||||
|
||||
Returns:
|
||||
Optional]: 包含 'config', 'checkpoint', 'log' 路径的字典,
|
||||
如果缺少任何必要文件,则返回 None。
|
||||
"""
|
||||
config_files = glob.glob(os.path.join(model_dir, '*.py'))
|
||||
if not config_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到配置文件 (.py)。")
|
||||
return None
|
||||
config_path = config_files[0]
|
||||
|
||||
checkpoint_path = os.path.join(model_dir, 'best.pth')
|
||||
if not os.path.exists(checkpoint_path):
|
||||
epoch_files = glob.glob(os.path.join(model_dir, 'epoch_*.pth'))
|
||||
if not epoch_files:
|
||||
logging.warning(f"在目录 {model_dir} 中未找到 'best.pth' 或 'epoch_*.pth' 检查点文件。")
|
||||
return None
|
||||
|
||||
# 通过正则表达式从文件名中提取周期数并找到最大的
|
||||
latest_epoch = -1
|
||||
latest_file = None
|
||||
for f in epoch_files:
|
||||
match = re.search(r'epoch_(\d+)\.pth', os.path.basename(f))
|
||||
if match:
|
||||
epoch_num = int(match.group(1))
|
||||
if epoch_num > latest_epoch:
|
||||
latest_epoch = epoch_num
|
||||
latest_file = f
|
||||
|
||||
if latest_file:
|
||||
checkpoint_path = latest_file
|
||||
else:
|
||||
logging.warning(f"在目录 {model_dir} 中无法确定最新的检查点文件。")
|
||||
return None
|
||||
|
||||
|
||||
# log_files = glob.glob(os.path.join(model_dir, 'vis_data', '*.json'))
|
||||
# if not log_files:
|
||||
# logging.warning(f"在目录 {model_dir}/vis_data/ 中未找到日志文件 (.json)。")
|
||||
# return None
|
||||
# # 通常只有一个json日志文件,取第一个
|
||||
# log_path = log_files
|
||||
|
||||
# return {'config': config_path, 'checkpoint': checkpoint_path, 'log': log_path}
|
||||
return {'config': config_path, 'checkpoint': checkpoint_path}
|
||||
|
||||
def calculate_model_complexity(cfg: Config) -> Dict[str, float]:
|
||||
"""
|
||||
计算模型的 FLOPs 和参数量。
|
||||
|
||||
Args:
|
||||
cfg (Config): 加载的模型配置对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: 包含 'flops_G' 和 'params_M' 的字典。
|
||||
"""
|
||||
try:
|
||||
model = MODELS.build(cfg.model)
|
||||
model.eval()
|
||||
|
||||
# 从配置中获取输入尺寸,如果不存在则使用默认值
|
||||
if hasattr(cfg, 'crop_size'):
|
||||
input_shape = (3, *cfg.crop_size)
|
||||
else:
|
||||
# 尝试从test_pipeline中查找
|
||||
input_shape = None
|
||||
if hasattr(cfg, 'test_pipeline'):
|
||||
for transform in cfg.test_pipeline:
|
||||
if transform['type'] == 'Resize':
|
||||
input_shape = (3, transform['scale'][1], transform['scale'])
|
||||
break
|
||||
if input_shape is None:
|
||||
input_shape = (3, 512, 512)
|
||||
logging.info(f"配置中未找到 'crop_size' 或 'Resize',使用默认输入尺寸 {input_shape} 进行 FLOPs 计算。")
|
||||
|
||||
# 使用 MMEngine 的内置工具进行分析
|
||||
result = get_model_complexity_info(model, input_shape)
|
||||
|
||||
flops_str = result.get('flops_str', '0.0G')
|
||||
params_str = result.get('params_str', '0.0M')
|
||||
|
||||
# 解析字符串以获取数值
|
||||
flops_g = float(re.search(r'(\d+\.?\d*)', flops_str).group(1))
|
||||
params_m = float(re.search(r'(\d+\.?\d*)', params_str).group(1))
|
||||
|
||||
# 单位转换
|
||||
if 'G' not in flops_str.upper():
|
||||
flops_g /= 1e3 # 假设是 MMac
|
||||
if 'M' not in params_str.upper():
|
||||
params_m /= 1e6 # 假设是 K
|
||||
|
||||
return {'flops_G': flops_g, 'params_M': params_m}
|
||||
except Exception as e:
|
||||
logging.error(f"计算模型复杂度时出错: {e}")
|
||||
return {'flops_G': 0.0, 'params_M': 0.0}
|
||||
|
||||
def analyze_training_log(log_path: str, output_dir: str):
|
||||
"""
|
||||
解析训练日志文件并绘制指标曲线。
|
||||
|
||||
Args:
|
||||
log_path (str): 训练日志文件 (.json) 的路径。
|
||||
output_dir (str): 保存绘图的目录。
|
||||
"""
|
||||
try:
|
||||
log_data = []
|
||||
with open(log_path, 'r') as f:
|
||||
for line in f:
|
||||
log_data.append(json.loads(line))
|
||||
|
||||
df = pd.DataFrame(log_data)
|
||||
|
||||
# 提取关键指标
|
||||
metrics_to_plot = {
|
||||
'mIoU': 'mIoU',
|
||||
'mAcc': 'mAcc',
|
||||
'aAcc': 'aAcc',
|
||||
'loss': 'loss'
|
||||
}
|
||||
|
||||
# 筛选出验证和训练周期的行
|
||||
val_df = df[df['mode'] == 'val'].dropna(subset=['step'])
|
||||
train_df = df[df['mode'] == 'train'].dropna(subset=['step'])
|
||||
|
||||
sns.set_theme(style="whitegrid")
|
||||
|
||||
for key, name in metrics_to_plot.items():
|
||||
plt.figure(figsize=(12, 6))
|
||||
|
||||
if key in val_df.columns:
|
||||
sns.lineplot(data=val_df, x='step', y=key, label=f'Validation {name}')
|
||||
|
||||
if key in train_df.columns and key == 'loss':
|
||||
# 训练 loss 通常波动较大,可以进行平滑处理
|
||||
train_df[f'{key}_smooth'] = train_df[key].rolling(window=50, min_periods=1).mean()
|
||||
sns.lineplot(data=train_df, x='step', y=f'{key}_smooth', label=f'Training {name} (Smoothed)')
|
||||
|
||||
plt.title(f'{name} Curve during Training')
|
||||
plt.xlabel('Training Steps')
|
||||
plt.ylabel(name)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
save_path = os.path.join(output_dir, f'{name}_curve.png')
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
logging.info(f"已保存日志图表: {save_path}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"分析训练日志 {log_path} 时出错: {e}")
|
||||
|
||||
def run_inference_and_evaluation(cfg: Config, checkpoint_path: str, output_dir: str) -> Dict:
|
||||
"""
|
||||
执行模型推理、评估和可视化。
|
||||
|
||||
该函数通过一次性遍历验证数据集来高效地完成所有任务,避免了重复的数据加载和模型前向传播。
|
||||
|
||||
Args:
|
||||
cfg (Config): 加载的模型配置对象。
|
||||
checkpoint_path (str): 模型检查点文件的路径。
|
||||
output_dir (str): 保存所有输出的根目录。
|
||||
|
||||
Returns:
|
||||
Dict: 包含评估指标和逐类 TP/TN/FP/FN 计数的字典。
|
||||
"""
|
||||
# --- 1. 初始化 ---
|
||||
model = init_model(cfg, checkpoint_path, device=DEVICE)
|
||||
|
||||
# 以编程方式构建数据加载器,确保与训练时的数据预处理流程一致
|
||||
# V1.
|
||||
# train_dataloader = dict(
|
||||
# batch_size=1,
|
||||
# dataset=dict(
|
||||
# data_prefix=dict(img_path='images/train', seg_map_path='labels_GT/train'),
|
||||
# data_root=
|
||||
# '/home/wkmgc/Desktop/Seg/DataSet_Public/3_1_Endovis_2017-8Type-512x512',
|
||||
# pipeline=[
|
||||
# dict(type='LoadImageFromFile'),
|
||||
# dict(keep_ratio=True, scale=(
|
||||
# 512,
|
||||
# 512,
|
||||
# ), type='Resize'),
|
||||
# dict(type='LoadAnnotations'),
|
||||
# dict(type='PackSegInputs'),
|
||||
# ],
|
||||
# type='PublicDataSet_Endovis_2017'),
|
||||
# num_workers=4,
|
||||
# persistent_workers=True,
|
||||
# sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
# val_dataloader = Runner.build_dataloader(train_dataloader) # TODO cfg.val_dataloader 修改为 cfg.train_dataloader
|
||||
|
||||
# V2.
|
||||
val_dataloader = Runner.build_dataloader(cfg.val_dataloader)
|
||||
|
||||
# 实例化评估器和可视化器
|
||||
metric = IoUMetric(iou_metrics=['mIoU'])
|
||||
visualizer = SegLocalVisualizer(save_dir=os.path.join(output_dir, 'prediction_analysis'))
|
||||
|
||||
# 关键步骤:为可视化器设置数据集元信息(类别名和调色板)
|
||||
visualizer.dataset_meta = val_dataloader.dataset.metainfo
|
||||
|
||||
metric.dataset_meta = val_dataloader.dataset.metainfo
|
||||
|
||||
# --- 2. 创建输出子目录 ---
|
||||
raw_mask_dir = os.path.join(output_dir, 'predicted_raw_masks')
|
||||
viz_dir = os.path.join(output_dir, 'prediction_analysis')
|
||||
os.makedirs(raw_mask_dir, exist_ok=True)
|
||||
os.makedirs(viz_dir, exist_ok=True)
|
||||
|
||||
# --- 3. 推理循环 ---
|
||||
model.eval()
|
||||
metric.results = [] # 清空历史结果
|
||||
with torch.no_grad():
|
||||
for data in tqdm(val_dataloader, desc=f"推理与评估: {os.path.basename(output_dir)}"):
|
||||
# 将数据移动到指定设备
|
||||
# inputs = data['inputs'].to(DEVICE) # 有误
|
||||
inputs = data['inputs'][0].to(DEVICE).float().unsqueeze(0)
|
||||
data_samples =data['data_samples']
|
||||
|
||||
# 执行推理
|
||||
result = model(inputs, data_samples=data_samples, mode='predict')
|
||||
|
||||
# 将预测结果合并到 GT 样本中,以便于评估和可视化
|
||||
for i in range(len(data_samples)):
|
||||
data_samples[i].pred_sem_seg = result[i].pred_sem_seg
|
||||
|
||||
# 使用评估器的 process 方法处理一个批次的结果
|
||||
predictions_as_dicts = [r.to_dict() for r in result]
|
||||
metric.process(data, predictions_as_dicts)
|
||||
|
||||
# 保存原始预测掩码和可视化结果
|
||||
for sample in data_samples:
|
||||
pred_mask = sample.pred_sem_seg.data.squeeze().cpu().numpy().astype(np.uint8)
|
||||
img_filename = os.path.basename(sample.img_path)
|
||||
cv2.imwrite(os.path.join(raw_mask_dir, img_filename), pred_mask)
|
||||
|
||||
# --- 开始修改:手动生成并拼接三图对比的可视化结果 ---
|
||||
# 1. 读取原始图像
|
||||
original_img = mmcv.imread(sample.img_path)
|
||||
# 2. 生成着色后的预测图
|
||||
# 从 sample 中获取预测结果的张量
|
||||
pred_mask_tensor = sample.pred_sem_seg.data
|
||||
# 调用可视化器的内部方法进行绘制
|
||||
predicted_img_colored = visualizer._draw_sem_seg(
|
||||
image=original_img.copy(),
|
||||
sem_seg=pred_mask_tensor,
|
||||
classes=visualizer.dataset_meta.get('classes'),
|
||||
palette=visualizer.dataset_meta.get('palette')
|
||||
)
|
||||
# 3. 生成着色后的真值图
|
||||
# 从 sample 中获取真值标签的张量
|
||||
gt_mask_tensor = sample.gt_sem_seg.data
|
||||
# 调用可视化器的内部方法进行绘制
|
||||
gt_img_colored = visualizer._draw_sem_seg(
|
||||
image=original_img.copy(),
|
||||
sem_seg=gt_mask_tensor,
|
||||
classes=visualizer.dataset_meta.get('classes'),
|
||||
palette=visualizer.dataset_meta.get('palette')
|
||||
)
|
||||
# 4. 将三张图水平拼接
|
||||
# 确保所有图像的数据类型一致,以便拼接
|
||||
original_img = original_img.astype(np.uint8)
|
||||
predicted_img_colored = predicted_img_colored.astype(np.uint8)
|
||||
gt_img_colored = gt_img_colored.astype(np.uint8)
|
||||
# 使用 numpy.hstack 进行水平拼接
|
||||
comparison_image = np.hstack([original_img, predicted_img_colored, gt_img_colored])
|
||||
# 5. 保存拼接后的图像
|
||||
save_path = os.path.join(viz_dir, img_filename)
|
||||
cv2.imwrite(save_path, comparison_image)
|
||||
# --- 修改结束 ---
|
||||
|
||||
# --- 4. 计算最终指标 ---
|
||||
metrics_summary = metric.compute_metrics(metric.results)
|
||||
|
||||
# --- 5. 派生 TP, TN, FP, FN ---
|
||||
if not metric.results:
|
||||
logging.error("评估结果为空,无法计算 TP/TN/FP/FN。")
|
||||
return {'summary': metrics_summary, 'per_class': {}}
|
||||
|
||||
# 聚合所有批次的结果
|
||||
total_area_intersect = torch.stack([res for res in metric.results]).sum(0)
|
||||
total_area_union = torch.stack([res[1] for res in metric.results]).sum(0)
|
||||
total_area_pred_label = torch.stack([res[2] for res in metric.results]).sum(0)
|
||||
total_area_label = torch.stack([res[3] for res in metric.results]).sum(0)
|
||||
|
||||
# TP: 预测为正类,实际也为正类 (交集) [4]
|
||||
tp = total_area_intersect.numpy()
|
||||
# FP: 预测为正类,实际为负类 (预测区域 - 交集) [4]
|
||||
fp = (total_area_pred_label - total_area_intersect).numpy()
|
||||
# FN: 预测为负类,实际为正类 (真值区域 - 交集) [4]
|
||||
fn = (total_area_label - total_area_intersect).numpy()
|
||||
|
||||
# TN: 预测为负类,实际也为负类
|
||||
# TN_i = total_valid_pixels - (TP_i + FP_i + FN_i)
|
||||
# total_valid_pixels = sum of all elements in confusion matrix = sum of all ground truth pixels
|
||||
total_valid_pixels = total_area_label.sum().item() * len(val_dataloader.dataset) / sum(len(b['inputs']) for b in val_dataloader) # 估算
|
||||
union = total_area_pred_label + total_area_label - total_area_intersect
|
||||
tn = total_valid_pixels - union.numpy()
|
||||
|
||||
per_class_metrics = {'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn}
|
||||
|
||||
return {'summary': metrics_summary, 'per_class': per_class_metrics}
|
||||
|
||||
def consolidate_and_save_results(metrics_data: Dict, complexity_data: Dict, metainfo: Dict, output_dir: str):
|
||||
"""
|
||||
整合所有量化指标并保存为 CSV 文件。
|
||||
|
||||
Args:
|
||||
metrics_data (Dict): 来自 run_inference_and_evaluation 的指标数据。
|
||||
complexity_data (Dict): 来自 calculate_model_complexity 的复杂度数据。
|
||||
metainfo (Dict): 数据集的元信息,包含类别名。
|
||||
output_dir (str): 保存 CSV 文件的目录。
|
||||
"""
|
||||
num_classes = len(metainfo['classes'])
|
||||
class_names = metainfo['classes']
|
||||
|
||||
rows = []
|
||||
|
||||
# 添加总体指标
|
||||
summary = metrics_data.get('summary', {})
|
||||
rows.append({'metric': 'iou_score', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mIoU', 0)})
|
||||
rows.append({'metric': 'f1_score', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mFscore', summary.get('mDice', 0))})
|
||||
rows.append({'metric': 'accuracy', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('aAcc', 0)})
|
||||
rows.append({'metric': 'recall', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mRecall', 0)})
|
||||
rows.append({'metric': 'precision', 'class_id': -1, 'class_name': 'Overall', 'value': summary.get('mPrecision', 0)})
|
||||
|
||||
# 添加模型复杂度
|
||||
rows.append({'metric': 'params_M', 'class_id': -1, 'class_name': 'N/A', 'value': complexity_data.get('params_M', 0)})
|
||||
rows.append({'metric': 'flops_G', 'class_id': -1, 'class_name': 'N/A', 'value': complexity_data.get('flops_G', 0)})
|
||||
|
||||
# 添加逐类 TP/TN/FP/FN
|
||||
per_class = metrics_data.get('per_class', {})
|
||||
for metric_name, values in per_class.items():
|
||||
if len(values) == num_classes:
|
||||
for i in range(num_classes):
|
||||
rows.append({
|
||||
'metric': metric_name,
|
||||
'class_id': i,
|
||||
'class_name': class_names[i],
|
||||
'value': values[i]
|
||||
})
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
save_path = os.path.join(output_dir, 'prediction_analysis', 'test_set_metrics.csv')
|
||||
df.to_csv(save_path, index=False)
|
||||
logging.info(f"所有量化指标已保存至: {save_path}")
|
||||
|
||||
# --- 主函数 ---
|
||||
def main(args):
|
||||
"""
|
||||
脚本主入口,负责编排整个自动化分析流程。
|
||||
"""
|
||||
input_root = args.input_dir
|
||||
output_root = args.output_dir
|
||||
# --- 开始交互式选择修改 (V2 - 两级菜单) ---
|
||||
if not os.path.isdir(input_root):
|
||||
logging.error(f"输入目录不存在: {input_root}")
|
||||
return
|
||||
|
||||
# 1. 定义有效的数据集文件夹白名单
|
||||
VALID_DATASET_FOLDERS = [
|
||||
'1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_AutoLaparo-10Type-1920x1080_outputs-MMSeg',
|
||||
'3_1_Endovis_2017-8Type-512x512_outputs-MMSeg',
|
||||
'3_2_Endovis_2018-8Type-512x512_outputs-MMSeg',
|
||||
'4_Dresden-11Type-512x512_outputs-MMSeg'
|
||||
]
|
||||
|
||||
# 2. 查找存在的、有效的数据集目录
|
||||
existing_dataset_dirs = [
|
||||
os.path.join(input_root, d) for d in VALID_DATASET_FOLDERS
|
||||
if os.path.isdir(os.path.join(input_root, d))
|
||||
]
|
||||
|
||||
if not existing_dataset_dirs:
|
||||
logging.error(f"在输入目录 {input_root} 中未找到任何有效的数据集文件夹。")
|
||||
return
|
||||
|
||||
# 3. 第一级菜单:选择数据集
|
||||
dataset_map = {str(i + 1): path for i, path in enumerate(existing_dataset_dirs)}
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 1: 请选择要处理的数据集 ---")
|
||||
for key, path in dataset_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice1 = input("请输入数据集编号并按回车键: ").strip()
|
||||
|
||||
model_dirs = [] # 初始化最终要处理的目录列表
|
||||
|
||||
if choice1 in dataset_map:
|
||||
selected_dataset_dir = dataset_map[choice1]
|
||||
logging.info(f"您已选择数据集: [{os.path.basename(selected_dataset_dir)}]")
|
||||
|
||||
# 4. 查找选定数据集下的所有算法子目录
|
||||
alg_dirs = sorted([
|
||||
d for d in glob.glob(os.path.join(selected_dataset_dir, '*')) if os.path.isdir(d)
|
||||
])
|
||||
|
||||
if not alg_dirs:
|
||||
logging.warning(f"在 {os.path.basename(selected_dataset_dir)} 中未发现任何算法子文件夹。程序退出。")
|
||||
else:
|
||||
# 5. 第二级菜单:选择算法
|
||||
alg_map = {str(i + 1): path for i, path in enumerate(alg_dirs)}
|
||||
print("\n" + "="*50)
|
||||
print("--- 步骤 2: 请选择要处理的算法 ---")
|
||||
print("0: 批量处理当前数据集下的【全部】算法")
|
||||
for key, path in alg_map.items():
|
||||
print(f"{key}: {os.path.basename(path)}")
|
||||
print("="*50)
|
||||
|
||||
choice2 = input("请输入算法编号 (或输入 '0' 处理全部) 并按回车键: ").strip()
|
||||
|
||||
# 6. 根据第二级选择,最终确定 model_dirs
|
||||
if choice2 == '0':
|
||||
model_dirs = alg_dirs
|
||||
logging.info(f"您选择了批量处理全部 {len(model_dirs)} 个算法。")
|
||||
elif choice2 in alg_map:
|
||||
model_dirs = [alg_map[choice2]] # 将单个路径放入列表中
|
||||
logging.info(f"您选择了处理单个算法: {os.path.basename(model_dirs[0])}")
|
||||
else:
|
||||
logging.error("无效的算法选择,程序已退出。")
|
||||
else:
|
||||
logging.error("无效的数据集选择,程序已退出。")
|
||||
|
||||
# --- 交互式选择修改结束 ---
|
||||
|
||||
# 修改后的循环,将遍历经过用户筛选后的 model_dirs 列表
|
||||
for model_dir in model_dirs:
|
||||
model_name = os.path.basename(model_dir)
|
||||
logging.info(f"--- 开始处理模型: {model_name} ---")
|
||||
|
||||
files = find_model_files(model_dir)
|
||||
if not files:
|
||||
logging.warning(f"跳过目录 {model_dir},因为缺少必要文件。")
|
||||
continue
|
||||
|
||||
# 构建输出目录
|
||||
# 从模型名中提取数据集标识作为Key
|
||||
dataset_key = model_name.split('-')[0]
|
||||
dataset_map = {
|
||||
'1_cholecseg8k': '1_CholecSeg8k-13Type-1920x1080_outputs-MMSeg',
|
||||
'2_autolaparo': '2_AutoLaparo-10Type-1280x1024_outputs-MMSeg',
|
||||
'3_1_endovis_2017': '3_1_EndoVis_2017-7Type-1280x1024_outputs-MMSeg',
|
||||
'3_2_endovis_2018': '3_2_EndoVis_2018-11Type-1280x1024_outputs-MMSeg',
|
||||
'4_dresden': '4_Dresden-6Type-1920x1080_outputs-MMSeg'
|
||||
}
|
||||
# 使用提取的Key(字符串)进行查询,并为默认值也使用该Key
|
||||
output_dataset_folder = dataset_map.get(dataset_key, f"{dataset_key}_outputs-MMSeg")
|
||||
|
||||
final_output_dir = os.path.join(output_root, output_dataset_folder, model_name)
|
||||
os.makedirs(final_output_dir, exist_ok=True)
|
||||
logging.info(f"结果将保存至: {final_output_dir}")
|
||||
|
||||
try:
|
||||
# 加载配置
|
||||
cfg = Config.fromfile(files['config'])
|
||||
|
||||
# 1. 执行推理、评估和可视化
|
||||
metrics_data = run_inference_and_evaluation(cfg, files['checkpoint'], final_output_dir)
|
||||
|
||||
# 2. 分析训练日志
|
||||
# analyze_training_log(files['log'], os.path.join(final_output_dir, 'prediction_analysis'))
|
||||
|
||||
# 3. 计算模型复杂度
|
||||
complexity_data = calculate_model_complexity(cfg)
|
||||
|
||||
# 4. 整合并保存所有量化结果
|
||||
# V1.
|
||||
# train_dataloader = dict(
|
||||
# batch_size=1,
|
||||
# dataset=dict(
|
||||
# data_prefix=dict(img_path='images/train', seg_map_path='labels_GT/train'),
|
||||
# data_root=
|
||||
# '/home/wkmgc/Desktop/Seg/DataSet_Public/3_1_Endovis_2017-8Type-512x512',
|
||||
# pipeline=[
|
||||
# dict(type='LoadImageFromFile'),
|
||||
# dict(keep_ratio=True, scale=(
|
||||
# 512,
|
||||
# 512,
|
||||
# ), type='Resize'),
|
||||
# dict(type='LoadAnnotations'),
|
||||
# dict(type='PackSegInputs'),
|
||||
# ],
|
||||
# type='PublicDataSet_Endovis_2017'),
|
||||
# num_workers=4,
|
||||
# persistent_workers=True,
|
||||
# sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
# val_dataloader = Runner.build_dataloader(train_dataloader) # TODO cfg.val_dataloader 修改为 cfg.train_dataloader
|
||||
|
||||
# V2.
|
||||
val_dataloader = Runner.build_dataloader(cfg.val_dataloader)
|
||||
|
||||
metainfo = val_dataloader.dataset.metainfo
|
||||
|
||||
consolidate_and_save_results(metrics_data, complexity_data, metainfo, final_output_dir)
|
||||
|
||||
logging.info(f"--- 模型 {model_name} 处理完成 ---")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"处理模型 {model_name} 时发生严重错误: {e}", exc_info=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="MMSegmentation 自动化评估脚本")
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='../Hardisk',
|
||||
help="包含已训练模型文件夹的根目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='../BestMode_Predict_Results_DataSet_Public',
|
||||
help="用于存储所有分析结果的根目录。"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
420
Seg_All_In_One_MMSeg/README.md
Normal file
420
Seg_All_In_One_MMSeg/README.md
Normal file
@@ -0,0 +1,420 @@
|
||||
<div align="center">
|
||||
<img src="resources/mmseg-logo.png" width="600"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">OpenMMLab website</font></b>
|
||||
<sup>
|
||||
<a href="https://openmmlab.com">
|
||||
<i><font size="4">HOT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
|
||||
<b><font size="5">OpenMMLab platform</font></b>
|
||||
<sup>
|
||||
<a href="https://platform.openmmlab.com">
|
||||
<i><font size="4">TRY IT OUT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
</div>
|
||||
<div> </div>
|
||||
|
||||
[](https://pypi.org/project/mmsegmentation/)
|
||||
[](https://pypi.org/project/mmsegmentation)
|
||||
[](https://mmsegmentation.readthedocs.io/en/latest/)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmsegmentation)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/blob/main/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
[](https://openxlab.org.cn/apps?search=mmseg)
|
||||
|
||||
Documentation: <https://mmsegmentation.readthedocs.io/en/latest/>
|
||||
|
||||
English | [简体中文](README_zh-CN.md)
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://discord.gg/raweFPmdzG" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
|
||||
</div>
|
||||
|
||||
## Introduction
|
||||
|
||||
MMSegmentation is an open source semantic segmentation toolbox based on PyTorch.
|
||||
It is a part of the OpenMMLab project.
|
||||
|
||||
The [main](https://github.com/open-mmlab/mmsegmentation/tree/main) branch works with PyTorch 1.6+.
|
||||
|
||||
### 🎉 Introducing MMSegmentation v1.0.0 🎉
|
||||
|
||||
We are thrilled to announce the official release of MMSegmentation's latest version! For this new release, the [main](https://github.com/open-mmlab/mmsegmentation/tree/main) branch serves as the primary branch, while the development branch is [dev-1.x](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x). The stable branch for the previous release remains as the [0.x](https://github.com/open-mmlab/mmsegmentation/tree/0.x) branch. Please note that the [master](https://github.com/open-mmlab/mmsegmentation/tree/master) branch will only be maintained for a limited time before being removed. We encourage you to be mindful of branch selection and updates during use. Thank you for your unwavering support and enthusiasm, and let's work together to make MMSegmentation even more robust and powerful! 💪
|
||||
|
||||
MMSegmentation v1.x brings remarkable improvements over the 0.x release, offering a more flexible and feature-packed experience. To utilize the new features in v1.x, we kindly invite you to consult our detailed [📚 migration guide](https://mmsegmentation.readthedocs.io/en/latest/migration/interface.html), which will help you seamlessly transition your projects. Your support is invaluable, and we eagerly await your feedback!
|
||||
|
||||

|
||||
|
||||
### Major features
|
||||
|
||||
- **Unified Benchmark**
|
||||
|
||||
We provide a unified benchmark toolbox for various semantic segmentation methods.
|
||||
|
||||
- **Modular Design**
|
||||
|
||||
We decompose the semantic segmentation framework into different components and one can easily construct a customized semantic segmentation framework by combining different modules.
|
||||
|
||||
- **Support of multiple methods out of box**
|
||||
|
||||
The toolbox directly supports popular and contemporary semantic segmentation frameworks, *e.g.* PSPNet, DeepLabV3, PSANet, DeepLabV3+, etc.
|
||||
|
||||
- **High efficiency**
|
||||
|
||||
The training speed is faster than or comparable to other codebases.
|
||||
|
||||
## What's New
|
||||
|
||||
v1.2.0 was released on 10/12/2023, from 1.1.0 to 1.2.0, we have added or updated the following features:
|
||||
|
||||
### Highlights
|
||||
|
||||
- Support for the open-vocabulary semantic segmentation algorithm [SAN](configs/san/README.md)
|
||||
|
||||
- Support monocular depth estimation task, please refer to [VPD](configs/vpd/README.md) and [Adabins](projects/Adabins/README.md) for more details.
|
||||
|
||||

|
||||
|
||||
- Add new projects: open-vocabulary semantic segmentation algorithm [CAT-Seg](projects/CAT-Seg/README.md), real-time semantic segmentation algofithm [PP-MobileSeg](projects/pp_mobileseg/README.md)
|
||||
|
||||
## Installation
|
||||
|
||||
Please refer to [get_started.md](docs/en/get_started.md#installation) for installation and [dataset_prepare.md](docs/en/user_guides/2_dataset_prepare.md#prepare-datasets) for dataset preparation.
|
||||
|
||||
## Get Started
|
||||
|
||||
Please see [Overview](docs/en/overview.md) for the general introduction of MMSegmentation.
|
||||
|
||||
Please see [user guides](https://mmsegmentation.readthedocs.io/en/latest/user_guides/index.html#) for the basic usage of MMSegmentation.
|
||||
There are also [advanced tutorials](https://mmsegmentation.readthedocs.io/en/latest/advanced_guides/index.html) for in-depth understanding of mmseg design and implementation .
|
||||
|
||||
A Colab tutorial is also provided. You may preview the notebook [here](demo/MMSegmentation_Tutorial.ipynb) or directly [run](https://colab.research.google.com/github/open-mmlab/mmsegmentation/blob/main/demo/MMSegmentation_Tutorial.ipynb) on Colab.
|
||||
|
||||
To migrate from MMSegmentation 0.x, please refer to [migration](docs/en/migration).
|
||||
|
||||
## Tutorial
|
||||
|
||||
<div align="center">
|
||||
<b>MMSegmentation Tutorials</b>
|
||||
</div>
|
||||
<table align="center">
|
||||
<tbody>
|
||||
<tr align="center" valign="center">
|
||||
<td>
|
||||
<b>Get Started</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>MMSeg Basic Tutorial</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>MMSeg Detail Tutorial</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>MMSeg Development Tutorial</b>
|
||||
</td>
|
||||
</tr>
|
||||
<tr valign="top">
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/en/overview.md">MMSeg overview</a></li>
|
||||
<li><a href="docs/en/get_started.md">MMSeg Installation</a></li>
|
||||
<li><a href="docs/en/notes/faq.md">FAQ</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/en/user_guides/1_config.md">Tutorial 1: Learn about Configs</a></li>
|
||||
<li><a href="docs/en/user_guides/2_dataset_prepare.md">Tutorial 2: Prepare datasets</a></li>
|
||||
<li><a href="docs/en/user_guides/3_inference.md">Tutorial 3: Inference with existing models</a></li>
|
||||
<li><a href="docs/en/user_guides/4_train_test.md">Tutorial 4: Train and test with existing models</a></li>
|
||||
<li><a href="docs/en/user_guides/5_deployment.md">Tutorial 5: Model deployment</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/deploy_jetson.md">Deploy mmsegmentation on Jetson platform</a></li>
|
||||
<li><a href="docs/en/user_guides/useful_tools.md">Useful Tools</a></li>
|
||||
<li><a href="docs/en/user_guides/visualization_feature_map.md">Feature Map Visualization</a></li>
|
||||
<li><a href="docs/en/user_guides/visualization.md">Visualization</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/en/advanced_guides/datasets.md">MMSeg Dataset</a></li>
|
||||
<li><a href="docs/en/advanced_guides/models.md">MMSeg Models</a></li>
|
||||
<li><a href="docs/en/advanced_guides/structures.md">MMSeg Dataset Structures</a></li>
|
||||
<li><a href="docs/en/advanced_guides/transforms.md">MMSeg Data Transforms</a></li>
|
||||
<li><a href="docs/en/advanced_guides/data_flow.md">MMSeg Dataflow</a></li>
|
||||
<li><a href="docs/en/advanced_guides/engine.md">MMSeg Training Engine</a></li>
|
||||
<li><a href="docs/en/advanced_guides/evaluation.md">MMSeg Evaluation</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/en/advanced_guides/add_datasets.md">Add New Datasets</a></li>
|
||||
<li><a href="docs/en/advanced_guides/add_metrics.md">Add New Metrics</a></li>
|
||||
<li><a href="docs/en/advanced_guides/add_models.md">Add New Modules</a></li>
|
||||
<li><a href="docs/en/advanced_guides/add_transforms.md">Add New Data Transforms</a></li>
|
||||
<li><a href="docs/en/advanced_guides/customize_runtime.md">Customize Runtime Settings</a></li>
|
||||
<li><a href="docs/en/advanced_guides/training_tricks.md">Training Tricks</a></li>
|
||||
<li><a href=".github/CONTRIBUTING.md">Contribute code to MMSeg</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/contribute_dataset.md">Contribute a standard dataset in projects</a></li>
|
||||
<li><a href="docs/en/device/npu.md">NPU (HUAWEI Ascend)</a></li>
|
||||
<li><a href="docs/en/migration/interface.md">0.x → 1.x migration</a></li>
|
||||
<li><a href="docs/en/migration/package.md">0.x → 1.x package</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## Benchmark and model zoo
|
||||
|
||||
Results and models are available in the [model zoo](docs/en/model_zoo.md).
|
||||
|
||||
<div align="center">
|
||||
<b>Overview</b>
|
||||
</div>
|
||||
<table align="center">
|
||||
<tbody>
|
||||
<tr align="center" valign="center">
|
||||
<td>
|
||||
<b>Supported backbones</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>Supported methods</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>Supported Head</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>Supported datasets</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>Other</b>
|
||||
</td>
|
||||
</tr>
|
||||
<tr valign="top">
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="mmseg/models/backbones/resnet.py">ResNet(CVPR'2016)</a></li>
|
||||
<li><a href="mmseg/models/backbones/resnext.py">ResNeXt (CVPR'2017)</a></li>
|
||||
<li><a href="configs/hrnet">HRNet (CVPR'2019)</a></li>
|
||||
<li><a href="configs/resnest">ResNeSt (ArXiv'2020)</a></li>
|
||||
<li><a href="configs/mobilenet_v2">MobileNetV2 (CVPR'2018)</a></li>
|
||||
<li><a href="configs/mobilenet_v3">MobileNetV3 (ICCV'2019)</a></li>
|
||||
<li><a href="configs/vit">Vision Transformer (ICLR'2021)</a></li>
|
||||
<li><a href="configs/swin">Swin Transformer (ICCV'2021)</a></li>
|
||||
<li><a href="configs/twins">Twins (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/beit">BEiT (ICLR'2022)</a></li>
|
||||
<li><a href="configs/convnext">ConvNeXt (CVPR'2022)</a></li>
|
||||
<li><a href="configs/mae">MAE (CVPR'2022)</a></li>
|
||||
<li><a href="configs/poolformer">PoolFormer (CVPR'2022)</a></li>
|
||||
<li><a href="configs/segnext">SegNeXt (NeurIPS'2022)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/san/">SAN (CVPR'2023)</a></li>
|
||||
<li><a href="configs/vpd">VPD (ICCV'2023)</a></li>
|
||||
<li><a href="configs/ddrnet">DDRNet (T-ITS'2022)</a></li>
|
||||
<li><a href="configs/pidnet">PIDNet (ArXiv'2022)</a></li>
|
||||
<li><a href="configs/mask2former">Mask2Former (CVPR'2022)</a></li>
|
||||
<li><a href="configs/maskformer">MaskFormer (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/knet">K-Net (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/segformer">SegFormer (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/segmenter">Segmenter (ICCV'2021)</a></li>
|
||||
<li><a href="configs/dpt">DPT (ArXiv'2021)</a></li>
|
||||
<li><a href="configs/setr">SETR (CVPR'2021)</a></li>
|
||||
<li><a href="configs/stdc">STDC (CVPR'2021)</a></li>
|
||||
<li><a href="configs/bisenetv2">BiSeNetV2 (IJCV'2021)</a></li>
|
||||
<li><a href="configs/cgnet">CGNet (TIP'2020)</a></li>
|
||||
<li><a href="configs/point_rend">PointRend (CVPR'2020)</a></li>
|
||||
<li><a href="configs/dnlnet">DNLNet (ECCV'2020)</a></li>
|
||||
<li><a href="configs/ocrnet">OCRNet (ECCV'2020)</a></li>
|
||||
<li><a href="configs/isanet">ISANet (ArXiv'2019/IJCV'2021)</a></li>
|
||||
<li><a href="configs/fastscnn">Fast-SCNN (ArXiv'2019)</a></li>
|
||||
<li><a href="configs/fastfcn">FastFCN (ArXiv'2019)</a></li>
|
||||
<li><a href="configs/gcnet">GCNet (ICCVW'2019/TPAMI'2020)</a></li>
|
||||
<li><a href="configs/ann">ANN (ICCV'2019)</a></li>
|
||||
<li><a href="configs/emanet">EMANet (ICCV'2019)</a></li>
|
||||
<li><a href="configs/ccnet">CCNet (ICCV'2019)</a></li>
|
||||
<li><a href="configs/dmnet">DMNet (ICCV'2019)</a></li>
|
||||
<li><a href="configs/sem_fpn">Semantic FPN (CVPR'2019)</a></li>
|
||||
<li><a href="configs/danet">DANet (CVPR'2019)</a></li>
|
||||
<li><a href="configs/apcnet">APCNet (CVPR'2019)</a></li>
|
||||
<li><a href="configs/nonlocal_net">NonLocal Net (CVPR'2018)</a></li>
|
||||
<li><a href="configs/encnet">EncNet (CVPR'2018)</a></li>
|
||||
<li><a href="configs/deeplabv3plus">DeepLabV3+ (CVPR'2018)</a></li>
|
||||
<li><a href="configs/upernet">UPerNet (ECCV'2018)</a></li>
|
||||
<li><a href="configs/icnet">ICNet (ECCV'2018)</a></li>
|
||||
<li><a href="configs/psanet">PSANet (ECCV'2018)</a></li>
|
||||
<li><a href="configs/bisenetv1">BiSeNetV1 (ECCV'2018)</a></li>
|
||||
<li><a href="configs/deeplabv3">DeepLabV3 (ArXiv'2017)</a></li>
|
||||
<li><a href="configs/pspnet">PSPNet (CVPR'2017)</a></li>
|
||||
<li><a href="configs/erfnet">ERFNet (T-ITS'2017)</a></li>
|
||||
<li><a href="configs/unet">UNet (MICCAI'2016/Nat. Methods'2019)</a></li>
|
||||
<li><a href="configs/fcn">FCN (CVPR'2015/TPAMI'2017)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="mmseg/models/decode_heads/ann_head.py">ANN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/apc_head.py">APC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/aspp_head.py">ASPP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/cc_head.py">CC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/da_head.py">DA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/ddr_head.py">DDR_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/dm_head.py">DM_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/dnl_head.py">DNL_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/dpt_head.py">DPT_HEAD</li>
|
||||
<li><a href="mmseg/models/decode_heads/ema_head.py">EMA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/enc_head.py">ENC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/fcn_head.py">FCN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/fpn_head.py">FPN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/gc_head.py">GC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/ham_head.py">LightHam_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/isa_head.py">ISA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/knet_head.py">Knet_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/lraspp_head.py">LRASPP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/mask2former_head.py">mask2former_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/maskformer_head.py">maskformer_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/nl_head.py">NL_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/ocr_head.py">OCR_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/pid_head.py">PID_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/point_head.py">point_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/psa_head.py">PSA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/psp_head.py">PSP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/san_head.py">SAN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/segformer_head.py">segformer_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/segmenter_mask_head.py">segmenter_mask_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/sep_aspp_head.py">SepASPP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/sep_fcn_head.py">SepFCN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/setr_mla_head.py">SETRMLAHead_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/setr_up_head.py">SETRUP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/stdc_head.py">STDC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/uper_head.py">Uper_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/vpd_depth_head.py">VPDDepth_Head</li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#cityscapes">Cityscapes</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#pascal-voc">PASCAL VOC</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#ade20k">ADE20K</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#pascal-context">Pascal Context</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#coco-stuff-10k">COCO-Stuff 10k</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#coco-stuff-164k">COCO-Stuff 164k</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#chase-db1">CHASE_DB1</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#drive">DRIVE</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hrf">HRF</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#stare">STARE</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#dark-zurich">Dark Zurich</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#nighttime-driving">Nighttime Driving</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#loveda">LoveDA</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#isprs-potsdam">Potsdam</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#isprs-vaihingen">Vaihingen</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#isaid">iSAID</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets">Mapillary Vistas</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#levir-cd">LEVIR-CD</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#bdd100K">BDD100K</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#nyu">NYU</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-2.0">HSIDrive20</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><b>Supported loss</b></li>
|
||||
<ul>
|
||||
<li><a href="mmseg/models/losses/boundary_loss.py">boundary_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/cross_entropy_loss.py">cross_entropy_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/dice_loss.py">dice_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/focal_loss.py">focal_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/huasdorff_distance_loss.py">huasdorff_distance_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/kldiv_loss.py">kldiv_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/lovasz_loss.py">lovasz_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/ohem_cross_entropy_loss.py">ohem_cross_entropy_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/silog_loss.py">silog_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/tversky_loss.py">tversky_loss</a></li>
|
||||
</ul>
|
||||
</ul>
|
||||
</td>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
Please refer to [FAQ](docs/en/notes/faq.md) for frequently asked questions.
|
||||
|
||||
## Projects
|
||||
|
||||
[Here](projects/README.md) are some implementations of SOTA models and solutions built on MMSegmentation, which are supported and maintained by community users. These projects demonstrate the best practices based on MMSegmentation for research and product development. We welcome and appreciate all the contributions to OpenMMLab ecosystem.
|
||||
|
||||
## Contributing
|
||||
|
||||
We appreciate all contributions to improve MMSegmentation. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
MMSegmentation is an open source project that welcome any contribution and feedback.
|
||||
We wish that the toolbox and benchmark could serve the growing research
|
||||
community by providing a flexible as well as standardized toolkit to reimplement existing methods
|
||||
and develop their own new semantic segmentation methods.
|
||||
|
||||
## Citation
|
||||
|
||||
If you find this project useful in your research, please consider cite:
|
||||
|
||||
```bibtex
|
||||
@misc{mmseg2020,
|
||||
title={{MMSegmentation}: OpenMMLab Semantic Segmentation Toolbox and Benchmark},
|
||||
author={MMSegmentation Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmsegmentation}},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This project is released under the [Apache 2.0 license](LICENSE).
|
||||
|
||||
## OpenMMLab Family
|
||||
|
||||
- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models.
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
|
||||
- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab pre-training toolbox and benchmark.
|
||||
- [MMagic](https://github.com/open-mmlab/mmagic): Open**MM**Lab **A**dvanced, **G**enerative and **I**ntelligent **C**reation toolbox.
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
|
||||
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark.
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
|
||||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab Model Deployment Framework.
|
||||
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
|
||||
- [Playground](https://github.com/open-mmlab/playground): A central hub for gathering and showcasing amazing projects built upon OpenMMLab.
|
||||
426
Seg_All_In_One_MMSeg/README_zh-CN.md
Normal file
426
Seg_All_In_One_MMSeg/README_zh-CN.md
Normal file
@@ -0,0 +1,426 @@
|
||||
<div align="center">
|
||||
<img src="resources/mmseg-logo.png" width="600"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">OpenMMLab 官网</font></b>
|
||||
<sup>
|
||||
<a href="https://openmmlab.com">
|
||||
<i><font size="4">HOT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
|
||||
<b><font size="5">OpenMMLab 开放平台</font></b>
|
||||
<sup>
|
||||
<a href="https://platform.openmmlab.com">
|
||||
<i><font size="4">TRY IT OUT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
</div>
|
||||
<div> </div>
|
||||
|
||||
[](https://pypi.org/project/mmsegmentation/)
|
||||
[](https://pypi.org/project/mmsegmentation)
|
||||
[](https://mmsegmentation.readthedocs.io/zh_CN/latest/)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmsegmentation)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/blob/main/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
[](https://openxlab.org.cn/apps?search=mmseg)
|
||||
|
||||
文档: <https://mmsegmentation.readthedocs.io/zh_CN/latest>
|
||||
|
||||
[English](README.md) | 简体中文
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://discord.gg/raweFPmdzG" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
|
||||
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
||||
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
|
||||
</div>
|
||||
|
||||
## 简介
|
||||
|
||||
MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 OpenMMLab 项目的一部分。
|
||||
|
||||
[main](https://github.com/open-mmlab/mmsegmentation/tree/main) 分支代码目前支持 PyTorch 1.6 以上的版本。
|
||||
|
||||
### 🎉 MMSegmentation v1.0.0 简介 🎉
|
||||
|
||||
我们非常高兴地宣布 MMSegmentation 最新版本的正式发布!在这个新版本中,主要分支是 [main](https://github.com/open-mmlab/mmsegmentation/tree/main) 分支,开发分支是 [dev-1.x](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x)。而之前版本的稳定分支保留为 [0.x](https://github.com/open-mmlab/mmsegmentation/tree/0.x) 分支。请注意,[master](https://github.com/open-mmlab/mmsegmentation/tree/master) 分支将只在有限的时间内维护,然后将被删除。我们鼓励您在使用过程中注意分支选择和更新。感谢您一如既往的支持和热情,让我们共同努力,使 MMSegmentation 变得更加健壮和强大!💪
|
||||
|
||||
MMSegmentation v1.x 在 0.x 版本的基础上有了显著的提升,提供了更加灵活和功能丰富的体验。为了更好使用 v1.x 中的新功能,我们诚挚邀请您查阅我们详细的 [📚 迁移指南](https://mmsegmentation.readthedocs.io/zh_CN/latest/migration/interface.html),以帮助您无缝地过渡您的项目。您的支持对我们来说非常宝贵,我们热切期待您的反馈!
|
||||
|
||||

|
||||
|
||||
### 主要特性
|
||||
|
||||
- **统一的基准平台**
|
||||
|
||||
我们将各种各样的语义分割算法集成到了一个统一的工具箱,进行基准测试。
|
||||
|
||||
- **模块化设计**
|
||||
|
||||
MMSegmentation 将分割框架解耦成不同的模块组件,通过组合不同的模块组件,用户可以便捷地构建自定义的分割模型。
|
||||
|
||||
- **丰富的即插即用的算法和模型**
|
||||
|
||||
MMSegmentation 支持了众多主流的和最新的检测算法,例如 PSPNet,DeepLabV3,PSANet,DeepLabV3+ 等.
|
||||
|
||||
- **速度快**
|
||||
|
||||
训练速度比其他语义分割代码库更快或者相当。
|
||||
|
||||
## 更新日志
|
||||
|
||||
最新版本 v1.2.0 在 2023.10.12 发布。
|
||||
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/notes/changelog.md)。
|
||||
|
||||
## 安装
|
||||
|
||||
请参考[快速入门文档](docs/zh_cn/get_started.md#installation)进行安装,参考[数据集准备](docs/zh_cn/user_guides/2_dataset_prepare.md)处理数据。
|
||||
|
||||
## 快速入门
|
||||
|
||||
请参考[概述](docs/zh_cn/overview.md)对 MMSegmetation 进行初步了解
|
||||
|
||||
请参考[用户指南](https://mmsegmentation.readthedocs.io/zh_CN/latest/user_guides/index.html)了解 mmseg 的基本使用,以及[进阶指南](https://mmsegmentation.readthedocs.io/zh_CN/latest/advanced_guides/index.html)深入了解 mmseg 设计和代码实现。
|
||||
|
||||
同时,我们提供了 Colab 教程。你可以在[这里](demo/MMSegmentation_Tutorial.ipynb)浏览教程,或者直接在 Colab 上[运行](https://colab.research.google.com/github/open-mmlab/mmsegmentation/blob/main/demo/MMSegmentation_Tutorial.ipynb)。
|
||||
|
||||
若需要将 0.x 版本的代码迁移至新版,请参考[迁移文档](docs/zh_cn/migration)。
|
||||
|
||||
## 教程文档
|
||||
|
||||
<div align="center">
|
||||
<b>mmsegmentation 教程文档</b>
|
||||
</div>
|
||||
<table align="center">
|
||||
<tbody>
|
||||
<tr align="center" valign="center">
|
||||
<td>
|
||||
<b>开启 MMSeg 之旅</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>MMSeg 快速入门教程</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>MMSeg 细节介绍</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>MMSeg 开发教程</b>
|
||||
</td>
|
||||
</tr>
|
||||
<tr valign="top">
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/zh_cn/overview.md">MMSeg 概述</a></li>
|
||||
<li><a href="docs/zh_cn/get_started.md">安装和验证</a></li>
|
||||
<li><a href="docs/zh_cn/notes/faq.md">常见问题解答</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/zh_cn/user_guides/1_config.md">教程1:了解配置文件</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/2_dataset_prepare.md">教程2:准备数据集</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/3_inference.md">教程3:使用预训练模型推理</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/4_train_test.md">教程4:模型训练和测试</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/5_deployment.md">教程5:模型部署</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/deploy_jetson.md">在 Jetson 平台部署 MMSeg</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/useful_tools.md">常用工具</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/visualization_feature_map.md">特征图可视化</a></li>
|
||||
<li><a href="docs/zh_cn/user_guides/visualization.md">可视化</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/zh_cn/advanced_guides/datasets.md">MMSeg 数据集介绍</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/models.md">MMSeg 模型介绍</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/structures.md">MMSeg 数据结构介绍</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/transforms.md">MMSeg 数据增强介绍</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/data_flow.md">MMSeg 数据流介绍</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/engine.md">MMSeg 训练引擎介绍</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/evaluation.md">MMSeg 模型评测介绍</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="docs/zh_cn/advanced_guides/add_datasets.md">新增自定义数据集</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/add_metrics.md">新增评测指标</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/add_models.md">新增自定义模型</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/add_transforms.md">新增自定义数据增强</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/customize_runtime.md">自定义运行设定</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/training_tricks.md">训练技巧</a></li>
|
||||
<li><a href=".github/CONTRIBUTING.md">如何给 MMSeg 贡献代码</a></li>
|
||||
<li><a href="docs/zh_cn/advanced_guides/contribute_dataset.md">给 MMSeg 贡献数据集教程</a></li>
|
||||
<li><a href="docs/zh_cn/device/npu.md">NPU (华为 昇腾)</a></li>
|
||||
<li><a href="docs/zh_cn/migration/interface.md">0.x → 1.x 迁移文档</a></li>
|
||||
<li><a href="docs/zh_cn/migration/package.md">0.x → 1.x 库变更文档</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## 基准测试和模型库
|
||||
|
||||
测试结果和模型可以在[模型库](docs/zh_cn/model_zoo.md)中找到。
|
||||
|
||||
<div align="center">
|
||||
<b>概览</b>
|
||||
</div>
|
||||
<table align="center">
|
||||
<tbody>
|
||||
<tr align="center" valign="bottom">
|
||||
<td>
|
||||
<b>已支持的主干网络</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>已支持的算法架构</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>已支持的分割头</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>已支持的数据集</b>
|
||||
</td>
|
||||
<td>
|
||||
<b>其他</b>
|
||||
</td>
|
||||
</tr>
|
||||
<tr valign="top">
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="mmseg/models/backbones/resnet.py">ResNet(CVPR'2016)</a></li>
|
||||
<li><a href="mmseg/models/backbones/resnext.py">ResNeXt (CVPR'2017)</a></li>
|
||||
<li><a href="configs/hrnet">HRNet (CVPR'2019)</a></li>
|
||||
<li><a href="configs/resnest">ResNeSt (ArXiv'2020)</a></li>
|
||||
<li><a href="configs/mobilenet_v2">MobileNetV2 (CVPR'2018)</a></li>
|
||||
<li><a href="configs/mobilenet_v3">MobileNetV3 (ICCV'2019)</a></li>
|
||||
<li><a href="configs/vit">Vision Transformer (ICLR'2021)</a></li>
|
||||
<li><a href="configs/swin">Swin Transformer (ICCV'2021)</a></li>
|
||||
<li><a href="configs/twins">Twins (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/beit">BEiT (ICLR'2022)</a></li>
|
||||
<li><a href="configs/convnext">ConvNeXt (CVPR'2022)</a></li>
|
||||
<li><a href="configs/mae">MAE (CVPR'2022)</a></li>
|
||||
<li><a href="configs/poolformer">PoolFormer (CVPR'2022)</a></li>
|
||||
<li><a href="configs/segnext">SegNeXt (NeurIPS'2022)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="configs/san/">SAN (CVPR'2023)</a></li>
|
||||
<li><a href="configs/vpd">VPD (ICCV'2023)</a></li>
|
||||
<li><a href="configs/ddrnet">DDRNet (T-ITS'2022)</a></li>
|
||||
<li><a href="configs/pidnet">PIDNet (ArXiv'2022)</a></li>
|
||||
<li><a href="configs/mask2former">Mask2Former (CVPR'2022)</a></li>
|
||||
<li><a href="configs/maskformer">MaskFormer (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/knet">K-Net (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/segformer">SegFormer (NeurIPS'2021)</a></li>
|
||||
<li><a href="configs/segmenter">Segmenter (ICCV'2021)</a></li>
|
||||
<li><a href="configs/dpt">DPT (ArXiv'2021)</a></li>
|
||||
<li><a href="configs/setr">SETR (CVPR'2021)</a></li>
|
||||
<li><a href="configs/stdc">STDC (CVPR'2021)</a></li>
|
||||
<li><a href="configs/bisenetv2">BiSeNetV2 (IJCV'2021)</a></li>
|
||||
<li><a href="configs/cgnet">CGNet (TIP'2020)</a></li>
|
||||
<li><a href="configs/point_rend">PointRend (CVPR'2020)</a></li>
|
||||
<li><a href="configs/dnlnet">DNLNet (ECCV'2020)</a></li>
|
||||
<li><a href="configs/ocrnet">OCRNet (ECCV'2020)</a></li>
|
||||
<li><a href="configs/isanet">ISANet (ArXiv'2019/IJCV'2021)</a></li>
|
||||
<li><a href="configs/fastscnn">Fast-SCNN (ArXiv'2019)</a></li>
|
||||
<li><a href="configs/fastfcn">FastFCN (ArXiv'2019)</a></li>
|
||||
<li><a href="configs/gcnet">GCNet (ICCVW'2019/TPAMI'2020)</a></li>
|
||||
<li><a href="configs/ann">ANN (ICCV'2019)</a></li>
|
||||
<li><a href="configs/emanet">EMANet (ICCV'2019)</a></li>
|
||||
<li><a href="configs/ccnet">CCNet (ICCV'2019)</a></li>
|
||||
<li><a href="configs/dmnet">DMNet (ICCV'2019)</a></li>
|
||||
<li><a href="configs/sem_fpn">Semantic FPN (CVPR'2019)</a></li>
|
||||
<li><a href="configs/danet">DANet (CVPR'2019)</a></li>
|
||||
<li><a href="configs/apcnet">APCNet (CVPR'2019)</a></li>
|
||||
<li><a href="configs/nonlocal_net">NonLocal Net (CVPR'2018)</a></li>
|
||||
<li><a href="configs/encnet">EncNet (CVPR'2018)</a></li>
|
||||
<li><a href="configs/deeplabv3plus">DeepLabV3+ (CVPR'2018)</a></li>
|
||||
<li><a href="configs/upernet">UPerNet (ECCV'2018)</a></li>
|
||||
<li><a href="configs/icnet">ICNet (ECCV'2018)</a></li>
|
||||
<li><a href="configs/psanet">PSANet (ECCV'2018)</a></li>
|
||||
<li><a href="configs/bisenetv1">BiSeNetV1 (ECCV'2018)</a></li>
|
||||
<li><a href="configs/deeplabv3">DeepLabV3 (ArXiv'2017)</a></li>
|
||||
<li><a href="configs/pspnet">PSPNet (CVPR'2017)</a></li>
|
||||
<li><a href="configs/erfnet">ERFNet (T-ITS'2017)</a></li>
|
||||
<li><a href="configs/unet">UNet (MICCAI'2016/Nat. Methods'2019)</a></li>
|
||||
<li><a href="configs/fcn">FCN (CVPR'2015/TPAMI'2017)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="mmseg/models/decode_heads/ann_head.py">ANN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/apc_head.py">APC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/aspp_head.py">ASPP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/cc_head.py">CC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/da_head.py">DA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/ddr_head.py">DDR_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/dm_head.py">DM_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/dnl_head.py">DNL_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/dpt_head.py">DPT_HEAD</li>
|
||||
<li><a href="mmseg/models/decode_heads/ema_head.py">EMA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/enc_head.py">ENC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/fcn_head.py">FCN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/fpn_head.py">FPN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/gc_head.py">GC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/ham_head.py">LightHam_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/isa_head.py">ISA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/knet_head.py">Knet_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/lraspp_head.py">LRASPP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/mask2former_head.py">mask2former_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/maskformer_head.py">maskformer_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/nl_head.py">NL_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/ocr_head.py">OCR_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/pid_head.py">PID_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/point_head.py">point_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/psa_head.py">PSA_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/psp_head.py">PSP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/san_head.py">SAN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/segformer_head.py">segformer_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/segmenter_mask_head.py">segmenter_mask_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/sep_aspp_head.py">SepASPP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/sep_fcn_head.py">SepFCN_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/setr_mla_head.py">SETRMLAHead_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/setr_up_head.py">SETRUP_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/stdc_head.py">STDC_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/uper_head.py">Uper_Head</li>
|
||||
<li><a href="mmseg/models/decode_heads/vpd_depth_head.py">VPDDepth_Head</li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#cityscapes">Cityscapes</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#pascal-voc">PASCAL VOC</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#ade20k">ADE20K</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#pascal-context">Pascal Context</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#coco-stuff-10k">COCO-Stuff 10k</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#coco-stuff-164k">COCO-Stuff 164k</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#chase-db1">CHASE_DB1</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#drive">DRIVE</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#hrf">HRF</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#stare">STARE</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#dark-zurich">Dark Zurich</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#nighttime-driving">Nighttime Driving</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#loveda">LoveDA</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#isprs-potsdam">Potsdam</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#isprs-vaihingen">Vaihingen</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#isaid">iSAID</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets">Mapillary Vistas</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#levir-cd">LEVIR-CD</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#bdd100K">BDD100K</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#nyu">NYU</a></li>
|
||||
<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-2.0">HSIDrive20</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li><b>已支持的 loss</b></li>
|
||||
<ul>
|
||||
<li><a href="mmseg/models/losses/boundary_loss.py">boundary_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/cross_entropy_loss.py">cross_entropy_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/dice_loss.py">dice_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/focal_loss.py">focal_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/huasdorff_distance_loss.py">huasdorff_distance_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/kldiv_loss.py">kldiv_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/lovasz_loss.py">lovasz_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/ohem_cross_entropy_loss.py">ohem_cross_entropy_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/silog_loss.py">silog_loss</a></li>
|
||||
<li><a href="mmseg/models/losses/tversky_loss.py">tversky_loss</a></li>
|
||||
</ul>
|
||||
</ul>
|
||||
</td>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
如果遇到问题,请参考 [常见问题解答](docs/zh_cn/notes/faq.md)。
|
||||
|
||||
## 社区项目
|
||||
|
||||
[这里](projects/README.md)有一些由社区用户支持和维护的基于 MMSegmentation 的 SOTA 模型和解决方案的实现。这些项目展示了基于 MMSegmentation 的研究和产品开发的最佳实践。
|
||||
我们欢迎并感谢对 OpenMMLab 生态系统的所有贡献。
|
||||
|
||||
## 贡献指南
|
||||
|
||||
我们感谢所有的贡献者为改进和提升 MMSegmentation 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。
|
||||
|
||||
## 致谢
|
||||
|
||||
MMSegmentation 是一个由来自不同高校和企业的研发人员共同参与贡献的开源项目。我们感谢所有为项目提供算法复现和新功能支持的贡献者,以及提供宝贵反馈的用户。我们希望这个工具箱和基准测试可以为社区提供灵活的代码工具,供用户复现已有算法并开发自己的新模型,从而不断为开源社区提供贡献。
|
||||
|
||||
## 引用
|
||||
|
||||
如果你觉得本项目对你的研究工作有所帮助,请参考如下 bibtex 引用 MMSegmentation。
|
||||
|
||||
```bibtex
|
||||
@misc{mmseg2020,
|
||||
title={{MMSegmentation}: OpenMMLab Semantic Segmentation Toolbox and Benchmark},
|
||||
author={MMSegmentation Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmsegmentation}},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## 开源许可证
|
||||
|
||||
该项目采用 [Apache 2.0 开源许可证](LICENSE)。
|
||||
|
||||
## OpenMMLab 的其他项目
|
||||
|
||||
- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab 深度学习模型训练基础库
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab 计算机视觉基础库
|
||||
- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab 深度学习预训练工具箱
|
||||
- [MMagic](https://github.com/open-mmlab/mmagic): OpenMMLab 新一代人工智能内容生成(AIGC)工具箱
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
|
||||
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO 系列工具箱与测试基准
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具包
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 人体参数化模型工具箱与测试基准
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准
|
||||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab 模型部署框架
|
||||
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准
|
||||
- [MIM](https://github.com/open-mmlab/mim): OpenMMLab 项目、算法、模型的统一入口
|
||||
- [Playground](https://github.com/open-mmlab/playground): 收集和展示 OpenMMLab 相关的前沿、有趣的社区项目
|
||||
|
||||
## 欢迎加入 OpenMMLab 社区
|
||||
|
||||
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),扫描下方微信二维码添加喵喵好友,进入 MMSegmentation 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】
|
||||
|
||||
<div align="center">
|
||||
<img src="docs/zh_cn/imgs/zhihu_qrcode.jpg" height="400" /> <img src="resources/miaomiao_qrcode.jpg" height="400" />
|
||||
</div>
|
||||
|
||||
我们会在 OpenMMLab 社区为大家
|
||||
|
||||
- 📢 分享 AI 框架的前沿核心技术
|
||||
- 💻 解读 PyTorch 常用模块源码
|
||||
- 📰 发布 OpenMMLab 的相关新闻
|
||||
- 🚀 介绍 OpenMMLab 开发的前沿算法
|
||||
- 🏃 获取更高效的问题答疑和意见反馈
|
||||
- 🔥 提供与各行各业开发者充分交流的平台
|
||||
|
||||
干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬
|
||||
68
Seg_All_In_One_MMSeg/configs/_base_/datasets/ade20k.py
Normal file
68
Seg_All_In_One_MMSeg/configs/_base_/datasets/ade20k.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ADE20KDataset'
|
||||
data_root = 'data/ade/ADEChallengeData2016'
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 512),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/training', seg_map_path='annotations/training'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/validation',
|
||||
seg_map_path='annotations/validation'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,68 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ADE20KDataset'
|
||||
data_root = 'data/ade/ADEChallengeData2016'
|
||||
crop_size = (640, 640)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2560, 640),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2560, 640), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/training', seg_map_path='annotations/training'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/validation',
|
||||
seg_map_path='annotations/validation'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
70
Seg_All_In_One_MMSeg/configs/_base_/datasets/bdd100k.py
Normal file
70
Seg_All_In_One_MMSeg/configs/_base_/datasets/bdd100k.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# dataset settings
|
||||
dataset_type = 'BDD100KDataset'
|
||||
data_root = 'data/bdd100k/'
|
||||
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/10k/train',
|
||||
seg_map_path='labels/sem_seg/masks/train'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/10k/val',
|
||||
seg_map_path='labels/sem_seg/masks/val'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
75
Seg_All_In_One_MMSeg/configs/_base_/datasets/chase_db1.py
Normal file
75
Seg_All_In_One_MMSeg/configs/_base_/datasets/chase_db1.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ChaseDB1Dataset'
|
||||
data_root = 'data/CHASE_DB1'
|
||||
img_scale = (960, 999)
|
||||
crop_size = (128, 128)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=img_scale,
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=img_scale, keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/training',
|
||||
seg_map_path='annotations/training'),
|
||||
pipeline=train_pipeline)))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/validation',
|
||||
seg_map_path='annotations/validation'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
|
||||
test_evaluator = val_evaluator
|
||||
67
Seg_All_In_One_MMSeg/configs/_base_/datasets/cityscapes.py
Normal file
67
Seg_All_In_One_MMSeg/configs/_base_/datasets/cityscapes.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# dataset settings
|
||||
dataset_type = 'CityscapesDataset'
|
||||
data_root = 'data/cityscapes/'
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='leftImg8bit/train', seg_map_path='gtFine/train'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='leftImg8bit/val', seg_map_path='gtFine/val'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,29 @@
|
||||
_base_ = './cityscapes.py'
|
||||
crop_size = (1024, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,29 @@
|
||||
_base_ = './cityscapes.py'
|
||||
crop_size = (768, 768)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2049, 1025),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2049, 1025), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,29 @@
|
||||
_base_ = './cityscapes.py'
|
||||
crop_size = (769, 769)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2049, 1025),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2049, 1025), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,29 @@
|
||||
_base_ = './cityscapes.py'
|
||||
crop_size = (832, 832)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,69 @@
|
||||
# dataset settings
|
||||
dataset_type = 'COCOStuffDataset'
|
||||
data_root = 'data/coco_stuff10k'
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 512),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
reduce_zero_label=True,
|
||||
data_prefix=dict(
|
||||
img_path='images/train2014', seg_map_path='annotations/train2014'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
reduce_zero_label=True,
|
||||
data_prefix=dict(
|
||||
img_path='images/test2014', seg_map_path='annotations/test2014'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,67 @@
|
||||
# dataset settings
|
||||
dataset_type = 'COCOStuffDataset'
|
||||
data_root = 'data/coco_stuff164k'
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 512),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/train2017', seg_map_path='annotations/train2017'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/val2017', seg_map_path='annotations/val2017'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
73
Seg_All_In_One_MMSeg/configs/_base_/datasets/drive.py
Normal file
73
Seg_All_In_One_MMSeg/configs/_base_/datasets/drive.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# dataset settings
|
||||
dataset_type = 'DRIVEDataset'
|
||||
data_root = 'data/DRIVE'
|
||||
img_scale = (584, 565)
|
||||
crop_size = (64, 64)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=img_scale,
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=img_scale, keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/training',
|
||||
seg_map_path='annotations/training'),
|
||||
pipeline=train_pipeline)))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/validation',
|
||||
seg_map_path='annotations/validation'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
|
||||
test_evaluator = val_evaluator
|
||||
73
Seg_All_In_One_MMSeg/configs/_base_/datasets/hrf.py
Normal file
73
Seg_All_In_One_MMSeg/configs/_base_/datasets/hrf.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# dataset settings
|
||||
dataset_type = 'HRFDataset'
|
||||
data_root = 'data/HRF'
|
||||
img_scale = (2336, 3504)
|
||||
crop_size = (256, 256)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=img_scale,
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=img_scale, keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/training',
|
||||
seg_map_path='annotations/training'),
|
||||
pipeline=train_pipeline)))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/validation',
|
||||
seg_map_path='annotations/validation'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
|
||||
test_evaluator = val_evaluator
|
||||
53
Seg_All_In_One_MMSeg/configs/_base_/datasets/hsi_drive.py
Normal file
53
Seg_All_In_One_MMSeg/configs/_base_/datasets/hsi_drive.py
Normal file
@@ -0,0 +1,53 @@
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromNpyFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='RandomCrop', crop_size=(192, 384)),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromNpyFile'),
|
||||
dict(type='RandomCrop', crop_size=(192, 384)),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=1,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='HSIDrive20Dataset',
|
||||
data_root='data/HSIDrive20',
|
||||
data_prefix=dict(
|
||||
img_path='images/training', seg_map_path='annotations/training'),
|
||||
pipeline=train_pipeline))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=1,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type='HSIDrive20Dataset',
|
||||
data_root='data/HSIDrive20',
|
||||
data_prefix=dict(
|
||||
img_path='images/validation',
|
||||
seg_map_path='annotations/validation'),
|
||||
pipeline=test_pipeline))
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=1,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type='HSIDrive20Dataset',
|
||||
data_root='data/HSIDrive20',
|
||||
data_prefix=dict(
|
||||
img_path='images/test', seg_map_path='annotations/test'),
|
||||
pipeline=test_pipeline))
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=0)
|
||||
test_evaluator = val_evaluator
|
||||
73
Seg_All_In_One_MMSeg/configs/_base_/datasets/isaid.py
Normal file
73
Seg_All_In_One_MMSeg/configs/_base_/datasets/isaid.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# dataset settings
|
||||
dataset_type = 'iSAIDDataset'
|
||||
data_root = 'data/iSAID'
|
||||
"""
|
||||
This crop_size setting is followed by the implementation of
|
||||
`PointFlow: Flowing Semantics Through Points for Aerial Image
|
||||
Segmentation <https://arxiv.org/pdf/2103.06564.pdf>`_.
|
||||
"""
|
||||
|
||||
crop_size = (896, 896)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(896, 896),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(896, 896), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='img_dir/train', seg_map_path='ann_dir/train'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,68 @@
|
||||
# dataset settings
|
||||
dataset_type = 'LEVIRCDDataset'
|
||||
data_root = r'data/LEVIRCD'
|
||||
|
||||
albu_train_transforms = [
|
||||
dict(type='RandomBrightnessContrast', p=0.2),
|
||||
dict(type='HorizontalFlip', p=0.5),
|
||||
dict(type='VerticalFlip', p=0.5)
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadMultipleRSImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='Albu',
|
||||
keymap={
|
||||
'img': 'image',
|
||||
'img2': 'image2',
|
||||
'gt_seg_map': 'mask'
|
||||
},
|
||||
transforms=albu_train_transforms,
|
||||
additional_targets={'image2': 'image'},
|
||||
bgr_to_rgb=False),
|
||||
dict(type='ConcatCDInput'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadMultipleRSImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='ConcatCDInput'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
|
||||
tta_pipeline = [
|
||||
dict(type='LoadMultipleRSImageFromFile'),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[[dict(type='LoadAnnotations')],
|
||||
[dict(type='ConcatCDInput')],
|
||||
[dict(type='PackSegInputs')]])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='train/A',
|
||||
img_path2='train/B',
|
||||
seg_map_path='train/label'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='test/A', img_path2='test/B', seg_map_path='test/label'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
66
Seg_All_In_One_MMSeg/configs/_base_/datasets/loveda.py
Normal file
66
Seg_All_In_One_MMSeg/configs/_base_/datasets/loveda.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# dataset settings
|
||||
dataset_type = 'LoveDADataset'
|
||||
data_root = 'data/loveDA'
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 512),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', backend_args=None),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='img_dir/train', seg_map_path='ann_dir/train'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
68
Seg_All_In_One_MMSeg/configs/_base_/datasets/mapillary_v1.py
Normal file
68
Seg_All_In_One_MMSeg/configs/_base_/datasets/mapillary_v1.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# dataset settings
|
||||
dataset_type = 'MapillaryDataset_v1'
|
||||
data_root = 'data/mapillary/'
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='training/images', seg_map_path='training/v1.2/labels'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='validation/images',
|
||||
seg_map_path='validation/v1.2/labels'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,37 @@
|
||||
# dataset settings
|
||||
_base_ = './mapillary_v1.py'
|
||||
metainfo = dict(
|
||||
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
|
||||
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
|
||||
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane',
|
||||
'Sidewalk', 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist',
|
||||
'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
|
||||
'Billboard', 'Catch Basin', 'CCTV Camera', 'Fire Hydrant',
|
||||
'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole',
|
||||
'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole',
|
||||
'Traffic Light', 'Traffic Sign (Back)', 'Traffic Sign (Front)',
|
||||
'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan',
|
||||
'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||
'Wheeled Slow', 'Car Mount', 'Ego Vehicle'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
|
||||
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
|
||||
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
|
||||
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
|
||||
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
|
||||
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
|
||||
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
|
||||
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
|
||||
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
|
||||
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
|
||||
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
|
||||
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
|
||||
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10]])
|
||||
|
||||
train_dataloader = dict(dataset=dict(metainfo=metainfo))
|
||||
val_dataloader = dict(dataset=dict(metainfo=metainfo))
|
||||
test_dataloader = val_dataloader
|
||||
68
Seg_All_In_One_MMSeg/configs/_base_/datasets/mapillary_v2.py
Normal file
68
Seg_All_In_One_MMSeg/configs/_base_/datasets/mapillary_v2.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# dataset settings
|
||||
dataset_type = 'MapillaryDataset_v2'
|
||||
data_root = 'data/mapillary/'
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='training/images', seg_map_path='training/v2.0/labels'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='validation/images',
|
||||
seg_map_path='validation/v2.0/labels'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
64
Seg_All_In_One_MMSeg/configs/_base_/datasets/my_dataset.py
Normal file
64
Seg_All_In_One_MMSeg/configs/_base_/datasets/my_dataset.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# dataset settings
|
||||
dataset_class_name = 'MyDataset' # TODO 上一步中你定义的数据集的名字
|
||||
data_root = '/home/audience/Desktop/Seg_data/Data' # TODO 数据集存储路径
|
||||
# img_norm_cfg = dict(
|
||||
# mean=[33.30, 35.03, 47.23], std=[48.00, 50.4, 60.51], to_rgb=True) # TODO 数据集的均值和标准差,空引用默认的,也可以网上搜代码计算
|
||||
img_scale = (1920, 1080) # img_scale图像尺寸 TODO (1920,1080)
|
||||
crop_size = (512, 512) # 数据增强时裁剪的大小 TODO 之后可以修改
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'), # ", reduce_zero_label=False" TODO 是否忽略0直选项
|
||||
dict(type='RandomResize', scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
# dict(type='GenerateEdge', edge_width=4), # For pidnet
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=img_scale, keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict( # Train dataloader config
|
||||
batch_size=4, # Batch size of a single GPU TODO
|
||||
num_workers=4, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate training speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=True), # Randomly shuffle during training.
|
||||
dataset=dict( # Train dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='A_Ori',
|
||||
seg_map_path='A_Label_GT_label_fold'),
|
||||
pipeline=train_pipeline)) # Processing pipeline. This is passed by the train_pipeline created before.
|
||||
val_dataloader = dict(
|
||||
batch_size=1, # Batch size of a single GPU
|
||||
num_workers=4, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate testing speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=False), # Not shuffle during validation and testing.
|
||||
dataset=dict( # Test dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='A_Ori',
|
||||
seg_map_path='A_Label_GT_label_fold'),
|
||||
pipeline=test_pipeline)) # Processing pipeline. This is passed by the test_pipeline created before.
|
||||
test_dataloader = dict(
|
||||
batch_size=1, # Batch size of a single GPU
|
||||
num_workers=4, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate testing speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=False), # Not shuffle during validation and testing.
|
||||
dataset=dict( # Test dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='A_Ori',
|
||||
seg_map_path='A_Label_GT_label_fold'),
|
||||
pipeline=test_pipeline)) # Processing pipeline. This is passed by the test_pipeline created before.
|
||||
# The metric to measure the accuracy. Here, we use IoUMetric.
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,64 @@
|
||||
# dataset settings
|
||||
dataset_class_name = 'MyDataset_model' # TODO 上一步中你定义的数据集的名字
|
||||
data_root = '/home/wkmgc/Desktop/Seg/Seg_All_In_One_MMSeg/My_Data' # TODO 数据集存储路径
|
||||
# img_norm_cfg = dict(
|
||||
# mean=[33.30, 35.03, 47.23], std=[48.00, 50.4, 60.51], to_rgb=True) # TODO 数据集的均值和标准差,空引用默认的,也可以网上搜代码计算
|
||||
img_scale = (1920, 1080) # img_scale图像尺寸 TODO (1920,1080)
|
||||
crop_size = (256, 256) # 数据增强时裁剪的大小 TODO 之后可以修改
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'), # ", reduce_zero_label=False" TODO 是否忽略0直选项
|
||||
dict(type='RandomResize', scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
# dict(type='GenerateEdge', edge_width=4), # For pidnet
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=img_scale, keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict( # Train dataloader config
|
||||
batch_size=16, # Batch size of a single GPU TODO
|
||||
num_workers=4, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate training speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=True), # Randomly shuffle during training.
|
||||
dataset=dict( # Train dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='A_Ori',
|
||||
seg_map_path='A_Label_GT_label_fold'),
|
||||
pipeline=train_pipeline)) # Processing pipeline. This is passed by the train_pipeline created before.
|
||||
val_dataloader = dict(
|
||||
batch_size=1, # Batch size of a single GPU
|
||||
num_workers=4, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate testing speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=False), # Not shuffle during validation and testing.
|
||||
dataset=dict( # Test dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='A_Ori',
|
||||
seg_map_path='A_Label_GT_label_fold'),
|
||||
pipeline=test_pipeline)) # Processing pipeline. This is passed by the test_pipeline created before.
|
||||
test_dataloader = dict(
|
||||
batch_size=1, # Batch size of a single GPU
|
||||
num_workers=4, # Worker to pre-fetch data for each single GPU
|
||||
persistent_workers=True, # Shut down the worker processes after an epoch end, which can accelerate testing speed.
|
||||
sampler=dict(type='DefaultSampler', shuffle=False), # Not shuffle during validation and testing.
|
||||
dataset=dict( # Test dataset config
|
||||
type=dataset_class_name, # Type of dataset, refer to mmseg/datasets/ for details.
|
||||
data_root=data_root, # The root of dataset.
|
||||
data_prefix=dict(
|
||||
img_path='A_Ori',
|
||||
seg_map_path='A_Label_GT_label_fold'),
|
||||
pipeline=test_pipeline)) # Processing pipeline. This is passed by the test_pipeline created before.
|
||||
# The metric to measure the accuracy. Here, we use IoUMetric.
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
67
Seg_All_In_One_MMSeg/configs/_base_/datasets/nyu.py
Normal file
67
Seg_All_In_One_MMSeg/configs/_base_/datasets/nyu.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# dataset settings
|
||||
dataset_type = 'NYUDataset'
|
||||
data_root = 'data/nyu'
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
|
||||
dict(type='RandomDepthMix', prob=0.25),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='RandomCrop', crop_size=(480, 480)),
|
||||
dict(
|
||||
type='Albu',
|
||||
transforms=[
|
||||
dict(type='RandomBrightnessContrast'),
|
||||
dict(type='RandomGamma'),
|
||||
dict(type='HueSaturationValue'),
|
||||
]),
|
||||
dict(
|
||||
type='PackSegInputs',
|
||||
meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
|
||||
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
|
||||
'category_id')),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2000, 480), keep_ratio=True),
|
||||
dict(dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3)),
|
||||
dict(
|
||||
type='PackSegInputs',
|
||||
meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
|
||||
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
|
||||
'category_id'))
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/train', depth_map_path='annotations/train'),
|
||||
pipeline=train_pipeline))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
test_mode=True,
|
||||
data_prefix=dict(
|
||||
img_path='images/test', depth_map_path='annotations/test'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(
|
||||
type='DepthMetric',
|
||||
min_depth_eval=0.001,
|
||||
max_depth_eval=10.0,
|
||||
crop_type='nyu_crop')
|
||||
test_evaluator = val_evaluator
|
||||
72
Seg_All_In_One_MMSeg/configs/_base_/datasets/nyu_512x512.py
Normal file
72
Seg_All_In_One_MMSeg/configs/_base_/datasets/nyu_512x512.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# dataset settings
|
||||
dataset_type = 'NYUDataset'
|
||||
data_root = 'data/nyu'
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
|
||||
dict(type='RandomDepthMix', prob=0.25),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(768, 512),
|
||||
ratio_range=(0.8, 1.5),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=(512, 512)),
|
||||
dict(
|
||||
type='Albu',
|
||||
transforms=[
|
||||
dict(type='RandomBrightnessContrast'),
|
||||
dict(type='RandomGamma'),
|
||||
dict(type='HueSaturationValue'),
|
||||
]),
|
||||
dict(
|
||||
type='PackSegInputs',
|
||||
meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
|
||||
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
|
||||
'category_id')),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
|
||||
dict(dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3)),
|
||||
dict(
|
||||
type='PackSegInputs',
|
||||
meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
|
||||
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
|
||||
'category_id'))
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='images/train', depth_map_path='annotations/train'),
|
||||
pipeline=train_pipeline))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
test_mode=True,
|
||||
data_prefix=dict(
|
||||
img_path='images/test', depth_map_path='annotations/test'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(
|
||||
type='DepthMetric',
|
||||
min_depth_eval=0.001,
|
||||
max_depth_eval=10.0,
|
||||
crop_type='nyu_crop')
|
||||
test_evaluator = val_evaluator
|
||||
@@ -0,0 +1,56 @@
|
||||
# dataset settings
|
||||
dataset_type = 'PascalContextDataset'
|
||||
data_root = 'data/VOCdevkit/VOC2010/'
|
||||
|
||||
img_scale = (520, 520)
|
||||
crop_size = (480, 480)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=img_scale,
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=img_scale, keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=4,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
|
||||
ann_file='ImageSets/SegmentationContext/train.txt',
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
|
||||
ann_file='ImageSets/SegmentationContext/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user