winglian commited on
Commit
d653859
·
1 Parent(s): 5749eb0

improve inference

Browse files
Files changed (2) hide show
  1. scripts/finetune.py +25 -25
  2. src/axolotl/utils/models.py +18 -15
scripts/finetune.py CHANGED
@@ -79,31 +79,31 @@ def do_inference(cfg, model, tokenizer):
79
 
80
  from axolotl.prompters import ReflectAlpacaPrompter
81
 
82
- instruction = str(input("Give me an instruction: "))
83
- instruction = (
84
- instruction if not instruction else "Tell me a joke about dromedaries."
85
- )
86
- prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
87
- batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
88
-
89
- model.eval()
90
- with torch.no_grad():
91
- # gc = GenerationConfig() # TODO swap out and use this
92
- generated = model.generate(
93
- inputs=batch["input_ids"].to("cuda"),
94
- do_sample=True,
95
- use_cache=True,
96
- repetition_penalty=1.1,
97
- max_new_tokens=100,
98
- temperature=0.9,
99
- top_p=0.95,
100
- top_k=40,
101
- return_dict_in_generate=True,
102
- output_attentions=False,
103
- output_hidden_states=False,
104
- output_scores=False,
105
- )
106
- print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
107
 
108
 
109
  def choose_config(path: Path):
 
79
 
80
  from axolotl.prompters import ReflectAlpacaPrompter
81
 
82
+ while True:
83
+ instruction = str(input("Give me an instruction: "))
84
+ if not instruction:
85
+ return
86
+ prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
87
+ batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
88
+
89
+ model.eval()
90
+ with torch.no_grad():
91
+ # gc = GenerationConfig() # TODO swap out and use this
92
+ generated = model.generate(
93
+ inputs=batch["input_ids"].to("cuda"),
94
+ do_sample=True,
95
+ use_cache=True,
96
+ repetition_penalty=1.1,
97
+ max_new_tokens=100,
98
+ temperature=0.9,
99
+ top_p=0.95,
100
+ top_k=40,
101
+ return_dict_in_generate=True,
102
+ output_attentions=False,
103
+ output_hidden_states=False,
104
+ output_scores=False,
105
+ )
106
+ print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
107
 
108
 
109
  def choose_config(path: Path):
src/axolotl/utils/models.py CHANGED
@@ -66,22 +66,25 @@ def load_model(
66
  from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
67
  from huggingface_hub import snapshot_download
68
 
69
- snapshot_download_kwargs = {}
70
- if cfg.base_model_ignore_patterns:
71
- snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
72
- cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
73
- files = (
74
- list(cache_model_path.glob("*.pt"))
75
- + list(cache_model_path.glob("*.safetensors"))
76
- + list(cache_model_path.glob("*.bin"))
77
- )
78
- if len(files) > 0:
79
- model_path = str(files[0])
80
- else:
81
- logging.warning(
82
- "unable to find a cached model file, this will likely fail..."
83
  )
84
- model_path = str(cache_model_path)
 
 
 
 
 
 
 
 
85
  model, tokenizer = load_llama_model_4bit_low_ram(
86
  base_model_config if base_model_config else base_model,
87
  model_path,
 
66
  from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
67
  from huggingface_hub import snapshot_download
68
 
69
+ try:
70
+ snapshot_download_kwargs = {}
71
+ if cfg.base_model_ignore_patterns:
72
+ snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
73
+ cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
74
+ files = (
75
+ list(cache_model_path.glob("*.pt"))
76
+ + list(cache_model_path.glob("*.safetensors"))
77
+ + list(cache_model_path.glob("*.bin"))
 
 
 
 
 
78
  )
79
+ if len(files) > 0:
80
+ model_path = str(files[0])
81
+ else:
82
+ logging.warning(
83
+ "unable to find a cached model file, this will likely fail..."
84
+ )
85
+ model_path = str(cache_model_path)
86
+ except:
87
+ model_path = cfg.base_model
88
  model, tokenizer = load_llama_model_4bit_low_ram(
89
  base_model_config if base_model_config else base_model,
90
  model_path,