ag0 commited on
Commit
f2e665e
·
1 Parent(s): 4ec6edc

Correct the output dtype of rmsnorm_func

Browse files

Currently the output dtype of `rmsnorm_func` is not the same as the input dtype, I'm not sure if this is the intended behaviour but this looks like a bug.

How to reproduce:
```
import torch

hidden_size = 8

hidden_states = torch.rand((4, hidden_size), dtype=torch.float16)
weight = torch.ones(hidden_size, dtype=torch.float32)
variance_epsilon = torch.tensor(1e-6)

def rmsnorm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
return weight * hidden_states.to(input_dtype)

print('input', hidden_states.dtype)
print('output', rmsnorm_func(hidden_states, weight, variance_epsilon).dtype)
```

Result:
```
input torch.float16
output torch.float32
```

With this PR:
```
input torch.float16
output torch.float16
```

Files changed (1) hide show
  1. modeling_flash_llama.py +1 -1
modeling_flash_llama.py CHANGED
@@ -68,7 +68,7 @@ def rmsnorm_func(hidden_states, weight, variance_epsilon):
68
  hidden_states = hidden_states.to(torch.float32)
69
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
70
  hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
71
- return weight * hidden_states.to(input_dtype)
72
 
73
 
74
  class LlamaRMSNorm(nn.Module):
 
68
  hidden_states = hidden_states.to(torch.float32)
69
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
70
  hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
71
+ return (weight * hidden_states).to(input_dtype)
72
 
73
 
74
  class LlamaRMSNorm(nn.Module):