import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
from peft import PeftModel

from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang

DIRECTION = "en-indic"
QUANTIZATION = None  # None, "4-bit", or "8-bit"

IP = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HALF = torch.cuda.is_available()


def initialize_model_and_tokenizer():
    # Build an optional bitsandbytes quantization config based on QUANTIZATION
    if QUANTIZATION == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif QUANTIZATION == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = IndicTransTokenizer(direction=DIRECTION)

    # Load two copies of the base model: one is used as-is, the other is
    # wrapped with the LoRA adapter so both variants can be compared.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )
    model2 = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Quantized models are placed on device by bitsandbytes; only move
    # full-precision models manually.
    if qconfig is None:
        model = model.to(DEVICE)
        model2 = model2.to(DEVICE)

    model.eval()
    model2.eval()

    lora_model = PeftModel.from_pretrained(model2, lora_repo_id)

    return tokenizer, model, lora_model


def batch_translate(input_sentences, model, tokenizer):
    translations = []
    for i in range(0, len(input_sentences), batch_size):
        batch = input_sentences[i : i + batch_size]

        # Preprocess the batch and extract entity mappings
        batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            src=True,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.inference_mode():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(), src=False
        )

        # Postprocess the translations, including entity replacement
        translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs

    return translations
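
# --- Usage sketch (not part of the original script; the sample sentences and
# the side-by-side comparison are illustrative assumptions). It shows how the
# two functions above could be wired together: load the tokenizer plus the
# base and LoRA-adapted models, then translate the same batch with each so
# the adapter's effect on the output can be inspected. Source and target
# language codes come from config (e.g. "eng_Latn" as src_lang for the
# "en-indic" direction).
if __name__ == "__main__":
    tokenizer, base_model, lora_model = initialize_model_and_tokenizer()

    # Illustrative inputs only; replace with real data as needed.
    sample_sentences = [
        "The weather is pleasant today.",
        "Machine translation has improved rapidly in recent years.",
    ]

    base_translations = batch_translate(sample_sentences, base_model, tokenizer)
    lora_translations = batch_translate(sample_sentences, lora_model, tokenizer)

    for src, base_out, lora_out in zip(
        sample_sentences, base_translations, lora_translations
    ):
        print(f"SRC : {src}")
        print(f"BASE: {base_out}")
        print(f"LORA: {lora_out}")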