Skip to content

Commit

Permalink
[Feature] add -with-labels arg to inferencer for visualization withou…
Browse files Browse the repository at this point in the history
…t labels (#3466)

Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

It is difficult to visualize without "labels" when using the inferencer.

- While using the `MMSegInferencer`, the visualized prediction contains
labels on the mask, but it is difficult to pass `withLabels=False`
without rewriting the config (which is harder to do when you initialize
the inferencer with a model name rather than the config).
- I thought it would be easier to just pass `withLabels=False` to
`inferencer.__call__()` since you can also pass `opacity` and other
parameters anyway.

## Modification

Please briefly describe what modification is made in this PR.

- Added `with_labels` to `visualize_kwargs` inside `MMSegInferencer`.
- Modified to `visualize()` function.

## BC-breaking (Optional)

Does the modification introduce changes that break the
backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the
downstream projects should modify their code to keep compatibility with
this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases
here, and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.

---------

Co-authored-by: xiexinch <[email protected]>
  • Loading branch information
haruishi43 and xiexinch authored Dec 14, 2023
1 parent 7451459 commit c7ac97d
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 14 deletions.
6 changes: 6 additions & 0 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def main():
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument(
'--with-labels',
action='store_true',
default=False,
help='Whether to display the class labels.')
parser.add_argument(
'--title', default='result', help='The image identifier.')
args = parser.parse_args()
Expand All @@ -36,6 +41,7 @@ def main():
result,
title=args.title,
opacity=args.opacity,
with_labels=args.with_labels,
draw_gt=False,
show=False if args.out_file is not None else True,
out_file=args.out_file)
Expand Down
11 changes: 10 additions & 1 deletion demo/image_demo_with_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def main():
type=float,
default=0.5,
help='Opacity of painted segmentation map. In (0, 1] range.')
parser.add_argument(
'--with-labels',
action='store_true',
default=False,
help='Whether to display the class labels.')
args = parser.parse_args()

# build the model from a config file and a checkpoint file
Expand All @@ -38,7 +43,11 @@ def main():

# test a single image
mmseg_inferencer(
args.img, show=args.show, out_dir=args.out_dir, opacity=args.opacity)
args.img,
show=args.show,
out_dir=args.out_dir,
opacity=args.opacity,
with_labels=args.with_labels)


if __name__ == '__main__':
Expand Down
6 changes: 3 additions & 3 deletions mmseg/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def show_result_pyplot(model: BaseSegmentor,
draw_pred: bool = True,
wait_time: float = 0,
show: bool = True,
withLabels: Optional[bool] = True,
with_labels: Optional[bool] = True,
save_dir=None,
out_file=None):
"""Visualize the segmentation results on the image.
Expand All @@ -147,7 +147,7 @@ def show_result_pyplot(model: BaseSegmentor,
that means "forever". Defaults to 0.
show (bool): Whether to display the drawn image.
Default to True.
withLabels(bool, optional): Add semantic labels in visualization
with_labels(bool, optional): Add semantic labels in visualization
result, Default to True.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
Expand Down Expand Up @@ -183,7 +183,7 @@ def show_result_pyplot(model: BaseSegmentor,
wait_time=wait_time,
out_file=out_file,
show=show,
withLabels=withLabels)
with_labels=with_labels)
vis_img = visualizer.get_image()

return vis_img
9 changes: 6 additions & 3 deletions mmseg/apis/mmseg_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class MMSegInferencer(BaseInferencer):
preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'}
visualize_kwargs: set = {
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
'with_labels'
}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

Expand Down Expand Up @@ -201,7 +202,8 @@ def visualize(self,
show: bool = False,
wait_time: int = 0,
img_out_dir: str = '',
opacity: float = 0.8) -> List[np.ndarray]:
opacity: float = 0.8,
with_labels: Optional[bool] = True) -> List[np.ndarray]:
"""Visualize predictions.
Args:
Expand Down Expand Up @@ -254,7 +256,8 @@ def visualize(self,
wait_time=wait_time,
draw_gt=False,
draw_pred=True,
out_file=out_file)
out_file=out_file,
with_labels=with_labels)
if return_vis:
results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1
Expand Down
14 changes: 7 additions & 7 deletions mmseg/visualization/local_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _draw_sem_seg(self,
sem_seg: PixelData,
classes: Optional[List],
palette: Optional[List],
withLabels: Optional[bool] = True) -> np.ndarray:
with_labels: Optional[bool] = True) -> np.ndarray:
"""Draw semantic seg of GT or prediction.
Args:
Expand All @@ -119,7 +119,7 @@ def _draw_sem_seg(self,
palette (list, optional): Input palette for result rendering, which
is a list of color palette responding to the classes.
Defaults to None.
withLabels(bool, optional): Add semantic labels in visualization
with_labels(bool, optional): Add semantic labels in visualization
result, Default to True.
Returns:
Expand All @@ -139,7 +139,7 @@ def _draw_sem_seg(self,
for label, color in zip(labels, colors):
mask[sem_seg[0] == label, :] = color

if withLabels:
if with_labels:
font = cv2.FONT_HERSHEY_SIMPLEX
# (0,1] to change the size of the text relative to the image
scale = 0.05
Expand Down Expand Up @@ -265,7 +265,7 @@ def add_datasample(
# TODO: Supported in mmengine's Viusalizer.
out_file: Optional[str] = None,
step: int = 0,
withLabels: Optional[bool] = True) -> None:
with_labels: Optional[bool] = True) -> None:
"""Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are
Expand All @@ -291,7 +291,7 @@ def add_datasample(
wait_time (float): The interval of show (s). Defaults to 0.
out_file (str): Path to output file. Defaults to None.
step (int): Global step value to record. Defaults to 0.
withLabels(bool, optional): Add semantic labels in visualization
with_labels(bool, optional): Add semantic labels in visualization
result, Defaults to True.
"""
classes = self.dataset_meta.get('classes', None)
Expand All @@ -307,7 +307,7 @@ def add_datasample(
'visualizing semantic ' \
'segmentation results.'
gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg,
classes, palette, withLabels)
classes, palette, with_labels)

if 'gt_depth_map' in data_sample:
gt_img_data = gt_img_data if gt_img_data is not None else image
Expand All @@ -325,7 +325,7 @@ def add_datasample(
pred_img_data = self._draw_sem_seg(image,
data_sample.pred_sem_seg,
classes, palette,
withLabels)
with_labels)

if 'pred_depth_map' in data_sample:
pred_img_data = pred_img_data if pred_img_data is not None \
Expand Down

0 comments on commit c7ac97d

Please sign in to comment.