modify handler
handler.py  CHANGED  +4 -4
@@ -24,8 +24,8 @@ class EndpointHandler():
         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
         feature_extractor = processor.feature_extractor
         self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
-
-        self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+        self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+        # self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
         self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
         self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
         # self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model)
@@ -43,7 +43,7 @@ class EndpointHandler():

         inputs = data.pop("inputs", data)
         with torch.cuda.amp.autocast():
-
-            prediction = self.pipeline(inputs, return_timestamps=False)
+            prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)
+            # prediction = self.pipeline(inputs, return_timestamps=False)
             prediction['text'] = prediction['text'] + '????'
             return prediction
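
For context, a minimal sketch of how the changed constructor and __call__ lines might sit inside a complete handler. The PEFT adapter loading, the base Whisper model, and the constructor arguments (path, language, task) below are assumptions for illustration; the commit itself only changes the pipeline construction and the generation call.

# Minimal sketch of a handler built around the lines changed above.
# Assumptions not shown in the diff: how the PEFT adapter and base Whisper
# model are loaded, the constructor arguments, and device/precision handling.
from typing import Any, Dict

import torch
from peft import PeftConfig, PeftModel
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)


class EndpointHandler():
    def __init__(self, path: str = "", language: str = "Chinese", task: str = "transcribe"):
        # Assumed: the repo at `path` holds a PEFT adapter on top of a Whisper base model.
        peft_config = PeftConfig.from_pretrained(path)
        base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
        model = PeftModel.from_pretrained(base_model, path)

        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
        feature_extractor = processor.feature_extractor
        self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)

        # As in the commit: construct the ASR pipeline class directly
        # rather than going through pipeline(task="automatic-speech-recognition", ...).
        self.pipeline = AutomaticSpeechRecognitionPipeline(
            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.pop("inputs", data)
        with torch.cuda.amp.autocast():
            # As in the commit: force the language/task prompt at generation time.
            prediction = self.pipeline(
                inputs,
                generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids},
                max_new_tokens=255,
            )
        return prediction

Passing forced_decoder_ids through generate_kwargs keeps the language/task constraint local to each call, so decoding no longer depends on whether the pipeline picks up model.config.forced_decoder_ids or the generation config set in the constructor.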