# app.py — Hugging Face Space: Romanian text generation via pivot translation
# (ro -> fr -> en, generate in English, then en -> ro back).
import os
import torch
import gradio as gr
from transformers import MarianMTModel, MarianTokenizer, pipeline, AutoTokenizer
from huggingface_hub import login
# Authenticate with the Hugging Face Hub; the token must be provided via the
# HUGGINGFACE_TOKEN environment variable (e.g. a Space secret).
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HUGGINGFACE_TOKEN:
    raise ValueError("Hugging Face token not found in environment variables.")
login(token=HUGGINGFACE_TOKEN)
# MarianMT checkpoints for the translation pivot chain:
# Romanian -> French, French -> English, English -> Romanian.
rfr_md = "Helsinki-NLP/opus-mt-ro-fr"
frr_md = "Helsinki-NLP/opus-mt-fr-en"
enr_md = "Helsinki-NLP/opus-mt-en-ro"


def _load_marian(name):
    # One-line helper: fetch the tokenizer/model pair for a MarianMT checkpoint.
    return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)


rfr_token, rfr_model = _load_marian(rfr_md)
fren_token, fren_model = _load_marian(frr_md)
enr_token, enr_model = _load_marian(enr_md)
# Text-generation model, pinned to CPU (no GPU on this Space).
# NOTE(review): despite the "gemma" variable names, the checkpoint actually
# loaded is StableLM 2 1.6B chat — the names are kept for compatibility.
gemma_model = "stabilityai/stablelm-2-1_6b-chat"
gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model)
pipe = pipeline("text-generation", model=gemma_model, tokenizer=gemma_tokenizer, device="cpu")
# Function to split text into smaller blocks for translation
def char_split(text, tokenizer, max_length=498):
tokens = tokenizer(text, return_tensors="pt", truncation=False, padding=False)["input_ids"][0]
blocks_ = []
start = 0
while start < len(tokens):
end = min(start + max_length, len(tokens))
blocks_.append(tokens[start:end])
start = end
return blocks_
# Function to translate the text block by block
def translate(text, model, tokenizer, max_length=500):
token_blocks = char_split(text, tokenizer, max_length)
text_en = ""
for blk_ in token_blocks:
blk_char = tokenizer.decode(blk_, skip_special_tokens=True)
translated = model.generate(**tokenizer(blk_char, return_tensors="pt", padding=True, truncation=True))
text_en += tokenizer.decode(translated[0], skip_special_tokens=True) + " "
return text_en.strip()
# Function to remove formatting symbols
def rm_rf(text):
import re
return re.sub(r'\*+', '', text)
# Generate text based on Romanian input
def generate(text):
fr_txt = translate(text, rfr_model, rfr_token)
en_txt = translate(fr_txt, fren_model, fren_token)
sequences = pipe(
en_txt,
max_new_tokens=2048,
do_sample=True,
return_full_text=False,
)
generated_text = sequences[0]['generated_text']
cl_txt = rm_rf(generated_text)
ro_txt = translate(cl_txt, enr_model, enr_token)
return ro_txt
# Create the Gradio interface
interface = gr.Interface(
fn=generate,
inputs=gr.Textbox(label="prompt:", lines=2, placeholder="prompt..."),
outputs="text",
title="Gemma Romanian",
description="romanian gemma using nlps."
)
# Launch the Gradio app
interface.launch()