import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
from peft import PeftModel
from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang
DIRECTION = "en-indic"
QUANTIZATION = None
IP = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HALF = True if torch.cuda.is_available() else False


def initialize_model_and_tokenizer():
    # Build an optional bitsandbytes quantization config (4-bit / 8-bit).
    if QUANTIZATION == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif QUANTIZATION == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = IndicTransTokenizer(direction=DIRECTION)

    # Base IndicTrans2 model, used as-is for translation.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Second copy of the base model; the LoRA adapters are attached to this one.
    model2 = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Quantized models are placed on the GPU by bitsandbytes at load time;
    # only move the models manually when no quantization config is used.
    if qconfig is None:
        model = model.to(DEVICE)
        model2 = model2.to(DEVICE)

    model.eval()
    model2.eval()

    lora_model = PeftModel.from_pretrained(model2, lora_repo_id)

    return tokenizer, model, lora_model


def batch_translate(input_sentences, model, tokenizer):
    translations = []
    for i in range(0, len(input_sentences), batch_size):
        batch = input_sentences[i : i + batch_size]

        # Preprocess the batch and extract entity mappings
        batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            src=True,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.inference_mode():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(), src=False
        )

        # Postprocess the translations, including entity replacement
        translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs

    return translations
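

# Hypothetical usage sketch (not part of the original file): illustrates how the
# base model and the LoRA-adapted model returned by initialize_model_and_tokenizer
# could be compared on a small batch. The sample sentences below are made up, and
# this assumes src_lang/tgt_lang in config.py are set to an English -> Indic pair.
if __name__ == "__main__":
    tokenizer, base_model, lora_model = initialize_model_and_tokenizer()
    sample_sentences = [
        "The weather is pleasant today.",
        "The library opens at nine in the morning.",
    ]
    base_outputs = batch_translate(sample_sentences, base_model, tokenizer)
    lora_outputs = batch_translate(sample_sentences, lora_model, tokenizer)
    for src, base, lora in zip(sample_sentences, base_outputs, lora_outputs):
        print(f"source: {src}")
        print(f"base:   {base}")
        print(f"lora:   {lora}")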