kaushikbar commited on
Commit
f5b762d
·
1 Parent(s): 31d359c

added explain

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -15,12 +15,12 @@ models = {'en': 'joeddav/xlm-roberta-large-xnli', #'Narsil/deberta-large-mnli-ze
15
  #'tr': 'vicgalle/xlm-roberta-large-xnli-anli', # Turkish
16
  'no': 'NbAiLab/nb-bert-base-mnli'} # Norsk
17
 
18
- hypothesis_templates = {'en': 'This example is {}.', # English
19
- 'de': 'Dieses beispiel ist {}.', # German
20
- 'es': 'Este ejemplo es {}.', # Spanish
21
- 'it': 'Questo esempio è {}.', # Italian
22
- 'ru': 'Этот пример {}.', # Russian
23
- 'tr': 'Bu örnek {}.', # Turkish
24
  'no': 'Dette eksempelet er {}.'} # Norsk
25
 
26
  classifiers = {'en': pipeline("zero-shot-classification", hypothesis_template=hypothesis_templates['en'],
@@ -170,7 +170,7 @@ def sequence_to_classify(sequence, labels):
170
  puncts = list(string.punctuation)
171
 
172
  model_expl = ZeroShotClassificationExplainer(classifier.model, classifier.tokenizer)
173
- response_expl = model_expl(sequence, label_clean, hypothesis_template="This example is {}.")
174
 
175
  if len(predicted_labels) == 1:
176
  response_expl = response_expl[model_expl.predicted_label]
 
15
  #'tr': 'vicgalle/xlm-roberta-large-xnli-anli', # Turkish
16
  'no': 'NbAiLab/nb-bert-base-mnli'} # Norsk
17
 
18
+ hypothesis_templates = {'en': 'This passage talks about {}.', # English
19
+ #'de': 'Dieses beispiel ist {}.', # German
20
+ #'es': 'Este ejemplo es {}.', # Spanish
21
+ #'it': 'Questo esempio è {}.', # Italian
22
+ #'ru': 'Этот пример {}.', # Russian
23
+ #'tr': 'Bu örnek {}.', # Turkish
24
  'no': 'Dette eksempelet er {}.'} # Norsk
25
 
26
  classifiers = {'en': pipeline("zero-shot-classification", hypothesis_template=hypothesis_templates['en'],
 
170
  puncts = list(string.punctuation)
171
 
172
  model_expl = ZeroShotClassificationExplainer(classifier.model, classifier.tokenizer)
173
+ response_expl = model_expl(sequence, label_clean, hypothesis_template="This passage talks about {}.")
174
 
175
  if len(predicted_labels) == 1:
176
  response_expl = response_expl[model_expl.predicted_label]