theobjectivedad
commited on
Commit
·
553a86b
1
Parent(s):
ef17e15
Adding logging enhancement
Browse files- scripts/alpaca_json_to_jsonl.py +3 -0
- scripts/finetune.py +17 -14
- src/axolotl/datasets.py +2 -1
- src/axolotl/logging_config.py +27 -0
- src/axolotl/monkeypatch/llama_landmark_attn.py +1 -2
- src/axolotl/prompt_strategies/pygmalion.py +1 -1
- src/axolotl/prompt_tokenizers.py +3 -1
- src/axolotl/prompters.py +1 -1
- src/axolotl/utils/data.py +20 -18
- src/axolotl/utils/models.py +23 -21
- src/axolotl/utils/tokenization.py +4 -2
- src/axolotl/utils/trainer.py +3 -1
- src/axolotl/utils/validation.py +10 -10
- tests/test_prompt_tokenizers.py +4 -1
scripts/alpaca_json_to_jsonl.py
CHANGED
@@ -15,6 +15,9 @@ from axolotl.convert import (
|
|
15 |
JsonToJsonlConverter,
|
16 |
StdoutWriter,
|
17 |
)
|
|
|
|
|
|
|
18 |
|
19 |
# add src to the pythonpath so we don't need to pip install this
|
20 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
|
15 |
JsonToJsonlConverter,
|
16 |
StdoutWriter,
|
17 |
)
|
18 |
+
from axolotl.logging_config import configure_logging
|
19 |
+
|
20 |
+
configure_logging()
|
21 |
|
22 |
# add src to the pythonpath so we don't need to pip install this
|
23 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
scripts/finetune.py
CHANGED
@@ -24,13 +24,16 @@ from axolotl.utils.tokenization import check_dataset_labels
|
|
24 |
from axolotl.utils.trainer import setup_trainer
|
25 |
from axolotl.utils.validation import validate_config
|
26 |
from axolotl.utils.wandb import setup_wandb_env_vars
|
|
|
27 |
|
28 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
29 |
src_dir = os.path.join(project_root, "src")
|
30 |
sys.path.insert(0, src_dir)
|
31 |
|
|
|
|
|
|
|
32 |
|
33 |
-
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
34 |
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
35 |
|
36 |
|
@@ -212,7 +215,7 @@ def train(
|
|
212 |
|
213 |
# load the tokenizer first
|
214 |
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
215 |
-
|
216 |
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
217 |
|
218 |
if (
|
@@ -234,7 +237,7 @@ def train(
|
|
234 |
eval_dataset = None
|
235 |
|
236 |
if cfg.debug or "debug" in kwargs:
|
237 |
-
|
238 |
check_dataset_labels(
|
239 |
train_dataset.select(
|
240 |
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
@@ -243,11 +246,11 @@ def train(
|
|
243 |
)
|
244 |
|
245 |
if prepare_ds_only:
|
246 |
-
|
247 |
return
|
248 |
|
249 |
# Load the model and tokenizer
|
250 |
-
|
251 |
model, peft_config = load_model(
|
252 |
cfg.base_model,
|
253 |
cfg.base_model_config,
|
@@ -258,17 +261,17 @@ def train(
|
|
258 |
)
|
259 |
|
260 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
261 |
-
|
262 |
model = model.merge_and_unload()
|
263 |
model.to(dtype=torch.float16)
|
264 |
|
265 |
if cfg.local_rank == 0:
|
266 |
-
|
267 |
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
268 |
return
|
269 |
|
270 |
if cfg.inference:
|
271 |
-
|
272 |
prompter: Optional[str] = "AlpacaPrompter"
|
273 |
if "prompter" in kwargs:
|
274 |
if kwargs["prompter"] == "None":
|
@@ -287,12 +290,12 @@ def train(
|
|
287 |
model.config.use_cache = False
|
288 |
|
289 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
290 |
-
|
291 |
model = torch.compile(model)
|
292 |
|
293 |
# go ahead and presave, so we have the adapter config available to inspect
|
294 |
if peft_config:
|
295 |
-
|
296 |
peft_config.save_pretrained(cfg.output_dir)
|
297 |
|
298 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
@@ -308,9 +311,9 @@ def train(
|
|
308 |
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
309 |
)
|
310 |
|
311 |
-
|
312 |
if cfg.group_by_length:
|
313 |
-
|
314 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
315 |
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
316 |
possible_checkpoints = [
|
@@ -322,7 +325,7 @@ def train(
|
|
322 |
key=lambda path: int(path.split("-")[-1]),
|
323 |
)
|
324 |
resume_from_checkpoint = sorted_paths[-1]
|
325 |
-
|
326 |
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
327 |
)
|
328 |
|
@@ -336,7 +339,7 @@ def train(
|
|
336 |
else:
|
337 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
338 |
|
339 |
-
|
340 |
|
341 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
342 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
|
|
24 |
from axolotl.utils.trainer import setup_trainer
|
25 |
from axolotl.utils.validation import validate_config
|
26 |
from axolotl.utils.wandb import setup_wandb_env_vars
|
27 |
+
from axolotl.logging_config import configure_logging
|
28 |
|
29 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
30 |
src_dir = os.path.join(project_root, "src")
|
31 |
sys.path.insert(0, src_dir)
|
32 |
|
33 |
+
configure_logging()
|
34 |
+
LOG = logging.getLogger("axolotl.scripts")
|
35 |
+
|
36 |
|
|
|
37 |
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
38 |
|
39 |
|
|
|
215 |
|
216 |
# load the tokenizer first
|
217 |
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
218 |
+
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
219 |
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
220 |
|
221 |
if (
|
|
|
237 |
eval_dataset = None
|
238 |
|
239 |
if cfg.debug or "debug" in kwargs:
|
240 |
+
LOG.info("check_dataset_labels...")
|
241 |
check_dataset_labels(
|
242 |
train_dataset.select(
|
243 |
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
|
|
246 |
)
|
247 |
|
248 |
if prepare_ds_only:
|
249 |
+
LOG.info("Finished preparing dataset. Exiting...")
|
250 |
return
|
251 |
|
252 |
# Load the model and tokenizer
|
253 |
+
LOG.info("loading model and peft_config...")
|
254 |
model, peft_config = load_model(
|
255 |
cfg.base_model,
|
256 |
cfg.base_model_config,
|
|
|
261 |
)
|
262 |
|
263 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
264 |
+
LOG.info("running merge of LoRA with base model")
|
265 |
model = model.merge_and_unload()
|
266 |
model.to(dtype=torch.float16)
|
267 |
|
268 |
if cfg.local_rank == 0:
|
269 |
+
LOG.info("saving merged model")
|
270 |
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
271 |
return
|
272 |
|
273 |
if cfg.inference:
|
274 |
+
LOG.info("calling do_inference function")
|
275 |
prompter: Optional[str] = "AlpacaPrompter"
|
276 |
if "prompter" in kwargs:
|
277 |
if kwargs["prompter"] == "None":
|
|
|
290 |
model.config.use_cache = False
|
291 |
|
292 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
293 |
+
LOG.info("Compiling torch model")
|
294 |
model = torch.compile(model)
|
295 |
|
296 |
# go ahead and presave, so we have the adapter config available to inspect
|
297 |
if peft_config:
|
298 |
+
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
299 |
peft_config.save_pretrained(cfg.output_dir)
|
300 |
|
301 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
|
|
311 |
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
312 |
)
|
313 |
|
314 |
+
LOG.info("Starting trainer...")
|
315 |
if cfg.group_by_length:
|
316 |
+
LOG.info("hang tight... sorting dataset for group_by_length")
|
317 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
318 |
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
319 |
possible_checkpoints = [
|
|
|
325 |
key=lambda path: int(path.split("-")[-1]),
|
326 |
)
|
327 |
resume_from_checkpoint = sorted_paths[-1]
|
328 |
+
LOG.info(
|
329 |
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
330 |
)
|
331 |
|
|
|
339 |
else:
|
340 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
341 |
|
342 |
+
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
343 |
|
344 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
345 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
src/axolotl/datasets.py
CHANGED
@@ -14,6 +14,7 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
|
|
14 |
# let's check to ensure we don't truncate an item in the middle, we'll use
|
15 |
# the collators later on to pad the datasets
|
16 |
|
|
|
17 |
|
18 |
class TokenizedPromptDataset(IterableDataset):
|
19 |
"""
|
@@ -115,7 +116,7 @@ class ConstantLengthDataset(IterableDataset):
|
|
115 |
"attention_mask": attention_mask,
|
116 |
}
|
117 |
else:
|
118 |
-
|
119 |
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
120 |
)
|
121 |
buffer = {
|
|
|
14 |
# let's check to ensure we don't truncate an item in the middle, we'll use
|
15 |
# the collators later on to pad the datasets
|
16 |
|
17 |
+
LOG = logging.getLogger("axolotl")
|
18 |
|
19 |
class TokenizedPromptDataset(IterableDataset):
|
20 |
"""
|
|
|
116 |
"attention_mask": attention_mask,
|
117 |
}
|
118 |
else:
|
119 |
+
LOG.warning(
|
120 |
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
121 |
)
|
122 |
buffer = {
|
src/axolotl/logging_config.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from logging.config import dictConfig
|
3 |
+
from typing import Any, Dict
|
4 |
+
|
5 |
+
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
6 |
+
"version": 1,
|
7 |
+
"formatters": {
|
8 |
+
"simple": {
|
9 |
+
"format": "[%(asctime)s] [%(levelname)s] [PID:%(process)d] [%(name)s.%(funcName)s:%(lineno)d] %(message)s",
|
10 |
+
},
|
11 |
+
},
|
12 |
+
"filters": {},
|
13 |
+
"handlers": {
|
14 |
+
"console": {
|
15 |
+
"class": "logging.StreamHandler",
|
16 |
+
"formatter": "simple",
|
17 |
+
"filters": [],
|
18 |
+
"stream": sys.stdout,
|
19 |
+
},
|
20 |
+
},
|
21 |
+
"root": {"handlers": ["console"], "level": "INFO"},
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
def configure_logging():
|
26 |
+
"""Configure with default logging"""
|
27 |
+
dictConfig(DEFAULT_LOGGING_CONFIG)
|
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
@@ -52,8 +52,7 @@ from transformers.utils import (
|
|
52 |
logging,
|
53 |
replace_return_docstrings,
|
54 |
)
|
55 |
-
|
56 |
-
logger = logging.get_logger(__name__)
|
57 |
|
58 |
_CONFIG_FOR_DOC = "LlamaConfig"
|
59 |
|
|
|
52 |
logging,
|
53 |
replace_return_docstrings,
|
54 |
)
|
55 |
+
LOG = logging.getLogger("axolotl")
|
|
|
56 |
|
57 |
_CONFIG_FOR_DOC = "LlamaConfig"
|
58 |
|
src/axolotl/prompt_strategies/pygmalion.py
CHANGED
@@ -64,7 +64,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
64 |
*copy.deepcopy(res["input_ids"])
|
65 |
][len(self.bot_prefix_token_ids) :]
|
66 |
else:
|
67 |
-
|
68 |
res = defaultdict(lambda: [])
|
69 |
|
70 |
# pylint: disable=duplicate-code
|
|
|
64 |
*copy.deepcopy(res["input_ids"])
|
65 |
][len(self.bot_prefix_token_ids) :]
|
66 |
else:
|
67 |
+
LOG.warning(f"unknown role in conversation: {role}")
|
68 |
res = defaultdict(lambda: [])
|
69 |
|
70 |
# pylint: disable=duplicate-code
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizer
|
|
10 |
|
11 |
from axolotl.prompters import IGNORE_TOKEN_ID
|
12 |
|
|
|
|
|
13 |
IGNORE_INDEX = -100
|
14 |
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
|
15 |
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
@@ -384,7 +386,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
384 |
# everything from this is masked out from the labels
|
385 |
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
386 |
else:
|
387 |
-
|
388 |
|
389 |
# pylint: disable=duplicate-code
|
390 |
result, current_len = parse_tokenized_to_result(
|
|
|
10 |
|
11 |
from axolotl.prompters import IGNORE_TOKEN_ID
|
12 |
|
13 |
+
LOG = logging.getLogger("axolotl")
|
14 |
+
|
15 |
IGNORE_INDEX = -100
|
16 |
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
|
17 |
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
|
|
386 |
# everything from this is masked out from the labels
|
387 |
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
388 |
else:
|
389 |
+
LOG.warning(f"unhandled role: {part[0]}")
|
390 |
|
391 |
# pylint: disable=duplicate-code
|
392 |
result, current_len = parse_tokenized_to_result(
|
src/axolotl/prompters.py
CHANGED
@@ -241,7 +241,7 @@ class Conversation:
|
|
241 |
if message:
|
242 |
yield (role + ":", " " + message)
|
243 |
else:
|
244 |
-
|
245 |
yield (role + ":", "")
|
246 |
|
247 |
def copy(self):
|
|
|
241 |
if message:
|
242 |
yield (role + ":", " " + message)
|
243 |
else:
|
244 |
+
LOG.warning(f"role with empty message: {role}")
|
245 |
yield (role + ":", "")
|
246 |
|
247 |
def copy(self):
|
src/axolotl/utils/data.py
CHANGED
@@ -35,6 +35,8 @@ from axolotl.prompters import (
|
|
35 |
SummarizeTLDRPrompter,
|
36 |
)
|
37 |
|
|
|
|
|
38 |
|
39 |
def load_tokenized_prepared_datasets(
|
40 |
tokenizer, cfg, default_dataset_prepared_path
|
@@ -73,17 +75,17 @@ def load_tokenized_prepared_datasets(
|
|
73 |
if dataset:
|
74 |
...
|
75 |
elif any(prepared_ds_path.glob("*")):
|
76 |
-
|
77 |
dataset = load_from_disk(str(prepared_ds_path))
|
78 |
-
|
79 |
else:
|
80 |
-
|
81 |
-
|
82 |
|
83 |
if cfg.seed:
|
84 |
seed = cfg.seed
|
85 |
else:
|
86 |
-
|
87 |
seed = 42
|
88 |
|
89 |
datasets = []
|
@@ -256,25 +258,25 @@ def load_tokenized_prepared_datasets(
|
|
256 |
suffix = ""
|
257 |
if ":load_" in d.type:
|
258 |
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
259 |
-
|
260 |
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
|
261 |
)
|
262 |
raise ValueError(
|
263 |
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
264 |
)
|
265 |
-
|
266 |
|
267 |
samples: List[int] = []
|
268 |
for d in datasets:
|
269 |
samples = samples + list(d)
|
270 |
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
271 |
if cfg.local_rank == 0:
|
272 |
-
|
273 |
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
274 |
)
|
275 |
dataset.save_to_disk(prepared_ds_path)
|
276 |
if cfg.push_dataset_to_hub:
|
277 |
-
|
278 |
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
279 |
)
|
280 |
dataset.push_to_hub(
|
@@ -325,7 +327,7 @@ def load_prepare_datasets(
|
|
325 |
use_auth_token = cfg.hf_use_auth_token
|
326 |
try:
|
327 |
if cfg.push_dataset_to_hub:
|
328 |
-
|
329 |
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
330 |
)
|
331 |
dataset = load_dataset(
|
@@ -339,13 +341,13 @@ def load_prepare_datasets(
|
|
339 |
if dataset:
|
340 |
...
|
341 |
elif any(prepared_ds_path.glob("*")):
|
342 |
-
|
343 |
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
344 |
)
|
345 |
dataset = load_from_disk(str(prepared_ds_path))
|
346 |
-
|
347 |
if cfg.push_dataset_to_hub:
|
348 |
-
|
349 |
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
350 |
)
|
351 |
dataset.push_to_hub(
|
@@ -364,7 +366,7 @@ def load_prepare_datasets(
|
|
364 |
[dataset],
|
365 |
seq_length=max_packed_sequence_len,
|
366 |
)
|
367 |
-
|
368 |
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
369 |
)
|
370 |
dataset = Dataset.from_list(list(constant_len_dataset))
|
@@ -382,12 +384,12 @@ def load_prepare_datasets(
|
|
382 |
)
|
383 |
|
384 |
if cfg.local_rank == 0:
|
385 |
-
|
386 |
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
387 |
)
|
388 |
dataset.save_to_disk(prepared_ds_path)
|
389 |
if cfg.push_dataset_to_hub:
|
390 |
-
|
391 |
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
392 |
)
|
393 |
dataset.push_to_hub(
|
@@ -400,7 +402,7 @@ def load_prepare_datasets(
|
|
400 |
)
|
401 |
|
402 |
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
403 |
-
|
404 |
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
405 |
)
|
406 |
dataset = dataset.shard(
|
@@ -521,7 +523,7 @@ def encode_pretraining(tokenizer, max_tokens, examples):
|
|
521 |
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
522 |
}
|
523 |
|
524 |
-
|
525 |
return ret
|
526 |
|
527 |
|
|
|
35 |
SummarizeTLDRPrompter,
|
36 |
)
|
37 |
|
38 |
+
LOG = logging.getLogger("axolotl")
|
39 |
+
|
40 |
|
41 |
def load_tokenized_prepared_datasets(
|
42 |
tokenizer, cfg, default_dataset_prepared_path
|
|
|
75 |
if dataset:
|
76 |
...
|
77 |
elif any(prepared_ds_path.glob("*")):
|
78 |
+
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
79 |
dataset = load_from_disk(str(prepared_ds_path))
|
80 |
+
LOG.info("Prepared dataset loaded from disk...")
|
81 |
else:
|
82 |
+
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
83 |
+
LOG.info("Loading raw datasets...")
|
84 |
|
85 |
if cfg.seed:
|
86 |
seed = cfg.seed
|
87 |
else:
|
88 |
+
LOG.info("No seed provided, using default seed of 42")
|
89 |
seed = 42
|
90 |
|
91 |
datasets = []
|
|
|
258 |
suffix = ""
|
259 |
if ":load_" in d.type:
|
260 |
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
261 |
+
LOG.error(
|
262 |
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
|
263 |
)
|
264 |
raise ValueError(
|
265 |
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
266 |
)
|
267 |
+
LOG.info("tokenizing, merging, and shuffling master dataset")
|
268 |
|
269 |
samples: List[int] = []
|
270 |
for d in datasets:
|
271 |
samples = samples + list(d)
|
272 |
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
273 |
if cfg.local_rank == 0:
|
274 |
+
LOG.info(
|
275 |
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
276 |
)
|
277 |
dataset.save_to_disk(prepared_ds_path)
|
278 |
if cfg.push_dataset_to_hub:
|
279 |
+
LOG.info(
|
280 |
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
281 |
)
|
282 |
dataset.push_to_hub(
|
|
|
327 |
use_auth_token = cfg.hf_use_auth_token
|
328 |
try:
|
329 |
if cfg.push_dataset_to_hub:
|
330 |
+
LOG.info(
|
331 |
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
332 |
)
|
333 |
dataset = load_dataset(
|
|
|
341 |
if dataset:
|
342 |
...
|
343 |
elif any(prepared_ds_path.glob("*")):
|
344 |
+
LOG.info(
|
345 |
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
346 |
)
|
347 |
dataset = load_from_disk(str(prepared_ds_path))
|
348 |
+
LOG.info("Prepared packed dataset loaded from disk...")
|
349 |
if cfg.push_dataset_to_hub:
|
350 |
+
LOG.info(
|
351 |
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
352 |
)
|
353 |
dataset.push_to_hub(
|
|
|
366 |
[dataset],
|
367 |
seq_length=max_packed_sequence_len,
|
368 |
)
|
369 |
+
LOG.info(
|
370 |
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
371 |
)
|
372 |
dataset = Dataset.from_list(list(constant_len_dataset))
|
|
|
384 |
)
|
385 |
|
386 |
if cfg.local_rank == 0:
|
387 |
+
LOG.info(
|
388 |
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
389 |
)
|
390 |
dataset.save_to_disk(prepared_ds_path)
|
391 |
if cfg.push_dataset_to_hub:
|
392 |
+
LOG.info(
|
393 |
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
394 |
)
|
395 |
dataset.push_to_hub(
|
|
|
402 |
)
|
403 |
|
404 |
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
405 |
+
LOG.info(
|
406 |
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
407 |
)
|
408 |
dataset = dataset.shard(
|
|
|
523 |
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
524 |
}
|
525 |
|
526 |
+
LOG.debug(len(ret["input_ids"]))
|
527 |
return ret
|
528 |
|
529 |
|
src/axolotl/utils/models.py
CHANGED
@@ -23,6 +23,8 @@ from transformers import ( # noqa: F401
|
|
23 |
|
24 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
25 |
|
|
|
|
|
26 |
if TYPE_CHECKING:
|
27 |
from peft import PeftConfig # noqa: F401
|
28 |
|
@@ -50,10 +52,10 @@ def load_tokenizer(
|
|
50 |
use_fast=use_fast,
|
51 |
)
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
if tokenizer.__class__.__name__ in [
|
59 |
"LlamaTokenizer",
|
@@ -92,21 +94,21 @@ def load_model(
|
|
92 |
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
93 |
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
94 |
|
95 |
-
|
96 |
replace_llama_attn_with_flash_attn()
|
97 |
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
98 |
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
99 |
hijack_llama_attention,
|
100 |
)
|
101 |
|
102 |
-
|
103 |
hijack_llama_attention()
|
104 |
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
105 |
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
106 |
hijack_llama_sdp_attention,
|
107 |
)
|
108 |
|
109 |
-
|
110 |
hijack_llama_sdp_attention()
|
111 |
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
112 |
from axolotl.monkeypatch.llama_landmark_attn import (
|
@@ -114,7 +116,7 @@ def load_model(
|
|
114 |
patch_llama_with_landmark_attn,
|
115 |
)
|
116 |
|
117 |
-
|
118 |
patch_llama_with_landmark_attn()
|
119 |
|
120 |
# Note: This might overwrite previous additional_special_tokens
|
@@ -125,7 +127,7 @@ def load_model(
|
|
125 |
replace_llama_rope_with_xpos_rope,
|
126 |
)
|
127 |
|
128 |
-
|
129 |
replace_llama_rope_with_xpos_rope()
|
130 |
|
131 |
if cfg.bf16 or cfg.bfloat16:
|
@@ -142,7 +144,7 @@ def load_model(
|
|
142 |
|
143 |
replace_peft_model_with_int4_lora_model()
|
144 |
except Exception as err:
|
145 |
-
|
146 |
raise err
|
147 |
|
148 |
try:
|
@@ -187,7 +189,7 @@ def load_model(
|
|
187 |
if len(files) > 0:
|
188 |
model_path = str(files[0])
|
189 |
else:
|
190 |
-
|
191 |
"unable to find a cached model file, this will likely fail..."
|
192 |
)
|
193 |
model_path = str(cache_model_path)
|
@@ -266,14 +268,14 @@ def load_model(
|
|
266 |
and cfg.sequence_len > config.max_seq_len
|
267 |
):
|
268 |
config.max_seq_len = cfg.sequence_len
|
269 |
-
|
270 |
elif (
|
271 |
hasattr(config, "max_sequence_length")
|
272 |
and config.max_sequence_length
|
273 |
and cfg.sequence_len > config.max_sequence_length
|
274 |
):
|
275 |
config.max_sequence_length = cfg.sequence_len
|
276 |
-
|
277 |
model = AutoModelForCausalLM.from_pretrained(
|
278 |
base_model,
|
279 |
config=config,
|
@@ -285,10 +287,10 @@ def load_model(
|
|
285 |
**model_kwargs,
|
286 |
)
|
287 |
except Exception as err: # pylint: disable=broad-exception-caught
|
288 |
-
|
289 |
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
290 |
)
|
291 |
-
|
292 |
model = AutoModelForCausalLM.from_pretrained(
|
293 |
base_model,
|
294 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
@@ -307,7 +309,7 @@ def load_model(
|
|
307 |
and model.config.max_position_embeddings
|
308 |
and cfg.sequence_len >= model.config.max_position_embeddings
|
309 |
):
|
310 |
-
|
311 |
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
312 |
)
|
313 |
model.config.max_position_embeddings = cfg.sequence_len
|
@@ -316,7 +318,7 @@ def load_model(
|
|
316 |
(cfg.adapter == "lora" and load_in_8bit)
|
317 |
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
318 |
):
|
319 |
-
|
320 |
model = prepare_model_for_kbit_training(
|
321 |
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
322 |
)
|
@@ -328,7 +330,7 @@ def load_model(
|
|
328 |
|
329 |
if cfg.gptq:
|
330 |
# Scales to half
|
331 |
-
|
332 |
for _, module in model.named_modules():
|
333 |
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
|
334 |
type(module)
|
@@ -354,7 +356,7 @@ def load_model(
|
|
354 |
if param.requires_grad:
|
355 |
requires_grad.append(f"{name}: {param.requires_grad}")
|
356 |
if len(requires_grad) == 0:
|
357 |
-
|
358 |
model.config.use_cache = False
|
359 |
|
360 |
if cfg.flash_optimum:
|
@@ -388,7 +390,7 @@ def load_llama_adapter(model, cfg):
|
|
388 |
)
|
389 |
|
390 |
if cfg.lora_model_dir:
|
391 |
-
|
392 |
model = PeftModel.from_pretrained(
|
393 |
model,
|
394 |
cfg.lora_model_dir,
|
@@ -435,7 +437,7 @@ def load_lora(model, cfg):
|
|
435 |
bits = 8
|
436 |
|
437 |
linear_names = find_all_linear_names(bits, model)
|
438 |
-
|
439 |
lora_target_modules = list(set(lora_target_modules + linear_names))
|
440 |
|
441 |
lora_config = LoraConfig(
|
|
|
23 |
|
24 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
25 |
|
26 |
+
LOG = logging.getLogger("axolotl")
|
27 |
+
|
28 |
if TYPE_CHECKING:
|
29 |
from peft import PeftConfig # noqa: F401
|
30 |
|
|
|
52 |
use_fast=use_fast,
|
53 |
)
|
54 |
|
55 |
+
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
56 |
+
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
57 |
+
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
58 |
+
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
59 |
|
60 |
if tokenizer.__class__.__name__ in [
|
61 |
"LlamaTokenizer",
|
|
|
94 |
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
95 |
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
96 |
|
97 |
+
LOG.info("patching with flash attention")
|
98 |
replace_llama_attn_with_flash_attn()
|
99 |
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
100 |
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
101 |
hijack_llama_attention,
|
102 |
)
|
103 |
|
104 |
+
LOG.info("patching with xformers attention")
|
105 |
hijack_llama_attention()
|
106 |
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
107 |
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
108 |
hijack_llama_sdp_attention,
|
109 |
)
|
110 |
|
111 |
+
LOG.info("patching with sdp attention")
|
112 |
hijack_llama_sdp_attention()
|
113 |
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
114 |
from axolotl.monkeypatch.llama_landmark_attn import (
|
|
|
116 |
patch_llama_with_landmark_attn,
|
117 |
)
|
118 |
|
119 |
+
LOG.info("patching with landmark attention")
|
120 |
patch_llama_with_landmark_attn()
|
121 |
|
122 |
# Note: This might overwrite previous additional_special_tokens
|
|
|
127 |
replace_llama_rope_with_xpos_rope,
|
128 |
)
|
129 |
|
130 |
+
LOG.info("patching with xpos rope")
|
131 |
replace_llama_rope_with_xpos_rope()
|
132 |
|
133 |
if cfg.bf16 or cfg.bfloat16:
|
|
|
144 |
|
145 |
replace_peft_model_with_int4_lora_model()
|
146 |
except Exception as err:
|
147 |
+
LOG.exception(err)
|
148 |
raise err
|
149 |
|
150 |
try:
|
|
|
189 |
if len(files) > 0:
|
190 |
model_path = str(files[0])
|
191 |
else:
|
192 |
+
LOG.warning(
|
193 |
"unable to find a cached model file, this will likely fail..."
|
194 |
)
|
195 |
model_path = str(cache_model_path)
|
|
|
268 |
and cfg.sequence_len > config.max_seq_len
|
269 |
):
|
270 |
config.max_seq_len = cfg.sequence_len
|
271 |
+
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
272 |
elif (
|
273 |
hasattr(config, "max_sequence_length")
|
274 |
and config.max_sequence_length
|
275 |
and cfg.sequence_len > config.max_sequence_length
|
276 |
):
|
277 |
config.max_sequence_length = cfg.sequence_len
|
278 |
+
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
279 |
model = AutoModelForCausalLM.from_pretrained(
|
280 |
base_model,
|
281 |
config=config,
|
|
|
287 |
**model_kwargs,
|
288 |
)
|
289 |
except Exception as err: # pylint: disable=broad-exception-caught
|
290 |
+
LOG.error(
|
291 |
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
292 |
)
|
293 |
+
LOG.exception(err)
|
294 |
model = AutoModelForCausalLM.from_pretrained(
|
295 |
base_model,
|
296 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
309 |
and model.config.max_position_embeddings
|
310 |
and cfg.sequence_len >= model.config.max_position_embeddings
|
311 |
):
|
312 |
+
LOG.warning(
|
313 |
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
314 |
)
|
315 |
model.config.max_position_embeddings = cfg.sequence_len
|
|
|
318 |
(cfg.adapter == "lora" and load_in_8bit)
|
319 |
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
320 |
):
|
321 |
+
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
322 |
model = prepare_model_for_kbit_training(
|
323 |
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
324 |
)
|
|
|
330 |
|
331 |
if cfg.gptq:
|
332 |
# Scales to half
|
333 |
+
LOG.info("Fitting 4bit scales and zeros to half")
|
334 |
for _, module in model.named_modules():
|
335 |
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
|
336 |
type(module)
|
|
|
356 |
if param.requires_grad:
|
357 |
requires_grad.append(f"{name}: {param.requires_grad}")
|
358 |
if len(requires_grad) == 0:
|
359 |
+
LOG.warning("there are no parameters that require gradient updates")
|
360 |
model.config.use_cache = False
|
361 |
|
362 |
if cfg.flash_optimum:
|
|
|
390 |
)
|
391 |
|
392 |
if cfg.lora_model_dir:
|
393 |
+
LOG.info("Loading pretained LORA")
|
394 |
model = PeftModel.from_pretrained(
|
395 |
model,
|
396 |
cfg.lora_model_dir,
|
|
|
437 |
bits = 8
|
438 |
|
439 |
linear_names = find_all_linear_names(bits, model)
|
440 |
+
LOG.info(f"found linear modules: {repr(linear_names)}")
|
441 |
lora_target_modules = list(set(lora_target_modules + linear_names))
|
442 |
|
443 |
lora_config = LoraConfig(
|
src/axolotl/utils/tokenization.py
CHANGED
@@ -5,6 +5,8 @@ import logging
|
|
5 |
|
6 |
from termcolor import colored
|
7 |
|
|
|
|
|
8 |
|
9 |
def check_dataset_labels(dataset, tokenizer):
|
10 |
# the dataset is already shuffled, so let's just check the first 5 elements
|
@@ -32,7 +34,7 @@ def check_example_labels(example, tokenizer):
|
|
32 |
)
|
33 |
colored_tokens.append(colored_token)
|
34 |
|
35 |
-
|
36 |
-
|
37 |
|
38 |
return " ".join(colored_tokens)
|
|
|
5 |
|
6 |
from termcolor import colored
|
7 |
|
8 |
+
LOG = logging.getLogger("axolotl")
|
9 |
+
|
10 |
|
11 |
def check_dataset_labels(dataset, tokenizer):
|
12 |
# the dataset is already shuffled, so let's just check the first 5 elements
|
|
|
34 |
)
|
35 |
colored_tokens.append(colored_token)
|
36 |
|
37 |
+
LOG.info(" ".join(colored_tokens))
|
38 |
+
LOG.info("\n\n\n")
|
39 |
|
40 |
return " ".join(colored_tokens)
|
src/axolotl/utils/trainer.py
CHANGED
@@ -26,6 +26,8 @@ from axolotl.utils.schedulers import (
|
|
26 |
get_cosine_schedule_with_quadratic_warmup,
|
27 |
)
|
28 |
|
|
|
|
|
29 |
|
30 |
class AxolotlTrainingArguments(TrainingArguments):
|
31 |
"""
|
@@ -320,7 +322,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
320 |
|
321 |
set_model_mem_id(model, tokenizer)
|
322 |
|
323 |
-
|
324 |
|
325 |
for dataset in [train_dataset, eval_dataset]:
|
326 |
dataset = dataset.map(
|
|
|
26 |
get_cosine_schedule_with_quadratic_warmup,
|
27 |
)
|
28 |
|
29 |
+
LOG = logging.getLogger("axolotl")
|
30 |
+
|
31 |
|
32 |
class AxolotlTrainingArguments(TrainingArguments):
|
33 |
"""
|
|
|
322 |
|
323 |
set_model_mem_id(model, tokenizer)
|
324 |
|
325 |
+
LOG.info("Adding landmark attention tokens to dataset")
|
326 |
|
327 |
for dataset in [train_dataset, eval_dataset]:
|
328 |
dataset = dataset.map(
|
src/axolotl/utils/validation.py
CHANGED
@@ -4,6 +4,8 @@ import logging
|
|
4 |
|
5 |
import torch
|
6 |
|
|
|
|
|
7 |
|
8 |
def validate_config(cfg):
|
9 |
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
@@ -11,7 +13,7 @@ def validate_config(cfg):
|
|
11 |
"please set only one of gradient_accumulation_steps or batch_size"
|
12 |
)
|
13 |
if cfg.batch_size:
|
14 |
-
|
15 |
"%s\n%s",
|
16 |
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
17 |
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
@@ -44,10 +46,10 @@ def validate_config(cfg):
|
|
44 |
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
45 |
|
46 |
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
47 |
-
|
48 |
|
49 |
if cfg.trust_remote_code:
|
50 |
-
|
51 |
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
52 |
)
|
53 |
|
@@ -66,31 +68,29 @@ def validate_config(cfg):
|
|
66 |
|
67 |
if cfg.flash_optimum is True:
|
68 |
if cfg.adapter:
|
69 |
-
|
70 |
-
"BetterTransformers probably doesn't work with PEFT adapters"
|
71 |
-
)
|
72 |
if cfg.fp16 or cfg.bf16:
|
73 |
raise ValueError("AMP is not supported with BetterTransformer")
|
74 |
if cfg.float16 is not True and cfg.bloat16 is not True:
|
75 |
-
|
76 |
"You should probably set bfloat16 or float16 to true to "
|
77 |
"load the model in float16 for BetterTransformers"
|
78 |
)
|
79 |
if int(torch.__version__.split(".")[0]) < 2:
|
80 |
-
|
81 |
raise ValueError(
|
82 |
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
83 |
)
|
84 |
|
85 |
if cfg.pretraining_dataset and cfg.group_by_length:
|
86 |
-
|
87 |
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
88 |
)
|
89 |
|
90 |
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
91 |
not cfg.optimizer or "adamw" not in cfg.optimizer
|
92 |
):
|
93 |
-
|
94 |
|
95 |
if cfg.push_to_hub_model_id:
|
96 |
raise ValueError(
|
|
|
4 |
|
5 |
import torch
|
6 |
|
7 |
+
LOG = logging.getLogger("axolotl")
|
8 |
+
|
9 |
|
10 |
def validate_config(cfg):
|
11 |
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
|
|
13 |
"please set only one of gradient_accumulation_steps or batch_size"
|
14 |
)
|
15 |
if cfg.batch_size:
|
16 |
+
LOG.warning(
|
17 |
"%s\n%s",
|
18 |
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
19 |
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
|
|
46 |
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
47 |
|
48 |
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
49 |
+
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
50 |
|
51 |
if cfg.trust_remote_code:
|
52 |
+
LOG.warning(
|
53 |
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
54 |
)
|
55 |
|
|
|
68 |
|
69 |
if cfg.flash_optimum is True:
|
70 |
if cfg.adapter:
|
71 |
+
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
|
|
|
|
72 |
if cfg.fp16 or cfg.bf16:
|
73 |
raise ValueError("AMP is not supported with BetterTransformer")
|
74 |
if cfg.float16 is not True and cfg.bloat16 is not True:
|
75 |
+
LOG.warning(
|
76 |
"You should probably set bfloat16 or float16 to true to "
|
77 |
"load the model in float16 for BetterTransformers"
|
78 |
)
|
79 |
if int(torch.__version__.split(".")[0]) < 2:
|
80 |
+
LOG.warning("torch>=2.0.0 required")
|
81 |
raise ValueError(
|
82 |
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
83 |
)
|
84 |
|
85 |
if cfg.pretraining_dataset and cfg.group_by_length:
|
86 |
+
LOG.warning(
|
87 |
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
88 |
)
|
89 |
|
90 |
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
91 |
not cfg.optimizer or "adamw" not in cfg.optimizer
|
92 |
):
|
93 |
+
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
94 |
|
95 |
if cfg.push_to_hub_model_id:
|
96 |
raise ValueError(
|
tests/test_prompt_tokenizers.py
CHANGED
@@ -16,8 +16,11 @@ from axolotl.prompt_tokenizers import (
|
|
16 |
ShareGPTPromptTokenizingStrategy,
|
17 |
)
|
18 |
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
|
|
19 |
|
20 |
-
|
|
|
|
|
21 |
|
22 |
|
23 |
class TestPromptTokenizationStrategies(unittest.TestCase):
|
|
|
16 |
ShareGPTPromptTokenizingStrategy,
|
17 |
)
|
18 |
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
19 |
+
from axolotl.logging_config import configure_logging
|
20 |
|
21 |
+
configure_logging()
|
22 |
+
|
23 |
+
LOG = logging.getLogger("axolotl")
|
24 |
|
25 |
|
26 |
class TestPromptTokenizationStrategies(unittest.TestCase):
|