Skip to content

Commit 6e22a31

Browse files
committed
Release nyu & pascal context models
Release nyu & pascal context models
1 parent 9d4b561 commit 6e22a31

File tree

6 files changed

+383
-3
lines changed

6 files changed

+383
-3
lines changed

segmentation/README.md

+28-2
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,40 @@ Prepare datasets according to the [guidelines](https://github.com/open-mmlab/mms
151151
<br>
152152
<div>
153153

154-
| method | backbone | resolution | mIoU (ss) | #params | FLOPs | Config | Download |
155-
| :---------: | :-----------: | :--------: | :---------: | :-----: | :---: | :------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
154+
| method | backbone | resolution | mIoU (ss) | #params | FLOPs | Config | Download |
155+
| :---------: | :-----------: | :--------: | :---------: | :-----: | :---: | :-----------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
156156
| Mask2Former | InternImage-H | 512x512 | 59.2 / 59.6 | 1.28B | 1528G | [config](./configs/coco_stuff10k/mask2former_internimage_h_512_40k_cocostuff164k_to_10k.py) | [ckpt](https://huggingface.co/OpenGVLab/InternImage/resolve/main/mask2former_internimage_h_512_40k_cocostuff164k_to_10k.pth) \| [log](https://huggingface.co/OpenGVLab/InternImage/raw/main/mask2former_internimage_h_512_40k_cocostuff164k_to_10k.log.json) |
157157

158158
</div>
159159

160160
</details>
161161

162+
<details>
163+
<summary> Dataset: Pascal-Context-59 </summary>
164+
<br>
165+
<div>
166+
167+
| method | backbone | resolution | mIoU (ss/ms) | #param | FLOPs | Config | Download |
168+
| :---------: | :-----------: | :--------: | :----------: | :----: | :---: | :--------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
169+
| Mask2Former | InternImage-H | 480x480 | 69.7 / 70.3 | 1.07B | 867G | [config](./configs/coco_stuff10k/mask2former_internimage_h_480_40k_pascal_context_59.py) | [ckpt](https://huggingface.co/OpenGVLab/InternImage/resolve/main/mask2former_internimage_h_480_40k_pascal_context_59.pth) \| [log](https://huggingface.co/OpenGVLab/InternImage/raw/main/mask2former_internimage_h_480_40k_pascal_context_59.log.json) |
170+
171+
</div>
172+
173+
</details>
174+
175+
<details>
176+
<summary> Dataset: NYU-Depth-V2 </summary>
177+
<br>
178+
<div>
179+
180+
| method | backbone | resolution | mIoU (ss/ms) | #param | FLOPs | Config | Download |
181+
| :---------: | :-----------: | :--------: | :----------: | :----: | :---: | :-----------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
182+
| Mask2Former | InternImage-H | 480x480 | 67.1 / 68.1 | 1.07B | 867G | [config](./configs/nyu_depth_v2/mask2former_internimage_h_480_40k_nyu.py) | [ckpt](https://huggingface.co/OpenGVLab/InternImage/resolve/main/mask2former_internimage_h_480_40k_nyu.pth) \| [log](https://huggingface.co/OpenGVLab/InternImage/raw/main/mask2former_internimage_h_480_40k_nyu.log.json) |
183+
184+
</div>
185+
186+
</details>
187+
162188
## Evaluation
163189

164190
To evaluate our `InternImage` on ADE20K val, run:

segmentation/configs/coco_stuff164k/mask2former_internimage_h_896_80k_cocostuff164k.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
center_feature_scale=True, # for InternImage-H/G
3232
with_cp=False,
3333
out_indices=(0, 1, 2, 3),
34-
init_cfg=None),
34+
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
3535
decode_head=dict(
3636
in_channels=[320, 640, 1280, 2560],
3737
feat_channels=1024,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# NYU-Depth-V2
2+
3+
<!-- [ALGORITHM] -->
4+
5+
## Introduction
6+
7+
The NYU Depth V2 dataset is a comprehensive collection of indoor scene data captured using a Microsoft Kinect device. It is widely utilized in computer vision research, particularly for tasks such as depth estimation and semantic segmentation.
8+
9+
## Model Zoo
10+
11+
### Mask2Former + InternImage
12+
13+
| backbone | resolution | mIoU (ss/ms) | #param | FLOPs | Config | Download |
14+
| :-----------: | :--------: | :----------: | :----: | :---: | :--------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
15+
| InternImage-H | 480x480 | 67.1 / 68.1 | 1.07B | 867G | [config](./mask2former_internimage_h_480_40k_nyu.py) | [ckpt](https://huggingface.co/OpenGVLab/InternImage/resolve/main/mask2former_internimage_h_480_40k_nyu.pth) \| [log](https://huggingface.co/OpenGVLab/InternImage/raw/main/mask2former_internimage_h_480_40k_nyu.log.json) |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# --------------------------------------------------------
2+
# InternImage
3+
# Copyright (c) 2022 OpenGVLab
4+
# Licensed under The MIT License [see LICENSE for details]
5+
# --------------------------------------------------------
6+
_base_ = [
7+
'../_base_/models/mask2former_beit.py', '../_base_/datasets/nyu_depth_v2.py',
8+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
9+
]
10+
num_classes = 40
11+
crop_size = (480, 480)
12+
pretrained = 'https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_h_jointto22k_384.pth'
13+
model = dict(
14+
type='EncoderDecoderMask2Former',
15+
backbone=dict(
16+
_delete_=True,
17+
type='InternImage',
18+
core_op='DCNv3',
19+
channels=320,
20+
depths=[6, 6, 32, 6],
21+
groups=[10, 20, 40, 80],
22+
mlp_ratio=4.,
23+
drop_path_rate=0.5,
24+
norm_layer='LN',
25+
layer_scale=None,
26+
offset_scale=1.0,
27+
post_norm=False,
28+
dw_kernel_size=5, # for InternImage-H/G
29+
res_post_norm=True, # for InternImage-H/G
30+
level2_post_norm=True, # for InternImage-H/G
31+
level2_post_norm_block_ids=[5, 11, 17, 23, 29], # for InternImage-H/G
32+
center_feature_scale=True, # for InternImage-H/G
33+
with_cp=False,
34+
out_indices=(0, 1, 2, 3),
35+
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
36+
decode_head=dict(
37+
in_channels=[320, 640, 1280, 2560],
38+
feat_channels=256,
39+
out_channels=256,
40+
num_classes=num_classes,
41+
num_queries=100,
42+
pixel_decoder=dict(
43+
type='MSDeformAttnPixelDecoder',
44+
num_outs=3,
45+
norm_cfg=dict(type='GN', num_groups=32),
46+
act_cfg=dict(type='ReLU'),
47+
encoder=dict(
48+
type='DetrTransformerEncoder',
49+
num_layers=6,
50+
transformerlayers=dict(
51+
type='BaseTransformerLayer',
52+
attn_cfgs=dict(
53+
type='MultiScaleDeformableAttention',
54+
embed_dims=256,
55+
num_heads=8,
56+
num_levels=3,
57+
num_points=4,
58+
im2col_step=64,
59+
dropout=0.0,
60+
batch_first=False,
61+
norm_cfg=None,
62+
init_cfg=None),
63+
ffn_cfgs=dict(
64+
type='FFN',
65+
embed_dims=256,
66+
feedforward_channels=1024,
67+
num_fcs=2,
68+
ffn_drop=0.0,
69+
with_cp=False, # set with_cp=True to save memory
70+
act_cfg=dict(type='ReLU', inplace=True)),
71+
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
72+
init_cfg=None),
73+
positional_encoding=dict(
74+
type='SinePositionalEncoding', num_feats=128, normalize=True),
75+
init_cfg=None),
76+
positional_encoding=dict(
77+
type='SinePositionalEncoding', num_feats=128, normalize=True),
78+
transformer_decoder=dict(
79+
type='DetrTransformerDecoder',
80+
return_intermediate=True,
81+
num_layers=9,
82+
transformerlayers=dict(
83+
type='DetrTransformerDecoderLayer',
84+
attn_cfgs=dict(
85+
type='MultiheadAttention',
86+
embed_dims=256,
87+
num_heads=8,
88+
attn_drop=0.0,
89+
proj_drop=0.0,
90+
dropout_layer=None,
91+
batch_first=False),
92+
ffn_cfgs=dict(
93+
embed_dims=256,
94+
feedforward_channels=2048,
95+
num_fcs=2,
96+
act_cfg=dict(type='ReLU', inplace=True),
97+
ffn_drop=0.0,
98+
dropout_layer=None,
99+
with_cp=False, # set with_cp=True to save memory
100+
add_identity=True),
101+
feedforward_channels=2048,
102+
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
103+
'ffn', 'norm')),
104+
init_cfg=None),
105+
loss_cls=dict(
106+
type='CrossEntropyLoss',
107+
use_sigmoid=False,
108+
loss_weight=2.0,
109+
reduction='mean',
110+
class_weight=[1.0] * num_classes + [0.1])
111+
),
112+
test_cfg=dict(mode='whole'))
113+
img_norm_cfg = dict(
114+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
115+
train_pipeline = [
116+
dict(type='LoadImageFromFile'),
117+
dict(type='LoadAnnotations', reduce_zero_label=True),
118+
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
119+
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
120+
dict(type='RandomFlip', prob=0.5),
121+
dict(type='PhotoMetricDistortion'),
122+
dict(type='Normalize', **img_norm_cfg),
123+
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
124+
dict(type='ToMask'),
125+
dict(type='DefaultFormatBundle'),
126+
dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels'])
127+
]
128+
test_pipeline = [
129+
dict(type='LoadImageFromFile'),
130+
dict(
131+
type='MultiScaleFlipAug',
132+
img_scale=(640, 480),
133+
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
134+
flip=False,
135+
transforms=[
136+
dict(type='Resize', keep_ratio=True),
137+
dict(type='ResizeToMultiple', size_divisor=32),
138+
dict(type='RandomFlip'),
139+
dict(type='Normalize', **img_norm_cfg),
140+
dict(type='ImageToTensor', keys=['img']),
141+
dict(type='Collect', keys=['img']),
142+
])
143+
]
144+
optimizer = dict(
145+
_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05,
146+
constructor='CustomLayerDecayOptimizerConstructor',
147+
paramwise_cfg=dict(num_layers=50, layer_decay_rate=0.95,
148+
depths=[6, 6, 32, 6], offset_lr_scale=1.0))
149+
lr_config = dict(_delete_=True, policy='poly',
150+
warmup='linear',
151+
warmup_iters=1500,
152+
warmup_ratio=1e-6,
153+
power=1.0, min_lr=0.0, by_epoch=False)
154+
# By default, models are trained on 16 GPUs with 1 images per GPU
155+
data = dict(samples_per_gpu=1,
156+
train=dict(pipeline=train_pipeline),
157+
val=dict(pipeline=test_pipeline),
158+
test=dict(pipeline=test_pipeline))
159+
runner = dict(type='IterBasedRunner')
160+
checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1)
161+
evaluation = dict(interval=2000, metric='mIoU', save_best='mIoU')
162+
# fp16 = dict(loss_scale=dict(init_scale=512))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Pascal Context 59
2+
3+
<!-- [ALGORITHM] -->
4+
5+
## Introduction
6+
7+
The PASCAL Context dataset is an extension of the PASCAL VOC 2010 dataset, providing comprehensive pixel-wise annotations for over 400 classes, including the original 20 object categories and additional background classes. Due to the sparsity of many object categories, a subset of the 59 most frequent classes is commonly used for tasks like semantic segmentation.
8+
9+
## Model Zoo
10+
11+
### Mask2Former + InternImage
12+
13+
| backbone | resolution | mIoU (ss/ms) | #param | FLOPs | Config | Download |
14+
| :-----------: | :--------: | :----------: | :----: | :---: | :----------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
15+
| InternImage-H | 480x480 | 69.7 / 70.3 | 1.07B | 867G | [config](./mask2former_internimage_h_480_40k_pascal_context_59.py) | [ckpt](https://huggingface.co/OpenGVLab/InternImage/resolve/main/mask2former_internimage_h_480_40k_pascal_context_59.pth) \| [log](https://huggingface.co/OpenGVLab/InternImage/raw/main/mask2former_internimage_h_480_40k_pascal_context_59.log.json) |

0 commit comments

Comments
 (0)