Spaces:
Running
Running
import os | |
import json | |
import logging | |
import gradio as gr | |
from backend.inference import section_infer, cwe_infer, PREDEF_MODEL_MAP, LOCAL_MODEL_PEFT_MAP, PREDEF_CWE_MODEL | |
APP_TITLE = "PATCHOULI"
# Centred HTML banner spelling out the acronym; the letters that make up
# "PATCHOULI" are wrapped in spans coloured with the brand green #14e166.
STYLE_APP_TITLE = (
    '<div style="text-align: center; font-weight: bold; font-family: Arial, sans-serif; font-size: 44px;">'
    '<span style="color: #14e166">PATCH</span> '
    '<span style="color: #14e166">O</span>bserving '
    'and '
    '<span style="color: #14e166">U</span>ntang<span style="color: #14e166">l</span>ing '
    'Eng<span style="color: #14e166">i</span>ne'
    '</div>'
)
# Colour gradient for "non-vul-fixing" labels: 41 greens ordered from the
# palest shade (confidence 0.00) to the brand green #14e166 (confidence 1.00).
# The palest shade #d3f8d6 appears twice (indices 0 and 1).
NONVUL_GRADIENT_COLORS = """
    #d3f8d6
    #d3f8d6 #d0f8d3 #ccf7d0 #c9f7cd #c6f6cb #c2f6c8 #bff5c5 #bcf5c2 #b8f4bf #b5f4bc
    #b1f3ba #aef2b7 #aaf2b4 #a7f1b1 #a3f1ae #9ff0ab #9cf0a9 #98efa6 #94efa3 #90eea0
    #8ced9d #88ed9a #84ec98 #80ec95 #7ceb92 #78ea8f #73ea8c #6fe989 #6ae886 #65e883
    #60e781 #5ae67e #55e67b #4fe578 #48e475 #41e472 #39e36f #30e26c #25e269 #14e166
""".split()
# Colour gradient for "vul-fixing" labels: 41 reds ordered from the palest
# shade (confidence 0.00) to #de4a4a (confidence 1.00). Like
# NONVUL_GRADIENT_COLORS, the palest shade is repeated at indices 0 and 1.
# Index 0 must stay red: it colours "vul-fixing: 0.00".."0.02" labels.
VUL_GRADIENT_COLORS = ["#fdcfc9",
    "#fdcfc9", "#fdccc5", "#fcc9c2", "#fcc5bf", "#fcc2bb", "#fbbfb8", "#fbbcb4", "#fab9b1", "#fab5ad", "#f9b2aa",
    "#f8afa7", "#f8aca3", "#f7a8a0", "#f7a59c", "#f6a299", "#f59f96", "#f59c92", "#f4988f", "#f3958c", "#f29288",
    "#f18e85", "#f18b82", "#f0887f", "#ef847c", "#ee8178", "#ed7e75", "#ec7a72", "#eb776f", "#ea736c", "#e97068",
    "#e86c65", "#e76962", "#e6655f", "#e5615c", "#e45e59", "#e35a56", "#e25653", "#e05250", "#df4e4d", "#de4a4a"
]
# Root logger: INFO and above, formatted as "time - logger - level - message".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Quiet the httpx client logger (presumably chatty at INFO); keep its warnings.
logging.getLogger("httpx").setLevel(logging.WARNING)
def generate_color_map(nonvul_colors=None, vul_colors=None, error_color="#d9d9d9"):
    """Build the label -> colour mapping used to highlight patch sections.

    Keys are the exact label strings the UI renders, ``"<class>: <conf>"``,
    where ``<conf>`` is every confidence from 0.00 to 1.00 in 0.01 steps,
    formatted with two decimals. Confidence ``i/100`` selects entry
    ``floor(i * (len(gradient) - 1) / 100)`` of the class's gradient, so 0.00
    maps to the first colour and 1.00 to the last.

    Args:
        nonvul_colors: gradient for "non-vul-fixing" labels; defaults to
            NONVUL_GRADIENT_COLORS.
        vul_colors: gradient for "vul-fixing" labels; defaults to
            VUL_GRADIENT_COLORS.
        error_color: single colour for "error" labels (the name the UI uses
            for a prediction of -1), so those labels also have an entry.

    Returns:
        dict mapping each label string to a hex colour string.

    Raises:
        ValueError: if either gradient is empty.
    """
    if nonvul_colors is None:
        nonvul_colors = NONVUL_GRADIENT_COLORS
    if vul_colors is None:
        vul_colors = VUL_GRADIENT_COLORS
    if not nonvul_colors or not vul_colors:
        raise ValueError("colour gradients must be non-empty")

    nonvul_last = len(nonvul_colors) - 1
    vul_last = len(vul_colors) - 1
    color_map = {}
    for i in range(101):
        conf = f"{i / 100:0.2f}"
        # Integer arithmetic keeps the bucket index exact (no float rounding)
        # and adapts to gradients of any length.
        color_map[f"non-vul-fixing: {conf}"] = nonvul_colors[i * nonvul_last // 100]
        color_map[f"vul-fixing: {conf}"] = vul_colors[i * vul_last // 100]
        color_map[f"error: {conf}"] = error_color
    return color_map
def on_submit(diff_code, patch_message, cwe_model, section_model_type, progress=gr.Progress(track_tqdm=True), *model_config):
    """Classify a patch section by section and, if it fixes a vulnerability, its CWE type.

    This is the click handler for the "Run" buttons. Gradio recognises the
    ``gr.Progress`` default and injects the tracker itself, so the remaining
    positional inputs (the backend-specific model settings) are collected in
    ``model_config`` and forwarded unchanged to ``section_infer``.

    Args:
        diff_code: git diff text from the code box.
        patch_message: optional commit/patch message.
        cwe_model: name of the vulnerability-type classifier; falsy means none.
        section_model_type: label of the selected model tab, passed to
            ``section_infer``.
        progress: progress tracker supplied by Gradio (unused directly; it
            enables tqdm tracking).
        *model_config: model settings for the selected backend.

    Returns:
        ``(patch label component, section results, CWE result)``, or three
        ``gr.skip()`` values when no diff was entered so the outputs are left
        unchanged. Section results are presumably a mapping of file name to a
        list of dicts carrying at least a ``"predict"`` key (as read below).

    Raises:
        gr.Error: if section or CWE inference raises.
    """
    # Treat a missing or whitespace-only diff as "nothing to analyse".
    if not diff_code or not diff_code.strip():
        return gr.skip(), gr.skip(), gr.skip()

    try:
        section_results = section_infer(diff_code, patch_message, section_model_type, *model_config)
    except Exception as e:
        logging.exception("Section inference failed")
        raise gr.Error(f"Error: {e}") from e

    # The whole patch counts as vul-fixing if any section is predicted as 1.
    is_vul_fixing = any(
        item["predict"] == 1
        for file_results in section_results.values()
        for item in file_results
    )
    label_text = "Vul-fixing patch" if is_vul_fixing else "Non-vul-fixing patch"
    color = "#de4a4a" if is_vul_fixing else "#14e166"
    patch_category_label = gr.Label(value=label_text, color=color)

    # A cleared dropdown may yield "" or None; both mean "no classifier".
    if not cwe_model:
        cwe_cls_result = "No model selected"
    elif not is_vul_fixing:
        cwe_cls_result = "No vulnerability found"
    else:
        try:
            cwe_cls_result = cwe_infer(diff_code, patch_message, cwe_model)
        except Exception as e:
            logging.exception("CWE inference failed")
            raise gr.Error(f"Error: {e}") from e
    return patch_category_label, section_results, cwe_cls_result
with gr.Blocks(title=APP_TITLE, fill_width=True) as demo:
    # Per-session results: written by on_submit, read by the result widgets.
    section_results_state = gr.State({})
    cls_results_state = gr.State({})
    title = gr.HTML(STYLE_APP_TITLE)

    with gr.Row() as main_block:
        # ---- Left column: inputs and model selection ----
        with gr.Column(scale=1) as input_block:
            diff_codebox = gr.Code(label="Input git diff here", max_lines=10)
            with gr.Accordion("Patch message (optional)", open=False):
                message_textbox = gr.Textbox(label="Patch message", placeholder="Enter patch message here",
                                             container=False, lines=2, max_lines=5)
            cwe_model_selector = gr.Dropdown(PREDEF_CWE_MODEL, label="Select vulnerability type classifier",
                                             allow_custom_value=True)
            with gr.Tabs(selected=0) as model_type_tabs:
                MODEL_TYPE_NAMES = list(PREDEF_MODEL_MAP.keys())
                with gr.Tab(MODEL_TYPE_NAMES[0]) as local_llm_tab:
                    local_model_selector = gr.Dropdown(PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[0]], label="Select model",
                                                       allow_custom_value=True)
                    local_peft_selector = gr.Dropdown(LOCAL_MODEL_PEFT_MAP[local_model_selector.value],
                                                      label="Select PEFT model (optional)", allow_custom_value=True)
                    local_submit_btn = gr.Button("Run", variant="primary")
                with gr.Tab(MODEL_TYPE_NAMES[1]) as online_llm_tab:
                    online_model_selector = gr.Dropdown(PREDEF_MODEL_MAP[MODEL_TYPE_NAMES[1]], label="Select model",
                                                        allow_custom_value=True)
                    online_api_url_textbox = gr.Textbox(label="API URL")
                    # Do not pre-fill the server's ONLINE_API_KEY here: a component's value is
                    # sent to every browser, and type="password" only masks it on screen.
                    # The env key is applied server-side when this box is left blank.
                    online_api_key_textbox = gr.Textbox(label="API Key", placeholder="We won't store your API key",
                                                        type="password")
                    online_submit_btn = gr.Button("Run", variant="primary")
            # Label of the selected model tab; updated by update_model_type_state.
            section_model_type = gr.State(model_type_tabs.children[0].label)
            with gr.Accordion("Load examples", open=False):
                with open("./backend/examples.json", "r", encoding="utf-8") as f:
                    examples = json.load(f)
                # Each example presumably holds [diff, message] for the two outputs — TODO confirm.
                gr.Button("Load example 1", size='sm').click(lambda: examples[0], outputs=[diff_codebox, message_textbox])
                gr.Button("Load example 2", size='sm').click(lambda: examples[1], outputs=[diff_codebox, message_textbox])
                gr.Button("Load example 3", size='sm').click(lambda: examples[2], outputs=[diff_codebox, message_textbox])

        # ---- Middle column: per-file, per-section results ----
        with gr.Column(scale=2) as section_result_block:
            # Without @gr.render this function is never invoked; with it, Gradio
            # re-renders the tabs whenever section_results_state changes.
            @gr.render(inputs=[section_results_state])
            def display_result(section_results):
                """Render one tab per file with its sections highlighted by prediction.

                ``section_results`` is presumably a mapping of file name to a list of
                dicts with "section", "predict" (-1/0/1) and "conf" keys (as read
                below) — confirm against section_infer.
                """
                if not section_results:
                    with gr.Tab("File tabs"):
                        gr.Markdown("No results")
                    return
                predict_names = {-1: 'error', 0: 'non-vul-fixing', 1: 'vul-fixing'}
                # Built once per render instead of once per file.
                full_color_map = generate_color_map()
                for file_name, file_results in section_results.items():
                    with gr.Tab(file_name):
                        highlighted_results = []
                        this_color_map = {}
                        for item in file_results:
                            predict_name = predict_names.get(item['predict'], 'error')
                            text_label = f"{predict_name}: {item['conf']:0.2f}"
                            # Labels without a predefined colour (e.g. "error", or a
                            # confidence outside 0..1) fall back to HighlightedText's
                            # automatic colour instead of raising KeyError.
                            if text_label in full_color_map:
                                this_color_map[text_label] = full_color_map[text_label]
                            highlighted_results.append((item["section"], text_label))
                        gr.HighlightedText(
                            highlighted_results,
                            label="Results",
                            color_map=this_color_map
                        )

        # ---- Right column: whole-patch verdict and vulnerability type ----
        with gr.Column(scale=1) as result_block:
            patch_category_label = gr.Label(value="No results", label="Result of the whole patch")

            def update_vul_type_label(cls_results):
                return gr.Label(cls_results)

            # Value is recomputed from cls_results_state whenever that state changes.
            vul_type_label = gr.Label(update_vul_type_label, label="Possible fixed vulnerability type",
                                      inputs=[cls_results_state])

    def update_model_type_state(evt: gr.SelectData):
        # For a Tabs select event, evt.value is the selected tab's label.
        return evt.value

    model_type_tabs.select(update_model_type_state, outputs=[section_model_type])

    def update_support_peft(base_model):
        """Refresh the PEFT dropdown for ``base_model``.

        Custom model names (allowed by the dropdown) or models without PEFT
        adapters yield an empty list instead of raising KeyError/IndexError.
        """
        try:
            peft_choices = LOCAL_MODEL_PEFT_MAP[base_model]
        except KeyError:
            peft_choices = []
        return gr.Dropdown(peft_choices, value=peft_choices[0] if peft_choices else None)

    local_model_selector.change(update_support_peft, inputs=[local_model_selector], outputs=[local_peft_selector])

    local_submit_btn.click(fn=on_submit,
                           inputs=[diff_codebox, message_textbox, cwe_model_selector, section_model_type,
                                   local_model_selector, local_peft_selector],
                           outputs=[patch_category_label, section_results_state, cls_results_state])

    def on_online_submit(diff_code, patch_message, cwe_model, model_type, model_name, api_url, api_key,
                         progress=gr.Progress(track_tqdm=True)):
        """Forward to on_submit, substituting the server-side key when the box is blank."""
        api_key = api_key or os.getenv("ONLINE_API_KEY", "")
        return on_submit(diff_code, patch_message, cwe_model, model_type, progress,
                         model_name, api_url, api_key)

    online_submit_btn.click(fn=on_online_submit,
                            inputs=[diff_codebox, message_textbox, cwe_model_selector, section_model_type,
                                    online_model_selector, online_api_url_textbox, online_api_key_textbox],
                            outputs=[patch_category_label, section_results_state, cls_results_state])

demo.launch()