samidh committed
Commit d09f7dc · verified · 1 parent: 8e64dab

Update app.py

Files changed (1): app.py (+15 -7)
app.py CHANGED
@@ -3,9 +3,12 @@
 import gradio as gr
 import os
 
+import torch
 from peft import PeftConfig, PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 base_model_name = "google/gemma-7b"
 #adapter_model_name = "samidh/cope-g2b-2c-hs-skr-s1.5.9-sx-sk-s5.d25"
 #adapter_model_name = "samidh/cope-g2b-2c-hs-skr-s1.5.9-sx-sk-s1.5.l1e4-e10-d25"
@@ -20,6 +23,8 @@ model = AutoModelForCausalLM.from_pretrained(base_model_name, token=os.environ['
 model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ['HF_TOKEN'])
 model.merge_and_unload()
 
+model = model.to(device)
+
 tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
 #inputs = tokenizer.encode("This movie was really", return_tensors="pt")
@@ -90,14 +95,17 @@ DEFAULT_CONTENT = "LLMs steal our jobs."
 # Function to make predictions
 def predict(content, policy):
     input_text = PROMPT.format(policy=policy, content=content)
-    inputs = tokenizer.encode(input_text, return_tensors="pt")
-    outputs = model.generate(inputs, max_new_tokens=1)
-    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    if int(decoded_output[-1]) == 0:
-        return f'NON-Violating ({decoded_output[-1]})'
+    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits[:, -1, :]  # Get logits for the last token
+        predicted_token_id = torch.argmax(logits, dim=-1).item()
+        decoded_output = tokenizer.decode([predicted_token_id])
+    if decoded_output == '0':
+        return f'NON-Violating ({decoded_output})'
     else:
-        return f'VIOLATING ({decoded_output[-1]})'
-
+        return f'VIOLATING ({decoded_output})'
 
 with gr.Blocks() as iface:
     gr.Markdown("# CoPE Alpha Preview")
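Note on the main change: model.generate(inputs, max_new_tokens=1) is replaced by a single forward pass under torch.no_grad(), taking the argmax over the final position's logits. Under greedy decoding the two paths pick the same token, but the forward pass skips the generation machinery, and comparing the decoded token to '0' avoids the ValueError the old int(decoded_output[-1]) could raise when the trailing character was not a digit. A minimal sketch of the equivalence, using sshleifer/tiny-gpt2 as a small stand-in model (an assumption for illustration; the app itself uses the merged Gemma-7B adapter):

# Sketch: greedy 1-token generation == argmax over the last position's logits.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "sshleifer/tiny-gpt2"  # hypothetical stand-in for google/gemma-7b
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

input_ids = tokenizer.encode("Answer 0 or 1:", return_tensors="pt").to(model.device)

with torch.no_grad():
    logits = model(input_ids).logits[:, -1, :]  # next-token logits only
predicted_token_id = torch.argmax(logits, dim=-1).item()

# The old path: greedy generate() appends exactly the same token.
generated = model.generate(input_ids, max_new_tokens=1, do_sample=False)
assert generated[0, -1].item() == predicted_token_id
print(tokenizer.decode([predicted_token_id]))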
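And a note on the device handling: the device is chosen once at startup, the merged model is moved to it, and input tensors follow via model.device, so the Space runs unchanged on CPU-only or GPU hardware. A self-contained sketch of the pattern with a stand-in module (plain nn.Modules lack the .device attribute that transformers models expose, hence the parameters() lookup):

# Device-placement pattern from the commit, on a tiny stand-in module.
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = nn.Linear(4, 2).to(device)  # stands in for the merged PEFT model
model_device = next(model.parameters()).device  # HF models: model.device

x = torch.randn(1, 4).to(model_device)  # inputs follow the model's device
with torch.no_grad():
    print(model(x))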