QLWD committed
Commit b4223e6 · verified · 1 Parent(s): 5026de8

Update app.py

Files changed (1):
  1. app.py +28 -17
app.py CHANGED
@@ -1,14 +1,13 @@
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from peft import PeftModel
 import gradio as gr
 import os
 
 # Read Hugging Face model info from environment variables
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"  # replace with the base model
-LORA_MODEL_PATH = "QLWD/test-7b"  # replace with the LoRA model repo path
+LORA_MODEL_PATH = "QLWD/test-7b"  # replace with the LoRA fine-tuned model path
 
 # Define the UI title and description
 TITLE = "<h1><center>Vulnerability Detection Fine-tuned Model Test</center></h1>"
@@ -33,14 +32,30 @@ text-align: center;
 """
 
 # Load the base model and LoRA fine-tuned weights
-base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto", use_auth_token=HF_TOKEN)
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_auth_token=HF_TOKEN)
-
-# Load the LoRA fine-tuned weights
-model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH, use_auth_token=HF_TOKEN)
-model = model.to("cuda" if torch.cuda.is_available() else "cpu")
-
-# Define the inference function (synchronous, no threads)
+model_name = BASE_MODEL_ID
+lora_model_name = LORA_MODEL_PATH
+
+# Load the base model
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,  # use bfloat16 for better performance
+    device_map="auto",  # assign devices automatically
+    use_auth_token=HF_TOKEN
+)
+
+# Load the fine-tuned weights
+model = AutoModelForCausalLM.from_pretrained(
+    lora_model_name,
+    torch_dtype=torch.bfloat16,  # likewise bfloat16 for better performance
+    device_map="auto",  # assign devices automatically
+    use_auth_token=HF_TOKEN,
+    trust_remote_code=True  # in case the remote repo needs custom loading logic
+)
+
+# Load the tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
+
+# Define the inference function
 @spaces.GPU(duration=50)
 def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
     conversation = []
@@ -85,16 +100,12 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
         eos_token_id=[151645, 151643],
     )
 
-    # Call the model's generate directly, without a thread
-    model.generate(**generate_kwargs)
-
-    # Collect the generated text and return it incrementally
+    # Stream the generated output
     buffer = ""
-    for new_text in streamer:
+    for new_text in model.generate(**generate_kwargs):
         buffer += new_text
         yield buffer
 
-
 # Define the Gradio interface
 chatbot = gr.Chatbot(height=450)
 
@@ -120,4 +131,4 @@ with gr.Blocks(css=CSS) as demo:
 
 # Launch the Gradio app
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
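Note on the model-loading hunk above: the second AutoModelForCausalLM.from_pretrained call rebinds the variable model, so the base model loaded just before it is discarded, and loading QLWD/test-7b this way only works if that repo hosts a full merged checkpoint rather than a bare LoRA adapter. If the repo holds only adapter weights, the PeftModel route this commit removes is the usual one. A minimal sketch, assuming QLWD/test-7b is a LoRA adapter for the Qwen base model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
LORA_MODEL_PATH = "QLWD/test-7b"  # assumption: this repo holds adapter weights only

# Load the base model, then attach the LoRA adapter weights on top of it
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)

# Optionally fold the adapter into the base weights; the result can then be
# saved and reloaded with plain AutoModelForCausalLM
model = model.merge_and_unload()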
 
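A second note, on the streaming hunk: model.generate(**generate_kwargs) returns token-id tensors, and only after generation completes, so iterating over its return value does not yield text chunks (the loop body would try to concatenate tensors onto a string). The usual pattern with TextIteratorStreamer, which the removed code was close to, is to run generate on a background thread and iterate the streamer on the calling thread. A minimal sketch, assuming model and tokenizer are loaded as above and input_ids holds the tokenized prompt:

from threading import Thread

from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, input_ids, max_new_tokens=1024):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs=input_ids,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
    )

    # generate() blocks until it finishes, so run it on a worker thread and
    # read decoded text chunks off the streamer as they become available
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:  # yields decoded text pieces incrementally
        buffer += new_text
        yield buffer
    thread.join()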