"""RMSNorm."""
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: https://arxiv.org/abs/1910.07467

    Each vector along the last dimension of the input is divided by its root
    mean square (no mean-centering, no bias term) and then scaled element-wise
    by a learnable ``weight`` initialized to ones.

    Args:
        hidden_size (int): size of the last dimension of the input; also the
            length of the learnable ``weight`` vector.
        eps (float): value added to the mean square before taking the
            reciprocal square root, for numerical stability. Defaults to 1e-6.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` over its last dimension and apply ``weight``.

        Args:
            hidden_states: tensor whose last dimension has size ``hidden_size``.

        Returns:
            Tensor of the same shape as ``hidden_states``, with dtype equal to
            ``self.weight.dtype`` (not the input dtype).
        """
        # Compute the statistics in float32 so reduced-precision inputs do not
        # lose precision (or overflow) while squaring and averaging.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Cast to the parameter's dtype before scaling; the output dtype
        # therefore follows ``self.weight``, not the incoming tensor.
        return hidden_states.to(self.weight.dtype) * self.weight

    def extra_repr(self) -> str:
        """Describe the layer's size and epsilon in ``repr(module)``."""
        return f"{self.weight.shape[0]}, eps={self.eps}"