winglian committed on
Commit 215d775 · unverified · 2 Parent(s): 931e606 f36e227

Merge pull request #180 from Glavin001/feat/stream-inference

Files changed (1)
  1. scripts/finetune.py +12 -5
scripts/finetune.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig
+from transformers import GenerationConfig, TextStreamer
 
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.dict import DictDefault
@@ -64,13 +64,17 @@ def get_multi_line_input() -> Optional[str]:
 
 
 def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
-    tokenizer.add_special_tokens({"unk_token": "<unk>"})
-    tokenizer.add_special_tokens({"bos_token": "<s>"})
-    tokenizer.add_special_tokens({"eos_token": "</s>"})
+    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})
 
     prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
 
     while True:
+        print("=" * 80)
         # support for multiline inputs
         instruction = get_multi_line_input()
         if not instruction:
@@ -79,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
             prompter_module().build_prompt(instruction=instruction.strip("\n"))
         )
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
+        print("=" * 40)
         model.eval()
         with torch.no_grad():
             generation_config = GenerationConfig(
@@ -98,10 +102,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
                 output_hidden_states=False,
                 output_scores=False,
             )
+            streamer = TextStreamer(tokenizer)
             generated = model.generate(
                 inputs=batch["input_ids"].to(cfg.device),
                 generation_config=generation_config,
+                streamer=streamer,
            )
+            print("=" * 40)
             print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 
 
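The special-token change above replaces three unconditional add_special_tokens calls with a loop that defers to the user's config. A runnable sketch of the same guard, assuming a stand-in dict for cfg.special_tokens and a placeholder tokenizer (neither is taken from the commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model, for illustration only
special_tokens_cfg = {"bos_token": "<s>"}  # hypothetical cfg.special_tokens from a user config

default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
    # Mirror the new guard: only fall back to the default when the config is silent
    if not (special_tokens_cfg and token in special_tokens_cfg):
        tokenizer.add_special_tokens({token: symbol})

print(tokenizer.special_tokens_map)  # unk_token/eos_token now set; bos_token left as configured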
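The TextStreamer wired into model.generate is what makes tokens print incrementally instead of appearing only after generation finishes; the final decode still works because the return_dict_in_generate output is unchanged. A minimal end-to-end sketch, assuming a small placeholder model (gpt2 here stands in for whatever the script's cfg points at):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextStreamer(tokenizer)  # decodes and prints tokens to stdout as they are generated

model.eval()
with torch.no_grad():
    generated = model.generate(
        inputs=batch["input_ids"],
        generation_config=GenerationConfig(
            max_new_tokens=32,
            return_dict_in_generate=True,
        ),
        streamer=streamer,  # stream to stdout while generating
    )

# The full sequence is still available afterwards, exactly as the script prints it
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))

TextStreamer writes straight to stdout, which suits this interactive CLI loop; for consuming tokens programmatically, transformers also provides TextIteratorStreamer.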