Global CLS attention
#13
by
Markus28
- opened
- configuration_bert.py +2 -0
- modeling_bert.py +8 -3
configuration_bert.py
CHANGED
@@ -129,6 +129,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
129 |
feed_forward_type="original",
|
130 |
emb_pooler=None,
|
131 |
attn_implementation='torch',
|
|
|
132 |
**kwargs,
|
133 |
):
|
134 |
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
@@ -151,6 +152,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
151 |
self.feed_forward_type = feed_forward_type
|
152 |
self.emb_pooler = emb_pooler
|
153 |
self.attn_implementation = attn_implementation
|
|
|
154 |
|
155 |
class JinaBertOnnxConfig(OnnxConfig):
|
156 |
@property
|
|
|
129 |
feed_forward_type="original",
|
130 |
emb_pooler=None,
|
131 |
attn_implementation='torch',
|
132 |
+
cls_bias=None,
|
133 |
**kwargs,
|
134 |
):
|
135 |
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
|
|
152 |
self.feed_forward_type = feed_forward_type
|
153 |
self.emb_pooler = emb_pooler
|
154 |
self.attn_implementation = attn_implementation
|
155 |
+
self.cls_bias = cls_bias
|
156 |
|
157 |
class JinaBertOnnxConfig(OnnxConfig):
|
158 |
@property
|
modeling_bert.py
CHANGED
@@ -701,12 +701,12 @@ class JinaBertEncoder(nn.Module):
|
|
701 |
self.num_attention_heads = config.num_attention_heads
|
702 |
self.register_buffer(
|
703 |
"alibi",
|
704 |
-
self.rebuild_alibi_tensor(size=config.max_position_embeddings),
|
705 |
persistent=False,
|
706 |
)
|
707 |
|
708 |
def rebuild_alibi_tensor(
|
709 |
-
self, size: int, device: Optional[Union[torch.device, str]] = None
|
710 |
):
|
711 |
# Alibi
|
712 |
# Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
|
@@ -747,6 +747,10 @@ class JinaBertEncoder(nn.Module):
|
|
747 |
alibi = alibi.unsqueeze(0)
|
748 |
assert alibi.shape == torch.Size([1, n_heads, size, size])
|
749 |
|
|
|
|
|
|
|
|
|
750 |
self._current_alibi_size = size
|
751 |
return alibi
|
752 |
|
@@ -778,7 +782,8 @@ class JinaBertEncoder(nn.Module):
|
|
778 |
)
|
779 |
self.register_buffer(
|
780 |
"alibi",
|
781 |
-
self.rebuild_alibi_tensor(size=seqlen,
|
|
|
782 |
hidden_states.dtype
|
783 |
),
|
784 |
persistent=False,
|
|
|
701 |
self.num_attention_heads = config.num_attention_heads
|
702 |
self.register_buffer(
|
703 |
"alibi",
|
704 |
+
self.rebuild_alibi_tensor(size=config.max_position_embeddings, cls_bias=config.cls_bias),
|
705 |
persistent=False,
|
706 |
)
|
707 |
|
708 |
def rebuild_alibi_tensor(
|
709 |
+
self, size: int, device: Optional[Union[torch.device, str]] = None, cls_bias=None
|
710 |
):
|
711 |
# Alibi
|
712 |
# Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
|
|
|
747 |
alibi = alibi.unsqueeze(0)
|
748 |
assert alibi.shape == torch.Size([1, n_heads, size, size])
|
749 |
|
750 |
+
if cls_bias is not None:
|
751 |
+
alibi[:, :, 0, :] = cls_bias
|
752 |
+
alibi[:, :, :, 0] = cls_bias
|
753 |
+
|
754 |
self._current_alibi_size = size
|
755 |
return alibi
|
756 |
|
|
|
782 |
)
|
783 |
self.register_buffer(
|
784 |
"alibi",
|
785 |
+
self.rebuild_alibi_tensor(size=seqlen, cls_bias=self.config.cls_bias,
|
786 |
+
device=hidden_states.device).to(
|
787 |
hidden_states.dtype
|
788 |
),
|
789 |
persistent=False,
|