import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel
import gradio as gr
from threading import Thread
# 从环境变量中获取 Hugging Face 模型信息
HF_TOKEN = os.environ.get("HF_TOKEN", None)
BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct" # 基础模型
LORA_MODEL_PATH = "QLWD/test-7b" # LoRA 模型路径
# 定义界面标题和描述
TITLE = "
漏洞检测 微调模型测试
"
DESCRIPTION = f"""
测试基础模型 + LoRA 补丁的生成效果。
"""
PLACEHOLDER = """
请输入您要分析的代码...
"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
device = "cuda"
# 加载tokenizer和基础模型
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL_ID,
use_fast=False,
trust_remote_code=True,
force_download=True
)
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
ignore_mismatched_sizes=True,
force_download=True
)
# 加载 LoRA 微调权重
model = PeftModel.from_pretrained(
base_model,
LORA_MODEL_PATH,
torch_dtype=torch.bfloat16,
device_map="auto"
)
def format_chat(system_prompt, history, message):
formatted_chat = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
for prompt, answer in history:
formatted_chat += f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>\n"
formatted_chat += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
return formatted_chat
@spaces.GPU()
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float = 0.3,
max_new_tokens: int = 256,
top_p: float = 1.0,
top_k: int = 20,
repetition_penalty: float = 1.2,
):
print(f'message: {message}')
print(f'history: {history}')
formatted_prompt = format_chat(system_prompt, history, message)
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, timeout=5000.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=inputs.input_ids,
max_new_tokens=max_new_tokens,
do_sample=False if temperature == 0 else True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
streamer=streamer,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
with torch.no_grad():
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
buffer = ""
for new_text in streamer:
buffer += new_text
if "<|endoftext|>" in buffer:
yield buffer.split("<|endoftext|>")[0]
break
yield buffer
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
SYSTEM_PROMPT = '''你是一位代码审计和漏洞修复专家,请仔细分析下面提供的代码,扫描并输出所有存在的漏洞和潜在的风险。每个漏洞或风险之间用分隔符 "--------" 隔开,报告内容左对齐。
从高危到低危的顺序来列出漏洞和风险,每个漏洞或风险的格式如下:
- **类型**:明确描述漏洞的类型或名称(如果已经有对应名称),或潜在的风险类型(如资源泄露、边界条件问题等)。
- **风险等级**:根据漏洞或风险的严重性进行评级(如高危、中危、低危)。
- **漏洞/风险描述**:以专业的角度详细解释漏洞的技术细节和成因,或描述潜在的风险。
- **影响**:说明该漏洞或风险可能对系统、数据或用户造成的具体影响。
- **修复建议**:提供修复该漏洞或风险的具体步骤或建议(不是给出修复代码,而是修复的实现方法)。
- **漏洞所在的代码段**:给出代码中存在漏洞的具体位置和代码段(如适用)。
- **修复的代码段**:给出修复漏洞的替换代码段,以便开发者使用(如适用)。
请确保扫描并**输出所有**漏洞和风险,请确保扫描并**输出所有你能够笃定和大概率存在的**漏洞和风险。
分隔符 "--------" 用于每个漏洞或风险之间。'''
with gr.Blocks(css=CSS, theme="soft") as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
gr.DuplicateButton(value="复制此 Space 进行私有部署", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ 参数设置", open=False, render=False),
additional_inputs=[
gr.Textbox(
value=SYSTEM_PROMPT,
label="系统提示词",
render=False,
),
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.1,
label="Temperature",
render=False,
),
gr.Slider(
minimum=128,
maximum=8192,
step=1,
value=8192,
label="最大生成长度",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
label="Top-p",
render=False,
),
gr.Slider(
minimum=1,
maximum=50,
step=1,
value=20,
label="Top-k",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="重复惩罚",
render=False,
),
],
examples=None,
cache_examples=False,
)
if __name__ == "__main__":
demo.launch(share=True)