Pull Master
Helw150 committed Nov 21, 2024
1 parent e26412d commit 69f29b4
Showing 3 changed files with 45 additions and 16 deletions.

src/levanter/models/diva.py (18 changes: 10 additions & 8 deletions)
@@ -119,6 +119,7 @@ class DivaConfig(HFCompatConfig, ASRConfig):
     )
     prefix = property(lambda self: hax.named(self.pre_audio_prompt, axis="position"))
     suffix = property(lambda self: hax.named(self.pre_text_prompt, axis="position"))
+    Embed = property(lambda self: self.dec_config.Embed)
     Pos = property(lambda self: Axis(name="position", size=self.max_length))
     AudioPos = property(lambda self: self.enc_config.AudioPos)
     KeyPos = property(lambda self: self.Pos.alias("key_position"))
@@ -204,12 +205,14 @@ def init(
         )
 
         if init_from_submodels:
-            llm: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel] = HFCheckpointConverter(
-                type(config.dec_config), config.reference_decoder
-            ).load_pretrained(
-                config.dec_config.model_type,
-                config.reference_decoder,
-                config.dec_config,
+            llm: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel] = (
+                HFCheckpointConverter(type(config.dec_config), config.reference_decoder)
+                .load_pretrained(
+                    config.dec_config.model_type,
+                    config.reference_decoder,
+                    config.dec_config,
+                )
+                .resize_vocab(Vocab.size)
             )  # type: ignore[assignment]
             whisper: WhisperModel = HFCheckpointConverter(
                 WhisperConfig, config.reference_encoder, ignore_prefix="model"
@@ ... @@
                 config.enc_config,
             )  # type: ignore[assignment]
             encoder = whisper.encoder
-            # connector = whisper.decoder
-            connector = WhisperDecoder.init(config.enc_config, key=k_connector)
+            connector = whisper.decoder
             decoder = llm
             mean_embedding = hax.mean(llm.embeddings.token_embeddings.weight, llm.embeddings.Vocab)
             projection = dataclasses.replace(
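A note on the decoder-loading change above: the new code chains load_pretrained with resize_vocab, so the pretrained decoder is loaded first and its vocabulary is then resized to Vocab.size before being wired into DiVA. The sketch below is a minimal illustration of that pattern only; ToyLMHead and its methods are hypothetical stand-ins rather than levanter's HFCheckpointConverter machinery, and the mean-embedding padding is just one plausible resize heuristic.

import numpy as np


class ToyLMHead:
    """Toy decoder holding a (vocab, embed) token-embedding table."""

    def __init__(self, embeddings: np.ndarray):
        self.embeddings = embeddings

    @classmethod
    def load_pretrained(cls, vocab_size: int = 32000, embed_size: int = 8) -> "ToyLMHead":
        # Stand-in for loading a pretrained checkpoint.
        rng = np.random.default_rng(0)
        return cls(rng.normal(size=(vocab_size, embed_size)))

    def resize_vocab(self, new_vocab: int) -> "ToyLMHead":
        # Grow the table by padding with the mean embedding, or truncate it.
        old_vocab, _ = self.embeddings.shape
        if new_vocab <= old_vocab:
            return ToyLMHead(self.embeddings[:new_vocab])
        pad = np.tile(self.embeddings.mean(axis=0), (new_vocab - old_vocab, 1))
        return ToyLMHead(np.concatenate([self.embeddings, pad], axis=0))


# Mirrors the chained call in the new diva.py code: load, then resize.
llm = ToyLMHead.load_pretrained().resize_vocab(32128)
print(llm.embeddings.shape)  # (32128, 8)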

src/levanter/models/qwen.py (11 changes: 9 additions & 2 deletions)
@@ -15,7 +15,14 @@
 from levanter.compat.hf_checkpoints import HFCheckpointConverter
 from levanter.logging import silence_transformer_nag
 from levanter.models.attention import AttentionMask, dot_product_attention
-from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaMlp, LlamaRMSNorm, LlamaTransformer
+from levanter.models.llama import (
+    LlamaConfig,
+    LlamaEmbedding,
+    LlamaMlp,
+    LlamaRMSNorm,
+    LlamaTransformer,
+    LlamaLMHeadModel,
+)
 from levanter.models.lm_model import LmConfig, LmHeadModel
 from levanter.models.rotary import RotaryEmbeddingsConfig
 from levanter.types import BlockFoldable
@@ -264,7 +271,7 @@ def init(config: QwenConfig, *, key) -> "QwenTransformer":
 
 
 # Modified LM head model for Qwen
-class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization):
+class QwenLMHeadModel(LlamaLMHeadModel, ModuleWithStateDictSerialization):
     transformer: QwenTransformer
     embeddings: LlamaEmbedding  # Can reuse Llama embeddings
     lm_head: Optional[hnn.Linear]
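A note on the base-class change above: QwenLMHeadModel now derives from LlamaLMHeadModel instead of the generic LmHeadModel[QwenConfig], so the shared head-model plumbing is inherited and only the Qwen-specific transformer differs. The sketch below shows that inheritance pattern with deliberately toy classes; ToyLlamaLMHead, ToyQwenTransformer, and the rest are hypothetical names, not levanter's implementation.

import numpy as np


class ToyLlamaTransformer:
    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x  # placeholder for the Llama block stack


class ToyQwenTransformer(ToyLlamaTransformer):
    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x + 1.0  # placeholder for Qwen-specific attention details


class ToyLlamaLMHead:
    def __init__(self, embed: np.ndarray, transformer: ToyLlamaTransformer):
        self.embed = embed              # (vocab, hidden) embedding table
        self.transformer = transformer

    def __call__(self, token_ids: np.ndarray) -> np.ndarray:
        h = self.transformer(self.embed[token_ids])  # embed -> blocks
        return h @ self.embed.T                      # tied unembedding -> logits


class ToyQwenLMHead(ToyLlamaLMHead):
    # Nothing to override except the transformer it is constructed with:
    # the embedding, lm head, and forward pass all come from the base class.
    pass


vocab, hidden = 16, 4
embed = np.random.default_rng(0).normal(size=(vocab, hidden))
model = ToyQwenLMHead(embed, ToyQwenTransformer())
print(model(np.array([1, 2, 3])).shape)  # (3, 16)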

src/levanter/models/whisper.py (32 changes: 26 additions & 6 deletions)
@@ -131,8 +131,8 @@ class WhisperMlp(eqx.Module):
     @staticmethod
     def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = True) -> "WhisperMlp":
         k_fc, k_proj = haliax.jax_utils.maybe_rng_split(key, 2)
-        fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=False)
-        fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias, out_first=False)
+        fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias)
+        fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias)
         if isinstance(activation_fn, str):
             activation_fn = ACT2FN[activation_fn]
         act = activation_fn  # type: ignore
@@ -164,10 +164,30 @@ def init(Heads: Axis, HeadSize: Axis, config: WhisperConfig, *, key) -> "WhisperAttention":
         Embed = config.Embed
 
         k_q, k_k, k_v, k_out = haliax.jax_utils.maybe_rng_split(key, 4)
-        q_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_q, use_bias=use_bias, out_first=False)
-        k_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_k, use_bias=False, out_first=False)
-        v_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_v, use_bias=use_bias, out_first=False)
-        out_proj = hnn.Linear.init(In=(Heads, HeadSize), Out=Embed, key=k_out, use_bias=use_bias, out_first=False)
+        q_proj = hnn.Linear.init(
+            In=Embed,
+            Out=(Heads, HeadSize),
+            key=k_q,
+            use_bias=use_bias,
+        )
+        k_proj = hnn.Linear.init(
+            In=Embed,
+            Out=(Heads, HeadSize),
+            key=k_k,
+            use_bias=False,
+        )
+        v_proj = hnn.Linear.init(
+            In=Embed,
+            Out=(Heads, HeadSize),
+            key=k_v,
+            use_bias=use_bias,
+        )
+        out_proj = hnn.Linear.init(
+            In=(Heads, HeadSize),
+            Out=Embed,
+            key=k_out,
+            use_bias=use_bias,
+        )
 
         return WhisperAttention(config, q_proj, k_proj, v_proj, out_proj, inference=False)
 
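A note on the dropped out_first=False arguments above: both whisper.py hunks stop passing the flag and rely on hnn.Linear.init's default instead. Assuming an out_first-style flag simply selects whether the weight is stored with the output axis first or the input axis first (an assumption about its meaning here, not a statement of haliax's actual default), the toy numpy sketch below shows that the two layouts are transposes of each other and give the same result when contracted consistently.

import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim = 4, 3
x = rng.normal(size=(in_dim,))

W_in_first = rng.normal(size=(in_dim, out_dim))  # input-axis-first layout
W_out_first = W_in_first.T                       # output-axis-first layout

y_a = x @ W_in_first   # contract over the input axis
y_b = W_out_first @ x  # same contraction against the transposed storage
assert np.allclose(y_a, y_b)
print(y_a)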
