Refactor bloom modeling and add tests
Signed-off-by: char-1ee <[email protected]>
char-1ee committed May 3, 2024
1 parent 7f9f667 commit 67d67fb
Showing 5 changed files with 354 additions and 178 deletions.
42 changes: 28 additions & 14 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Any, List, Tuple

import torch
from transformers.configuration_utils import PretrainedConfig
@@ -15,9 +15,11 @@
GIGABYTE = 1024**3


-def get_model_config_attr(config: PretrainedConfig, attr_name: str):
+def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
if hasattr(config, attr_name):
return getattr(config, attr_name)
+if alter_attr is not None: # TODO, rebase caidi changes
+return alter_attr
elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
return getattr(config, config.attribute_map[attr_name])
raise AttributeError(f"{attr_name} is not found in config")
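
Reader's note (illustration, not part of the diff): a minimal sketch of how the new alter_attr fallback is expected to behave, assuming get_model_config_attr stays importable from colossalai/inference/kv_cache/kvcache_manager.py; the SimpleNamespace below is a hypothetical stand-in for a PretrainedConfig that lacks num_key_value_heads.

# Illustrative only; not part of this commit.
from types import SimpleNamespace
from colossalai.inference.kv_cache.kvcache_manager import get_model_config_attr

dummy_config = SimpleNamespace(num_attention_heads=32, hidden_size=4096)
get_model_config_attr(dummy_config, "num_attention_heads")                  # 32: attribute found directly
get_model_config_attr(dummy_config, "num_key_value_heads", alter_attr=32)   # 32: attribute missing, alter_attr fallback
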
@@ -53,7 +55,12 @@ class KVCacheManager:
And it's possible to have a batch of sequences with different lengths of block tables.
"""

-def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+def __init__(
+self,
+config: InferenceConfig,
+model_config: PretrainedConfig,
+verbose: bool = False,
+) -> None:
self.logger = get_dist_logger(__name__)
self.device = get_current_device()

@@ -64,14 +71,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
+self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads", alter_attr=self.head_num)
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num

-if hasattr(config, "num_key_value_heads"):
-self.kv_head_num = getattr(config, "num_key_value_heads")
-elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
-self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
-else:
-self.kv_head_num = self.head_num
+# if hasattr(config, "num_key_value_heads"):
+# self.kv_head_num = getattr(config, "num_key_value_heads")
+# elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
+# self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+# else:
+# self.kv_head_num = self.head_num

assert (
self.kv_head_num % self.tp_size == 0
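
Reader's note (illustrative arithmetic, not from the diff): with kv_head_num now resolved through the alter_attr fallback, a BLOOM-style config without num_key_value_heads ends up with kv_head_num == head_num, and the assert requires that the KV heads split evenly across tensor-parallel ranks.

# Assumed numbers for illustration; not from the commit.
kv_head_num, tp_size = 32, 4                  # MHA fallback case; a GQA model might expose e.g. 8 KV heads
assert kv_head_num % tp_size == 0
kv_heads_per_rank = kv_head_num // tp_size    # presumably 8 KV heads held per tensor-parallel rank
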
@@ -211,7 +219,8 @@ def allocate_context_from_block_table(self, block_table: torch.Tensor, context_l
block.add_ref()
if block_id == block_indexes[-1].item():
self._allocate_on_block(
-block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size
+block,
+(block.block_size if context_len % block.block_size == 0 else context_len % block.block_size),
)
else:
self._allocate_on_block(block, block.block_size)
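
Reader's note (worked numbers, not from the diff): the wrapped ternary sizes only the last block of a sequence. Assuming block_size = 16, a 35-token context leaves 35 % 16 = 3 slots to allocate on its final block, while a 32-token context hits the modulo-zero branch and takes the full block_size.

# Illustrative check of the ternary above; block_size = 16 is an assumption.
block_size = 16
for context_len in (35, 32):
    space_asked = block_size if context_len % block_size == 0 else context_len % block_size
    print(context_len, space_asked)   # 35 -> 3, 32 -> 16
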
@@ -278,9 +287,11 @@ def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context
block.add_ref()
self._allocate_on_block(
block,
-block.block_size
-if context_lengths[i] % block.block_size == 0
-else context_lengths[i].item() % block.block_size,
+(
+block.block_size
+if context_lengths[i] % block.block_size == 0
+else context_lengths[i].item() % block.block_size
+),
)
for block_id in alloc_block_ids:
if block_id in alloc_block_ids[last_block_locs]:
@@ -453,7 +464,10 @@ def clear_all(self) -> None:

def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
-return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
+return (
+self._kv_caches[0][layer_id][block_idx],
+self._kv_caches[1][layer_id][block_idx],
+)

def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
"""Allocate a specific size of space on a provided cache block.
