cathyi committed on
Commit
107e350
·
1 Parent(s): 33b7632

modify handler

Browse files
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -24,8 +24,8 @@ class EndpointHandler():
24
  processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
25
  feature_extractor = processor.feature_extractor
26
  self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
27
- # self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
28
- self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
29
  self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
30
  self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
31
  # self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
@@ -43,7 +43,7 @@ class EndpointHandler():
43
 
44
  inputs = data.pop("inputs", data)
45
  with torch.cuda.amp.autocast():
46
- # prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
47
- prediction = self.pipeline(inputs, return_timestamps=False)
48
  prediction['text'] = prediction['text'] + '????'
49
  return prediction
 
24
  processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
25
  feature_extractor = processor.feature_extractor
26
  self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
27
+ self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
28
+ # self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
29
  self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
30
  self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
31
  # self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
 
43
 
44
  inputs = data.pop("inputs", data)
45
  with torch.cuda.amp.autocast():
46
+ prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)
47
+ # prediction = self.pipeline(inputs, return_timestamps=False)
48
  prediction['text'] = prediction['text'] + '????'
49
  return prediction