Update app.py
app.py CHANGED
@@ -3,9 +3,12 @@
 import gradio as gr
 import os
 
+import torch
 from peft import PeftConfig, PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 base_model_name = "google/gemma-7b"
 #adapter_model_name = "samidh/cope-g2b-2c-hs-skr-s1.5.9-sx-sk-s5.d25"
 #adapter_model_name = "samidh/cope-g2b-2c-hs-skr-s1.5.9-sx-sk-s1.5.l1e4-e10-d25"
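The import and device line added here are the standard PyTorch capability check; a minimal standalone illustration (the print is only for demonstration):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)  # 'cuda' on a GPU Space such as an L4, 'cpu' otherwise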
@@ -20,6 +23,8 @@ model = AutoModelForCausalLM.from_pretrained(base_model_name, token=os.environ['
 model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ['HF_TOKEN'])
 model.merge_and_unload()
 
+model = model.to(device)
+
 tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
 #inputs = tokenizer.encode("This movie was really", return_tensors="pt")
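Taken together, the first two hunks implement a load, merge, and move-to-device flow. A condensed sketch, assuming the model name and HF_TOKEN usage shown in the diff; the adapter name is a placeholder, since the diff only shows commented-out candidates. Note that PEFT's merge_and_unload() returns the merged base model, so capturing its return value is the documented usage:

import os
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'

adapter_model_name = "..."  # placeholder; see the commented candidates in the diff

base = AutoModelForCausalLM.from_pretrained("google/gemma-7b", token=os.environ['HF_TOKEN'])
model = PeftModel.from_pretrained(base, adapter_model_name, token=os.environ['HF_TOKEN'])
model = model.merge_and_unload()  # fold the LoRA weights into the base model
model = model.to(device)          # a single transfer after merging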
@@ -90,14 +95,17 @@ DEFAULT_CONTENT = "LLMs steal our jobs."
 # Function to make predictions
 def predict(content, policy):
     input_text = PROMPT.format(policy=policy, content=content)
-
-
-
-
-
+    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        outputs = model(input_ids)
+    logits = outputs.logits[:, -1, :] # Get logits for the last token
+    predicted_token_id = torch.argmax(logits, dim=-1).item()
+    decoded_output = tokenizer.decode([predicted_token_id])
+    if decoded_output == '0':
+        return f'NON-Violating ({decoded_output})'
     else:
-        return f'VIOLATING ({decoded_output
-
+        return f'VIOLATING ({decoded_output})'
 
 with gr.Blocks() as iface:
     gr.Markdown("# CoPE Alpha Preview")
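The rewritten predict() performs single-token classification: one forward pass, the logits at the last position, and an argmax to pick the model's next token, which is then mapped to a label. A self-contained sketch of that decision rule (the classify name and its arguments are illustrative, not part of the Space):

import torch

def classify(model, tokenizer, prompt):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(input_ids).logits[:, -1, :]  # logits for the next token only
    token_id = torch.argmax(logits, dim=-1).item()  # greedy one-token "answer"
    label = tokenizer.decode([token_id])
    return 'NON-VIOLATING' if label == '0' else 'VIOLATING'

Because only a single token is scored, there is no sampling and no generate() call, which keeps per-request latency low.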