File size: 1,954 Bytes
f243c21 ff939d8 f243c21 7d1d22f ff939d8 f243c21 ff939d8 f243c21 78c5b19 f243c21 ff939d8 f243c21 7d1d22f f243c21 7d1d22f f243c21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
"""
unit tests for axolotl.core.trainer_builder
"""
import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@pytest.fixture(name="cfg")
def fixture_cfg():
    """Provide a normalized config for a small Llama-family chat model.

    The Adam and dataloader values set here are the ones that
    ``TestHFRLTrainerBuilder.test_build_training_arguments`` expects to find
    on the built training arguments.
    """
    settings = {
        # model / tokenizer selection
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "model_type": "AutoModelForCausalLM",
        "tokenizer_type": "LlamaTokenizer",
        # batching, schedule and output
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "learning_rate": 5e-5,
        "save_steps": 100,
        "output_dir": "./model-out",
        "warmup_steps": 10,
        "gradient_checkpointing": False,
        "optimizer": "adamw_torch",
        "sequence_len": 2048,
        # NOTE(review): some axolotl versions expect an RL method name
        # (e.g. "dpo") here rather than a bool — confirm against the builder.
        "rl": True,
        # optimizer hyper-parameters checked by the test
        "adam_beta1": 0.998,
        "adam_beta2": 0.9,
        "adam_epsilon": 1e-5,
        # dataloader options checked by the test
        "dataloader_num_workers": 1,
        "dataloader_pin_memory": True,
        "model_config_type": "llama",
    }
    config = DictDefault(settings)
    # normalize_config's return value is ignored; the fixture returns the
    # same object that was passed in.
    normalize_config(config)
    return config
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
    """Load the tokenizer described by the ``cfg`` fixture."""
    tok = load_tokenizer(cfg)
    return tok
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
    """Load the model described by ``cfg`` and return the model object itself.

    NOTE(review): axolotl's ``load_model`` returns a ``(model, peft_config)``
    tuple. Returning that tuple unchanged would hand a tuple, not a model,
    to ``HFRLTrainerBuilder``. The first element is unwrapped here. The
    ``isinstance`` guard keeps the fixture working if ``load_model`` ever
    returns a bare model instead.
    """
    loaded = load_model(cfg, tokenizer)
    if isinstance(loaded, tuple):
        loaded = loaded[0]
    return loaded
class TestHFRLTrainerBuilder:
    """
    Tests for ``HFRLTrainerBuilder``.

    Verifies that optimizer and dataloader settings supplied through the
    ``cfg`` fixture are carried into the training arguments the builder
    produces.
    """

    def test_build_training_arguments(self, cfg, model, tokenizer):
        """Check that Adam and dataloader options survive argument building."""
        expected_values = (
            ("adam_beta1", 0.998),
            ("adam_beta2", 0.9),
            ("adam_epsilon", 1e-5),
            ("dataloader_num_workers", 1),
        )
        # The step count passed here is not itself asserted on below.
        args = HFRLTrainerBuilder(cfg, model, tokenizer).build_training_arguments(100)
        for attr_name, wanted in expected_values:
            assert getattr(args, attr_name) == wanted, attr_name
        # Identity check: only the bool True passes, not a truthy value like 1.
        assert args.dataloader_pin_memory is True
|