
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 26, 2024
1 parent ce12073 commit 0ad0d12
Showing 3 changed files with 108 additions and 122 deletions.
203 changes: 98 additions & 105 deletions colossalai/inference/modeling/models/baichuan_13b.py
@@ -8,9 +8,9 @@
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging
from transformers.generation.utils import GenerationConfig

from .configuration_baichuan import BaichuanConfig

@@ -19,42 +19,42 @@

def _get_interleave(n):
def _get_interleave_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
return [start * ratio**i for i in range(n)]

if math.log2(n).is_integer():
return _get_interleave_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return _get_interleave_power_of_2(closest_power_of_2) + \
_get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
return (
_get_interleave_power_of_2(closest_power_of_2)
+ _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)


def _fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float("-inf")).type_as(t)


def _gen_alibi_mask(n_head, max_pos):
"""used in inference only"""
slopes = torch.Tensor(_get_interleave(n_head))
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
n_head, -1, -1)
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
alibi = alibi.view(n_head, 1, max_pos)
alibi_mask = torch.triu(
_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
)
alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
alibi_mask = alibi_mask.unsqueeze(0) + alibi
return alibi_mask


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
"""used in training only"""
dim = tensor.size(1)
_future_mask = torch.triu(
_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
)
tensor.size(1)
_future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
_future_mask = _future_mask.unsqueeze(0) + alibi
_future_mask = _future_mask.to(tensor)
return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]
return _future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]


class RMSNorm(torch.nn.Module):
@@ -76,10 +76,10 @@ def forward(self, hidden_states):

class MLP(torch.nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
@@ -101,24 +101,21 @@ def __init__(self, config: BaichuanConfig):
self.max_position_embeddings = config.model_max_length

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
)
raise ValueError(f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}")
self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

bsz, q_len, _ = hidden_states.size()

proj = self.W_pack(hidden_states)
@@ -141,11 +138,11 @@ def forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:
if q_len == 1: # inference with cache
if q_len == 1: # inference with cache
if len(attention_mask.size()) == 4:
attention_mask = attention_mask[:, :, -1:, :]
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

@@ -177,14 +174,13 @@ def __init__(self, config: BaichuanConfig):
self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
@@ -261,33 +257,36 @@ def set_input_embeddings(self, value):
def get_alibi_mask(self, tensor, seq_length_with_past):
if self.training:
slopes = torch.Tensor(_get_interleave(self.n_head))
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
self.n_head,
-1, -1)
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(
0
).expand(self.n_head, -1, -1)
alibi = alibi.view(self.n_head, 1, seq_length_with_past)
mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)
else:
if self.first_run:
self.first_run = False
self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
self.register_buffer(
"future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False
)
if seq_length_with_past > self.max_cache_pos:
self.max_cache_pos = seq_length_with_past
self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
self.register_buffer(
"future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False
)
mask = self.future_mask[: self.n_head, :seq_length_with_past, :seq_length_with_past]
return mask

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
elif input_ids is not None:
@@ -318,10 +317,11 @@ def forward(
if attention_mask is not None:
if len(attention_mask.shape) == 2:
expanded_mask = attention_mask.to(alibi_mask.dtype)
expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
expanded_mask = torch.tril(
torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
else:
expanded_mask = attention_mask
expanded_mask = attention_mask
bsz = inputs_embeds.size(0)
src_len, tgt_len = alibi_mask.size()[-2:]
expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
@@ -428,21 +428,20 @@ def get_decoder(self):
return self.model

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
**kwargs
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
@@ -484,12 +483,12 @@ def forward(
)

def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
):
if past_key_values:
input_ids = input_ids[:, -1:]
@@ -501,65 +500,58 @@ def prepare_inputs_for_generation(
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask
}
{"past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask}
)
return model_inputs

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
return tuple(
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
for layer_past in past_key_values
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past_key_values
)

def quantize(self, bits: int):
try:
from .quantizer import QLinear
except ImportError:
raise ImportError(
f"Needs QLinear to run quantize."
)
raise ImportError(f"Needs QLinear to run quantize.")

for layer in self.model.layers:
layer.self_attn.W_pack = QLinear(
bits=bits,
weight=layer.self_attn.W_pack.weight,
bias = None,
bias=None,
)
layer.self_attn.o_proj = QLinear(
bits=bits,
weight=layer.self_attn.o_proj.weight,
bias = None,
bias=None,
)
layer.mlp.gate_proj = QLinear(
bits=bits,
weight=layer.mlp.gate_proj.weight,
bias = None,
bias=None,
)
layer.mlp.down_proj = QLinear(
bits=bits,
weight=layer.mlp.down_proj.weight,
bias = None,
bias=None,
)
layer.mlp.up_proj = QLinear(
bits=bits,
weight=layer.mlp.up_proj.weight,
bias = None,
bias=None,
)
return self

def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int = 0):
max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
max_input_tokens = self.config.model_max_length - max_new_tokens
max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
total_input, round_input = [], []
for i, message in enumerate(messages[::-1]):
content_tokens = tokenizer.encode(message['content'])
if message['role'] == 'user':
content_tokens = tokenizer.encode(message["content"])
if message["role"] == "user":
round_input = [self.generation_config.user_token_id] + content_tokens + round_input
if total_input and len(total_input) + len(round_input) > max_input_tokens:
break
@@ -569,12 +561,13 @@ def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int
break
else:
round_input = []
elif message['role'] == 'assistant':
round_input = [
self.generation_config.assistant_token_id
] + content_tokens + [
self.generation_config.eos_token_id
] + round_input
elif message["role"] == "assistant":
round_input = (
[self.generation_config.assistant_token_id]
+ content_tokens
+ [self.generation_config.eos_token_id]
+ round_input
)
else:
raise ValueError(f"message role not supported yet: {message['role']}")
total_input = total_input[-max_input_tokens:] # truncate left
@@ -583,12 +576,12 @@ def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int
return total_input

@torch.no_grad()
def chat(self, tokenizer, messages: List[dict], stream=False,
generation_config: Optional[GenerationConfig]=None):
def chat(self, tokenizer, messages: List[dict], stream=False, generation_config: Optional[GenerationConfig] = None):
generation_config = generation_config or self.generation_config
input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
if stream:
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig

self.__class__.generate = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
Expand All @@ -603,5 +596,5 @@ def stream_generator():
else:
self.__class__.generate = PreTrainedModel.generate # disable stream
outputs = self.generate(input_ids, generation_config=generation_config)
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
return response
response = tokenizer.decode(outputs[0][len(input_ids[0]) :], skip_special_tokens=True)
return response
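
For reference, a minimal sketch of how a model exposing the chat() entrypoint shown above is typically driven; the checkpoint name, loading flags, and example message are illustrative assumptions, not part of this commit.

# Hedged usage sketch for the chat() method defined in baichuan_13b.py above.
# The checkpoint name and loading flags are assumptions; adjust to your deployment.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_path = "baichuan-inc/Baichuan-13B-Chat"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(model_path)

# Messages follow the {"role": ..., "content": ...} layout consumed by _build_chat_input().
messages = [{"role": "user", "content": "Hello!"}]
response = model.chat(tokenizer, messages)
print(response)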