[CI] Compatible with paddle.where #9534

Merged · 2 commits · Dec 2, 2024
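The change is the same throughout: any mask handed to paddle.where as the condition is first cast to bool, via Tensor.to("bool") or Tensor.astype("bool") (both spellings perform the same dtype cast here), presumably because the Paddle build used in CI rejects non-boolean conditions. A minimal standalone sketch of the masked_fill pattern most of the diffs touch, assuming a recent Paddle 2.x release; the sample tensors are illustrative, not taken from the PR:

import paddle

def masked_fill(x, mask, value):
    # paddle.where expects a bool condition; cast defensively in case the
    # mask arrives as an int or float tensor.
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask.astype("bool"), y, x)

x = paddle.zeros([2, 3], dtype="float32")
mask = paddle.to_tensor([[1, 0, 1], [0, 1, 0]], dtype="int64")  # non-bool mask
out = masked_fill(x, mask, -1e4)  # entries where mask is 1 become -1e4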
llm/experimental/ernie-3.5-se/modeling.py (1 addition, 1 deletion)

@@ -135,7 +135,7 @@ class BFloatFInfo:
 
 def masked_fill(x, mask, value):
     y = paddle.full(x.shape, value, x.dtype)
-    return paddle.where(mask, y, x)
+    return paddle.where(mask.to("bool"), y, x)
 
 
 def scaled_dot_product_attention(
paddlenlp/data/data_collator.py (2 additions, 1 deletion)

@@ -571,7 +571,7 @@ def paddle_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = N
 
         def masked_fill(x, mask, value):
            y = paddle.full(x.shape, value, x.dtype)
-            return paddle.where(mask, y, x)
+            return paddle.where(mask.to("bool"), y, x)
 
         # probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
         probability_matrix = masked_fill(probability_matrix, special_tokens_mask, value=0.0)

@@ -789,6 +789,7 @@ def paddle_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
         ]
 
         def masked_fill(x, mask, value):
+            mask = mask.astype("bool")
             y = paddle.full(x.shape, value, x.dtype)
             return paddle.where(mask, y, x)
 
paddlenlp/transformers/bloom/modeling.py (1 addition, 1 deletion)

@@ -855,7 +855,7 @@ def _prepare_attn_mask(
         # Attention score will be cast to float32 in the following calculation, therefore we set attention_mask dtype as float32
         zero = paddle.zeros(expanded_attn_mask.shape, dtype=paddle.float32)
         neg_inf = paddle.full(expanded_attn_mask.shape, paddle.finfo(paddle.float32).min, dtype=paddle.float32)
-        expanded_attn_mask = paddle.where(expanded_attn_mask, zero, neg_inf)
+        expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), zero, neg_inf)
         batch_size, num_heads, sq_len, kv_len = expanded_attn_mask.shape
         return expanded_attn_mask.reshape([batch_size * num_heads, sq_len, kv_len])
 
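The attention-mask helpers follow a second recurring pattern: a boolean keep/drop mask is turned into an additive float mask, 0 where attention is allowed and a very large negative value where it is not, and the mask must now be bool before the paddle.where call. A standalone sketch of that conversion under the same Paddle assumption (names and shapes are illustrative):

import paddle

def to_additive_mask(keep_mask, dtype="float32"):
    # keep_mask: True/1 where attention is allowed, False/0 where it is masked out.
    # Cast to bool for paddle.where, then map True -> 0.0 and
    # False -> the most negative float32 value (roughly -3.4e38).
    keep_mask = keep_mask.astype("bool")
    return paddle.where(keep_mask, 0.0, paddle.finfo(paddle.float32).min).astype(dtype)

mask = paddle.to_tensor([[1, 1, 0, 0]], dtype="int64")
additive = to_additive_mask(mask)  # added to attention scores before softmax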
paddlenlp/transformers/codegen/modeling.py (1 addition, 1 deletion)

@@ -135,7 +135,7 @@ def _attn(self, query, key, value, attention_mask=None):
         attn_weights = attn_weights / self.scale_attn
         mask_value = paddle.to_tensor(-1e4, dtype=attn_weights.dtype)
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        attn_weights = paddle.where(causal_mask, attn_weights, mask_value)
+        attn_weights = paddle.where(causal_mask.to("bool"), attn_weights, mask_value)
 
         if attention_mask is not None:
             # Apply the attention mask
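In the codegen and gptj attention, paddle.where keeps attn_weights wherever the causal mask is set and substitutes a large negative constant elsewhere; only the cast of causal_mask is new. A rough sketch of that step, with the mask construction simplified and not taken from these models:

import paddle

seq_len = 4
attn_weights = paddle.randn([seq_len, seq_len], dtype="float32")

# Lower-triangular causal mask: position i may only attend to positions <= i.
causal_mask = paddle.tril(paddle.ones([seq_len, seq_len], dtype="float32"))

mask_value = paddle.to_tensor(-1e4, dtype=attn_weights.dtype)
# The cast makes the condition bool, as paddle.where now requires.
attn_weights = paddle.where(causal_mask.astype("bool"), attn_weights, mask_value)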
paddlenlp/transformers/gemma/modeling.py (1 addition, 1 deletion)

@@ -1135,7 +1135,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+        expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
 
     @paddle.jit.not_to_static
paddlenlp/transformers/gptj/modeling.py (1 addition, 1 deletion)

@@ -152,7 +152,7 @@ def _attn(
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
         mask_value = paddle.to_tensor(mask_value, dtype=attn_weights.dtype, place=attn_weights.place)
-        attn_weights = paddle.where(causal_mask, attn_weights, mask_value)
+        attn_weights = paddle.where(causal_mask.to("bool"), attn_weights, mask_value)
 
         attn_weights = attn_weights / self.scale_attn
 
paddlenlp/transformers/mixtral/modeling.py (3 additions, 3 deletions)

@@ -299,7 +299,7 @@ def scaled_dot_product_attention(
 
 def masked_fill(x, mask, value):
     y = paddle.full(x.shape, value, x.dtype)
-    return paddle.where(mask, y, x)
+    return paddle.where(mask.to("bool"), y, x)
 
 
 def is_casual_mask(attention_mask):

@@ -519,7 +519,7 @@ def forward(self, hidden_states):
         # this will be used to easily index which expert is going to be sollicitated.
         # shape: [num_experts, top_k, batch_size * seq_len]
         expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0])
-
+        expert_mask = expert_mask.to("bool")
         # Loop over all available experts in the model and perform the computation on each expert.
         for expert_id in range(self.num_experts):
             expert_layer = self.experts[expert_id]

@@ -1098,7 +1098,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                 past_key_values_length=past_key_values_length,
             )
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+        expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
 
     @paddle.jit.not_to_static
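The mixtral MoE routing change is slightly different: F.one_hot returns a float tensor, so expert_mask is cast to bool once, right after it is built, before being used to look up which tokens each expert should process. A small sketch of that routing step with made-up sizes (the surrounding expert computation is omitted):

import paddle
import paddle.nn.functional as F

num_experts, top_k, num_tokens = 4, 2, 6
router_logits = paddle.randn([num_tokens, num_experts])
_, selected_experts = paddle.topk(router_logits, top_k, axis=-1)  # [num_tokens, top_k]

# F.one_hot produces float32; transpose to [num_experts, top_k, num_tokens].
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).transpose([2, 1, 0])
expert_mask = expert_mask.astype("bool")  # cast once so downstream mask ops see bool

# Coordinates of the tokens routed to expert 0 (single-argument paddle.where
# returns the indices of True elements).
idx, top_x = paddle.where(expert_mask[0])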
paddlenlp/transformers/qwen2/modeling.py (2 additions, 2 deletions)

@@ -233,7 +233,7 @@ def scaled_dot_product_attention(
 
 def masked_fill(x, mask, value):
     y = paddle.full(x.shape, value, x.dtype)
-    return paddle.where(mask, y, x)
+    return paddle.where(mask.to("bool"), y, x)
 
 
 def is_casual_mask(attention_mask):

@@ -1020,7 +1020,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                 past_key_values_length=past_key_values_length,
             )
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+        expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
 
     @paddle.jit.not_to_static
paddlenlp/transformers/qwen2_moe/modeling.py (2 additions, 2 deletions)

@@ -300,7 +300,7 @@ def scaled_dot_product_attention(
 
 def masked_fill(x, mask, value):
     y = paddle.full(x.shape, value, x.dtype)
-    return paddle.where(mask, y, x)
+    return paddle.where(mask.to("bool"), y, x)
 
 
 def is_casual_mask(attention_mask):

@@ -1124,7 +1124,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                 past_key_values_length=past_key_values_length,
             )
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+        expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
 
     @paddle.jit.not_to_static