From 59ba43bb4d56d45fe862844e810b0c9424ade3aa Mon Sep 17 00:00:00 2001 From: char-1ee Date: Fri, 26 Apr 2024 09:06:33 +0000 Subject: [PATCH] Rebase upstream commits and refactor Signed-off-by: char-1ee --- colossalai/inference/core/engine.py | 15 +- .../inference/kv_cache/kvcache_manager.py | 2 +- .../modeling/models/nopadding_baichuan.py | 18 +- .../modeling/models/nopadding_bloom.py | 98 ++++-- colossalai/inference/utils.py | 28 +- colossalai/kernel/triton/alibi_embedding.py | 327 ------------------ examples/inference/test_bloom_generation.py | 82 +++++ tests/test_infer/test_models/test_baichuan.py | 3 +- tests/test_infer/test_models/test_bloom.py | 41 ++- .../triton/test_context_attn_unpad.py | 2 +- .../test_ops/triton/test_decoding_attn.py | 2 +- 11 files changed, 218 insertions(+), 400 deletions(-) delete mode 100644 colossalai/kernel/triton/alibi_embedding.py create mode 100644 examples/inference/test_bloom_generation.py diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index b42c21a5175b..3ae392c18677 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -13,8 +13,8 @@ PreTrainedTokenizer, PreTrainedTokenizerFast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers.models.bloom.modeling_bloom import BloomForCausalLM +from transformers.models.llama.modeling_llama import LlamaForCausalLM from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh @@ -43,7 +43,6 @@ "BloomForCausalLM": BloomForCausalLM, } -_alibi_models = ["bloom", "baichuan"] _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] @@ -83,7 +82,7 @@ def __init__( self.tokenizer = tokenizer self.tokenizer.pad_token = self.tokenizer.eos_token - self.request_handler = RequestHandler(self.inference_config, self.model_config, alibi_attn=self.alibi_attn) + self.request_handler = RequestHandler(self.inference_config, self.model_config) self.k_cache, self.v_cache = self.request_handler.get_kvcache() # DISCUSS maybe move this into batch info? @@ -164,14 +163,6 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) tp_group = pg_mesh.get_group_along_axis(TP_AXIS) - self.alibi_attn = False - if self.model_config.model_type in _alibi_models: - # Used for bloom, baichuan 13b and baichuan2 13b. - self.alibi_attn = True - # Hardcode used to distinguish between baichuan 7b and baichuan 13b.(There might be a better way to handle this.) 
- if self.model_config.model_type == "baichuan" and self.model_config.hidden_size == 4096: - self.alibi_attn = False - self.model = self._shardformer( model, model_policy, @@ -747,4 +738,4 @@ def step(self) -> List[str]: finished_sequences = self.request_handler.update() - return finished_sequences \ No newline at end of file + return finished_sequences diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 734b79ac60e3..94c79dd412be 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -18,7 +18,7 @@ def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None): if hasattr(config, attr_name): return getattr(config, attr_name) - if alter_attr is not None: # TODO, rebase caidi changes + if alter_attr is not None: return alter_attr elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]): return getattr(config, config.attribute_map[attr_name]) diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py index e6b39ccfa20d..b802379e2e1a 100644 --- a/colossalai/inference/modeling/models/nopadding_baichuan.py +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup from colossalai.inference.flash_decoding_utils import FDIntermTensors -from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import ( context_attention_unpadded, @@ -47,22 +47,6 @@ logger = get_dist_logger(__name__) -# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 -def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: - closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) - powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) - slopes = torch.pow(base, powers) - if closest_power_of_2 != num_heads: - extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device - ) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - return slopes - - def baichuan_rmsnorm_forward( self, hidden_states: torch.Tensor, diff --git a/colossalai/inference/modeling/models/nopadding_bloom.py b/colossalai/inference/modeling/models/nopadding_bloom.py index d0297dbf5367..dd6b821648c5 100644 --- a/colossalai/inference/modeling/models/nopadding_bloom.py +++ b/colossalai/inference/modeling/models/nopadding_bloom.py @@ -6,24 +6,27 @@ from colossalai.inference.config import InputMetaData from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference from colossalai.kernel.jit.bias_gelu import GeLUFunction from colossalai.kernel.kernel_loader import InferenceOpsLoader -from colossalai.kernel.triton 
import context_attention_unpadded, flash_decoding_attention +from colossalai.kernel.triton import context_attention_unpadded, copy_k_to_blocked_cache, flash_decoding_attention from colossalai.logging import get_dist_logger logger = get_dist_logger(__name__) -inference_ops = InferenceOpsLoader.load() - try: - pass + from flash_attn import flash_attn_varlen_func use_flash_attn2 = True except ImportError: use_flash_attn2 = False logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") +inference_ops = InferenceOpsLoader().load() + +logger = get_dist_logger(__name__) + def bloom_causal_lm_forward( self: BloomForCausalLM, @@ -107,6 +110,7 @@ def bloom_model_forward( hidden_states = layer( hidden_states, block_tables=block_tables, + is_prompts=inputmetadata.is_prompts, k_cache=k_caches[layer_id], v_cache=v_caches[layer_id], sequence_lengths=sequence_lengths, @@ -144,7 +148,7 @@ def bloom_block_forward( use_cuda_kernel: bool = True, cu_seqlens: torch.Tensor = None, high_precision: bool = False, -) -> torch.Tensor: +) -> torch.FloatTensor: """ Replacement of forward function in the BloomBlock module. @@ -234,6 +238,7 @@ def __init__( self.hidden_size = hidden_size self.num_heads = n_heads + self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device) self.head_dim = self.hidden_size // self.num_heads self.o_proj_w = attn_oproj_w @@ -289,7 +294,7 @@ def forward( high_precision: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ - Forward function of the NopadBloomAttention. + Forward function of the NopadBloomAttention. Current attention does not support speculative decoding. Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. 
@@ -318,28 +323,73 @@ def forward( block_size = k_cache.size(-2) - # TODO: flash attention - if is_prompts: # Prefilling phase - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_size=block_size, - block_tables=block_tables, - output=output_tensor, - alibi_slopes=fd_inter_tensor.alibi_slopes, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - ) - else: # Decoding phase + if is_prompts: # Context stage (prefilling phase) + if ( + use_cuda_kernel + and query_states.dtype != torch.float32 + and use_flash_attn2 # flash attn 2 currently only supports FP16/BF16 + ): + # Copy the GPU memory of kvcache during context stage + inference_ops.context_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len + ) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ) + attn_output = attn_output.view(token_nums, -1) + + else: + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_size=block_size, + block_tables=block_tables, + output=output_tensor, + alibi_slopes=self.alibi_slopes, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) + + else: # Decode stage + if use_cuda_kernel: + # Copy the GPU memory of kvcache during decode stage + inference_ops.decode_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, block_size, block_tables + ) + else: + copy_k_to_blocked_cache( + key_states, + k_cache, + kv_lengths=sequence_lengths, + block_tables=block_tables, + ) + copy_k_to_blocked_cache( + value_states, + v_cache, + kv_lengths=sequence_lengths, + block_tables=block_tables, + ) + attn_output = flash_decoding_attention( q=query_states, k_cache=k_cache, v_cache=v_cache, - alibi_slopes=fd_inter_tensor.alibi_slopes, + alibi_slopes=self.alibi_slopes, kv_seq_len=sequence_lengths, block_tables=block_tables, block_size=block_size, diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 9e0d72586e37..266052ab7247 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -1,6 +1,7 @@ """ -Utils for model inference +Utilities for model inference """ +import math import os import re from pathlib import Path @@ -55,6 +56,31 @@ def init_to_get_rotary(self, base=10000, use_elem=False): self._sin_cached = torch.sin(freqs).to(self.dtype).cuda() +def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: + """ + Calculate the slopes for the Alibi positional encoding. The calculation is adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 + + Args: + num_heads (int): The number of heads. + device (torch.device): The device to perform the calculations on. + + Returns: + torch.Tensor: The calculated slopes tensor of (nheads,) or (batch_size, nheads). 
+ """ + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) + slopes = torch.pow(base, powers) + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: """ Check whether the checkpoint has an index file. diff --git a/colossalai/kernel/triton/alibi_embedding.py b/colossalai/kernel/triton/alibi_embedding.py deleted file mode 100644 index 99745d166b41..000000000000 --- a/colossalai/kernel/triton/alibi_embedding.py +++ /dev/null @@ -1,327 +0,0 @@ -import torch -import triton -import triton.language as tl - - -# Triton 2.1.0 -@triton.jit -def _flash_decoding_fwd_kernel( - Q, # [batch_size, head_num, head_dim] - KCache, # [num_blocks, num_kv_heads, block_size, head_dim] - VCache, # [num_blocks, num_kv_heads, block_size, head_dim] - block_tables, # [batch_size, max_blocks_per_sequence] - mid_output, # [batch_size, head_num, kv_split_num, head_dim] - mid_output_lse, # [batch_size, head_num, kv_split_num] - kv_seq_len, # [batch_size] - batch_size, - alibi, - stride_qt, - stride_qh, - stride_qd, - stride_cacheb, - stride_cacheh, - stride_cachebs, - stride_cached, - stride_bts, - stride_btb, - stride_mid_ot, - stride_mid_oh, - stride_mid_ob, - stride_mid_od, - stride_mid_o_lset, - stride_mid_o_lseh, - stride_mid_o_lseb, - sm_scale, - KV_GROUPS: tl.constexpr, - BLOCK_KV: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - cur_seq_idx = tl.program_id(0) - if cur_seq_idx >= batch_size: - return - cur_head_idx = tl.program_id(1) - block_start_kv = tl.program_id(2) # for splitting k/v - - # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same - # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE) - # and then support calculating multiple kv cache blocks on an instance - tl.static_assert(BLOCK_KV == BLOCK_SIZE) - - # get the current (kv) sequence length - cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) - if block_start_kv * BLOCK_KV >= cur_kv_seq_len: - return - - cur_kv_head_idx = cur_head_idx // KV_GROUPS - offsets_dmodel = tl.arange(0, HEAD_DIM) - offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd - offsets_n = block_start_kv * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - - alibi_mask = tl.load(alibi + offsets_q) - q = tl.load(Q + offsets_q) - - # block table for the current sequence - block_table_ptr = block_tables + cur_seq_idx * stride_bts - - cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb) - cur_occupied_size = tl.where( - (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE - ) - tl.device_assert(cur_occupied_size >= 0) - - cur_kv_head_idx = cur_head_idx // KV_GROUPS - offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh - - K_block_ptr = tl.make_block_ptr( - base=KCache + offset_kvcache, - shape=(cur_occupied_size, HEAD_DIM), - 
strides=(stride_cachebs, stride_cached), - offsets=(0, 0), - block_shape=(BLOCK_SIZE, HEAD_DIM), - order=(0, 1), - ) - V_block_ptr = tl.make_block_ptr( - base=VCache + offset_kvcache, - shape=(cur_occupied_size, HEAD_DIM), - strides=(stride_cachebs, stride_cached), - offsets=(0, 0), - block_shape=(BLOCK_SIZE, HEAD_DIM), - order=(0, 1), - ) - k_cur_block = tl.load(K_block_ptr) - v_cur_block = tl.load(V_block_ptr) - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - # use block size of the paged/blocked kv cache - S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - - # NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16, - # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail. - # Refer to https://github.com/openai/triton/discussions/895 - S_ij += tl.sum(q[None, :] * k_cur_block, 1) - S_ij *= sm_scale - - S_ij -= alibi_mask * (cur_kv_seq_len - 1 - offsets_n) - S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf")) - - m = tl.max(S_ij, 0) - S_ij -= m - p_ij_hat = tl.exp(S_ij) - l = tl.sum(p_ij_hat, 0) - p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) - acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0) - acc = acc / l - - offsets_mid_o = ( - cur_seq_idx * stride_mid_ot - + cur_head_idx * stride_mid_oh - + block_start_kv * stride_mid_ob - + offsets_dmodel * stride_mid_od - ) - tl.store(mid_output + offsets_mid_o, acc) - offsets_mid_o_lse = ( - cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb - ) - - # logsumexp L^(j) = m^(j) + log(l^(j)) - tl.store(mid_output_lse + offsets_mid_o_lse, m + tl.log(l)) - - -# Triton 2.1.0 -@triton.jit -def _flash_decoding_fwd_reduce_kernel( - mid_output, # [batch_size, head_num, kv_split_num, head_dim] - mid_output_lse, # [batch_size, head_num, kv_split_num] - O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim] - kv_seq_len, - batch_size, - stride_mid_ot, - stride_mid_oh, - stride_mid_ob, - stride_mid_od, - stride_o_lset, - stride_o_lseh, - stride_o_lseb, - stride_ot, - stride_oh, - stride_od, - BLOCK_KV: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - cur_seq_idx = tl.program_id(0) - if cur_seq_idx >= batch_size: - return - cur_head_idx = tl.program_id(1) - - cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) - offsets_dmodel = tl.arange(0, HEAD_DIM) - - # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have - # BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks of kv splitted. 
- kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV - m_i = float("-inf") # max logic - l = 0.0 # sum exp - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - - offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel - offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh - for block_i in range(0, kv_split_num, 1): - mid_o_block = tl.load(mid_output + offsets_mid_o + block_i * stride_mid_ob) - lse = tl.load(mid_output_lse + offset_mid_lse + block_i * stride_o_lseb) - m_ij = tl.maximum(m_i, lse) - scale = tl.exp(m_i - m_ij) - acc = acc * scale - lse -= m_ij - exp_logic = tl.exp(lse) - acc += exp_logic * mid_o_block - l = scale * l + exp_logic - m_i = m_ij - - acc = acc / l - offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel - tl.store(O + offsets_O, acc.to(O.type.element_ty)) - return - - -# Decoding Stage -# Used with blocked KV Cache (PagedAttention) -def flash_decoding_attention_with_alibi( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - alibi: torch.Tensor, - kv_seq_len: torch.Tensor, - block_tables: torch.Tensor, - block_size: int, - max_seq_len_in_batch: int = None, - output: torch.Tensor = None, - mid_output: torch.Tensor = None, - mid_output_lse: torch.Tensor = None, - sm_scale: int = None, - kv_group_num: int = 1, -): - """ - Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. - Args: - q (torch.Tensor): [bsz, num_heads, head_dim] - k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - kv_seq_len (torch.Tensor): [batch_size] - records the (kv) sequence lengths incorporating past kv sequence lengths. - block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] - max_seq_len_in_batch (int): Maximum sequence length in the batch. - output (torch.Tensor): [bsz, num_heads * head_dim] - mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim] - Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. - mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num] - Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`. - block_size (int): Size of each block in the blocked key/value cache. - num_kv_group (int, optional): Number of key/value groups. Defaults to 1. 
- Returns: - Output tensor with shape [bsz, num_heads * head_dim] - """ - - q = q.squeeze() if q.dim() == 4 else q - assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" - bsz, num_heads, head_dim = q.shape - - assert head_dim in {32, 64, 128, 256} - assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( - f"Got incompatible batch size (number of seqs):\n" - f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, " - f"batch size {bsz}" - ) - assert k_cache.size(-2) == v_cache.size(-2) == block_size, ( - f"Got incompatible block size on kv caches:\n" - f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, " - f"v_cache block_size {v_cache.size(-2)}" - ) - - # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v - # For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`) - assert block_size in {16, 32, 64, 128} - BLOCK_KV = block_size - - sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale - max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch - # For compatibility (TODO revise modeling in future) - kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV - - if mid_output is None: - mid_output = torch.empty( - (bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device - ) - - if mid_output_lse is None: - mid_output_lse = torch.empty((bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) - - if output is None: - output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) - - assert ( - mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num - ), "Incompatible kv split number of intermediate output tensors" - assert ( - mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == bsz - ), f"Incompatible first dimension of output tensors" - - grid = ( - triton.next_power_of_2(bsz), - num_heads, - triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV), - ) - _flash_decoding_fwd_kernel[grid]( - q, - k_cache, - v_cache, - block_tables, - mid_output, - mid_output_lse, - kv_seq_len, - bsz, - alibi, - q.stride(0), - q.stride(1), - q.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - block_tables.stride(0), - block_tables.stride(1), - mid_output.stride(0), - mid_output.stride(1), - mid_output.stride(2), - mid_output.stride(3), - mid_output_lse.stride(0), - mid_output_lse.stride(1), - mid_output_lse.stride(2), - sm_scale, - KV_GROUPS=kv_group_num, - BLOCK_KV=block_size, - BLOCK_SIZE=block_size, - HEAD_DIM=head_dim, - ) - - grid = (triton.next_power_of_2(bsz), num_heads) - _flash_decoding_fwd_reduce_kernel[grid]( - mid_output, - mid_output_lse, - output, - kv_seq_len, - bsz, - mid_output.stride(0), - mid_output.stride(1), - mid_output.stride(2), - mid_output.stride(3), - mid_output_lse.stride(0), - mid_output_lse.stride(1), - mid_output_lse.stride(2), - output.stride(0), - head_dim, - 1, - BLOCK_KV=block_size, - HEAD_DIM=head_dim, - ) - - return output diff --git a/examples/inference/test_bloom_generation.py b/examples/inference/test_bloom_generation.py new file mode 100644 index 000000000000..fcabe6200c94 --- /dev/null +++ b/examples/inference/test_bloom_generation.py @@ -0,0 +1,82 @@ +import argparse + +from transformers import AutoModelForCausalLM, BloomTokenizerFast, GenerationConfig + +import colossalai +from colossalai.cluster import DistCoordinator +from 
colossalai.inference.config import InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.modeling.policy.nopadding_bloom import NoPaddingBloomModelInferPolicy
+
+# Model and policy classes used to build the BLOOM inference engine
+MODEL_CLS = AutoModelForCausalLM
+POLICY_CLS = NoPaddingBloomModelInferPolicy
+
+
+def infer(args):
+    # ==============================
+    # Launch colossalai, setup distributed environment
+    # ==============================
+    colossalai.launch_from_torch(config={})
+    coordinator = DistCoordinator()
+
+    # ==============================
+    # Load model and tokenizer
+    # ==============================
+    # model_path_or_name = "/home/lixingjian/models/bloom-7b1"
+    model_path_or_name = "/home/lixingjian/models/bloom-560m"
+    model = MODEL_CLS.from_pretrained(model_path_or_name).cuda()
+    tokenizer = BloomTokenizerFast.from_pretrained(model_path_or_name)
+    tokenizer.pad_token = tokenizer.eos_token
+    coordinator.print_on_master(f"Model Config:\n{model.config}")
+
+    # ==============================
+    # Initialize InferenceEngine
+    # ==============================
+    inference_config = InferenceConfig(
+        dtype=args.dtype,
+        max_batch_size=args.max_batch_size,
+        max_input_len=args.max_input_len,
+        max_output_len=args.max_output_len,
+        prefill_ratio=1.2,
+        block_size=16,
+        tp_size=args.tp_size,
+        use_cuda_kernel=False,
+    )
+    coordinator.print_on_master(f"Initializing Inference Engine...")
+    engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
+
+    # ==============================
+    # Generation
+    # ==============================
+    generation_config = GenerationConfig(
+        pad_token_id=tokenizer.eos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_length=args.max_length,
+        do_sample=True,
+    )
+    coordinator.print_on_master(f"Generating...")
+    out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
+    coordinator.print_on_master(out[0])
+
+
+# colossalai run --nproc_per_node 1 test_bloom_generation.py
+if __name__ == "__main__":
+    # ==============================
+    # Parse Arguments
+    # ==============================
+    parser = argparse.ArgumentParser()
+    # parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
+    parser.add_argument(
+        "-p", "--prompt", type=str, default="Introduce some landmarks in the United Kingdom, such as", help="Prompt"
+    )
+    parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
+    parser.add_argument("-i", "--max_input_len", type=int, default=128, help="Max input length")
+    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="Max output length")
+    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
+    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+    parser.add_argument("--max_length", type=int, default=32, help="Max length for generation")
+    args = parser.parse_args()
+
+    infer(args)
diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py
index 5d6be5cb1982..6789e669191a 100644
--- a/tests/test_infer/test_models/test_baichuan.py
+++ b/tests/test_infer/test_models/test_baichuan.py
@@ -14,8 +14,7 @@
 from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
 from 
colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base" -BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base" +BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base" def setup_seed(seed): diff --git a/tests/test_infer/test_models/test_bloom.py b/tests/test_infer/test_models/test_bloom.py index 2448843aeb3d..b64060bd9718 100644 --- a/tests/test_infer/test_models/test_bloom.py +++ b/tests/test_infer/test_models/test_bloom.py @@ -4,7 +4,7 @@ import numpy as np import pytest import torch -from transformers import AutoModelForCausalLM, BloomTokenizerFast, GenerationConfig +from transformers import BloomForCausalLM, BloomTokenizerFast, GenerationConfig import colossalai from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig @@ -23,30 +23,35 @@ def setup_seed(seed): random.seed(seed) -def check_inference_engine(use_engine=False, prompt_template=None): +def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None): setup_seed(20) tokenizer = BloomTokenizerFast.from_pretrained(BLOOM_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True) - model = AutoModelForCausalLM.from_pretrained( - BLOOM_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True - ).cuda() + model = BloomForCausalLM.from_pretrained(BLOOM_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda() model = model.eval() inputs = [ "Please introduce some landmarks in the United Kingdom. ", ] - output_len = 50 - do_sample = False + output_len = 38 + do_sample = do_sample + + if do_sample: + top_p = 0.5 + top_k = 50 + else: + top_p = None + top_k = None if use_engine: inference_config = InferenceConfig( - max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True + max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel ) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample) + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) outputs = inference_engine.generate(generation_config=generation_config) else: if prompt_template: @@ -58,6 +63,8 @@ def check_inference_engine(use_engine=False, prompt_template=None): inputs = inputs.cuda() generation_config = GenerationConfig( do_sample=do_sample, + top_p=top_p, + top_k=top_k, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len, ) @@ -68,11 +75,17 @@ def check_inference_engine(use_engine=False, prompt_template=None): @parameterize("prompt_template", [None, "bloom"]) -def check_output_consistency(prompt_template): - outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template) - transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template) - - for s1, s2 in zip(outputs, transformer_outputs): +@parameterize("do_sample", [True, False]) +@parameterize("use_cuda_kernel", [True, False]) +def check_output_consistency(prompt_template, do_sample, use_cuda_kernel): + cai_outputs = check_inference_engine( + use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template + ) + transformer_outputs = check_inference_engine( + 
use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template + ) + + for s1, s2 in zip(cai_outputs, transformer_outputs): assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}" # clear singleton flash decoding tensors diff --git a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py index 76785d53095a..675bb5b22873 100644 --- a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py @@ -2,7 +2,7 @@ import torch from packaging import version -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import ( diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_ops/triton/test_decoding_attn.py index 616d7868beb0..94e996893bcb 100644 --- a/tests/test_infer/test_ops/triton/test_decoding_attn.py +++ b/tests/test_infer/test_ops/triton/test_decoding_attn.py @@ -3,7 +3,7 @@ import torch from packaging import version -from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.inference.utils import get_alibi_slopes from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import (
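
For reference, below is a minimal standalone sketch of how the shared get_alibi_slopes utility (moved into colossalai.inference.utils by this patch) relates to the bias the attention kernels apply via their alibi_slopes argument. It is an illustrative cross-check only, assuming torch and a ColossalAI build with this patch are available; reference_alibi_attention is a hypothetical helper written for this note, not an API from the codebase.

import math

import torch

from colossalai.inference.utils import get_alibi_slopes


def reference_alibi_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, slopes: torch.Tensor) -> torch.Tensor:
    """Naive causal attention with an ALiBi bias, for comparison against the fused kernels.

    q, k, v: [num_heads, seq_len, head_dim]; slopes: [num_heads].
    """
    num_heads, seq_len, head_dim = q.shape
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
    # ALiBi adds a per-head linear penalty: bias[h, i, j] = slopes[h] * (j - i), which is <= 0
    # in the causal region and grows with the query-key distance.
    positions = torch.arange(seq_len, device=q.device)
    bias = slopes.view(num_heads, 1, 1) * (positions.view(1, 1, seq_len) - positions.view(1, seq_len, 1)).float()
    causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=q.device), diagonal=1)
    return torch.softmax(scores + bias + causal_mask, dim=-1) @ v


if __name__ == "__main__":
    num_heads, seq_len, head_dim = 12, 16, 64
    slopes = get_alibi_slopes(num_heads, device=torch.device("cpu"))  # shape: [num_heads]
    q, k, v = (torch.randn(num_heads, seq_len, head_dim) for _ in range(3))
    print(reference_alibi_attention(q, k, v, slopes).shape)  # torch.Size([12, 16, 64])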