import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
from peft import PeftModel
from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang
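
# The config module imported above is not shown here; a minimal sketch of what
# it is assumed to provide (illustrative values only; the real repo IDs and
# language codes live in config.py):
#
#   model_repo_id = "ai4bharat/indictrans2-en-indic-1B"  # assumed base checkpoint
#   lora_repo_id = "your-username/en-indic-lora"         # hypothetical adapter repo
#   batch_size = 4
#   src_lang = "eng_Latn"  # FLORES-style source language code
#   tgt_lang = "hin_Deva"  # FLORES-style target language code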
DIRECTION = "en-indic"
QUANTIZATION = None  # set to "4-bit" or "8-bit" to enable bitsandbytes quantization
IP = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HALF = torch.cuda.is_available()  # half precision is only useful on GPU
def initialize_model_and_tokenizer():
    if QUANTIZATION == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif QUANTIZATION == "8-bit":
        # Double quantization and compute dtype are 4-bit-only options in
        # BitsAndBytesConfig, so the 8-bit path only needs the load flag.
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    else:
        qconfig = None
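    # Note: the quantized paths typically require the bitsandbytes package and
    # a CUDA GPU; with QUANTIZATION = None the checkpoints are loaded in full
    # precision and moved to DEVICE manually below.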
    tokenizer = IndicTransTokenizer(direction=DIRECTION)
    # Load two copies of the base checkpoint: one kept as-is for comparison,
    # and one to be wrapped with the LoRA adapter below.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )
    model2 = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )
    if qconfig is None:
        # Quantized models are already placed on the GPU by bitsandbytes;
        # only the full-precision models need an explicit move.
        model = model.to(DEVICE)
        model2 = model2.to(DEVICE)
    model.eval()
    model2.eval()
    lora_model = PeftModel.from_pretrained(model2, lora_repo_id)
    return tokenizer, model, lora_model
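
# PeftModel.from_pretrained attaches the adapter weights from lora_repo_id on
# top of the frozen base weights in model2, so lora_model shares memory with
# model2 rather than holding a second full copy of the checkpoint.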
def batch_translate(input_sentences, model, tokenizer):
    translations = []
    for i in range(0, len(input_sentences), batch_size):
        batch = input_sentences[i : i + batch_size]
        # Preprocess the batch and extract entity mappings
        batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            src=True,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        # Generate translations using the model
        with torch.inference_mode():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        # Decode the generated tokens into text
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(), src=False
        )
        # Postprocess the translations, including entity replacement
        translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)
        # Free the input tensors before the next batch
        del inputs
    return translations
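
# Minimal end-to-end sketch (assumptions: config.py supplies the imported
# names, and src_lang/tgt_lang are FLORES-style codes such as "eng_Latn" and
# "hin_Deva"; the sentences below are illustrative only):
if __name__ == "__main__":
    tokenizer, model, lora_model = initialize_model_and_tokenizer()
    sentences = [
        "The weather is pleasant today.",
        "Machine translation has improved rapidly in recent years.",
    ]
    base_out = batch_translate(sentences, model, tokenizer)
    lora_out = batch_translate(sentences, lora_model, tokenizer)
    for src, base, lora in zip(sentences, base_out, lora_out):
        print(f"SRC : {src}\nBASE: {base}\nLORA: {lora}\n")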