Merging DiVA to Levanter Main #779
base: main
Conversation
I am currently annoyed by how we initialize models, and this seems fine enough (cf. #780), so I don't have a super strong feeling on it right now. You could look at how we do LoRA if you want, but that's a bit of a different case.
Nice! A few minor comments. I don't understand everything, but overall it seems good to me!
lev_model: DivaModel = super().load_pretrained(
    DivaModel, ref, config, axis_mapping, resize_vocab_to_match_tokenizer, dtype
)  # type: ignore[assignment]
llm: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel] = HFCheckpointConverter(
Do you need this Union around your 3.10 union?
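(For reference, `typing.Union[...]` wrapped around a PEP 604 `|` union is redundant; either form alone expresses the same type. A minimal sketch, with import paths that are my guess at where these classes live in Levanter:)

```python
from typing import Union

from levanter.models.gemma import GemmaLMHeadModel    # import paths are assumptions
from levanter.models.llama import LlamaLMHeadModel
from levanter.models.mistral import MistralLMHeadModel

# Redundant: typing.Union wrapped around a PEP 604 `|` union.
llm_redundant: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel]

# Either of these alone says the same thing:
llm: LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel
llm_pre_310: Union[LlamaLMHeadModel, MistralLMHeadModel, GemmaLMHeadModel]
```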
elif "gemma" in model_id: | ||
config = GemmaConfig.from_hf_config(hf_config) | ||
elif "mistral" in model_id: | ||
config = MistralConfig.from_hf_config(hf_config) |
should you raise a better error if it's none of these?
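(Something like the sketch below would do it. The function name, the `llama` branch, the import paths, and the error wording are my additions; only the `gemma`/`mistral` branches appear in the diff.)

```python
from levanter.models.gemma import GemmaConfig      # import paths are assumptions
from levanter.models.llama import LlamaConfig
from levanter.models.mistral import MistralConfig


def _config_from_hf(model_id: str, hf_config):
    # Same dispatch as in the PR, plus an explicit failure branch so an
    # unsupported decoder doesn't fall through with `config` unbound.
    if "llama" in model_id:
        config = LlamaConfig.from_hf_config(hf_config)
    elif "gemma" in model_id:
        config = GemmaConfig.from_hf_config(hf_config)
    elif "mistral" in model_id:
        config = MistralConfig.from_hf_config(hf_config)
    else:
        raise ValueError(
            f"Unsupported reference decoder {model_id!r}: expected a Llama, Gemma, or Mistral checkpoint."
        )
    return config
```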
return config


def get_prefix(tokenizer_ref):
doccomment maybe?
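(e.g. something along these lines; the wording is my guess at the function's intent from how `get_prefix` is used in the config properties below:)

```python
def get_prefix(tokenizer_ref):
    """Return (prefix_tok, suffix_tok): the chat-template token ids that come
    before and after a user turn for the reference decoder, so the audio
    embeddings can be spliced in between them.

    `tokenizer_ref` is the HF hub id (or path) of the reference decoder's tokenizer.
    """
    ...
```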
return prefix_tok, suffix_tok


@LmConfig.register_subclass("diva")
AsrConfig?
Yes, good catch.
init_from_submodel: bool = True

# Connector Config
pre_audio_prompt = property(lambda self: get_prefix(self.reference_decoder)[0])
cache_property or who cares?
Definitely cache property.
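(A minimal sketch of the cached version, assuming the config class allows setting instance attributes — `functools.cached_property` writes the computed value into the instance `__dict__`, so a frozen or slotted dataclass would need a different approach. The class name is a toy stand-in, and `get_prefix` is the helper from this PR:)

```python
from functools import cached_property


class DivaConfigSketch:
    """Toy stand-in for the real config, showing only the cached property."""

    def __init__(self, reference_decoder: str):
        self.reference_decoder = reference_decoder

    @cached_property
    def pre_audio_prompt(self):
        # The tokenizer lookup inside get_prefix runs once per config instance
        # instead of on every attribute access.
        return get_prefix(self.reference_decoder)[0]
```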
)


class DivaModel(eqx.Module, ModelWithHfSerializationMixin[DivaConfig]):
maybe link to paper or something?
# Convert to Virtual LLM Tokens
virt_whisper_tokens = self.connector.transformer(
    (self.query_tokens + self.query_position_embeds).broadcast_axis(OtherAxes),
hrm i wouldn't think this should be necessary
I'll double check. I've somewhat forgotten why I added these explicit broadcasts.
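(For context on the broadcasting question: named-axis arithmetic in Haliax broadcasts by axis name, so the explicit `broadcast_axis` is usually only needed when you want the extra axis present, or in a particular order, before the op. A small sketch with toy axes, based on my understanding of the Haliax API:)

```python
import haliax as hax

Batch = hax.Axis("batch", 2)
Embed = hax.Axis("embed", 4)

x = hax.ones(Embed)            # axes: (embed,)
y = hax.zeros((Batch, Embed))  # axes: (batch, embed)

# Addition broadcasts by axis name, so these compute the same values;
# the explicit form just pins the batch axis onto x before the add.
z_implicit = x + y
z_explicit = x.broadcast_axis(Batch) + y

assert {a.name for a in z_implicit.axes} == {"batch", "embed"}
```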
text[
    {
        "batch": hax.arange(Batch),
        "position": (hax.sum(text_tokens == pad_token_id, "position") * -1) - 1,
you just broke my brain
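(For readers following along: with right-padded sequences, `hax.sum(text_tokens == pad_token_id, "position")` counts the pad tokens per row, so `-(n_pads) - 1` is the negative index of the last real token in each sequence. A NumPy-flavoured sketch of the same idea, on hypothetical toy data:)

```python
import numpy as np

pad_token_id = 0
# Two right-padded toy sequences of length 6.
text_tokens = np.array([
    [5, 7, 9, 0, 0, 0],  # 3 pads -> last real token at index -4 (the 9)
    [4, 8, 0, 0, 0, 0],  # 4 pads -> last real token at index -5 (the 8)
])

n_pads = (text_tokens == pad_token_id).sum(axis=-1)  # [3, 4]
last_real = (n_pads * -1) - 1                        # [-4, -5]
print(text_tokens[np.arange(2), last_real])          # [9 8]
```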
kl_proxy_loss = hax.dot(diff_distill, diff_distill, axis="embed") ** 0.5

# Compute Contrastive Loss on Input
# Correct for Normal Autoregressive Loss Mask
do you want to check that attn mask is causal or eh?
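(One possible guard, if it's worth checking: as far as I know Levanter's `AttentionMask` exposes an `is_causal` flag; treat the import path and attribute name as assumptions.)

```python
from levanter.models.attention import AttentionMask  # path assumed


def check_causal(attn_mask) -> None:
    # The loss-mask correction above assumes plain causal attention; fail fast
    # if a caller passes something more exotic (e.g. a prefix-LM style mask).
    if isinstance(attn_mask, AttentionMask) and not attn_mask.is_causal:
        raise ValueError("DivaModel's contrastive/distillation loss assumes a causal attention mask.")
```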
Still investigating a bit here: I used this code to reproduce the original DiVA model with Llama 3 8B, but I'm hitting some weirdness with Llama 3.1 8B where the resulting model produces a lot of repetitions. Hypotheses:
Hrm, happy to pair if that would be helpful. We can definitely investigate the RoPE thing; it's a constant pain.
It looks like RoPE is exactly the same for Llama 3 and 3.1, so it's probably not that, unless you haven't merged main in a month or two. I did fix a bug in #740.
This is a cleaned-up version of my code for the Distilled Voice Assistant (DiVA) models, which I trained using a fork of Levanter!
@dlwh The main thing I want to check with you here is what design pattern you think would make sense for initializing the model weights from multiple other pretrained models. What I've done here is much cleaner than what I did originally for the paper, but it still feels a bit messy.
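(For the sake of discussion, one shape this could take is a dedicated constructor on the composite model that accepts already-loaded submodels, so HF loading stays outside the model class and only the connector gets fresh weights. Everything below is a hypothetical sketch, not existing Levanter API:)

```python
import dataclasses

import equinox as eqx
import jax


@dataclasses.dataclass
class DivaSubmodels:
    """Bundle of pretrained pieces the composite model is assembled from."""
    encoder: eqx.Module  # e.g. a Whisper encoder loaded via its own checkpoint converter
    decoder: eqx.Module  # e.g. a Llama/Mistral/Gemma LM head model


class DivaSketch(eqx.Module):
    encoder: eqx.Module
    decoder: eqx.Module
    connector: eqx.Module

    @classmethod
    def from_pretrained_submodels(cls, submodels: DivaSubmodels, connector_init, *, key: jax.Array):
        # Only the connector is freshly initialized; the pretrained pieces are
        # passed in as-is, so this constructor never builds throwaway weights.
        connector = connector_init(key=key)
        return cls(encoder=submodels.encoder, decoder=submodels.decoder, connector=connector)
```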
Testing procedure for the correctness of this training code:

I trained a new DiVA model with this updated code and Llama 3.2 1B using the config in diva_flash.yaml.
Training log: https://wandb.ai/i18nlp/levanter/runs/jnxp463y?nw=nwuserheld
Resulting model (in PyTorch form) on HF: https://huggingface.co/WillHeld/DiVA-llama-3.2-1b
Demo which confirmed the result is ~reasonable, for now: https://b3f161194b514a990f.gradio.live/