[Improve] Update Otter and LLaVA docs and config. (#1653)
mzr1996 authored Jun 19, 2023
1 parent dbef2b4 commit 7d850df
Showing 5 changed files with 21 additions and 20 deletions.
1 change: 1 addition & 0 deletions configs/llava/README.md
@@ -34,6 +34,7 @@ from mmpretrain import get_model, inference_model
 model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
 out = inference_model(model, 'demo/cat-dog.png')
 print(out)
+# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
 ```

 **Test Command**
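The documented call extends naturally to several images. A minimal sketch reusing only the `get_model`/`inference_model` calls shown above (`MERGED_CHECKPOINT_PATH` is the README's placeholder; the image paths are illustrative):

```python
from mmpretrain import get_model, inference_model

# MERGED_CHECKPOINT_PATH is the placeholder from the README above;
# replace it with a real merged checkpoint before running.
model = get_model('llava-7b-v1_caption',
                  pretrained='MERGED_CHECKPOINT_PATH',
                  device='cuda')

for img in ['demo/cat-dog.png', 'demo/demo.JPEG']:  # example paths
    out = inference_model(model, img)
    print(img, '->', out['pred_caption'])
```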
15 changes: 8 additions & 7 deletions configs/otter/README.md
@@ -22,9 +22,10 @@ Large language models (LLMs) have demonstrated significant universal capabilities
 import torch
 from mmpretrain import get_model, inference_model

-model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda')
+model = get_model('otter-9b_3rdparty_caption', pretrained=True, device='cuda', generation_cfg=dict(max_new_tokens=50))
 out = inference_model(model, 'demo/cat-dog.png')
 print(out)
+# {'pred_caption': 'The image features two adorable small puppies sitting next to each other on the grass. One puppy is brown and white, while the other is tan and white. They appear to be relaxing outdoors, enjoying each other'}
 ```

 **Test Command**
@@ -43,17 +44,17 @@ python tools/test.py configs/otter/otter-9b_caption.py https://download.openmmla

 ### Image Caption on COCO

-| Model                         |   Pretrain   | Params (M) |  BLEU-4  |  CIDER   |            Config             | Download |
-| :---------------------------- | :----------: | :--------: | :------: | :------: | :---------------------------: | :------: |
-| `otter-9b_3rdparty_caption`\* | From scratch |  8220.45   | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
+| Model                         | Params (M) |  BLEU-4  |  CIDER   |            Config             | Download |
+| :---------------------------- | :--------: | :------: | :------: | :---------------------------: | :------: |
+| `otter-9b_3rdparty_caption`\* |  8220.45   | Upcoming | Upcoming | [config](otter-9b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |

 *Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*

 ### Visual Question Answering on VQAv2

-| Model                     |   Pretrain   | Params (M) | Accuracy |          Config           | Download |
-| :------------------------ | :----------: | :--------: | :------: | :-----------------------: | :------: |
-| `otter-9b_3rdparty_vqa`\* | From scratch |  8220.45   | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |
+| Model                     | Params (M) | Accuracy |          Config           | Download |
+| :------------------------ | :--------: | :------: | :-----------------------: | :------: |
+| `otter-9b_3rdparty_vqa`\* |  8220.45   | Upcoming | [config](otter-9b_vqa.py) | [model](https://download.openmmlab.com/mmclassification/v1/otter/otter-9b-adapter_20230613-51c5be8d.pth) |

 *Models with * are converted from the [official repo](https://github.com/Luodian/Otter/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*
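The `generation_cfg` argument added to the README snippet controls text decoding; the sample caption above stops mid-sentence, which is the visible effect of a 50-token budget. A hedged sketch of tuning it (the assumption here is that `generation_cfg` keys follow common HuggingFace `generate()` options; only `max_new_tokens` is confirmed by the diff):

```python
from mmpretrain import get_model, inference_model

# max_new_tokens=50 reproduces the README setting; a larger budget
# should avoid the truncated caption shown above.
model = get_model(
    'otter-9b_3rdparty_caption',
    pretrained=True,
    device='cuda',
    generation_cfg=dict(max_new_tokens=128),
)
print(inference_model(model, 'demo/cat-dog.png')['pred_caption'])
```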
10 changes: 3 additions & 7 deletions configs/otter/otter-9b_caption.py
@@ -65,22 +65,18 @@
     batch_size=8,
     num_workers=8,
     dataset=dict(
-        type='FlamingoEvalCOCOCaption',
+        type='COCOCaption',
         data_root='data/coco',
-        ann_file='annotations/captions_train2014.json',
-        data_prefix=dict(img_path='train2014'),
+        ann_file='annotations/coco_karpathy_val.json',
         pipeline=test_pipeline,
-        num_shots=0,
-        num_support_examples=2048,
-        num_query_examples=5000,
     ),
     sampler=dict(type='DefaultSampler', shuffle=False),
     persistent_workers=True,
 )

 val_evaluator = dict(
     type='COCOCaption',
-    ann_file='data/coco/annotations/captions_train2014.json')
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json')

 # If you want standard test, please manually configure the test dataset
 test_dataloader = val_dataloader
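This change swaps the few-shot Flamingo evaluation dataset for the standard `COCOCaption` dataset on the Karpathy validation split, and points the evaluator at the matching ground-truth file. A quick sanity check of the resulting config (a sketch assuming `mmengine` is installed and the repository root is the working directory):

```python
from mmengine.config import Config

# Load the updated config and confirm dataset and evaluator agree.
cfg = Config.fromfile('configs/otter/otter-9b_caption.py')
print(cfg.val_dataloader.dataset.type)      # COCOCaption
print(cfg.val_dataloader.dataset.ann_file)  # annotations/coco_karpathy_val.json
print(cfg.val_evaluator.ann_file)           # data/coco/annotations/coco_karpathy_val_gt.json
```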
9 changes: 5 additions & 4 deletions mmpretrain/models/multimodal/llava/llava.py
@@ -129,8 +129,9 @@ def forward(
         mode: str = 'loss',
     ):
         """The unified entry for a forward process in both training and test.
-        The method should accept only one mode "loss":

+        - "predict": Forward and return the predictions, which are fully
+          processed to a list of :obj:`DataSample`.
         - "loss": Forward and return a dict of losses according to the given
           inputs and data samples.
@@ -150,10 +151,10 @@ def forward(
             - If ``mode="loss"``, return a dict of tensor.
         """

-        if mode == 'loss':
-            return self.loss(images, data_samples)
-        elif mode == 'predict':
+        if mode == 'predict':
             return self.predict(images, data_samples)
+        elif mode == 'loss':
+            raise NotImplementedError
         else:
             raise RuntimeError(f'Invalid mode "{mode}".')
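After this change, `forward` dispatches `"predict"` to `predict()` and explicitly rejects `"loss"`, so callers fail fast instead of hitting a missing `self.loss`. A hedged usage sketch (the model name and checkpoint placeholder come from the LLaVA README; the input resolution is an assumption):

```python
import torch
from mmpretrain import get_model

model = get_model('llava-7b-v1_caption',
                  pretrained='MERGED_CHECKPOINT_PATH',  # README placeholder
                  device='cuda')
images = torch.rand(1, 3, 224, 224, device='cuda')  # assumed input size

preds = model(images, data_samples=None, mode='predict')  # list of DataSample
try:
    model(images, data_samples=None, mode='loss')
except NotImplementedError:
    print('training mode is not implemented for this LLaVA config')
```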
6 changes: 4 additions & 2 deletions mmpretrain/models/multimodal/otter/otter.py
@@ -10,13 +10,15 @@

 @MODELS.register_module()
 class Otter(Flamingo):
-    """The Open Flamingo model for multiple tasks.
+    """The Otter model for multiple tasks.

     Args:
         vision_encoder (dict): The config of the vision encoder.
         lang_encoder (dict): The config of the language encoder.
         tokenizer (dict): The tokenizer to encode the text.
         task (int): The task to perform prediction.
+        zeroshot_prompt (str): Prompt used for zero-shot inference.
+            Defaults to an.
         shot_prompt_tmpl (str): Prompt used for few-shot inference.
             Defaults to '<image>User:Please describe the image.
             GPT:<answer>{caption}<|endofchunk|>'.
@@ -69,7 +71,7 @@ def __init__(

         # init tokenizer
         self.tokenizer = TOKENIZER.build(tokenizer)
-        # add Flamingo special tokens to the tokenizer
+        # add Otter special tokens to the tokenizer
         self.tokenizer.add_special_tokens({
             'additional_special_tokens':
             ['<|endofchunk|>', '<image>', '<answer>']
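The renamed comment marks where Otter's markers are registered. Registering them as special tokens keeps each marker a single vocabulary item rather than letting the tokenizer split it into sub-word pieces, so prompts and outputs can be segmented reliably. A standalone sketch with a generic HuggingFace tokenizer (`gpt2` stands in purely for illustration; Otter builds its real tokenizer from the `tokenizer` config above):

```python
from transformers import AutoTokenizer

# gpt2 is a stand-in; any HuggingFace tokenizer behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({
    'additional_special_tokens': ['<|endofchunk|>', '<image>', '<answer>']
})

prompt = '<image>User:Please describe the image. GPT:<answer>'
print(tokenizer.tokenize(prompt))
# the three markers survive as single tokens instead of sub-word pieces
```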