diff --git a/config/diva_flash.yaml b/config/diva_flash.yaml
new file mode 100644
index 000000000..7aa49b1b4
--- /dev/null
+++ b/config/diva_flash.yaml
@@ -0,0 +1,38 @@
+data:
+  cache_dir: "gs://diva-flash/processed/llama"
+  tokenizer: "meta-llama/Llama-3.2-1B-Instruct"
+  processor: "openai/whisper-large-v3"
+  configs:
+    cv:
+      id: mozilla-foundation/common_voice_17_0
+      cache_dir: "gs://diva-flash/processed/llama/cv"
+      name: "en"
+      text_key: "sentence"
+      train_split: "train"
+      validation_split: "validation"
+  train_weights:
+    cv: 1.0
+model:
+  type: diva
+  reference_encoder: "openai/whisper-large-v3"
+  reference_decoder: "meta-llama/Llama-3.2-1B-Instruct"
+use_hf_model_config: true
+trainer:
+  steps_per_eval: 500
+  mp: p=f32,c=bf16
+  model_axis_size: 1
+  per_device_parallelism: -1
+  train_batch_size: 512
+  num_train_steps: 4300
+  checkpointer:
+    base_path: gs://diva-flash/cv-checkpoints
+    save_interval: 60m
+optimizer:
+  #learning_rate: 5E-5
+  learning_rate: 5e-4
+  weight_decay: 0.1
+  weight_decay_modules: None
+  default_weight_decay_mask: False
+  warmup: 0.01
+hf_save_path: gs://diva-flash/librispeech-hf-checkpoints
+diva_training: true
diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py
index 9bfc1e142..3a298203a 100644
--- a/src/levanter/data/audio.py
+++ b/src/levanter/data/audio.py
@@ -81,6 +81,10 @@ def __init__(
         padding=True,
     ):
         self.feature_extractor: SequenceFeatureExtractor = processor.feature_extractor
+        if tokenizer.pad_token_id is None:
+            override_token = list(tokenizer.added_tokens_decoder.items())[-1]
+            tokenizer.pad_token_id = override_token[0]
+            tokenizer.pad_token = str(override_token[1])
         self.bt = BatchTokenizer(
             tokenizer,
             enforce_bos=enforce_bos,
@@ -272,6 +276,7 @@ class ProcessedAudioCache(AsyncDataset[AudioTextDict]):
     def __init__(self, cache: TreeCache[AudioTextDict]):
         super().__init__()
         self.cache = cache
+        self._cached_len: Optional[int] = None

     async def async_len(self) -> int:
         return await self.cache.async_len()
@@ -285,6 +290,15 @@ def is_finite(self) -> bool:
     async def current_len(self) -> Optional[int]:
         return await self.cache.current_len()

+    async def wait_until_len_at_least(self, length: int) -> int:
+        # length is brutally slow to compute, so we cache it
+        if self._cached_len is not None and self._cached_len >= length:
+            return self._cached_len
+
+        length = await super().wait_until_len_at_least(length)
+        self._cached_len = length
+        return length
+
     async def get_batch(self, indices: Sequence[int]) -> Sequence[AudioTextDict]:
         return await self.cache.get_batch(indices)
diff --git a/src/levanter/main/train_asr.py b/src/levanter/main/train_asr.py
index c0d52eea2..9a8b481c1 100644
--- a/src/levanter/main/train_asr.py
+++ b/src/levanter/main/train_asr.py
@@ -16,6 +16,7 @@
 from levanter.compat.hf_checkpoints import HFCompatConfig, ModelWithHfSerializationMixin, save_hf_checkpoint_callback
 from levanter.data.audio import AudioIODatasetConfig, AudioMixtureDatasetConfig, AudioTextDataset
 from levanter.models.asr_model import ASRConfig, AudioTextExample
+from levanter.models.diva import DivaASRModel, diva_connector_only
 from levanter.models.whisper import WhisperConfig
 from levanter.optim import AdamConfig, OptimizerConfig
 from levanter.trainer import Trainer, TrainerConfig
@@ -45,6 +46,7 @@ class TrainASRConfig:
     hf_save_path: Optional[str] = None
     hf_upload: Optional[str] = None
     hf_save_steps: int = 10000
+    diva_training: bool = False


 def main(config: TrainASRConfig):
@@ -122,7 +124,7 @@ def compute_loss(
     train_dataset = AudioTextDataset(
         config.data.train_set(key=data_key),
         Pos,
-        [config.model.Mels, config.model.MelPos],
+        config.model.AudioPos,
         KeyPos,
         ignore_index=config.data.pad_token_id,
     )
@@ -136,8 +138,16 @@ def compute_loss(
     if vocab_size != Vocab.size:
         logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")

-    state = trainer.initial_state(training_key, model_init=lambda: config.model.build_asr(Vocab, key=model_key))
+    state = trainer.initial_state(
+        training_key,
+        model_init=lambda: config.model.build_asr(Vocab, key=model_key),
+    )
+    if config.diva_training and config.model.asr_model_type == DivaASRModel:
+        state = dataclasses.replace(state, model=None)
+        model = DivaASRModel.init(Vocab, config.model, key=model_key, init_from_submodels=True)
+        model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model)
+        state = dataclasses.replace(state, model=model, is_trainable=diva_connector_only(model))

     if int(state.step) == 0:
         # TODO: I don't love that we init the model twice, but it's not a big deal i think?
         if config.initialize_from_hf:
@@ -164,7 +174,7 @@ def compute_loss(
     hax_eval_dataset = AudioTextDataset(
         eval_dataset,
         Pos,
-        [config.model.Mels, config.model.MelPos],
+        config.model.AudioPos,
         KeyPos,
         ignore_index=config.data.pad_token_id,
     )
diff --git a/src/levanter/models/diva.py b/src/levanter/models/diva.py
new file mode 100644
index 000000000..323a82ce2
--- /dev/null
+++ b/src/levanter/models/diva.py
@@ -0,0 +1,352 @@
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Type, Union
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import PRNGKeyArray
+
+import haliax as hax
+import haliax.nn as hnn
+from haliax import Axis, NamedArray
+from haliax.jax_utils import maybe_rng_split
+from haliax.partitioning import ResourceMapping
+
+from levanter.compat.hf_checkpoints import (
+    HFCheckpointConverter,
+    HFCompatConfig,
+    ModelWithHfSerializationMixin,
+    RepoRef,
+)
+from levanter.logging import silence_transformer_nag
+from levanter.models.asr_model import ASRConfig, ASRMixin, AudioTextExample
+from levanter.models.attention import AttentionMask
+from levanter.models.gemma import GemmaConfig, GemmaLMHeadModel
+from levanter.models.llama import LlamaConfig, LlamaLMHeadModel
+from levanter.models.lm_model import LmConfig
+from levanter.models.mistral import MistralConfig, MistralLMHeadModel
+from levanter.models.whisper import WhisperConfig, WhisperDecoder, WhisperEncoder, WhisperModel
+
+
+silence_transformer_nag()
+from transformers import AutoTokenizer  # noqa: E402
+from transformers import PretrainedConfig as HfConfig  # noqa: E402
+
+
+class DivaHFCheckpointer(HFCheckpointConverter["DivaModel"]):
+    def load_pretrained(
+        self,
+        lm_model_cls: Type[ModelWithHfSerializationMixin],
+        ref: Optional[Union[str, RepoRef]] = None,
+        config: Optional[HFCompatConfig] = None,
+        axis_mapping: Optional[ResourceMapping] = None,
+        resize_vocab_to_match_tokenizer: bool = True,
+        dtype: Optional[jnp.dtype] = None,
+    ) -> ModelWithHfSerializationMixin:
+        lev_model: DivaModel = super().load_pretrained(
+            DivaModel, ref, config, axis_mapping, resize_vocab_to_match_tokenizer, dtype
+        )  # type: ignore[assignment]
+        llm: Union[LlamaLMHeadModel, MistralLMHeadModel, GemmaLMHeadModel] = HFCheckpointConverter(
+            type(lev_model.config.dec_config),
+            lev_model.config.reference_decoder,
+        ).load_pretrained(
+            lev_model.config.dec_config.model_type,
+            lev_model.config.reference_decoder,
+            lev_model.config.dec_config,
+            axis_mapping,
+            resize_vocab_to_match_tokenizer,
+            dtype,
+        )  # type: ignore[assignment]
+        whisper: WhisperModel = HFCheckpointConverter(
+            WhisperConfig, lev_model.config.reference_encoder
+        ).load_pretrained(
+            WhisperModel,
+            lev_model.config.reference_encoder,
+            lev_model.config.enc_config,
+            axis_mapping,
+            resize_vocab_to_match_tokenizer,
+            dtype,
+        )  # type: ignore[assignment]
+        lev_model.encoder = whisper.encoder
+        lev_model.decoder = llm
+
+        return lev_model
+
+
+def load_correct_config(reference_decoder):
+    model_id = reference_decoder.lower()
+    hf_config = HfConfig.from_pretrained(reference_decoder)
+    if "llama" in model_id:
+        config = LlamaConfig.from_hf_config(hf_config)
+    elif "gemma" in model_id:
+        config = GemmaConfig.from_hf_config(hf_config)
+    elif "mistral" in model_id:
+        config = MistralConfig.from_hf_config(hf_config)
+    else:
+        raise ValueError(f"Unsupported reference decoder: {reference_decoder}")
+    return config
+
+
+def get_prefix(tokenizer_ref):
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_ref)
+    prefix, suffix = tokenizer.apply_chat_template(
+        [{"role": "user", "content": "PLACEHOLDER"}], tokenize=False, add_generation_prompt=True
+    ).split("PLACEHOLDER")
+    prefix_tok = tokenizer.encode(prefix, add_special_tokens=False)
+    suffix_tok = tokenizer.encode(suffix, add_special_tokens=False)
+    return prefix_tok, suffix_tok
+
+
+@LmConfig.register_subclass("diva")
+@dataclass(frozen=True)
+class DivaConfig(HFCompatConfig, ASRConfig):
+    # Reference Models
+    reference_encoder: str = "openai/whisper-large-v3-turbo"
+    reference_decoder: str = "meta-llama/Llama-3.1-8B-Instruct"
+    reference_checkpoint: str = "WillHeld/DiVA-llama-3-v0-8b"
+    max_length: int = 448
+    init_from_submodel: bool = True
+
+    # Connector Config
+    pre_audio_prompt = property(lambda self: get_prefix(self.reference_decoder)[0])
+    pre_text_prompt = property(lambda self: get_prefix(self.reference_decoder)[1])
+
+    # SubConfigs
+    enc_config = property(
+        lambda self: WhisperConfig.from_hf_config(HfConfig.from_pretrained(self.reference_encoder)),
+    )
+    dec_config = property(
+        lambda self: load_correct_config(self.reference_decoder),
+    )
+    prefix = property(lambda self: hax.named(self.pre_audio_prompt, axis="position"))
+    suffix = property(lambda self: hax.named(self.pre_text_prompt, axis="position"))
+    Embed = property(lambda self: self.dec_config.Embed)
+    Pos = property(lambda self: Axis(name="position", size=self.max_length))
+    AudioPos = property(lambda self: self.enc_config.AudioPos)
+    KeyPos = property(lambda self: self.Pos.alias("key_position"))
+
+    @property
+    def model_type(self) -> Type["DivaModel"]:
+        return DivaModel
+
+    @property
+    def asr_model_type(self) -> Type["DivaASRModel"]:
+        return DivaASRModel
+
+    def to_hf_config(self, vocab_size: int, config_overrides: Optional[dict] = None) -> HfConfig:
+        merged_config = {
+            "model_type": "diva",
+            "architectures": ["DiVAModel"],
+            "auto_map": {"AutoConfig": "configuring_diva.DiVAConfig", "AutoModel": "modeling_diva.DiVAModel"},
+            "vocab_size": vocab_size,
+            "reference_encoder": self.reference_encoder,
+            "reference_decoder": self.reference_decoder,
+            "max_length": self.max_length,
+        }
+        return HfConfig.from_dict(merged_config)
+
+    @classmethod
+    def from_hf_config(cls, hf_config: HfConfig):
+        config_dict = hf_config.to_dict()
+        reference_encoder = config_dict["encoder_reference"]
+        reference_decoder = config_dict["decoder_reference"]
+        return DivaConfig(reference_encoder, reference_decoder, max_length=config_dict["max_length"])
+
+    def hf_checkpoint_converter(cls) -> HFCheckpointConverter["DivaModel"]:  # type: ignore
+        return DivaHFCheckpointer(cls, cls.reference_checkpoint, trust_remote_code=True)
+
+
+def diva_connector_only(model):
+    # freeze everything except the query tokens, the projection, and the Whisper-decoder connector
+    frozen_tree = jax.tree_util.tree_map(lambda _: False, model)
+    return eqx.tree_at(
+        lambda tree: (tree.query_tokens, tree.projection.weight, tree.projection.bias, tree.connector),
+        frozen_tree,
+        (True, True, True, True),
+    )
+
+
+class DivaModel(eqx.Module, ModelWithHfSerializationMixin[DivaConfig]):
+    query_tokens: NamedArray
+    projection: hnn.Linear
+    encoder: WhisperEncoder
+    connector: WhisperDecoder
+    decoder: Union[LlamaLMHeadModel, MistralLMHeadModel, GemmaLMHeadModel]
+    _config: DivaConfig = eqx.static_field()
+
+    @property
+    def config(self):
+        return self._config
+
+    @property
+    def Pos(self) -> Axis:
+        return self.config.Pos
+
+    @property
+    def Vocab(self) -> Axis:
+        return self.decoder.embeddings.Vocab
+
+    def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None) -> "DivaModel":
+        new_decoder = self.decoder.resize_vocab(new_size, key)
+        return dataclasses.replace(self, decoder=new_decoder)
+
+    @classmethod
+    def init(
+        cls,
+        Vocab: Axis,
+        config: DivaConfig,
+        *,
+        key,
+        init_from_submodels: bool = False,
+    ) -> "DivaModel":
+        k_query, k_projection, k_enc, k_connector, k_dec = maybe_rng_split(key, 5)
+
+        query_tokens = hax.random.normal(k_query, (config.Pos, config.enc_config.Embed)) * 0.02
+        projection = hnn.Linear.init(
+            In=config.enc_config.Embed.alias("whisp_embed"),
+            Out=config.dec_config.Embed,
+            init_scale=0.01,
+            key=k_projection,
+        )
+
+        if init_from_submodels:
+            llm: Union[LlamaLMHeadModel, MistralLMHeadModel, GemmaLMHeadModel] = (
+                HFCheckpointConverter(type(config.dec_config), config.reference_decoder)
+                .load_pretrained(
+                    config.dec_config.model_type,
+                    config.reference_decoder,
+                    config.dec_config,
+                )
+                .resize_vocab(Vocab.size)
+            )  # type: ignore[assignment]
+            whisper: WhisperModel = HFCheckpointConverter(
+                WhisperConfig, config.reference_encoder, ignore_prefix="model"
+            ).load_pretrained(
+                WhisperModel,
+                config.reference_encoder,
+                config.enc_config,
+            )  # type: ignore[assignment]
+            encoder = whisper.encoder
+            connector = whisper.decoder
+            decoder = llm
+            # initialize the projection weights from the decoder's mean token embedding
+            mean_embedding = hax.mean(llm.embeddings.token_embeddings.weight, llm.embeddings.Vocab)
+            projection = dataclasses.replace(
+                projection,
+                weight=hax.rearrange(mean_embedding.broadcast_axis(projection.In), (projection.Out, projection.In)),
+            )
+        else:
+            encoder = WhisperEncoder.init(config.enc_config, key=k_enc)
+            connector = WhisperDecoder.init(config.enc_config, key=k_connector)
+            decoder = config.dec_config.model_type.init(Vocab, config.dec_config, key=k_dec)
+
+        return cls(query_tokens, projection, encoder, connector, decoder, config)
+
+    @property
+    def query_position_embeds(self) -> NamedArray:
+        return self.connector.embeddings.position_embeddings.embed(hax.arange(self.config.Pos))
+
+    def __call__(
+        self,
+        mel: NamedArray,
+        input_ids: NamedArray,
+        attn_mask: Optional[AttentionMask | NamedArray] = None,
+        pad_token_id: int = 128255,
+        *,
+        key=None,
+    ) -> tuple[NamedArray, NamedArray, NamedArray, NamedArray]:
+        # Setup
+        Batch = input_ids.resolve_axis("batch")
+        OtherAxes = hax.axis.eliminate_axes(input_ids.axes, "position")
+        causal_mask = AttentionMask.causal()
+        if attn_mask is not None:
+            causal_mask = causal_mask & attn_mask
+        k_encoder, k_connector, k_decoder, k_head = maybe_rng_split(key, 4)
+
+        # Encode Audio With Whisper Encoder
+        audio_features = self.encoder(mel, key=k_encoder)
+
+        # Convert to Virtual LLM Tokens
+        virt_whisper_tokens = self.connector.transformer(
+            (self.query_tokens + self.query_position_embeds).broadcast_axis(OtherAxes),
+            audio_features,
+            causal_mask,
+            key=k_connector,
+        )
+
+        virtual_tokens = self.projection(virt_whisper_tokens.rename({"embed_dim": "whisp_embed"}))
+
+        # Embed Real LLM Tokens
+        prefix = self.decoder.embeddings.embed(self.config.prefix)
+        suffix = self.decoder.embeddings.embed(self.config.suffix)
+
+        # Create Mixed Virtual and Real Input
+        audio_embeds = hax.concatenate(
+            "position",
+            [
+                prefix.broadcast_axis(OtherAxes),
+                virtual_tokens,
+                suffix.broadcast_axis(OtherAxes),
+            ],
+        )
+        text_tokens = hax.concatenate(
+            "position",
+            [
+                self.config.prefix.broadcast_axis(OtherAxes),
+                input_ids,
+                self.config.suffix.broadcast_axis(OtherAxes),
+            ],
+        )
+        push_back_padding = hax.argsort(text_tokens == pad_token_id, "position")
+        text_tokens_left_pad = text_tokens[{"batch": hax.arange(Batch), "position": push_back_padding}]
+        text_embeds = self.decoder.embeddings.embed(text_tokens_left_pad)
+
+        # Create LLM Response
+        audio = self.decoder.transformer(audio_embeds, attn_mask=causal_mask, key=k_decoder)
+        text = self.decoder.transformer(text_embeds, attn_mask=causal_mask, key=k_decoder)
+
+        push_forward_padding = hax.argsort(input_ids != pad_token_id, "position")
+        input_ids_right_pad = input_ids[{"batch": hax.arange(Batch), "position": push_forward_padding}]
+        return (
+            audio["position", -1],
+            text[
+                {
+                    "batch": hax.arange(Batch),
+                    "position": (hax.sum(text_tokens == pad_token_id, "position") * -1) - 1,
+                }
+            ],
+            virtual_tokens,
+            self.decoder.embeddings.embed(input_ids_right_pad),
+        )
+
+
+class DivaASRModel(DivaModel, ASRMixin):
+    def compute_loss(
+        self,
+        example: AudioTextExample,
+        *,
+        key=None,
+        reduction: Optional[hax.ReductionFunction] = hax.mean,
+        reduction_axis: Optional[hax.AxisSelection] = None,
+    ) -> NamedArray:
+        LocalPos = example.tokens.resolve_axis("position")
+        # Compute Distillation Contrastive Loss
+        # Since the weight matrix is frozen, equivalent but faster than KL Div
+        audio_pred, text_pred, virtual_embeds, real_embeds = self(
+            example.audio, example.tokens, example.attn_mask, key=key
+        )
+        diff_distill = audio_pred - text_pred
+        kl_proxy_loss = hax.dot(diff_distill, diff_distill, axis="embed") ** 0.5
+
+        # Compute Contrastive Loss on Input
+        # Correct for Normal Autoregressive Loss Mask
+        corrected_loss_mask = hax.roll(example.loss_mask, 1, LocalPos) + hax.nn.one_hot(
+            0, LocalPos, dtype=jax.numpy.float32
+        )
+        # Mask Final Tokens So That Initial Tokens can be used for extra computation
+        reversed_loss_mask = corrected_loss_mask["position", -1::-1]
+        diff_contrast = virtual_embeds - real_embeds
+        tal_loss = hax.dot(diff_contrast, diff_contrast, axis="embed") ** 0.5
+
+        if reduction is None:
+            return kl_proxy_loss
+        else:
+            loss1 = reduction(kl_proxy_loss, axis=reduction_axis)
+            loss2 = reduction(tal_loss, axis=reduction_axis, where=reversed_loss_mask)
+            loss = loss1 + loss2
+            return loss
diff --git a/src/levanter/models/qwen.py b/src/levanter/models/qwen.py
index 807a768ad..604d270e4 100644
--- a/src/levanter/models/qwen.py
+++ b/src/levanter/models/qwen.py
@@ -15,7 +15,14 @@
 from levanter.compat.hf_checkpoints import HFCheckpointConverter
 from levanter.logging import silence_transformer_nag
 from levanter.models.attention import AttentionMask, dot_product_attention
-from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaMlp, LlamaRMSNorm, LlamaTransformer
+from levanter.models.llama import (
+    LlamaConfig,
+    LlamaEmbedding,
+    LlamaMlp,
+    LlamaRMSNorm,
+    LlamaTransformer,
+    LlamaLMHeadModel,
+)
 from levanter.models.lm_model import LmConfig, LmHeadModel
 from levanter.models.rotary import RotaryEmbeddingsConfig
 from levanter.types import BlockFoldable
@@ -264,7 +271,7 @@ def init(config: QwenConfig, *, key) -> "QwenTransformer":


 # Modified LM head model for Qwen
-class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization):
+class QwenLMHeadModel(LlamaLMHeadModel, ModuleWithStateDictSerialization):
     transformer: QwenTransformer
     embeddings: LlamaEmbedding  # Can reuse Llama embeddings
     lm_head: Optional[hnn.Linear]
diff --git a/src/levanter/models/whisper.py b/src/levanter/models/whisper.py
index 7239626f7..a3d1ff5c0 100644
--- a/src/levanter/models/whisper.py
+++ b/src/levanter/models/whisper.py
@@ -50,6 +50,7 @@ class WhisperConfig(HFCompatConfig, ASRConfig):
     initializer_range: float = 0.02
     gradient_checkpointing: bool = True
+    reference_checkpoint: str = "openai/whisper-base"

     # Attention-related config
     upcast_attn: bool = True
@@ -66,7 +67,7 @@ def asr_model_type(self) -> Type["WhisperASRModel"]:
         return WhisperASRModel

     def hf_checkpoint_converter(self) -> HFCheckpointConverter["WhisperModel"]:  # type: ignore
-        return HFCheckpointConverter(self, "openai/whisper-base", ignore_prefix="model")
+        return HFCheckpointConverter(self, self.reference_checkpoint, ignore_prefix="model")

     # Axis
     MelPos = property(lambda self: Axis(name="position", size=self.max_source_positions * 2))
@@ -84,6 +85,7 @@ def hf_checkpoint_converter(self) -> HFCheckpointConverter["WhisperModel"]:  # t
     DecoderHeadSize = property(lambda self: Axis(name="head_size", size=self.d_model // self.decoder_attention_heads))
     DecoderLayer = property(lambda self: Axis(name="decoder_layers", size=self.decoder_layers))
     Mels = property(lambda self: Axis(name="n_mels", size=self.num_mel_bins))
+    AudioPos = property(lambda self: [self.Mels, self.MelPos])

     def to_hf_config(self, vocab_size, config_overrides=None):
         if config_overrides is None:
@@ -104,8 +106,9 @@ def to_hf_config(self, vocab_size, config_overrides=None):
         )

     @classmethod
-    def from_hf_config(cls, hf_config: HfConfig):
+    def from_hf_config(cls, hf_config: HfConfig, reference_checkpoint: str = "openai/whisper-base"):
         return cls(
+            reference_checkpoint=reference_checkpoint,
             vocab_size=hf_config.vocab_size,
             num_mel_bins=hf_config.num_mel_bins,
             encoder_layers=hf_config.encoder_layers,
@@ -128,8 +131,8 @@ class WhisperMlp(eqx.Module):
     @staticmethod
     def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = True) -> "WhisperMlp":
         k_fc, k_proj = haliax.jax_utils.maybe_rng_split(key, 2)
-        fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=False)
-        fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias, out_first=False)
+        fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias)
+        fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias)
         if isinstance(activation_fn, str):
             activation_fn = ACT2FN[activation_fn]
         act = activation_fn  # type: ignore
@@ -161,10 +164,30 @@ def init(Heads: Axis, HeadSize: Axis, config: WhisperConfig, *, key) -> "Whisper
         Embed = config.Embed
         k_q, k_k, k_v, k_out = haliax.jax_utils.maybe_rng_split(key, 4)

-        q_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_q, use_bias=use_bias, out_first=False)
-        k_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_k, use_bias=False, out_first=False)
-        v_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_v, use_bias=use_bias, out_first=False)
-        out_proj = hnn.Linear.init(In=(Heads, HeadSize), Out=Embed, key=k_out, use_bias=use_bias, out_first=False)
+        q_proj = hnn.Linear.init(
+            In=Embed,
+            Out=(Heads, HeadSize),
+            key=k_q,
+            use_bias=use_bias,
+        )
+        k_proj = hnn.Linear.init(
+            In=Embed,
+            Out=(Heads, HeadSize),
+            key=k_k,
+            use_bias=False,
+        )
+        v_proj = hnn.Linear.init(
+            In=Embed,
+            Out=(Heads, HeadSize),
+            key=k_v,
+            use_bias=use_bias,
+        )
+        out_proj = hnn.Linear.init(
+            In=(Heads, HeadSize),
+            Out=Embed,
+            key=k_out,
+            use_bias=use_bias,
+        )

         return WhisperAttention(config, q_proj, k_proj, v_proj, out_proj, inference=False)

@@ -307,10 +330,9 @@ class WhisperEncoder(ModuleWithStateDictSerialization):
     def init(cls, config: WhisperConfig, *, key) -> "WhisperEncoder":
         k_conv1, k_conv2, k_t = haliax.jax_utils.maybe_rng_split(key, 3)

-        Len = hax.Axis("position", size=config.SourcePos.size * 2)
         Mid = hax.Axis("mid", config.Embed.size)
-        conv1 = hnn.Conv.init(Len, config.Mels, Mid, kernel_size=3, padding=1, key=k_conv1)
-        conv2 = hnn.Conv.init(Len, Mid, config.Embed, kernel_size=3, stride=2, padding=1, key=k_conv2)
+        conv1 = hnn.Conv.init(config.MelPos, config.Mels, Mid, kernel_size=3, padding=1, key=k_conv1)
+        conv2 = hnn.Conv.init(config.MelPos, Mid, config.Embed, kernel_size=3, stride=2, padding=1, key=k_conv2)
         if isinstance(config.activation_function, str):
             act = ACT2FN[config.activation_function]  # type: ignore
         else: