Skip to content

Commit c7ac97d

Browse files
haruishi43xiexinch
andauthored
[Feature] add -with-labels arg to inferencer for visualization without labels (#3466)
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation It is difficult to visualize without "labels" when using the inferencer. - While using the `MMSegInferencer`, the visualized prediction contains labels on the mask, but it is difficult to pass `withLabels=False` without rewriting the config (which is harder to do when you initialize the inferencer with a model name rather than the config). - I thought it would be easier to just pass `withLabels=False` to `inferencer.__call__()` since you can also pass `opacity` and other parameters anyway. ## Modification Please briefly describe what modification is made in this PR. - Added `with_labels` to `visualize_kwargs` inside `MMSegInferencer`. - Modified to `visualize()` function. ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. 3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstring or example tutorials. --------- Co-authored-by: xiexinch <[email protected]>
1 parent 7451459 commit c7ac97d

File tree

5 files changed

+32
-14
lines changed

5 files changed

+32
-14
lines changed

demo/image_demo.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def main():
1919
type=float,
2020
default=0.5,
2121
help='Opacity of painted segmentation map. In (0, 1] range.')
22+
parser.add_argument(
23+
'--with-labels',
24+
action='store_true',
25+
default=False,
26+
help='Whether to display the class labels.')
2227
parser.add_argument(
2328
'--title', default='result', help='The image identifier.')
2429
args = parser.parse_args()
@@ -36,6 +41,7 @@ def main():
3641
result,
3742
title=args.title,
3843
opacity=args.opacity,
44+
with_labels=args.with_labels,
3945
draw_gt=False,
4046
show=False if args.out_file is not None else True,
4147
out_file=args.out_file)

demo/image_demo_with_inferencer.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ def main():
2727
type=float,
2828
default=0.5,
2929
help='Opacity of painted segmentation map. In (0, 1] range.')
30+
parser.add_argument(
31+
'--with-labels',
32+
action='store_true',
33+
default=False,
34+
help='Whether to display the class labels.')
3035
args = parser.parse_args()
3136

3237
# build the model from a config file and a checkpoint file
@@ -38,7 +43,11 @@ def main():
3843

3944
# test a single image
4045
mmseg_inferencer(
41-
args.img, show=args.show, out_dir=args.out_dir, opacity=args.opacity)
46+
args.img,
47+
show=args.show,
48+
out_dir=args.out_dir,
49+
opacity=args.opacity,
50+
with_labels=args.with_labels)
4251

4352

4453
if __name__ == '__main__':

mmseg/apis/inference.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def show_result_pyplot(model: BaseSegmentor,
127127
draw_pred: bool = True,
128128
wait_time: float = 0,
129129
show: bool = True,
130-
withLabels: Optional[bool] = True,
130+
with_labels: Optional[bool] = True,
131131
save_dir=None,
132132
out_file=None):
133133
"""Visualize the segmentation results on the image.
@@ -147,7 +147,7 @@ def show_result_pyplot(model: BaseSegmentor,
147147
that means "forever". Defaults to 0.
148148
show (bool): Whether to display the drawn image.
149149
Default to True.
150-
withLabels(bool, optional): Add semantic labels in visualization
150+
with_labels(bool, optional): Add semantic labels in visualization
151151
result, Default to True.
152152
save_dir (str, optional): Save file dir for all storage backends.
153153
If it is None, the backend storage will not save any data.
@@ -183,7 +183,7 @@ def show_result_pyplot(model: BaseSegmentor,
183183
wait_time=wait_time,
184184
out_file=out_file,
185185
show=show,
186-
withLabels=withLabels)
186+
with_labels=with_labels)
187187
vis_img = visualizer.get_image()
188188

189189
return vis_img

mmseg/apis/mmseg_inferencer.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ class MMSegInferencer(BaseInferencer):
6060
preprocess_kwargs: set = set()
6161
forward_kwargs: set = {'mode', 'out_dir'}
6262
visualize_kwargs: set = {
63-
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
63+
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
64+
'with_labels'
6465
}
6566
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
6667

@@ -201,7 +202,8 @@ def visualize(self,
201202
show: bool = False,
202203
wait_time: int = 0,
203204
img_out_dir: str = '',
204-
opacity: float = 0.8) -> List[np.ndarray]:
205+
opacity: float = 0.8,
206+
with_labels: Optional[bool] = True) -> List[np.ndarray]:
205207
"""Visualize predictions.
206208
207209
Args:
@@ -254,7 +256,8 @@ def visualize(self,
254256
wait_time=wait_time,
255257
draw_gt=False,
256258
draw_pred=True,
257-
out_file=out_file)
259+
out_file=out_file,
260+
with_labels=with_labels)
258261
if return_vis:
259262
results.append(self.visualizer.get_image())
260263
self.num_visualized_imgs += 1

mmseg/visualization/local_visualizer.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _draw_sem_seg(self,
103103
sem_seg: PixelData,
104104
classes: Optional[List],
105105
palette: Optional[List],
106-
withLabels: Optional[bool] = True) -> np.ndarray:
106+
with_labels: Optional[bool] = True) -> np.ndarray:
107107
"""Draw semantic seg of GT or prediction.
108108
109109
Args:
@@ -119,7 +119,7 @@ def _draw_sem_seg(self,
119119
palette (list, optional): Input palette for result rendering, which
120120
is a list of color palette responding to the classes.
121121
Defaults to None.
122-
withLabels(bool, optional): Add semantic labels in visualization
122+
with_labels(bool, optional): Add semantic labels in visualization
123123
result, Default to True.
124124
125125
Returns:
@@ -139,7 +139,7 @@ def _draw_sem_seg(self,
139139
for label, color in zip(labels, colors):
140140
mask[sem_seg[0] == label, :] = color
141141

142-
if withLabels:
142+
if with_labels:
143143
font = cv2.FONT_HERSHEY_SIMPLEX
144144
# (0,1] to change the size of the text relative to the image
145145
scale = 0.05
@@ -265,7 +265,7 @@ def add_datasample(
265265
# TODO: Supported in mmengine's Viusalizer.
266266
out_file: Optional[str] = None,
267267
step: int = 0,
268-
withLabels: Optional[bool] = True) -> None:
268+
with_labels: Optional[bool] = True) -> None:
269269
"""Draw datasample and save to all backends.
270270
271271
- If GT and prediction are plotted at the same time, they are
@@ -291,7 +291,7 @@ def add_datasample(
291291
wait_time (float): The interval of show (s). Defaults to 0.
292292
out_file (str): Path to output file. Defaults to None.
293293
step (int): Global step value to record. Defaults to 0.
294-
withLabels(bool, optional): Add semantic labels in visualization
294+
with_labels(bool, optional): Add semantic labels in visualization
295295
result, Defaults to True.
296296
"""
297297
classes = self.dataset_meta.get('classes', None)
@@ -307,7 +307,7 @@ def add_datasample(
307307
'visualizing semantic ' \
308308
'segmentation results.'
309309
gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg,
310-
classes, palette, withLabels)
310+
classes, palette, with_labels)
311311

312312
if 'gt_depth_map' in data_sample:
313313
gt_img_data = gt_img_data if gt_img_data is not None else image
@@ -325,7 +325,7 @@ def add_datasample(
325325
pred_img_data = self._draw_sem_seg(image,
326326
data_sample.pred_sem_seg,
327327
classes, palette,
328-
withLabels)
328+
with_labels)
329329

330330
if 'pred_depth_map' in data_sample:
331331
pred_img_data = pred_img_data if pred_img_data is not None \

0 commit comments

Comments
 (0)