configure log level, add llama 7b config
scripts/finetune.py (+10 -3)
@@ -39,6 +39,7 @@ from axolotl.prompt_tokenizers import (
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
 
 logger = logging.getLogger(__name__)
+logger.setLevel(os.getenv("LOG_LEVEL", "INFO"))
 DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
 
 
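The added line lets the environment control verbosity: LOG_LEVEL defaults to INFO and can be overridden per run without editing the script. A minimal standalone sketch of the same pattern (the basicConfig call and the demo messages are illustrative additions, not part of the commit):

import logging
import os

logging.basicConfig()  # attach a handler so records are actually printed
logger = logging.getLogger(__name__)

# Same pattern as the diff: default to INFO, override per run,
# e.g. LOG_LEVEL=DEBUG python scripts/finetune.py ...
logger.setLevel(os.getenv("LOG_LEVEL", "INFO"))

logger.info("always shown at the default level")
logger.debug("shown only when LOG_LEVEL=DEBUG")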
@@ -353,9 +354,15 @@ def train(
     else:
         datasets = []
         for d in cfg.datasets:
-            ds: IterableDataset = load_dataset(
-                "json", data_files=d.path, streaming=True, split=None
-            )
+            if Path(d.path).exists():
+                ds: IterableDataset = load_dataset(
+                    "json", data_files=d.path, streaming=True, split=None
+                )
+            # elif d.name and d.path:
+            #     # TODO load from huggingface hub, but it only seems to support arrow or parquet atm
+            #     ds = load_dataset(d.path, split=None, data_files=d.name)
+            else:
+                raise Exception("unhandled dataset load")
 
             if d.type == "alpaca":
                 ds_strategy = AlpacaPromptTokenizingStrategy(
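The loading branch now only handles datasets whose path exists on disk, streaming them as JSON via load_dataset; loading by name from the Hugging Face Hub is left commented out as a TODO, and anything else fails fast. A minimal sketch of that call in isolation, assuming an invented example.jsonl written on the spot so the snippet runs (the real script takes the path from cfg.datasets):

import json
from pathlib import Path

from datasets import load_dataset

# Invented sample file, just so this snippet is self-contained.
path = "example.jsonl"
Path(path).write_text(json.dumps({"instruction": "hi", "output": "hello"}) + "\n")

if Path(path).exists():
    # streaming=True returns an iterable dataset (no full load into memory);
    # split=None returns a dict of splits, here just "train".
    ds = load_dataset("json", data_files=path, streaming=True, split=None)
    print(next(iter(ds["train"])))
else:
    raise Exception("unhandled dataset load")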