fix tokenizer loading, got openllama 3b working
Browse files
examples/{lora-alpaca-7b → lora-openllama-3b}/config.yml
RENAMED
@@ -1,5 +1,5 @@
|
|
1 |
-
base_model:
|
2 |
-
base_model_config:
|
3 |
model_type: LlamaForCausalLM
|
4 |
tokenizer_type: LlamaTokenizer
|
5 |
load_in_8bit: true
|
@@ -32,9 +32,9 @@ wandb_watch:
|
|
32 |
wandb_run_id:
|
33 |
wandb_log_model:
|
34 |
output_dir: ./lora-out
|
35 |
-
batch_size:
|
36 |
-
micro_batch_size:
|
37 |
-
num_epochs:
|
38 |
optimizer: adamw_bnb_8bit
|
39 |
torchdistx_path:
|
40 |
lr_scheduler: cosine
|
|
|
1 |
+
base_model: openlm-research/open_llama_3b_600bt_preview
|
2 |
+
base_model_config: openlm-research/open_llama_3b_600bt_preview
|
3 |
model_type: LlamaForCausalLM
|
4 |
tokenizer_type: LlamaTokenizer
|
5 |
load_in_8bit: true
|
|
|
32 |
wandb_run_id:
|
33 |
wandb_log_model:
|
34 |
output_dir: ./lora-out
|
35 |
+
batch_size: 16
|
36 |
+
micro_batch_size: 4
|
37 |
+
num_epochs: 3
|
38 |
optimizer: adamw_bnb_8bit
|
39 |
torchdistx_path:
|
40 |
lr_scheduler: cosine
|
src/axolotl/utils/models.py
CHANGED
@@ -211,12 +211,12 @@ def load_model(
|
|
211 |
try:
|
212 |
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
213 |
tokenizer = LlamaTokenizer.from_pretrained(
|
214 |
-
|
215 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
216 |
)
|
217 |
else:
|
218 |
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
219 |
-
|
220 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
221 |
)
|
222 |
except:
|
|
|
211 |
try:
|
212 |
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
213 |
tokenizer = LlamaTokenizer.from_pretrained(
|
214 |
+
base_model_config,
|
215 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
216 |
)
|
217 |
else:
|
218 |
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
219 |
+
base_model_config,
|
220 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
221 |
)
|
222 |
except:
|