param-bharat commited on
Commit
c3dbbe7
·
verified ·
1 Parent(s): 4c950b6

Upload NLIScorer

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -7
pipeline.py CHANGED
@@ -363,13 +363,13 @@ class NLIScorer(Pipeline):
363
  def _sanitize_parameters(self, **kwargs):
364
  preprocess_kwargs = {}
365
  postprocess_kwargs = {}
366
- if "task_name" in kwargs:
367
- preprocess_kwargs["task_name"] = kwargs["task_name"]
368
- preprocess_kwargs["task_name"] = kwargs["task_name"]
369
  return preprocess_kwargs, {}, postprocess_kwargs
370
 
371
- def preprocess(self, inputs, task_name):
372
- TaskClass = TASK_CLASSES[task_name]
373
  task_class = TaskClass(tokenizer=self.tokenizer, **inputs)
374
  return task_class.as_model_inputs
375
 
@@ -377,8 +377,8 @@ class NLIScorer(Pipeline):
377
  outputs = self.model(**model_inputs)
378
  return outputs
379
 
380
- def postprocess(self, model_outputs, task_name):
381
- threshold = TASK_THRESHOLDS[task_name]
382
  pos_scores = model_outputs["logits"].softmax(-1)[0][1]
383
  best_class = int(pos_scores > threshold)
384
  if best_class == 1:
 
363
  def _sanitize_parameters(self, **kwargs):
364
  preprocess_kwargs = {}
365
  postprocess_kwargs = {}
366
+ if "task_type" in kwargs:
367
+ preprocess_kwargs["task_type"] = kwargs["task_type"]
368
+ preprocess_kwargs["task_type"] = kwargs["task_type"]
369
  return preprocess_kwargs, {}, postprocess_kwargs
370
 
371
+ def preprocess(self, inputs, task_type):
372
+ TaskClass = TASK_CLASSES[task_type]
373
  task_class = TaskClass(tokenizer=self.tokenizer, **inputs)
374
  return task_class.as_model_inputs
375
 
 
377
  outputs = self.model(**model_inputs)
378
  return outputs
379
 
380
+ def postprocess(self, model_outputs, task_type):
381
+ threshold = TASK_THRESHOLDS[task_type]
382
  pos_scores = model_outputs["logits"].softmax(-1)[0][1]
383
  best_class = int(pos_scores > threshold)
384
  if best_class == 1: