stillerman commited on
Commit
738a057
·
unverified ·
1 Parent(s): cdc71f7

Feat: Added Gradio support (#812)

Browse files

* Added gradio support

* queuing and title

* pre-commit run

README.md CHANGED
@@ -97,6 +97,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
97
  # inference
98
  accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
99
  --lora_model_dir="./lora-out"
 
 
 
 
100
  ```
101
 
102
  ## Installation
@@ -919,6 +923,10 @@ Pass the appropriate flag to the train command:
919
  cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
920
  --base_model="./completed-model" --prompter=None --load_in_8bit=True
921
  ```
 
 
 
 
922
 
923
  Please use `--sample_packing False` if you have it on and receive the error similar to below:
924
 
 
97
  # inference
98
  accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
99
  --lora_model_dir="./lora-out"
100
+
101
+ # gradio
102
+ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
103
+ --lora_model_dir="./lora-out" --gradio
104
  ```
105
 
106
  ## Installation
 
923
  cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
924
  --base_model="./completed-model" --prompter=None --load_in_8bit=True
925
  ```
926
+ -- With gradio hosting
927
+ ```bash
928
+ python -m axolotl.cli.inference examples/your_config.yml --gradio
929
+ ```
930
 
931
  Please use `--sample_packing False` if you have it on and receive the error similar to below:
932
 
requirements.txt CHANGED
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
31
  pynvml
32
  art
33
  fschat==0.2.29
 
 
31
  pynvml
32
  art
33
  fschat==0.2.29
34
+ gradio
src/axolotl/cli/__init__.py CHANGED
@@ -6,8 +6,10 @@ import os
6
  import random
7
  import sys
8
  from pathlib import Path
 
9
  from typing import Any, Dict, List, Optional, Union
10
 
 
11
  import torch
12
  import yaml
13
 
@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
16
  from art import text2art
17
  from huggingface_hub import HfApi
18
  from huggingface_hub.utils import LocalTokenNotFoundError
19
- from transformers import GenerationConfig, TextStreamer
20
 
21
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
22
  from axolotl.logging_config import configure_logging
@@ -153,6 +155,91 @@ def do_inference(
153
  print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
154
 
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def choose_config(path: Path):
157
  yaml_files = list(path.glob("*.yml"))
158
 
 
6
  import random
7
  import sys
8
  from pathlib import Path
9
+ from threading import Thread
10
  from typing import Any, Dict, List, Optional, Union
11
 
12
+ import gradio as gr
13
  import torch
14
  import yaml
15
 
 
18
  from art import text2art
19
  from huggingface_hub import HfApi
20
  from huggingface_hub.utils import LocalTokenNotFoundError
21
+ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
22
 
23
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
24
  from axolotl.logging_config import configure_logging
 
155
  print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
156
 
157
 
158
+ def do_inference_gradio(
159
+ *,
160
+ cfg: DictDefault,
161
+ cli_args: TrainerCliArgs,
162
+ ):
163
+ model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
164
+ prompter = cli_args.prompter
165
+ default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
166
+
167
+ for token, symbol in default_tokens.items():
168
+ # If the token isn't already specified in the config, add it
169
+ if not (cfg.special_tokens and token in cfg.special_tokens):
170
+ tokenizer.add_special_tokens({token: symbol})
171
+
172
+ prompter_module = None
173
+ if prompter:
174
+ prompter_module = getattr(
175
+ importlib.import_module("axolotl.prompters"), prompter
176
+ )
177
+
178
+ if cfg.landmark_attention:
179
+ from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
180
+
181
+ set_model_mem_id(model, tokenizer)
182
+ model.set_mem_cache_args(
183
+ max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
184
+ )
185
+
186
+ model = model.to(cfg.device)
187
+
188
+ def generate(instruction):
189
+ if not instruction:
190
+ return
191
+ if prompter_module:
192
+ # pylint: disable=stop-iteration-return
193
+ prompt: str = next(
194
+ prompter_module().build_prompt(instruction=instruction.strip("\n"))
195
+ )
196
+ else:
197
+ prompt = instruction.strip()
198
+ batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
199
+
200
+ model.eval()
201
+ with torch.no_grad():
202
+ generation_config = GenerationConfig(
203
+ repetition_penalty=1.1,
204
+ max_new_tokens=1024,
205
+ temperature=0.9,
206
+ top_p=0.95,
207
+ top_k=40,
208
+ bos_token_id=tokenizer.bos_token_id,
209
+ eos_token_id=tokenizer.eos_token_id,
210
+ pad_token_id=tokenizer.pad_token_id,
211
+ do_sample=True,
212
+ use_cache=True,
213
+ return_dict_in_generate=True,
214
+ output_attentions=False,
215
+ output_hidden_states=False,
216
+ output_scores=False,
217
+ )
218
+ streamer = TextIteratorStreamer(tokenizer)
219
+ generation_kwargs = {
220
+ "inputs": batch["input_ids"].to(cfg.device),
221
+ "generation_config": generation_config,
222
+ "streamer": streamer,
223
+ }
224
+
225
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
226
+ thread.start()
227
+
228
+ all_text = ""
229
+
230
+ for new_text in streamer:
231
+ all_text += new_text
232
+ yield all_text
233
+
234
+ demo = gr.Interface(
235
+ fn=generate,
236
+ inputs="textbox",
237
+ outputs="text",
238
+ title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
239
+ )
240
+ demo.queue().launch(show_api=False, share=True)
241
+
242
+
243
  def choose_config(path: Path):
244
  yaml_files = list(path.glob("*.yml"))
245
 
src/axolotl/cli/inference.py CHANGED
@@ -6,11 +6,16 @@ from pathlib import Path
6
  import fire
7
  import transformers
8
 
9
- from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
 
 
 
 
 
10
  from axolotl.common.cli import TrainerCliArgs
11
 
12
 
13
- def do_cli(config: Path = Path("examples/"), **kwargs):
14
  # pylint: disable=duplicate-code
15
  print_axolotl_text_art()
16
  parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
21
  )
22
  parsed_cli_args.inference = True
23
 
24
- do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
 
 
 
25
 
26
 
27
  if __name__ == "__main__":
 
6
  import fire
7
  import transformers
8
 
9
+ from axolotl.cli import (
10
+ do_inference,
11
+ do_inference_gradio,
12
+ load_cfg,
13
+ print_axolotl_text_art,
14
+ )
15
  from axolotl.common.cli import TrainerCliArgs
16
 
17
 
18
+ def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
19
  # pylint: disable=duplicate-code
20
  print_axolotl_text_art()
21
  parsed_cfg = load_cfg(config, **kwargs)
 
26
  )
27
  parsed_cli_args.inference = True
28
 
29
+ if gradio:
30
+ do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
31
+ else:
32
+ do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
33
 
34
 
35
  if __name__ == "__main__":