Spaces:
Sleeping
Sleeping
import gradio as gr | |
import joblib | |
import torch | |
import sys | |
import os | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
from l3prune import LLMEncoder | |
#load the model | |
best_clf = joblib.load("./saved/classifier_llama32.joblib") | |
encoder = LLMEncoder.from_pretrained( | |
"./saved/pruned_encoder_llama32", | |
device_map="cpu", | |
torch_dtype=torch.bfloat16, | |
#torch_dtype=torch, | |
#cache_dir=cache_dir | |
) | |
def classify_prompt(prompt): | |
#response = client.text_classification(prompt) | |
#label = response[0]['label'] | |
#score = response[0]['score'] | |
#if label == 'hate': | |
# result = f"Harmful (Confidence: {score:.2%})" | |
#else: | |
# result = f"Benign (Confidence: {score:.2%})" | |
X = encoder.encode([prompt]) | |
result = best_clf.predict(X)[0] | |
return "Harmful" if result else "Benign" | |
demo = gr.Interface( | |
fn=classify_prompt, | |
inputs=gr.Textbox(lines=3, placeholder="Enter a prompt to classify..."), | |
outputs=gr.Textbox(label="Classification Result"), | |
title="Harmful Prompt Classifier", | |
description="This app classifies whether a given prompt is potentially harmful or benign.", | |
show_api=False, | |
show_response_timing=True | |
) | |
if __name__ == "__main__": | |
demo.launch() |