# codellm_1b_rotary / configuration_custom_t5.py
# Uploaded by kazemnejad — "Upload CustomDecoderOnlyT5" (commit ac0b14a, verified)
from transformers import T5Config
# String identifiers for the position-encoding scheme selected via
# ``CustomT5Config(position_encoding_type=...)``. The string value is what
# gets stored on the config object. Grouped below by identifier prefix.

# REL_* identifiers
POSITION_ENCODING_REL_T5_BIAS: str = "t5_relative_bias"
POSITION_ENCODING_REL_TRANSFORMER_XL: str = "transformer_xl_relative_encoding"

# ROTARY* identifiers
# NOTE(review): POSITION_ENCODING_ROTARY_RERUN is not in the list accepted by
# CustomT5Config.__init__, so passing it raises ValueError — confirm intent.
POSITION_ENCODING_ROTARY: str = "rotary"
POSITION_ENCODING_ROTARY_RERUN: str = "rotary_rerun"
POSITION_ENCODING_ROTARY_NEW: str = "new_rotary"

# ABS_* identifiers
POSITION_ENCODING_ABS_LEARNED: str = "abs_learned"
POSITION_ENCODING_ABS_SINUSOID: str = "abs_sinusoid"

# ALiBi* identifiers
POSITION_ENCODING_ALiBi: str = "alibi"
POSITION_ENCODING_ALiBi_LEARNED: str = "alibi_learned"

# NONE* identifiers
POSITION_ENCODING_NONE: str = "none"
POSITION_ENCODING_NONE_WINDOW: str = "none_window"
class CustomT5Config(T5Config):
    """``T5Config`` extended with a selectable position-encoding scheme.

    Adds one attribute, ``position_encoding_type``, on top of the parent
    config; every other keyword argument is forwarded untouched to
    ``T5Config.__init__``.

    Args:
        position_encoding_type: One of the module-level ``POSITION_ENCODING_*``
            string constants listed in ``__init__``. Defaults to
            ``POSITION_ENCODING_REL_T5_BIAS``.
        **kwargs: Forwarded to ``T5Config.__init__``.

    Raises:
        ValueError: If ``position_encoding_type`` is not one of the accepted
            values.
    """

    # Model-type tag for this config class (class attribute, not per-instance).
    model_type = "custom_decoder_only_t5"

    def __init__(
        self,
        position_encoding_type=POSITION_ENCODING_REL_T5_BIAS,
        **kwargs,
    ):
        # Accepted values, in the original order. A tuple (not a set) keeps
        # equality-based membership, so an unhashable argument still ends in
        # the ValueError below rather than a TypeError.
        # NOTE(review): POSITION_ENCODING_ROTARY_RERUN is defined at module
        # level but is absent here, so it is rejected — confirm intentional.
        accepted_encodings = (
            POSITION_ENCODING_ALiBi,
            POSITION_ENCODING_ALiBi_LEARNED,
            POSITION_ENCODING_ABS_LEARNED,
            POSITION_ENCODING_ABS_SINUSOID,
            POSITION_ENCODING_REL_T5_BIAS,
            POSITION_ENCODING_REL_TRANSFORMER_XL,
            POSITION_ENCODING_ROTARY,
            POSITION_ENCODING_ROTARY_NEW,
            POSITION_ENCODING_NONE,
            POSITION_ENCODING_NONE_WINDOW,
        )
        is_accepted = position_encoding_type in accepted_encodings
        if not is_accepted:
            raise ValueError(
                f"Invalid position_encoding_type: {position_encoding_type}"
            )

        # Set our own attribute first, then let the parent config consume
        # the remaining keyword arguments.
        self.position_encoding_type = position_encoding_type
        super().__init__(**kwargs)