samidh committed · Commit 7b371cc · verified · 1 Parent(s): 8f9126b

Update app.py

Files changed (1):
  1. app.py +28 -6
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
 import os
 
 import torch
+import torch.nn.functional as F
 from peft import PeftConfig, PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
@@ -97,13 +98,34 @@ def predict(content, policy):
 
     with torch.inference_mode():
         outputs = model(input_ids)
-        logits = outputs.logits[:, -1, :] # Get logits for the last token
+
+        # Get logits for the last token
+        logits = outputs.logits[:, -1, :]
+
+        # Apply softmax to get probabilities
+        probabilities = F.softmax(logits, dim=-1)
+
+        # Get the predicted token ID
         predicted_token_id = torch.argmax(logits, dim=-1).item()
-        decoded_output = tokenizer.decode([predicted_token_id])
-        if decoded_output == '0':
-            return f'NON-Violating ({decoded_output})'
+
+        # Decode the predicted token
+        decoded_output = tokenizer.decode([predicted_token_id])
+
+        # Get the probability of the predicted token
+        predicted_prob = probabilities[0, predicted_token_id].item()
+
+        # Function to get probability for a specific token
+        def get_token_probability(token):
+            token_id = tokenizer.encode(token, add_special_tokens=False)[0]
+            return probabilities[0, token_id].item()
+
+        predicted_prob_0 = get_token_probability('0')
+        predicted_prob_1 = get_token_probability('1')
+
+        if decoded_output == '1':
+            return f'VIOLATING\n(P: {predicted_prob_1:.2f})'
         else:
-            return f'VIOLATING ({decoded_output})'
+            return f'NON-Violating\n(P: {predicted_prob_0:.2f})'
 
 with gr.Blocks() as iface:
     gr.Markdown("# CoPE Alpha Preview")
@@ -127,7 +149,7 @@ with gr.Blocks() as iface:
 2. Specify your policy in the "Policy" box.
 3. Click "Submit" to see the results.
 
-**Note**: Inference times are **slow** (2-5 seconds) since this is built on dev infra and not yet optimized for live systems. Please be patient!
+**Note**: Inference times are **slow** (1-2 seconds) since this is built on dev infra and not yet optimized for live systems. Please be patient!
 
 ## More Info
 
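
The change above makes predict() report the model's confidence along with the label: it softmaxes the last-token logits and returns P('1') for VIOLATING or P('0') for NON-Violating rather than echoing only the argmax token. A condensed sketch of that pattern, using placeholder model/tokenizer names rather than the actual CoPE setup in app.py:

import torch
import torch.nn.functional as F

def label_probabilities(model, tokenizer, input_ids):
    # Next-token logits at the last position of the prompt.
    with torch.inference_mode():
        logits = model(input_ids).logits[:, -1, :]
    # Normalize over the vocabulary to get probabilities.
    probabilities = F.softmax(logits, dim=-1)

    # Probability mass assigned to the '0' (non-violating) and '1' (violating) tokens.
    def token_prob(token):
        token_id = tokenizer.encode(token, add_special_tokens=False)[0]
        return probabilities[0, token_id].item()

    return token_prob('0'), token_prob('1')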