QLWD committed · verified
Commit 4d1724f · 1 Parent(s): b4223e6

Update app.py

Files changed (1): app.py +17 -26
app.py CHANGED
@@ -1,13 +1,15 @@
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from peft import PeftModel
 import gradio as gr
+from threading import Thread
 import os
 
 # Read the Hugging Face model info from environment variables
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"  # replace with the base model
-LORA_MODEL_PATH = "QLWD/test-7b"  # replace with the LoRA fine-tuned model path
+LORA_MODEL_PATH = "QLWD/test-7b"  # replace with the LoRA model repo path
 
 # Define the UI title and description
 TITLE = "<h1><center>漏洞检测 微调模型测试</center></h1>"
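
Note on this hunk: both repo-id constants carry "replace with ..." comments, so the base model and the LoRA adapter are meant to be swapped per deployment. A mismatched pair only fails deep inside loading, so it can be worth checking the adapter's recorded base model up front. A minimal sketch of such a check (the check itself is an assumption, not part of this commit; PeftConfig reads the adapter's adapter_config.json):

    import os
    from peft import PeftConfig

    HF_TOKEN = os.environ.get("HF_TOKEN", None)
    BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
    LORA_MODEL_PATH = "QLWD/test-7b"

    # adapter_config.json records the base model the LoRA was trained against
    peft_cfg = PeftConfig.from_pretrained(LORA_MODEL_PATH, token=HF_TOKEN)
    if peft_cfg.base_model_name_or_path != BASE_MODEL_ID:
        raise ValueError(
            f"Adapter {LORA_MODEL_PATH} was trained on "
            f"{peft_cfg.base_model_name_or_path}, not {BASE_MODEL_ID}"
        )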
@@ -32,28 +34,12 @@ text-align: center;
 """
 
 # Load the base model and the LoRA fine-tuned weights
-model_name = BASE_MODEL_ID
-lora_model_name = LORA_MODEL_PATH
+base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto", use_auth_token=HF_TOKEN)
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_auth_token=HF_TOKEN)
 
-# Load the base model
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.bfloat16,  # use bfloat16 for better performance
-    device_map="auto",  # assign devices automatically
-    use_auth_token=HF_TOKEN
-)
-
-# Load the fine-tuned weights
-model = AutoModelForCausalLM.from_pretrained(
-    lora_model_name,
-    torch_dtype=torch.bfloat16,  # likewise bfloat16 for performance
-    device_map="auto",  # assign devices automatically
-    use_auth_token=HF_TOKEN,
-    trust_remote_code=True  # in case the remote code needs custom loading logic
-)
-
-# Load the tokenizer
-tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
+# Load the LoRA fine-tuned weights
+model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH, use_auth_token=HF_TOKEN)
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
 
 # Define the inference function
 @spaces.GPU(duration=50)
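
Two observations on the new loading block. With device_map="auto", accelerate has already placed the weights, so the trailing model.to(...) is redundant and can misbehave on models sharded across devices. And when the adapter never needs to be swapped at runtime, merging it into the base weights removes the per-layer LoRA indirection during generation. A sketch of the merged variant (same repos as the diff; merge_and_unload is standard peft API, but using it here is a suggestion, not what the commit does; token= replaces the deprecated use_auth_token=):

    import os
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    HF_TOKEN = os.environ.get("HF_TOKEN", None)
    BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
    LORA_MODEL_PATH = "QLWD/test-7b"

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",  # accelerate places the weights; no manual .to() needed
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)

    model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH, token=HF_TOKEN)
    model = model.merge_and_unload()  # bake the LoRA deltas into the base weights
    model.eval()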
@@ -76,15 +62,16 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
 
     # Append the prior conversation history to the session
     for prompt, answer in history:
-        conversation.extend([{"role": "user", "content": prompt}, {"role": "漏洞zhushou", "content": answer}])
+        conversation.extend([{"role": "user", "content": prompt}, {"role": "漏洞助手", "content": answer}])
 
     # Append the current user input to the conversation
     conversation.append({"role": "user", "content": message})
 
     # Generate input_ids with the custom chat template
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+    print("123")
     inputs = tokenizer(input_ids, return_tensors="pt").to("cuda")
-
+    print("321")
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     # Set the generation parameters
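
One caveat in this hunk: Qwen-style chat templates are written for the system/user/assistant roles, so the custom role string "漏洞助手" ("vulnerability assistant") on the history line renders a turn format the model never saw during training; the print("123")/print("321") calls also look like leftover debug output. A sketch of the same history-building step with the standard role (the helper name is hypothetical):

    def build_inputs(tokenizer, message, history):
        """Build model inputs from Gradio-style (prompt, answer) history."""
        conversation = []
        for prompt, answer in history:
            conversation.append({"role": "user", "content": prompt})
            conversation.append({"role": "assistant", "content": answer})
        conversation.append({"role": "user", "content": message})

        # Render the turns to a prompt string, then tokenize that string
        text = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        return tokenizer(text, return_tensors="pt")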
@@ -100,9 +87,12 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
         eos_token_id=[151645, 151643],
     )
 
-    # Stream the generated output
+    # Start the generation thread
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
     buffer = ""
-    for new_text in model.generate(**generate_kwargs):
+    for new_text in streamer:
         buffer += new_text
         yield buffer
 
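This hunk is the core fix. model.generate(**generate_kwargs) returns a tensor of token ids, so the old loop iterated over tensor rows and never streamed decoded text. The usual transformers pattern is to run generate on a background thread while the caller drains the TextIteratorStreamer, which is what the new lines do. A condensed, self-contained sketch of that pattern (function name and defaults are assumptions; inputs is the tokenized batch from the previous hunk):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, inputs, max_new_tokens=512):
        """Yield the decoded text incrementally as generate() produces tokens."""
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)

        # generate() blocks until finished, so it runs on its own thread while
        # this thread consumes decoded chunks from the streamer's queue.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:  # iteration ends when generation completes
            buffer += new_text
            yield buffer
        thread.join()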
@@ -132,3 +122,4 @@ with gr.Blocks(css=CSS) as demo:
 # Launch the Gradio app
 if __name__ == "__main__":
     demo.launch()
+