manandey committed on
Commit
08caf3a
·
1 Parent(s): 83afa77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import shutil
4
+ import requests
5
+
6
+ import gradio as gr
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+
10
# Hub repository that holds both the model checkpoint and its tokenizer.
_MODEL_REPO = "bs-modeling-metadata/checkpoints_all_04_23"

# Filled on first use so the checkpoint is loaded once per process rather
# than once per request.
_LOADED = {}


def _load_model_and_tokenizer():
    """Return the ``(model, tokenizer)`` pair, loading it on first use only."""
    if not _LOADED:
        # Load both before storing either: if one load fails the cache stays
        # empty and the next call retries from scratch.
        model = AutoModelForCausalLM.from_pretrained(
            _MODEL_REPO, subfolder="checkpoint-30000step"
        )
        tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO, subfolder="tokenizer")
        _LOADED["model"] = model
        _LOADED["tokenizer"] = tokenizer
    return _LOADED["model"], _LOADED["tokenizer"]


def _build_prompt(html, entity, website_desc, datasource, year, month, title):
    """Assemble the metadata prefix string from the individual form fields."""
    parts = []
    if html == "on":
        parts.append("html | ")

    # Field order is part of the prompt format: year, month, website
    # description, title, datasource. Empty (or missing) fields are skipped.
    labelled_fields = (
        ("Year", year),
        ("Month", month),
        ("Website Description", website_desc),
        ("Title", title),
        ("Datasource", datasource),
    )
    for label, value in labelled_fields:
        if value:
            parts.append(f"{label}: {value} | ")

    # Drop blank entries (trailing/double commas, whitespace-only input) so
    # no empty "||" markers end up inside the entity chain.
    entities = [name.strip() for name in (entity or "").split(",")]
    entities = [name for name in entities if name]
    if entities:
        chain = "".join(f" |{name}|" for name in entities)
        parts.append("entity ||| <ENTITY_CHAIN>" + chain + " </ENTITY_CHAIN> ")

    return "".join(parts)


def generate(html, entity, website_desc, datasource, year, month, title):
    """Generate text conditioned on the given metadata fields.

    Args:
        html: ``"on"`` to prepend the ``html | `` marker; any other value
            leaves it out.
        entity: Comma-separated entities or keywords; blank entries are
            ignored.
        website_desc: Website description, or ``""`` to omit it.
        datasource: Datasource name, or ``""`` to omit it.
        year: Year, or ``""`` to omit it.
        month: Month, or ``""`` to omit it.
        title: Website title, or ``""`` to omit it.

    Returns:
        The list produced by ``tokenizer.batch_decode`` for the generated
        token ids, with special tokens stripped.
    """
    prompt = _build_prompt(html, entity, website_desc, datasource, year, month, title)
    model, tokenizer = _load_model_and_tokenizer()

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
35
+
36
+
37
# Form controls, one per `generate` parameter.
html = gr.Radio(["on", "off"], info="turn html as on or off", label="html")
entity = gr.Textbox(
    label="list of entities",
    placeholder="enter a list of comma separated entities or keywords",
)
website_desc = gr.Textbox(
    label="website description",
    placeholder="enter a website description",
)
datasource = gr.Textbox(label="datasource", placeholder="enter a datasource")
year = gr.Textbox(label="year", placeholder="enter a year")
month = gr.Textbox(label="month", placeholder="enter a month")
title = gr.Textbox(label="website title", placeholder="enter a website title")

# Wire the controls to `generate` in the same order as its signature and
# show the result as plain text.
demo = gr.Interface(
    generate,
    [html, entity, website_desc, datasource, year, month, title],
    "text",
)
demo.launch()