From af38d6eb413cb98446b78a906c77cf5ba28be149 Mon Sep 17 00:00:00 2001
From: Etienne Pot
Date: Thu, 28 Nov 2024 07:23:54 -0800
Subject: [PATCH] Make cache_size optional

PiperOrigin-RevId: 701019861
---
 gemma/transformer.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/gemma/transformer.py b/gemma/transformer.py
index f75bf9f..630bf82 100644
--- a/gemma/transformer.py
+++ b/gemma/transformer.py
@@ -63,7 +63,7 @@ class TransformerConfig:
   use_post_attn_norm: bool
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
-  max_cache_length: int = 1024
+  max_cache_length: int | None = 1024
   query_pre_attn_norm: QueryPreAttentionNormalisation = (
       QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
   )
@@ -83,7 +83,7 @@ def query_pre_attn_scalar(self) -> float:
 
   @classmethod
   def from_params(
-      cls, params: params_lib.Params, cache_size: int = 1024
+      cls, params: params_lib.Params, cache_size: int | None = 1024
   ) -> 'TransformerConfig':
     """Creates a TransformerConfig from loaded parameters.
 
@@ -116,7 +116,7 @@ def from_params(
     )
 
   @classmethod
-  def gemma_2b(cls, cache_size: int):
+  def gemma_2b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA_2B,
         num_embed=256128,
@@ -133,7 +133,7 @@ def gemma_2b(cls, cache_size: int):
     )
 
   @classmethod
-  def gemma_7b(cls, cache_size: int):
+  def gemma_7b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA_7B,
         num_embed=256128,
@@ -150,7 +150,7 @@ def gemma_7b(cls, cache_size: int):
     )
 
   @classmethod
-  def gemma2_2b(cls, cache_size: int):
+  def gemma2_2b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_2B,
         num_embed=256128,
@@ -174,7 +174,7 @@ def gemma2_2b(cls, cache_size: int):
     )
 
   @classmethod
-  def gemma2_9b(cls, cache_size: int):
+  def gemma2_9b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_9B,
         num_embed=256128,
@@ -199,7 +199,7 @@ def gemma2_9b(cls, cache_size: int):
     )
 
   @classmethod
-  def gemma2_27b(cls, cache_size: int):
+  def gemma2_27b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_27B,
         num_embed=256128,
@@ -229,6 +229,8 @@ def init_cache(
       dtype: jnp.dtype = jnp.bfloat16,
   ) -> Cache:
     """Initializes a new Transformer cache."""
+    if self.max_cache_length is None:
+      raise ValueError('max_cache_length must be set to initialize cache.')
     cache = {
         f'layer_{i}': modules.Attention.init_cache(
             self.max_cache_length,
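
A minimal usage sketch of the behavior this patch enables (not part of the patch itself): building a config with cache_size=None and hitting the new guard in init_cache. The batch_size argument to init_cache is an assumption for illustration; only the dtype parameter of init_cache is visible in the hunk above.

    # Sketch only, assuming the package layout shown in the patch
    # (gemma/transformer.py) and that init_cache accepts a batch_size argument.
    from gemma import transformer as transformer_lib

    # cache_size may now be None, e.g. when no autoregressive KV cache is needed.
    config = transformer_lib.TransformerConfig.gemma_2b(cache_size=None)
    assert config.max_cache_length is None

    # The new guard makes init_cache fail fast instead of passing None down to
    # modules.Attention.init_cache.
    try:
      cache = config.init_cache(batch_size=1)  # batch_size is assumed here
    except ValueError as err:
      print(err)  # max_cache_length must be set to initialize cache.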