Update modeling_intern_vit.py
modeling_intern_vit.py CHANGED (+9 -2)
@@ -129,6 +129,12 @@ except Exception:
     pass
 
 
+NORM2FN = {
+    'rms_norm': InternRMSNorm,
+    'layer_norm': nn.LayerNorm,
+}
+
+
 class InternVisionEmbeddings(nn.Module):
     def __init__(self, config: InternVisionConfig):
         super().__init__()
@@ -267,11 +273,12 @@ class InternVisionEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size
         self.intermediate_size = config.intermediate_size
+        self.norm_type = config.norm_type
 
         self.attn = InternAttention(config)
         self.mlp = InternMLP(config)
-        self.norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
 
         self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
         self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))