File size: 4,617 Bytes
3b3bb6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80e788d
 
3b3bb6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import transformers

# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor


class UltravoxPipeline(transformers.Pipeline):
    """Text-generation pipeline that accepts chat turns and an optional audio clip.

    The input to the pipeline is a dict that may contain:

    * ``"turns"``: a list of chat messages (dicts with ``"role"`` and
      ``"content"``) forming the conversation so far.
    * ``"audio"``: the audio samples; NumPy arrays of dtype float64, int16 or
      int32 are converted to float32.
    * ``"prompt"``: text for the user turn that is synthesized when audio is
      given and the last turn is not a user turn (defaults to ``"<|audio|>"``).
    * ``"sampling_rate"``: sampling rate of ``"audio"`` (defaults to 16000).

    The output is the decoded text of the newly generated tokens.
    """

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline, loading any component that was not supplied.

        Args:
            model: The Ultravox model to run generation with.
            tokenizer: Tokenizer to use. When ``None`` it is loaded from
                ``model.config._name_or_path`` and, if that fails, from
                ``model.config.text_model_id`` (or the text config's
                ``_name_or_path``).
            audio_processor: Audio processor to use. When ``None`` it is
                loaded from ``model.config.audio_model_id`` (or the audio
                config's ``_name_or_path``).
            **kwargs: Forwarded unchanged to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Best-effort fallback to the underlying text model's
                # tokenizer. ``Exception`` (rather than a bare ``except``) is
                # caught so KeyboardInterrupt / SystemExit still propagate.
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # Combined text + audio processor used by ``preprocess``.
        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into per-stage parameter dicts.

        Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty``
        are kept, and they are all placed in the second dict; the first and
        third dicts are always empty. (Per the ``transformers.Pipeline``
        contract the three dicts feed ``preprocess``, ``_forward`` and
        ``postprocess`` respectively.)
        """
        generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
        generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn the raw input dict into model inputs.

        Args:
            inputs: Dict with optional ``"turns"``, ``"audio"``, ``"prompt"``
                and ``"sampling_rate"`` entries (see the class docstring).

        Returns:
            The output of ``self.processor``; if it contains
            ``"audio_values"`` they are cast to the model's dtype.
        """
        # Work on a copy: a user turn may be appended below and the caller's
        # conversation list must not be modified as a side effect. A missing
        # or ``None`` value is treated as an empty conversation.
        turns: list = list(inputs.get("turns") or [])

        audio = inputs.get("audio", None)
        # Convert to float32 if needed.
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                # Scale integer samples by the magnitude of the most negative
                # int16 value.
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                # Same scaling for int32 samples.
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        # Audio needs a user turn to attach to: synthesize one unless the
        # conversation already ends with a user message.
        if audio is not None and (not turns or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )

                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=audio,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run generation and return only the newly generated tokens.

        Args:
            model_inputs: Output of ``preprocess``; must contain
                ``"input_ids"``.
            temperature: Sampling temperature. ``None`` or any falsy value
                (e.g. ``0``) disables sampling.
            max_new_tokens: Passed through to ``model.generate``.
            repetition_penalty: Passed through to ``model.generate``.

        Returns:
            The first generated sequence with the prompt tokens removed.
            NOTE(review): this is whatever ``model.generate`` returns, sliced;
            presumably a tensor rather than a plain ``List[int]`` - confirm.
        """
        # Normalize falsy temperatures (0, 0.0) to None so that sampling is
        # turned off instead of being requested with a zero temperature.
        temperature = temperature or None
        do_sample = temperature is not None

        terminators = [self.tokenizer.eos_token_id]
        # Also stop on "<|eot_id|>" when the tokenizer defines it as an
        # added token.
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        # Drop the first ``input_len`` positions (the prompt) from the
        # first returned sequence.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode the generated tokens to text, skipping special tokens."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text


# Import-time side effect: register UltravoxPipeline with the transformers
# pipeline registry under the task name "ultravox-pipeline", using
# transformers.AutoModel as the PyTorch model class and "multimodal" as the
# pipeline type.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)