import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel
import gradio as gr
from threading import Thread
import spaces
import os
# Read the Hugging Face token and model identifiers from environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # needed for gated or private repos
BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"  # replace with your base model
LORA_MODEL_PATH = "QLWD/test-7b"  # replace with your LoRA adapter repo
# UI title and description
TITLE = "<h1><center>Vulnerability Detection Fine-tuned Model Test</center></h1>"
DESCRIPTION = f"""
<h3>Model: <a href="https://huggingface.co/{LORA_MODEL_PATH}">Vulnerability Detection Fine-tuned Model</a></h3>
<center>
<p>Test the generation quality of the base model with the LoRA adapter applied.</p>
</center>
"""
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""
# Load the base model and tokenizer (use_auth_token is deprecated; newer transformers expects token=)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
# Apply the LoRA adapter weights on top of the base model
model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH, token=HF_TOKEN)
# device_map="auto" already places the weights; the .to() is only a CPU fallback
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
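# Optional sketch (not part of the original app): merging the adapter into the base
# weights with PEFT's merge_and_unload() removes the LoRA indirection and can speed
# up inference, at the cost of no longer being able to swap the adapter out.
# model = model.merge_and_unload()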
# Streaming inference function
@spaces.GPU(duration=50)
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    conversation = []
    # System prompt defining the model's role
    conversation.append({"role": "system", "content": (
        "You are a binary pwn analysis assistant. The user gives you the static "
        "analysis results of a pwn binary; identify the corresponding vulnerability "
        "and provide an exploit (exp) for it."
    )})
    # Replay the conversation history
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    # Append the current user message
    conversation.append({"role": "user", "content": message})
    # Render the conversation with the model's chat template, then tokenize
    prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
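    # Alternative sketch (assumption): apply_chat_template(conversation,
    # add_generation_prompt=True, tokenize=True, return_tensors="pt") would tokenize
    # in one step; the string-then-tokenize form above also yields an attention_mask.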
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Generation parameters; dict(inputs, ...) merges input_ids/attention_mask with the sampling options
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=penalty,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=[151645, 151643],  # Qwen2.5 <|im_end|> and <|endoftext|>
    )
    # Run generation on a background thread and stream partial output to the UI
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
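    # Added for tidy cleanup: the streamer is exhausted only after generation ends,
    # so joining the worker thread here does not block.
    thread.join()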
# Gradio UI
chatbot = gr.Chatbot(height=450)
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0.1, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),  # must be > 0 when do_sample=True
            gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.8, label="top_p", render=False),
            gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Repetition penalty", render=False),
        ],
        cache_examples=False,
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
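# Client-side usage sketch (assumptions: the Space name below is hypothetical, and
# the ChatInterface endpoint is exposed under the default /chat route):
# from gradio_client import Client
# client = Client("your-username/your-space-name")
# print(client.predict("...static analysis output here...", api_name="/chat"))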