
Commit

Rebase upstream commits and refactor
Signed-off-by: char-1ee <[email protected]>
char-1ee committed Apr 26, 2024
1 parent 0de7a6b commit ce12073
Showing 11 changed files with 218 additions and 400 deletions.
15 changes: 3 additions & 12 deletions colossalai/inference/core/engine.py
@@ -13,8 +13,8 @@
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
@@ -43,7 +43,6 @@
"BloomForCausalLM": BloomForCausalLM,
}

_alibi_models = ["bloom", "baichuan"]

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

@@ -83,7 +82,7 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token

self.request_handler = RequestHandler(self.inference_config, self.model_config, alibi_attn=self.alibi_attn)
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?

@@ -154,14 +153,6 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

self.alibi_attn = False
if self.model_config.model_type in _alibi_models:
# Used for bloom, baichuan 13b and baichuan2 13b.
self.alibi_attn = True
# Hardcode used to distinguish between baichuan 7b and baichuan 13b.(There might be a better way to handle this.)
if self.model_config.model_type == "baichuan" and self.model_config.hidden_size == 4096:
self.alibi_attn = False

self.model = self._shardformer(
model,
model_policy,
@@ -737,4 +728,4 @@ def step(self) -> List[str]:

finished_sequences = self.request_handler.update()

return finished_sequences
return finished_sequences
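
Note: the hunk above removes the engine-level ALiBi detection (_alibi_models and the alibi_attn flag); after the refactor each attention module derives its own slopes via colossalai.inference.utils.get_alibi_slopes (added later in this commit). For reference, a minimal standalone sketch of the removed heuristic — the helper name needs_alibi and the hidden sizes in the asserts are illustrative, not part of the codebase:

_ALIBI_MODELS = ["bloom", "baichuan"]

def needs_alibi(model_type: str, hidden_size: int) -> bool:
    """Mirror of the removed check: ALiBi for bloom/baichuan, except Baichuan 7B."""
    if model_type not in _ALIBI_MODELS:
        return False
    # hidden_size 4096 was the hardcoded way to tell Baichuan 7B (rotary) from Baichuan 13B (ALiBi).
    if model_type == "baichuan" and hidden_size == 4096:
        return False
    return True

assert needs_alibi("bloom", 14336)
assert not needs_alibi("baichuan", 4096)   # Baichuan 7B
assert needs_alibi("baichuan", 5120)       # Baichuan 13B
assert not needs_alibi("llama", 4096)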
2 changes: 1 addition & 1 deletion colossalai/inference/kv_cache/kvcache_manager.py
@@ -18,7 +18,7 @@
def get_model_config_attr(config: PretrainedConfig, attr_name: str, alter_attr: Any = None):
if hasattr(config, attr_name):
return getattr(config, attr_name)
if alter_attr is not None: # TODO, rebase caidi changes
if alter_attr is not None:
return alter_attr
elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map[attr_name]):
return getattr(config, config.attribute_map[attr_name])
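
For context, a short illustration of the fallback order get_model_config_attr implements: direct attribute first, then the explicit alter_attr, then the config's attribute_map aliases. The toy config and values are made up, and the import path assumes the helper stays in kvcache_manager.py:

from transformers import PretrainedConfig

from colossalai.inference.kv_cache.kvcache_manager import get_model_config_attr

cfg = PretrainedConfig(num_attention_heads=32)  # toy config
assert get_model_config_attr(cfg, "num_attention_heads") == 32               # direct attribute
assert get_model_config_attr(cfg, "num_key_value_heads", alter_attr=8) == 8  # alter_attr fallback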
18 changes: 1 addition & 17 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -1,11 +1,11 @@
# This code is adapted from huggingface baichuan model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
@@ -32,22 +32,6 @@
logger = get_dist_logger(__name__)


# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes


def baichuan_rmsnorm_forward(
self,
hidden_states: torch.Tensor,
98 changes: 74 additions & 24 deletions colossalai/inference/modeling/models/nopadding_bloom.py
@@ -6,24 +6,27 @@

from colossalai.inference.config import InputMetaData
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import get_alibi_slopes
from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference
from colossalai.kernel.jit.bias_gelu import GeLUFunction
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
from colossalai.kernel.triton import context_attention_unpadded, copy_k_to_blocked_cache, flash_decoding_attention
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

inference_ops = InferenceOpsLoader.load()

try:
pass
from flash_attn import flash_attn_varlen_func

use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)


def bloom_causal_lm_forward(
self: BloomForCausalLM,
@@ -107,6 +110,7 @@ def bloom_model_forward(
hidden_states = layer(
hidden_states,
block_tables=block_tables,
is_prompts=inputmetadata.is_prompts,
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
sequence_lengths=sequence_lengths,
@@ -144,7 +148,7 @@ def bloom_block_forward(
use_cuda_kernel: bool = True,
cu_seqlens: torch.Tensor = None,
high_precision: bool = False,
) -> torch.Tensor:
) -> torch.FloatTensor:
"""
Replacement of forward function in the BloomBlock module.
@@ -234,6 +238,7 @@ def __init__(

self.hidden_size = hidden_size
self.num_heads = n_heads
self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)
self.head_dim = self.hidden_size // self.num_heads
self.o_proj_w = attn_oproj_w

@@ -289,7 +294,7 @@ def forward(
high_precision: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Forward function of the NopadBloomAttention.
Forward function of the NopadBloomAttention. Current attention does not support speculative decoding.
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
@@ -318,28 +323,73 @@ def forward(

block_size = k_cache.size(-2)

# TODO: flash attention
if is_prompts: # Prefilling phase
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_size=block_size,
block_tables=block_tables,
output=output_tensor,
alibi_slopes=fd_inter_tensor.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
else: # Decoding phase
if is_prompts: # Context stage (prefilling phase)
if (
use_cuda_kernel
and query_states.dtype != torch.float32
and use_flash_attn2 # flash attn 2 currently only supports FP16/BF16
):
# Copy the GPU memory of kvcache during context stage
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
alibi_slopes=self.alibi_slopes,
)
attn_output = attn_output.view(token_nums, -1)

else:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_size=block_size,
block_tables=block_tables,
output=output_tensor,
alibi_slopes=self.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)

else: # Decode stage
if use_cuda_kernel:
# Copy the GPU memory of kvcache during decode stage
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_size, block_tables
)
else:
copy_k_to_blocked_cache(
key_states,
k_cache,
kv_lengths=sequence_lengths,
block_tables=block_tables,
)
copy_k_to_blocked_cache(
value_states,
v_cache,
kv_lengths=sequence_lengths,
block_tables=block_tables,
)

attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
alibi_slopes=fd_inter_tensor.alibi_slopes,
alibi_slopes=self.alibi_slopes,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
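
Note on the alibi_slopes argument threaded through both branches above: ALiBi adds slopes[h] * (key_pos - query_pos) to head h's raw attention scores before softmax, so more distant keys are penalized in proportion to each head's slope. A minimal standalone illustration of that bias (not project code; alibi_bias is a hypothetical helper):

import torch

def alibi_bias(slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # rel[i, j] = j - i; under the causal mask only j <= i is used, where the bias is <= 0.
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]                 # (seq_len, seq_len)
    return slopes[:, None, None] * rel[None, :, :]    # (num_heads, seq_len, seq_len)

slopes = torch.tensor([0.5, 0.25])                    # two example heads
print(alibi_bias(slopes, 3)[0])
# -> [[ 0.0,  0.5,  1.0],
#     [-0.5,  0.0,  0.5],
#     [-1.0, -0.5,  0.0]]   (rows: query position, columns: key position)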
28 changes: 27 additions & 1 deletion colossalai/inference/utils.py
@@ -1,6 +1,7 @@
"""
Utils for model inference
Utilities for model inference
"""
import math
import os
import re
from pathlib import Path
@@ -55,6 +56,31 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()


def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
"""
Calculate the slopes for the Alibi positional encoding. The calculation is adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
Args:
num_heads (int): The number of heads.
device (torch.device): The device to perform the calculations on.
Returns:
torch.Tensor: The calculated slopes tensor of (nheads,) or (batch_size, nheads).
"""
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
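
A quick sanity check of the values this produces (illustrative usage; assumes the function is importable from colossalai.inference.utils): for a power-of-two head count the slopes form the geometric sequence 2**(-8/num_heads), 2**(-16/num_heads), ...

import torch

from colossalai.inference.utils import get_alibi_slopes

slopes = get_alibi_slopes(num_heads=8, device=torch.device("cpu"))
print(slopes)
# -> tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
# i.e. 2**-1, 2**-2, ..., 2**-8 for num_heads == 8.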


def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
"""
Check whether the checkpoint has an index file.