Update app.py
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
 import os
 
 import torch
+import torch.nn.functional as F
 from peft import PeftConfig, PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
@@ -97,13 +98,34 @@ def predict(content, policy):
 
     with torch.inference_mode():
         outputs = model(input_ids)
-
+
+        # Get logits for the last token
+        logits = outputs.logits[:, -1, :]
+
+        # Apply softmax to get probabilities
+        probabilities = F.softmax(logits, dim=-1)
+
+        # Get the predicted token ID
         predicted_token_id = torch.argmax(logits, dim=-1).item()
-
-
-
+
+        # Decode the predicted token
+        decoded_output = tokenizer.decode([predicted_token_id])
+
+        # Get the probability of the predicted token
+        predicted_prob = probabilities[0, predicted_token_id].item()
+
+        # Function to get probability for a specific token
+        def get_token_probability(token):
+            token_id = tokenizer.encode(token, add_special_tokens=False)[0]
+            return probabilities[0, token_id].item()
+
+        predicted_prob_0 = get_token_probability('0')
+        predicted_prob_1 = get_token_probability('1')
+
+    if decoded_output == '1':
+        return f'VIOLATING\n(P: {predicted_prob_1:.2f})'
     else:
-        return f'
+        return f'NON-Violating\n(P: {predicted_prob_0:.2f})'
 
 with gr.Blocks() as iface:
     gr.Markdown("# CoPE Alpha Preview")
@@ -127,7 +149,7 @@ with gr.Blocks() as iface:
 2. Specify your policy in the "Policy" box.
 3. Click "Submit" to see the results.
 
-**Note**: Inference times are **slow** (2
+**Note**: Inference times are **slow** (1-2 seconds) since this is built on dev infra and not yet optimized for live systems. Please be patient!
 
 ## More Info
 
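For context on the change to predict, the scoring pattern this commit introduces can be exercised in isolation. Below is a minimal sketch, not the Space's actual code: gpt2 stands in for the PEFT-adapted model the app loads (its setup is outside this diff), and the prompt is a placeholder for whatever the app builds from the content and policy inputs.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is a stand-in; the Space loads its own base model plus a PEFT adapter.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Placeholder prompt; the real app derives this from the content and policy boxes.
input_ids = tokenizer("Policy: no spam. Content: buy now!!! Label:",
                      return_tensors="pt").input_ids

with torch.inference_mode():
    outputs = model(input_ids)

# Distribution over the next token, taken from the final sequence position.
logits = outputs.logits[:, -1, :]
probabilities = F.softmax(logits, dim=-1)

def get_token_probability(token):
    # Probability mass assigned to the first sub-token of `token`.
    token_id = tokenizer.encode(token, add_special_tokens=False)[0]
    return probabilities[0, token_id].item()

# The commit reads these out for the '0' (non-violating) and '1' (violating) labels.
print(get_token_probability('0'), get_token_probability('1'))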
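The UI that the numbered instructions in the last hunk refer to can be approximated as follows. This is a hypothetical reconstruction from the context lines alone: the "Policy" box, "Submit" button, and the predict(content, policy) signature appear in the diff, but the exact Textbox arguments and layout are assumptions.

import gradio as gr

def predict(content, policy):
    # Stand-in for the model call shown in the diff above.
    return 'NON-Violating\n(P: 0.99)'

with gr.Blocks() as iface:
    gr.Markdown("# CoPE Alpha Preview")
    content = gr.Textbox(label="Content", lines=4)
    policy = gr.Textbox(label="Policy", lines=4)
    result = gr.Textbox(label="Result")
    submit = gr.Button("Submit")
    submit.click(fn=predict, inputs=[content, policy], outputs=result)

iface.launch()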