
Commit af38d6e
Make cache_size optional
PiperOrigin-RevId: 701019861
Conchylicultor authored and The gemma Authors committed Nov 28, 2024
1 parent 65a8858 commit af38d6e
Showing 1 changed file with 9 additions and 7 deletions.
gemma/transformer.py: 9 additions & 7 deletions
@@ -63,7 +63,7 @@ class TransformerConfig:
  use_post_attn_norm: bool
  use_post_ffw_norm: bool
  attention_types: Iterable[modules.AttentionType]
-  max_cache_length: int = 1024
+  max_cache_length: int | None = 1024
  query_pre_attn_norm: QueryPreAttentionNormalisation = (
      QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
  )
@@ -83,7 +83,7 @@ def query_pre_attn_scalar(self) -> float:

  @classmethod
  def from_params(
-      cls, params: params_lib.Params, cache_size: int = 1024
+      cls, params: params_lib.Params, cache_size: int | None = 1024
  ) -> 'TransformerConfig':
    """Creates a TransformerConfig from loaded parameters.
@@ -116,7 +116,7 @@ def from_params(
    )

  @classmethod
-  def gemma_2b(cls, cache_size: int):
+  def gemma_2b(cls, cache_size: int | None):
    return cls(
        num_layers=_NUM_LAYERS_GEMMA_2B,
        num_embed=256128,
@@ -133,7 +133,7 @@ def gemma_2b(cls, cache_size: int):
    )

  @classmethod
-  def gemma_7b(cls, cache_size: int):
+  def gemma_7b(cls, cache_size: int | None):
    return cls(
        num_layers=_NUM_LAYERS_GEMMA_7B,
        num_embed=256128,
@@ -150,7 +150,7 @@ def gemma_7b(cls, cache_size: int):
    )

  @classmethod
-  def gemma2_2b(cls, cache_size: int):
+  def gemma2_2b(cls, cache_size: int | None):
    return cls(
        num_layers=_NUM_LAYERS_GEMMA2_2B,
        num_embed=256128,
@@ -174,7 +174,7 @@ def gemma2_2b(cls, cache_size: int):
    )

  @classmethod
-  def gemma2_9b(cls, cache_size: int):
+  def gemma2_9b(cls, cache_size: int | None):
    return cls(
        num_layers=_NUM_LAYERS_GEMMA2_9B,
        num_embed=256128,
@@ -199,7 +199,7 @@ def gemma2_9b(cls, cache_size: int):
    )

  @classmethod
-  def gemma2_27b(cls, cache_size: int):
+  def gemma2_27b(cls, cache_size: int | None):
    return cls(
        num_layers=_NUM_LAYERS_GEMMA2_27B,
        num_embed=256128,
@@ -229,6 +229,8 @@ def init_cache(
      dtype: jnp.dtype = jnp.bfloat16,
  ) -> Cache:
    """Initializes a new Transformer cache."""
+    if self.max_cache_length is None:
+      raise ValueError('max_cache_length must be set to initialize cache.')
    cache = {
        f'layer_{i}': modules.Attention.init_cache(
            self.max_cache_length,
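A minimal usage sketch of the behaviour after this commit. Only the cache_size / max_cache_length signatures and the new ValueError come from the diff above; the import path gemma.transformer, the batch_size keyword to init_cache, and the example cache size of 1024 are assumptions.

# Sketch only: import path and init_cache arguments beyond `dtype` are assumed.
import jax.numpy as jnp

from gemma import transformer as transformer_lib

# cache_size may now be None, e.g. when the config is only needed to
# inspect parameters or train and no KV cache will ever be built.
config = transformer_lib.TransformerConfig.gemma_2b(cache_size=None)

# Building a cache without a cache length now fails with a clear error
# instead of propagating None into the per-layer attention caches.
try:
  config.init_cache(batch_size=1, dtype=jnp.bfloat16)  # batch_size is assumed
except ValueError as err:
  print(err)  # max_cache_length must be set to initialize cache.

# For autoregressive decoding, pass an explicit cache size as before.
decode_config = transformer_lib.TransformerConfig.gemma_2b(cache_size=1024)
cache = decode_config.init_cache(batch_size=1, dtype=jnp.bfloat16)

In short, passing cache_size=None leaves max_cache_length unset, and any later init_cache call fails fast with the new ValueError rather than silently misusing a None cache length.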
