From 7d850dfadd6f3bec676e9095d9b6d62c13c50756 Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Mon, 19 Jun 2023 20:16:13 +0800 Subject: [PATCH] [Improve] Update Otter and LLaVA docs and config. (#1653) --- configs/llava/README.md | 1 + configs/otter/README.md | 15 ++++++++------- configs/otter/otter-9b_caption.py | 10 +++------- mmpretrain/models/multimodal/llava/llava.py | 9 +++++---- mmpretrain/models/multimodal/otter/otter.py | 6 ++++-- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/configs/llava/README.md b/configs/llava/README.md index 8683a795845..31275f7748c 100644 --- a/configs/llava/README.md +++ b/configs/llava/README.md @@ -34,6 +34,7 @@ from mmpretrain import get_model, inference_model model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda') out = inference_model(model, 'demo/cat-dog.png') print(out) +# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'} ``` **Test Command** diff --git a/configs/otter/README.md b/configs/otter/README.md index b96779a6a4d..e0cafde546e 100644 --- a/configs/otter/README.md +++ b/configs/otter/README.md @@ -22,9 +22,10 @@ Large language models (LLMs) have demonstrated significant universal capabilitie import torch from mmpretrain import get_model, inference_model -model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda') +model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda', generation_cfg=dict(max_new_tokens=50)) out = inference_model(model, 'demo/cat-dog.png') print(out) +# {'pred_caption': 'The image features two adorable small puppies sitting next to each other on the grass. One puppy is brown and white, while the other is tan and white. They appear to be relaxing outdoors, enjoying each other'} ``` **Test Command** @@ -43,17 +44,17 @@ python tools/test.py configs/otter/otter-9b_caption.py https://download.openmmla ### Image Caption on COCO -| Model | Pretrain | Params (M) | BLEU-4 | CIDER | Config | Download | -| :---------------------------- | :----------: | :--------: | :------: | :------: | :---------------------------: | :------------------------------------------------------------------------------------------------------: | -| `otter-9b_3rdparty_caption`\* | From scratch | 8220.45 | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) | +| Model | Params (M) | BLEU-4 | CIDER | Config | Download | +| :---------------------------- | :--------: | :------: | :------: | :---------------------------: | :------------------------------------------------------------------------------------------------------: | +| `otter-9b_3rdparty_caption`\* | 8220.45 | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) | *Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reprodcue the training results.* ### Visual Question Answering on VQAv2 -| Model | Pretrain | Params (M) | Accuracy | Config | Download | -| :------------------------ | :----------: | :--------: | :------: | :-----------------------: | :------------------------------------------------------------------------------------------------------: | -| `otter-9b_3rdparty_vqa`\* | From scratch | 8220.45 | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) | +| Model | Params (M) | Accuracy | Config | Download | +| :------------------------ | :--------: | :------: | :-----------------------: | :------------------------------------------------------------------------------------------------------: | +| `otter-9b_3rdparty_vqa`\* | 8220.45 | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) | *Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reprodcue the training results.* diff --git a/configs/otter/otter-9b_caption.py b/configs/otter/otter-9b_caption.py index 3985e9534c8..e35e92ef40c 100644 --- a/configs/otter/otter-9b_caption.py +++ b/configs/otter/otter-9b_caption.py @@ -65,14 +65,10 @@ batch_size=8, num_workers=8, dataset=dict( - type='FlamingoEvalCOCOCaption', + type='COCOCaption', data_root='data/coco', - ann_file='annotations/captions_train2014.json', - data_prefix=dict(img_path='train2014'), + ann_file='annotations/coco_karpathy_val.json', pipeline=test_pipeline, - num_shots=0, - num_support_examples=2048, - num_query_examples=5000, ), sampler=dict(type='DefaultSampler', shuffle=False), persistent_workers=True, @@ -80,7 +76,7 @@ val_evaluator = dict( type='COCOCaption', - ann_file='data/coco/annotations/captions_train2014.json') + ann_file='data/coco/annotations/coco_karpathy_val_gt.json') # If you want standard test, please manually configure the test dataset test_dataloader = val_dataloader diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py index 966c9462854..1c300fdcd05 100644 --- a/mmpretrain/models/multimodal/llava/llava.py +++ b/mmpretrain/models/multimodal/llava/llava.py @@ -129,8 +129,9 @@ def forward( mode: str = 'loss', ): """The unified entry for a forward process in both training and test. - The method should accept only one mode "loss": + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`DataSample`. - "loss": Forward and return a dict of losses according to the given inputs and data samples. @@ -150,10 +151,10 @@ def forward( - If ``mode="loss"``, return a dict of tensor. """ - if mode == 'loss': - return self.loss(images, data_samples) - elif mode == 'predict': + if mode == 'predict': return self.predict(images, data_samples) + elif mode == 'loss': + raise NotImplementedError else: raise RuntimeError(f'Invalid mode "{mode}".') diff --git a/mmpretrain/models/multimodal/otter/otter.py b/mmpretrain/models/multimodal/otter/otter.py index 189b6619ace..2fed1a4d27c 100644 --- a/mmpretrain/models/multimodal/otter/otter.py +++ b/mmpretrain/models/multimodal/otter/otter.py @@ -10,13 +10,15 @@ @MODELS.register_module() class Otter(Flamingo): - """The Open Flamingo model for multiple tasks. + """The Otter model for multiple tasks. Args: vision_encoder (dict): The config of the vision encoder. lang_encoder (dict): The config of the language encoder. tokenizer (dict): The tokenizer to encode the text. task (int): The task to perform prediction. + zeroshot_prompt (str): Prompt used for zero-shot inference. + Defaults to an. shot_prompt_tmpl (str): Prompt used for few-shot inference. Defaults to 'User:Please describe the image. GPT:{caption}<|endofchunk|>'. @@ -69,7 +71,7 @@ def __init__( # init tokenizer self.tokenizer = TOKENIZER.build(tokenizer) - # add Flamingo special tokens to the tokenizer + # add Otter special tokens to the tokenizer self.tokenizer.add_special_tokens({ 'additional_special_tokens': ['<|endofchunk|>', '', '']