File size: 1,431 Bytes
ac0b14a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from transformers import T5Config
# String identifiers for the selectable position-encoding variants.
# CustomT5Config compares its ``position_encoding_type`` argument against
# these values, so the literal strings are part of the config contract.

# "REL" identifiers.
POSITION_ENCODING_REL_T5_BIAS = "t5_relative_bias"
POSITION_ENCODING_REL_TRANSFORMER_XL = "transformer_xl_relative_encoding"

# "ROTARY" identifiers.
POSITION_ENCODING_ROTARY = "rotary"
POSITION_ENCODING_ROTARY_NEW = "new_rotary"
POSITION_ENCODING_ROTARY_RERUN = "rotary_rerun"

# "ABS" identifiers.
POSITION_ENCODING_ABS_LEARNED = "abs_learned"
POSITION_ENCODING_ABS_SINUSOID = "abs_sinusoid"

# "ALiBi" identifiers.
POSITION_ENCODING_ALiBi = "alibi"
POSITION_ENCODING_ALiBi_LEARNED = "alibi_learned"

# "NONE" identifiers.
POSITION_ENCODING_NONE = "none"
POSITION_ENCODING_NONE_WINDOW = "none_window"
class CustomT5Config(T5Config):
    """``T5Config`` subclass that carries a ``position_encoding_type`` field.

    The constructor checks the requested type against a fixed set of
    module-level ``POSITION_ENCODING_*`` identifiers, stores it on the
    instance, and hands every other keyword argument to ``T5Config``.
    """

    model_type = "custom_decoder_only_t5"

    def __init__(
        self,
        position_encoding_type=POSITION_ENCODING_REL_T5_BIAS,
        **kwargs,
    ):
        """Validate and store the position-encoding type, then init the base.

        Args:
            position_encoding_type: One of the accepted
                ``POSITION_ENCODING_*`` identifiers; defaults to
                ``POSITION_ENCODING_REL_T5_BIAS``.
            **kwargs: Passed through unchanged to ``T5Config.__init__``.

        Raises:
            ValueError: If ``position_encoding_type`` is not one of the
                accepted identifiers.
        """
        # NOTE(review): POSITION_ENCODING_ROTARY_RERUN is defined at module
        # level but is not accepted here, so "rotary_rerun" raises
        # ValueError -- confirm whether that is intentional.
        accepted_types = (
            POSITION_ENCODING_ALiBi,
            POSITION_ENCODING_ALiBi_LEARNED,
            POSITION_ENCODING_ABS_LEARNED,
            POSITION_ENCODING_ABS_SINUSOID,
            POSITION_ENCODING_REL_T5_BIAS,
            POSITION_ENCODING_REL_TRANSFORMER_XL,
            POSITION_ENCODING_ROTARY,
            POSITION_ENCODING_ROTARY_NEW,
            POSITION_ENCODING_NONE,
            POSITION_ENCODING_NONE_WINDOW,
        )
        is_accepted = position_encoding_type in accepted_types
        if not is_accepted:
            message = f"Invalid position_encoding_type: {position_encoding_type}"
            raise ValueError(message)

        # The attribute is assigned before the base initializer runs,
        # matching the original statement order.
        self.position_encoding_type = position_encoding_type
        super().__init__(**kwargs)
|