Fix Token Shuffling
Helw150 committed Nov 14, 2024
1 parent 8f04255 commit 2e6ca68
Showing 3 changed files with 27 additions and 17 deletions.
2 changes: 2 additions & 0 deletions config/diva_flash.yaml
@@ -31,6 +31,8 @@ optimizer:
   #learning_rate: 5E-5
   learning_rate: 5e-4
   weight_decay: 0.1
+  weight_decay_modules: None
+  default_weight_decay_mask: False
   warmup: 0.01
 hf_save_path: gs://diva-flash/librispeech-hf-checkpoints
 diva_training: true
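The two new optimizer keys make the weight-decay mask explicit: with `weight_decay_modules: None` and `default_weight_decay_mask: False`, no parameter is decayed unless a mask selects it. As a rough illustration of how such a mask plugs into an optax optimizer (the parameter tree and the "decay only rank-2 weights" rule below are assumptions for the sketch, not levanter's actual internals):

```python
# Sketch: masked weight decay in optax. The parameter tree and the
# rank-2 rule are illustrative assumptions; levanter derives its mask
# from weight_decay_modules / default_weight_decay_mask instead.
import jax
import jax.numpy as jnp
import optax

params = {
    "proj": {"weight": jnp.ones((4, 8)), "bias": jnp.zeros((8,))},
    "ln": {"scale": jnp.ones((8,))},
}

def decay_mask(tree):
    # True => apply weight decay; skip biases and layernorm scales.
    return jax.tree_util.tree_map(lambda p: p.ndim >= 2, tree)

opt = optax.adamw(learning_rate=5e-4, weight_decay=0.1, mask=decay_mask)
opt_state = opt.init(params)
```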
10 changes: 5 additions & 5 deletions src/levanter/main/train_asr.py
@@ -143,12 +143,12 @@ def compute_loss(
         model_init=lambda: config.model.build_asr(Vocab, key=model_key),
     )
 
-    if config.diva_training and config.model.asr_model_type == DivaASRModel:
-        state = dataclasses.replace(state, model=None)
-        model = DivaASRModel.init(Vocab, config.model, key=model_key, init_from_submodels=True)
-        model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model)
-        state = dataclasses.replace(state, model=model, is_trainable=diva_connector_only(model))
     if int(state.step) == 0:
+        if config.diva_training and config.model.asr_model_type == DivaASRModel:
+            state = dataclasses.replace(state, model=None)
+            model = DivaASRModel.init(Vocab, config.model, key=model_key, init_from_submodels=True)
+            model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model)
+            state = dataclasses.replace(state, model=model, is_trainable=diva_connector_only(model))
         # TODO: I don't love that we init the model twice, but it's not a big deal i think?
         if config.initialize_from_hf:
             # initialize from an hf pretrained model
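The train_asr.py change is a checkpoint-resume fix: the DiVA-specific init now sits under the pre-existing `if int(state.step) == 0:` guard, so the model is rebuilt from its reference submodels only on a fresh run, while a resumed run keeps its checkpointed weights. A minimal sketch of that pattern (`TrainerState` and `fresh_model_fn` are placeholders, not levanter's actual types):

```python
# Sketch of a resume-safe init; TrainerState and fresh_model_fn are
# placeholders, not levanter's actual API.
from dataclasses import dataclass, replace
from typing import Any, Callable

@dataclass(frozen=True)
class TrainerState:
    step: int
    model: Any

def maybe_reinit(state: TrainerState, fresh_model_fn: Callable[[], Any]) -> TrainerState:
    # Only rebuild on a brand-new run; a resumed run (step > 0)
    # already carries checkpointed weights and must keep them.
    if int(state.step) == 0:
        return replace(state, model=fresh_model_fn())
    return state
```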
32 changes: 20 additions & 12 deletions src/levanter/models/diva.py
@@ -103,6 +103,7 @@ class DivaConfig(HFCompatConfig, ASRConfig):
     reference_encoder: str = "openai/whisper-large-v3-turbo"
     reference_decoder: str = "meta-llama/Llama-3.1-8B-Instruct"
     reference_checkpoint: str = "WillHeld/DiVA-llama-3-v0-8b"
+    max_length: int = 448
     init_from_submodel: bool = True
 
     # Connector Config
@@ -118,7 +119,7 @@ class DivaConfig(HFCompatConfig, ASRConfig):
     )
     prefix = property(lambda self: hax.named(self.pre_audio_prompt, axis="position"))
     suffix = property(lambda self: hax.named(self.pre_text_prompt, axis="position"))
-    Pos = property(lambda self: Axis(name="position", size=448))
+    Pos = property(lambda self: Axis(name="position", size=self.max_length))
     AudioPos = property(lambda self: self.enc_config.AudioPos)
     KeyPos = property(lambda self: self.Pos.alias("key_position"))

@@ -138,6 +139,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[dict] = None)
             "vocab_size": vocab_size,
             "reference_encoder": self.reference_encoder,
             "reference_decoder": self.reference_decoder,
+            "max_length": self.max_length,
         }
         return HfConfig.from_dict(merged_config)

@@ -146,7 +148,7 @@ def from_hf_config(cls, hf_config: HfConfig):
         config_dict = hf_config.to_dict()
         reference_encoder = config_dict["encoder_reference"]
         reference_decoder = config_dict["decoder_reference"]
-        return DivaConfig(reference_encoder, reference_decoder)
+        return DivaConfig(reference_encoder, reference_decoder, max_length=config_dict["max_length"])
 
     def hf_checkpoint_converter(cls) -> HFCheckpointConverter["DivaModel"]: # type: ignore
         return DivaHFCheckpointer(cls, cls.reference_checkpoint, trust_remote_code=True)
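Serializing `max_length` in `to_hf_config` and reading it back in `from_hf_config` means the positional axis size now survives a config round trip instead of silently resetting to 448. A toy round-trip check of the invariant, with plain dicts standing in for `HfConfig`:

```python
# Toy round-trip check; plain dicts stand in for the real HfConfig.
def to_hf(cfg: dict) -> dict:
    return {
        "encoder_reference": cfg["reference_encoder"],
        "decoder_reference": cfg["reference_decoder"],
        "max_length": cfg["max_length"],
    }

def from_hf(hf: dict) -> dict:
    return {
        "reference_encoder": hf["encoder_reference"],
        "reference_decoder": hf["decoder_reference"],
        "max_length": hf["max_length"],  # previously dropped, resetting Pos to 448
    }

cfg = {
    "reference_encoder": "openai/whisper-large-v3-turbo",
    "reference_decoder": "meta-llama/Llama-3.1-8B-Instruct",
    "max_length": 1024,
}
assert from_hf(to_hf(cfg)) == cfg
```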
@@ -198,7 +200,7 @@ def init(
 
         query_tokens = hax.random.normal(k_query, (config.Pos, config.enc_config.Embed)) * 0.02
         projection = hnn.Linear.init(
-            In=config.enc_config.Embed.alias("whisp_embed"), Out=config.dec_config.Embed, key=key
+            In=config.enc_config.Embed.alias("whisp_embed"), Out=config.dec_config.Embed, init_scale=0.01, key=key
         )
 
         if init_from_submodels:
@@ -217,8 +219,15 @@
                 config.enc_config,
             ) # type: ignore[assignment]
             encoder = whisper.encoder
-            connector = whisper.decoder
+            # connector = whisper.decoder
+            connector = WhisperDecoder.init(config.enc_config, key=k_connector)
             decoder = llm
+            mean_embedding = hax.mean(llm.embeddings.token_embeddings.weight, llm.embeddings.Vocab)
+            projection = dataclasses.replace(
+                projection,
+                weight=hax.rearrange(mean_embedding.broadcast_axis(projection.In), (projection.Out, projection.In)),
+            )
+
         else:
             encoder = WhisperEncoder.init(config.enc_config, key=k_enc)
             connector = WhisperDecoder.init(config.enc_config, key=k_connector)
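Two init changes here: the connector is now freshly initialized rather than copied from Whisper's decoder, and the projection (beyond its smaller `init_scale`) is seeded so that whatever the connector emits, its output starts near the decoder LM's mean token embedding: every column of the weight matrix is set to the mean embedding. A sketch of that seeding with plain arrays (sizes are illustrative; the real code does this with haliax named axes):

```python
# Sketch: seed a projection so its output always points along the LM's
# mean token embedding. Sizes are small illustrative placeholders.
import jax.numpy as jnp

vocab, lm_embed, enc_embed = 1000, 64, 32
token_embeddings = jnp.ones((vocab, lm_embed))  # stands in for llm.embeddings

mean_embedding = token_embeddings.mean(axis=0)  # (lm_embed,)
# Broadcast the mean across input features: every column of W equals the
# mean embedding, so W @ x = x.sum() * mean_embedding for any feature x.
W = jnp.broadcast_to(mean_embedding[:, None], (lm_embed, enc_embed))

x = jnp.ones((enc_embed,))
y = W @ x  # proportional to mean_embedding, keeping the frozen LM in-distribution
```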
@@ -235,7 +244,7 @@ def __call__(
         mel: NamedArray,
         input_ids: NamedArray,
         attn_mask: Optional[AttentionMask | NamedArray] = None,
-        pad_token_id: int = 128002,
+        pad_token_id: int = 128255,
         *,
         key=None,
     ) -> NamedArray:
@@ -273,7 +282,6 @@ def __call__(
                 suffix.broadcast_axis(OtherAxes),
             ],
         )
-
         text_tokens = hax.concatenate(
             "position",
             [
@@ -283,7 +291,6 @@ def __call__(
             ],
         )
         push_back_padding = hax.argsort(text_tokens == pad_token_id, "position")
-
         text_tokens_left_pad = text_tokens[{"batch": hax.arange(Batch), "position": push_back_padding}]
 
         text_embeds = self.decoder.embeddings.embed(text_tokens_left_pad)
@@ -292,7 +299,7 @@ def __call__(
         text = self.decoder.transformer(text_embeds, attn_mask=causal_mask, key=k_decoder)
 
         push_forward_padding = hax.argsort(input_ids != pad_token_id, "position")
-        input_ids_right_pad = text_tokens[{"batch": hax.arange(Batch), "position": push_forward_padding}]
+        input_ids_right_pad = input_ids[{"batch": hax.arange(Batch), "position": push_forward_padding}]
         return (
             audio["position", -1],
             text[
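This hunk is the bug in the commit title: the permutation `push_forward_padding` is computed from `input_ids`, but the gather was applied to `text_tokens`, shuffling the wrong tensor's tokens. The underlying trick is that `argsort` over a boolean mask is stable, so it moves one group (pad or non-pad) to one end while preserving order within each group. A standalone sketch with `jnp` (`take_along_axis` plays the role of the NamedArray fancy index):

```python
# Sketch of stable-argsort padding shuffles; jnp.argsort is stable, so
# relative order inside the pad and non-pad groups is preserved.
import jax.numpy as jnp

PAD = 128255
ids = jnp.array([[7, 8, 9, PAD, PAD],
                 [5, PAD, PAD, PAD, PAD]])

# Left-pad: sort on "is real token?" so False (pads) moves to the front.
left_perm = jnp.argsort(ids != PAD, axis=-1)
left_padded = jnp.take_along_axis(ids, left_perm, axis=-1)
# -> [[PAD, PAD, 7, 8, 9], [PAD, PAD, PAD, PAD, 5]]

# Right-pad: sort on "is pad?" so True (pads) moves to the back.
right_perm = jnp.argsort(ids == PAD, axis=-1)
right_padded = jnp.take_along_axis(ids, right_perm, axis=-1)
# -> [[7, 8, 9, PAD, PAD], [5, PAD, PAD, PAD, PAD]]

# The fixed bug: a permutation computed from one tensor must index that
# same tensor; gathering text_tokens with input_ids' permutation scrambles it.
```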
@@ -332,11 +339,12 @@ def compute_loss(
         # Mask Final Tokens So That Initial Tokens can be used for extra computation
         reversed_loss_mask = corrected_loss_mask["position", -1::-1]
         diff_contrast = virtual_embeds - real_embeds
-        loss2 = hax.dot(diff_contrast, diff_contrast, axis="embed") ** 0.5
+        tal_loss = hax.dot(diff_contrast, diff_contrast, axis="embed") ** 0.5
 
         if reduction is None:
             return kl_proxy_loss
         else:
-            return reduction(kl_proxy_loss, axis=reduction_axis) + reduction(
-                loss2, axis=reduction_axis, where=reversed_loss_mask
-            )
+            loss1 = reduction(kl_proxy_loss, axis=reduction_axis)
+            loss2 = reduction(tal_loss, axis=reduction_axis, where=reversed_loss_mask)
+            loss = loss1 + loss2
+            return loss
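The refactored reduction names the two terms: a KL-proxy distillation loss over all positions, plus a token-alignment term (the per-position L2 distance between virtual and real embeddings) reduced only where the reversed loss mask is set. A plain-`jnp` sketch of the same combination (the mean reduction and shapes are illustrative):

```python
# Sketch of the combined reduction; hax.dot(d, d, axis="embed") ** 0.5 is
# the per-position L2 norm of the embedding difference.
import jax.numpy as jnp

def combined_loss(kl_proxy, virtual_embeds, real_embeds, loss_mask):
    # kl_proxy, loss_mask: (batch, pos); *_embeds: (batch, pos, embed)
    diff = virtual_embeds - real_embeds
    tal = jnp.sqrt(jnp.sum(diff * diff, axis=-1))  # token-alignment distances
    reversed_mask = loss_mask[:, ::-1]  # free early positions for extra computation
    loss1 = kl_proxy.mean()
    # Masked mean, mirroring reduction(..., where=reversed_loss_mask).
    loss2 = jnp.sum(tal * reversed_mask) / jnp.maximum(reversed_mask.sum(), 1)
    return loss1 + loss2
```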
