Hynek Kydlíček commited on
Commit
861d5e9
·
1 Parent(s): 5259ca7

freeze version

Browse files
Files changed (2) hide show
  1. app.py +16 -14
  2. requirements.txt +3 -3
app.py CHANGED
@@ -1,13 +1,19 @@
1
  from html import unescape
2
  from unicodedata import normalize
3
  import gradio as gr
4
- from transformers import pipeline
5
  import re
6
 
7
  re_multispace = re.compile(r"\s+")
 
 
 
 
 
 
8
 
9
  def normalize_text(text):
10
- if text == None:
11
  return None
12
 
13
  text = text.strip()
@@ -20,28 +26,24 @@ def normalize_text(text):
20
  return text
21
 
22
 
23
- models = [
24
- "Server", "Category", "Gender", "Day Of Week"
25
- ]
26
-
27
- pipelines = {model: pipeline(task="text-classification",
28
- model=f"hynky/{model.replace(' ', '_')}", tokenizer="ufal/robeczech-base",
29
  truncation=True, max_length=512,
30
  top_k=5
31
- ) for model in models}
32
 
33
 
34
  def predict(article):
35
  article = normalize_text(article)
36
- predictions = [pipelines[model](article)[0] for model in models]
37
  predictions = [{pred["label"]: round(pred["score"], 3) for pred in task_preds} for task_preds in predictions]
38
- return tuple(predictions)
39
 
40
  gr.Interface(
41
  predict,
42
- inputs=gr.inputs.Textbox(lines=4, placeholder="Paste a news article here..."),
43
  # multioutput of gradio text
44
- outputs=[gr.outputs.Label(num_top_classes=5, label=model)
45
- for model in models],
46
  title="News Article Classifier",
47
  ).launch()
 
1
  from html import unescape
2
  from unicodedata import normalize
3
  import gradio as gr
4
+ from transformers import pipeline, AutoModel
5
  import re
6
 
7
  re_multispace = re.compile(r"\s+")
8
+ model_task_mapping = {
9
+ "Server": "Server",
10
+ "Category": "Category",
11
+ "Gender": "Gender",
12
+ "Day Of Week": "Day_of_week"
13
+ }
14
 
15
  def normalize_text(text):
16
+ if text is None:
17
  return None
18
 
19
  text = text.strip()
 
26
  return text
27
 
28
 
29
+ pipelines = {task: pipeline(task="text-classification",
30
+ model=f"hynky/{model}", tokenizer="ufal/robeczech-base",
 
 
 
 
31
  truncation=True, max_length=512,
32
  top_k=5
33
+ ) for task, model in model_task_mapping.items()}
34
 
35
 
36
  def predict(article):
37
  article = normalize_text(article)
38
+ predictions = [pipelines[model](article)[0] for model in model_task_mapping.keys()]
39
  predictions = [{pred["label"]: round(pred["score"], 3) for pred in task_preds} for task_preds in predictions]
40
+ return predictions
41
 
42
  gr.Interface(
43
  predict,
44
+ inputs=gr.Textbox(lines=4, placeholder="Paste a news article here..."),
45
  # multioutput of gradio text
46
+ outputs=[gr.Label(num_top_classes=5, label=task)
47
+ for task in model_task_mapping.keys()],
48
  title="News Article Classifier",
49
  ).launch()
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- transformers
2
- torch
3
- gradio==3.26.0
 
1
+ transformers==0.1.1
2
+ torch==2.1.0
3
+ gradio==0.3.26