hussain2030
commited on
Commit
·
87bede3
1
Parent(s):
5fc6a9a
Update modeling_jais.py
Browse filestorch.zeros to torch.empty
- modeling_jais.py +1 -1
modeling_jais.py
CHANGED
@@ -268,7 +268,7 @@ class JAISAttention(nn.Module):
|
|
268 |
_, _, k_seq_len, _ = key.size()
|
269 |
|
270 |
# Preallocate attn_weights for `baddbmm`
|
271 |
-
attn_weights = torch.
|
272 |
|
273 |
# Compute Scale Factor
|
274 |
scale_factor = 1.0
|
|
|
268 |
_, _, k_seq_len, _ = key.size()
|
269 |
|
270 |
# Preallocate attn_weights for `baddbmm`
|
271 |
+
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
|
272 |
|
273 |
# Compute Scale Factor
|
274 |
scale_factor = 1.0
|