directly use bool instead of torch.float16 to avoid crash in ASIC like HPU which does not support float16
Browse files- attention.py +2 -3
attention.py
CHANGED
@@ -46,9 +46,8 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
|
|
46 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
47 |
if is_causal and (not q.size(2) == 1):
|
48 |
s = max(s_q, s_k)
|
49 |
-
causal_mask = attn_weight.new_ones(s, s, dtype=torch.
|
50 |
causal_mask = causal_mask.tril()
|
51 |
-
causal_mask = causal_mask.to(torch.bool)
|
52 |
causal_mask = ~causal_mask
|
53 |
causal_mask = causal_mask[-s_q:, -s_k:]
|
54 |
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
@@ -297,4 +296,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None
|
|
297 |
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
298 |
alibi_bias = alibi_bias * slopes
|
299 |
return alibi_bias.to(dtype=dtype)
|
300 |
-
ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
|
|
|
46 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
47 |
if is_causal and (not q.size(2) == 1):
|
48 |
s = max(s_q, s_k)
|
49 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.bool)
|
50 |
causal_mask = causal_mask.tril()
|
|
|
51 |
causal_mask = ~causal_mask
|
52 |
causal_mask = causal_mask[-s_q:, -s_k:]
|
53 |
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
|
|
296 |
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
297 |
alibi_bias = alibi_bias * slopes
|
298 |
return alibi_bias.to(dtype=dtype)
|
299 |
+
ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
|