File size: 1,957 Bytes
195bb33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

class TTSModel:
    """Wrapper around a Parler-TTS checkpoint (default: ``ai4bharat/indic-parler-tts``).

    Loads the model, the prompt tokenizer and the description tokenizer once,
    then converts (text, description) pairs into audio arrays via
    :meth:`generate_audio`.
    """

    def __init__(self, model_name="ai4bharat/indic-parler-tts", device=None):
        """Load the model and both tokenizers.

        Args:
            model_name: Hugging Face hub id or local path of the Parler-TTS
                checkpoint. Defaults to the originally hard-coded model.
            device: Torch device string. When ``None``, uses ``"cuda"`` if
                available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model_name = model_name

        print(f"Loading model on device: {self.device}")

        # The prompt (text to speak) and the voice description are tokenized
        # separately: the description tokenizer belongs to the model's text
        # encoder, whose id is recorded in the model config.
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            self.model_name
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.description_tokenizer = AutoTokenizer.from_pretrained(
            self.model.config.text_encoder._name_or_path
        )
        # Sampling rate of the generated audio; callers need it to write or
        # play the output. None if the config does not define it.
        self.sampling_rate = getattr(self.model.config, "sampling_rate", None)

        print("Model loaded successfully")

    @staticmethod
    def _to_numpy(generation):
        """Convert a generated audio tensor to a squeezed float32 numpy array.

        Casting to float32 is required because numpy has no bfloat16 dtype, so
        ``.numpy()`` would fail on reduced-precision model outputs.
        """
        return generation.detach().to(torch.float32).cpu().numpy().squeeze()

    def generate_audio(self, text, description, **generate_kwargs):
        """Synthesize speech for ``text`` in the voice described by ``description``.

        Args:
            text: The text to be spoken.
            description: Natural-language description of the desired voice.
            **generate_kwargs: Extra keyword arguments forwarded unchanged to
                ``self.model.generate`` (e.g. sampling options).

        Returns:
            A numpy array of audio samples with singleton dimensions removed.
            Its sample rate is presumably ``self.sampling_rate``.

        Raises:
            Any exception from tokenization, generation or conversion is
            printed and then re-raised.
        """
        try:
            description_inputs = self.description_tokenizer(
                description,
                return_tensors="pt",
            ).to(self.device)

            prompt_inputs = self.tokenizer(
                text,
                return_tensors="pt",
            ).to(self.device)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                generation = self.model.generate(
                    input_ids=description_inputs.input_ids,
                    attention_mask=description_inputs.attention_mask,
                    prompt_input_ids=prompt_inputs.input_ids,
                    prompt_attention_mask=prompt_inputs.attention_mask,
                    **generate_kwargs,
                )

            return self._to_numpy(generation)

        except Exception as e:
            print(f"Error in speech generation: {str(e)}")
            raise

    def save_audio(self, audio_array, path):
        """Write ``audio_array`` to ``path`` using the model's sampling rate.

        Raises:
            ValueError: If the model config did not provide a sampling rate.
        """
        if self.sampling_rate is None:
            raise ValueError("Model config does not define a sampling_rate")
        sf.write(path, audio_array, self.sampling_rate)