# indictrans2-conversation / indictrans2.py
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
from peft import PeftModel
from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang

DIRECTION = "en-indic"
QUANTIZATION = None

IP = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HALF = torch.cuda.is_available()


def initialize_model_and_tokenizer():
    # Build an optional bitsandbytes quantization config (4-bit or 8-bit).
    if QUANTIZATION == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif QUANTIZATION == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = IndicTransTokenizer(direction=DIRECTION)

    # Load two copies of the base model: one used as-is, one wrapped with the LoRA adapter.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )
    model2 = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Quantized models are placed on the GPU by bitsandbytes; only move them when unquantized.
    if qconfig is None:
        model = model.to(DEVICE)
        model2 = model2.to(DEVICE)

    model.eval()
    model2.eval()

    # Attach the LoRA adapter to the second copy of the base model.
    lora_model = PeftModel.from_pretrained(model2, lora_repo_id)
    return tokenizer, model, lora_model


def batch_translate(input_sentences, model, tokenizer):
    translations = []
    for i in range(0, len(input_sentences), batch_size):
        batch = input_sentences[i : i + batch_size]

        # Preprocess the batch and extract entity mappings
        batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            src=True,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.inference_mode():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(), src=False
        )

        # Postprocess the translations, including entity replacement
        translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs

    return translations
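

# Minimal usage sketch (an illustrative addition, not part of the original module): it
# assumes `config.py` provides src_lang/tgt_lang as IndicTrans2 language tags (for
# example "eng_Latn" and "hin_Deva") and that model_repo_id/lora_repo_id are reachable.
if __name__ == "__main__":
    tokenizer, base_model, lora_model = initialize_model_and_tokenizer()

    # Hypothetical example sentences in the source language.
    sample_sentences = [
        "Hello, how are you?",
        "The weather is nice today.",
    ]

    # Translate with the plain base model and with the LoRA-adapted model to compare outputs.
    base_translations = batch_translate(sample_sentences, base_model, tokenizer)
    lora_translations = batch_translate(sample_sentences, lora_model, tokenizer)

    for src, base_out, lora_out in zip(sample_sentences, base_translations, lora_translations):
        print(f"SRC : {src}")
        print(f"BASE: {base_out}")
        print(f"LORA: {lora_out}")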