Ceshine Lee commited on
Commit
d19c498
·
1 Parent(s): 3084017

Compare results from the two fine-tuned models

Browse files
Files changed (1) hide show
  1. app.py +65 -47
app.py CHANGED
@@ -5,64 +5,82 @@ import requests
5
  import gradio as gr
6
  from gradio import inputs, outputs
7
 
8
- ENDPOINT = (
9
- "https://api-inference.huggingface.co/models/ceshine/t5-paraphrase-quora-paws"
 
10
  )
11
 
12
 
13
- def paraphrase(source_text: str, temperature: float):
14
- if temperature > 0:
15
- params = {
16
- "do_sample": True,
17
- "temperature": temperature,
18
- "top_k": 5,
19
- "num_return_sequences": 10,
20
- "max_length": 100
21
- }
22
- else:
23
- params = {
24
- "num_beams": 10,
25
- "num_return_sequences": 10,
26
- "max_length": 100
27
- }
28
- res = requests.post(
29
- ENDPOINT,
30
- headers={"Authorization": f"Bearer {os.environ['TOKEN']}"},
31
- data=json.dumps(
32
- {
33
- "inputs": "paraphrase: " + source_text,
34
- "parameters": params,
35
  }
36
- ),
37
- )
38
- if not (res.status_code == 200):
39
- return f"Got a {res.status_code} status code from HuggingFace."
40
- results = res.json()
41
- # print(results)
42
- outputs = [
43
- x["generated_text"]
44
- for x in results
45
- if x["generated_text"].lower() != source_text.lower().strip()
46
- ][:3]
47
- text = ""
48
- for i, output in enumerate(outputs):
49
- text += f"{i+1}: {output}\n\n"
50
- return text
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
- interface = gr.Interface(
54
- fn=paraphrase,
 
55
  inputs=[
56
  inputs.Textbox(label="Source text"),
57
- inputs.Number(default=0.0, label="Temperature (0 -> disable sampling and use beam search)"),
 
 
58
  ],
59
- outputs=outputs.Textbox(label="Generated text:"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  title="T5 Sentence Paraphraser",
61
- description="A paraphrasing model trained on PAWS and Quora datasets.",
62
  examples=[
63
  ["I bought a ticket from London to New York.", 0],
64
  ["Weh Seun spends 14 hours a week doing housework.", 1.2],
65
  ],
66
- )
67
-
68
- interface.launch(enable_queue=True)
 
5
  import gradio as gr
6
  from gradio import inputs, outputs
7
 
8
+ ENDPOINTS = (
9
+ "https://api-inference.huggingface.co/models/ceshine/t5-paraphrase-quora-paws",
10
+ "https://api-inference.huggingface.co/models/ceshine/t5-paraphrase-paws-msrp-opinosis",
11
  )
12
 
13
 
14
+ def get_fn(endpoint):
15
+ def paraphrase(source_text: str, temperature: float):
16
+ if temperature > 0:
17
+ params = {
18
+ "do_sample": True,
19
+ "temperature": temperature,
20
+ "top_k": 5,
21
+ "num_return_sequences": 10,
22
+ "max_length": 100,
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  }
24
+ else:
25
+ params = {"num_beams": 10, "num_return_sequences": 10, "max_length": 100}
26
+ res = requests.post(
27
+ endpoint,
28
+ headers={"Authorization": f"Bearer {os.environ['TOKEN']}"},
29
+ data=json.dumps(
30
+ {
31
+ "inputs": "paraphrase: " + source_text,
32
+ "parameters": params,
33
+ }
34
+ ),
35
+ )
36
+ if not (res.status_code == 200):
37
+ return f"Got a {res.status_code} status code from HuggingFace."
38
+ results = res.json()
39
+ # print(results)
40
+ outputs = [
41
+ x["generated_text"]
42
+ for x in results
43
+ if x["generated_text"].lower() != source_text.lower().strip()
44
+ ][:3]
45
+ text = ""
46
+ for i, output in enumerate(outputs):
47
+ text += f"{i+1}: {output}\n\n"
48
+ return text
49
+
50
+ return paraphrase
51
 
52
 
53
+ interface_1 = gr.Interface(
54
+ fn=get_fn(ENDPOINTS[0]),
55
+ title="quora-paws",
56
  inputs=[
57
  inputs.Textbox(label="Source text"),
58
+ inputs.Number(
59
+ default=0.0, label="Temperature (0 -> disable sampling and use beam search)"
60
+ ),
61
  ],
62
+ outputs=outputs.Textbox(label="quora-paws"),
63
+ )
64
+
65
+ interface_2 = gr.Interface(
66
+ fn=get_fn(ENDPOINTS[1]),
67
+ title="paws-msrp-opinosis",
68
+ inputs=[
69
+ inputs.Textbox(label="Source text"),
70
+ inputs.Number(
71
+ default=0.0, label="Temperature (0 -> disable sampling and use beam search)"
72
+ ),
73
+ ],
74
+ outputs=outputs.Textbox(label="paws-msrp-opinosis"),
75
+ )
76
+
77
+ gr.Parallel(
78
+ interface_1,
79
+ interface_2,
80
  title="T5 Sentence Paraphraser",
81
+ description="Compare generated paraphrases from two models (`ceshine/t5-paraphrase-quora-paws` and `ceshine/t5-paraphrase-paws-msrp-opinosis`).",
82
  examples=[
83
  ["I bought a ticket from London to New York.", 0],
84
  ["Weh Seun spends 14 hours a week doing housework.", 1.2],
85
  ],
86
+ ).launch(enable_queue=True)