From afc3d6f05025f4fbd91a9ab7e561753bd8aa0213 Mon Sep 17 00:00:00 2001 From: John Son <852172305@qq.com> Date: Mon, 10 Jul 2023 08:46:47 +0800 Subject: [PATCH] [feature] Visualizer compatible with MultiTaskDataSample --- mmpretrain/visualization/visualizer.py | 89 ++++++++++++++------- tests/test_visualization/test_visualizer.py | 42 +++++++++- 2 files changed, 103 insertions(+), 28 deletions(-) diff --git a/mmpretrain/visualization/visualizer.py b/mmpretrain/visualization/visualizer.py index 5d18ca87f6b..4a0dea2b804 100644 --- a/mmpretrain/visualization/visualizer.py +++ b/mmpretrain/visualization/visualizer.py @@ -11,7 +11,7 @@ from mmengine.visualization.utils import img_from_canvas from mmpretrain.registry import VISUALIZERS -from mmpretrain.structures import DataSample +from mmpretrain.structures import DataSample, MultiTaskDataSample from .utils import create_figure, get_adaptive_scale @@ -114,33 +114,9 @@ def visualize_cls(self, texts = [] self.set_image(image) - if draw_gt and 'gt_label' in data_sample: - idx = data_sample.gt_label.tolist() - class_labels = [''] * len(idx) - if classes is not None: - class_labels = [f' ({classes[i]})' for i in idx] - labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))] - prefix = 'Ground truth: ' - texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) - - if draw_pred and 'pred_label' in data_sample: - idx = data_sample.pred_label.tolist() - score_labels = [''] * len(idx) - class_labels = [''] * len(idx) - if draw_score and 'pred_score' in data_sample: - score_labels = [ - f', {data_sample.pred_score[i].item():.2f}' for i in idx - ] - - if classes is not None: - class_labels = [f' ({classes[i]})' for i in idx] + self.draw_gt(data_sample, classes, draw_gt, texts) - labels = [ - str(idx[i]) + score_labels[i] + class_labels[i] - for i in range(len(idx)) - ] - prefix = 'Prediction: ' - texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + self.draw_pred(data_sample, classes, draw_pred, draw_score, texts) img_scale = get_adaptive_scale(image.shape[:2]) text_cfg = { @@ -167,6 +143,65 @@ def visualize_cls(self, return drawn_img + def draw_pred(self, + data_sample: DataSample, + classes: Optional[Sequence[str]], + draw_pred: bool, + draw_score: bool, + texts: Sequence[str], + parent_task: str = ''): + if isinstance(data_sample, MultiTaskDataSample): + for task in data_sample.tasks: + sub_task = f'{parent_task}_{task}' if parent_task else task + self.draw_pred( + data_sample.get(task), classes, draw_pred, draw_score, + texts, sub_task) + else: + if draw_pred and 'pred_label' in data_sample: + idx = data_sample.pred_label.tolist() + score_labels = [''] * len(idx) + class_labels = [''] * len(idx) + if draw_score and 'pred_score' in data_sample: + score_labels = [ + f', {data_sample.pred_score[i].item():.2f}' + for i in idx + ] + + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + + labels = [ + str(idx[i]) + score_labels[i] + class_labels[i] + for i in range(len(idx)) + ] + prefix = f'{parent_task} Prediction: ' if parent_task \ + else 'Prediction: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + + def draw_gt(self, + data_sample: DataSample, + classes: Optional[Sequence[str]], + draw_gt: bool, + texts: Sequence[str], + parent_task: str = ''): + if isinstance(data_sample, MultiTaskDataSample): + for task in data_sample.tasks: + sub_task = f'{parent_task}_{task}' if parent_task else task + self.draw_gt( + data_sample.get(task), classes, draw_gt, texts, sub_task) + else: + if draw_gt and 'gt_label' in data_sample: + idx = data_sample.gt_label.tolist() + class_labels = [''] * len(idx) + if classes is not None: + class_labels = [f' ({classes[i]})' for i in idx] + labels = [ + str(idx[i]) + class_labels[i] for i in range(len(idx)) + ] + prefix = f'{parent_task} Ground truth: ' if parent_task \ + else 'Ground truth: ' + texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) + @master_only def visualize_image_retrieval(self, image: np.ndarray, diff --git a/tests/test_visualization/test_visualizer.py b/tests/test_visualization/test_visualizer.py index 900e495cf34..967a2f45abc 100644 --- a/tests/test_visualization/test_visualizer.py +++ b/tests/test_visualization/test_visualizer.py @@ -7,7 +7,7 @@ import numpy as np import torch -from mmpretrain.structures import DataSample +from mmpretrain.structures import DataSample, MultiTaskDataSample from mmpretrain.visualization import UniversalVisualizer @@ -123,6 +123,46 @@ def draw_texts(text, font_sizes, *_, **__): data_sample, rescale_factor=2.) + def test_visualize_multitask_cls(self): + image = np.ones((1000, 1000, 3), np.uint8) + gt_label = {'task0': {'task00': 2, 'task01': 1}, 'task1': 1} + data_sample = MultiTaskDataSample() + task_sample = DataSample().set_gt_label( + gt_label['task1']).set_pred_label(1).set_pred_score( + torch.tensor([0.1, 0.8, 0.1])) + data_sample.set_field(task_sample, 'task1') + data_sample.set_field(MultiTaskDataSample(), 'task0') + for task_name in gt_label['task0']: + task_sample = DataSample().set_gt_label( + gt_label['task0'][task_name]).set_pred_label(2).set_pred_score( + torch.tensor([0.1, 0.4, 0.5])) + data_sample.task0.set_field(task_sample, task_name) + + # Test show + def mock_show(drawn_img, win_name, wait_time): + self.assertFalse((image == drawn_img).all()) + self.assertEqual(win_name, 'test_cls') + self.assertEqual(wait_time, 0) + + with patch.object(self.vis, 'show', mock_show): + self.vis.visualize_cls( + image=image, + data_sample=data_sample, + show=True, + name='test_cls', + step=2) + + # Test storage backend. + save_file = osp.join(self.tmpdir.name, + 'vis_data/vis_image/test_cls_2.png') + self.assertTrue(osp.exists(save_file)) + + # Test out_file + out_file = osp.join(self.tmpdir.name, 'results_2.png') + self.vis.visualize_cls( + image=image, data_sample=data_sample, out_file=out_file) + self.assertTrue(osp.exists(out_file)) + def test_visualize_image_retrieval(self): image = np.ones((10, 10, 3), np.uint8) data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])