hynky HF staff commited on
Commit
2de65f4
·
1 Parent(s): da641b4
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -26,21 +26,22 @@ models = [
26
 
27
  pipelines = {model: pipeline(task="text-classification",
28
  model=f"hynky/{model.replace(' ', '_')}", tokenizer="ufal/robeczech-base",
29
- return_all_scores=True
 
30
  ) for model in models}
31
 
32
-
33
  def predict(article):
34
  article = normalize_text(article)
35
- predictions = {model: pipelines[model](article) for model in models}
36
- return predictions
 
37
 
38
  gr.Interface(
39
  predict,
40
- inputs=gr.inputs.Textbox(lines=2, placeholder="Paste a news article here..."),
41
  # multioutput of gradio text
42
- outputs=[gr.outputs.Label(num_top_classses=5)
43
-
44
  for model in models],
45
  title="News Article Classifier",
46
  ).launch()
 
26
 
27
  pipelines = {model: pipeline(task="text-classification",
28
  model=f"hynky/{model.replace(' ', '_')}", tokenizer="ufal/robeczech-base",
29
+ truncation=True, max_length=512,
30
+ top_k=5
31
  ) for model in models}
32
 
33
+
34
  def predict(article):
35
  article = normalize_text(article)
36
+ predictions = [pipelines[model](article)[0] for model in models]
37
+ predictions = [{pred["label"]: round(pred["score"], 3) for pred in task_preds} for task_preds in predictions]
38
+ return tuple(predictions)
39
 
40
  gr.Interface(
41
  predict,
42
+ inputs=gr.inputs.Textbox(lines=4, placeholder="Paste a news article here..."),
43
  # multioutput of gradio text
44
+ outputs=[gr.outputs.Label(num_top_classes=5, label=model)
 
45
  for model in models],
46
  title="News Article Classifier",
47
  ).launch()