CMLL committed
Commit bbe7b0a · verified · 1 Parent(s): a5064fd

Update app.py

Files changed (1):
  app.py (+117 −49)
app.py CHANGED
@@ -1,64 +1,132 @@
-import spaces  # Import spaces at the top
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch

-# Import the GPU decorator
-from spaces import GPU

-# Set the device to use GPU
-device = "cuda"  # Use CUDA for GPU

-# Initialize model and tokenizer
-peft_model_id = "CMLM/ZhongJing-2-1_8b"
 base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
-model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"cuda": 0})
-model.load_adapter(peft_model_id)
-tokenizer = AutoTokenizer.from_pretrained(
-    "CMLM/ZhongJing-2-1_8b",
-    padding_side="right",
-    trust_remote_code=True,
-    pad_token=''
-)

-@GPU(duration=120)  # Decorate with GPU usage and specify the duration
-def get_model_response(question):
-    # Create the prompt without context
-    prompt = f"Question: {question}"
-    messages = [
-        {"role": "system", "content": "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来 of Fudan University."},
-        {"role": "user", "content": prompt}
-    ]

-    # Prepare the input
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(device)

-    # Generate the response
    generated_ids = model.generate(
-        model_inputs.input_ids,
-        max_new_tokens=512
    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]

-    # Decode the response
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return response
-
-iface = gr.Interface(
-    fn=get_model_response,  # Directly use the decorated function
-    inputs=["text"],
-    outputs="text",
-    title="仲景GPT-V2-1.8B",
-    description="博极医源,精勤不倦。Unlocking the Wisdom of Traditional Chinese Medicine with AI."
)

-# Launch the interface with sharing enabled
-iface.launch(share=True)

+import os
 import gradio as gr
+import spaces  # required by the @spaces.GPU decorator used below
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from threading import Thread
+from typing import Iterator
+
+# Constants
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+DESCRIPTION = """\
+# Llama-2 7B Chat
+This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, a Llama 2 model with 7B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
+🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
+🔨 Looking for an even more powerful model? Check out the [13B version](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat) or the large [70B model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
+"""

+LICENSE = """
+<p/>
+---
+As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
+this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
+"""

+# Set the device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Model loading with the replacement setup
 base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
+model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
+model.load_adapter("CMLM/ZhongJing-2-1_8b")
+tokenizer = AutoTokenizer.from_pretrained("CMLM/ZhongJing-2-1_8b", padding_side="right", trust_remote_code=True, pad_token='')

+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    outputs = []
    generated_ids = model.generate(
+        input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty
    )
+    # Decode only the newly generated tokens, skipping the prompt that generate() echoes back
+    outputs.append(tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True))
+    return "".join(outputs)

+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
)

+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    chat_interface.render()
+    gr.Markdown(LICENSE)
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
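
Note on the unused streaming imports: the new app.py imports Thread and Iterator and annotates generate() as returning Iterator[str], but the committed body returns one complete string, so the chat UI only refreshes after generation finishes. Below is a minimal sketch of how those imports could be wired up with transformers' TextIteratorStreamer to stream partial replies; it is not part of this commit, it reuses the model and tokenizer loaded above, and the helper name generate_stream and its arguments are illustrative assumptions.

# Sketch only: a streaming variant of the generation step, not part of this commit.
from threading import Thread
from typing import Iterator

from transformers import TextIteratorStreamer


def generate_stream(conversation: list[dict], max_new_tokens: int = 1024) -> Iterator[str]:
    # Same chat-template preparation as in the committed generate()
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)

    # Streamer yields decoded text chunks as tokens are produced; skip the echoed prompt
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )

    # Run generation in a background thread so chunks can be consumed as they arrive
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)  # gr.ChatInterface renders each partial string as it is yielded

A generator function like this can be passed directly as fn= to gr.ChatInterface, which updates the chat bubble on every yield instead of waiting for the full response.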