update handler
Browse files
- handler.py +1 -5
handler.py
CHANGED
@@ -28,9 +28,7 @@ class EndpointHandler():
|
|
28 |
# self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
29 |
self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
|
30 |
self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
|
31 |
-
# self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model
|
32 |
-
# self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
|
33 |
-
# self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids
|
34 |
|
35 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
36 |
"""
|
@@ -45,6 +43,4 @@ class EndpointHandler():
|
|
45 |
# with torch.cuda.amp.autocast():
|
46 |
with torch.no_grad():
|
47 |
prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)
|
48 |
-
# prediction = self.pipeline(inputs, return_timestamps=False)
|
49 |
-
prediction['text'] = prediction['text'] + '????'
|
50 |
return prediction
|
|
|
28 |
# self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
29 |
self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
|
30 |
self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
|
31 |
+
# self.pipeline = pipeline(task= "automatic-speech-recognition", model=self.model
|
|
|
|
|
32 |
|
33 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
34 |
"""
|
|
|
43 |
# with torch.cuda.amp.autocast():
|
44 |
with torch.no_grad():
|
45 |
prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)
|
|
|
|
|
46 |
return prediction
|