# Hugging Face Space: Vietnamese-to-English translation chat app (Streamlit).
# Imports: stdlib, then third-party, then local.
from random import randint
from time import sleep

import streamlit as st
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from preprocessing import preprocess_pipeline

# Download the quantized Vistral-7B MT model (GGUF, Q4_K_M) into ./models and
# load it with llama.cpp. The download is cached, so reruns are cheap.
# NOTE(review): resume_download is deprecated in recent huggingface_hub
# releases (resuming is now the default) — confirm the pinned version.
vistral_path = hf_hub_download(
    repo_id="nguyen1207/Vistral-7B-MT-GGUF",
    filename="unsloth.Q4_K_M.gguf",
    resume_download=True,
    cache_dir="models",
)
llm = Llama(model_path=vistral_path)
def disable_input():
    """on_submit callback for the chat input: mark the session as translating.

    st.chat_input reads `st.session_state.translating` for its `disabled`
    flag, so setting it here greys out the input while a translation runs.
    """
    st.session_state.translating = True
def translate(llm, prompt, top_p, top_k, temperature, repetition_penalty, max_length):
    """Stream a completion from the llama.cpp model, chunk by chunk.

    Args:
        llm: object exposing `create_completion` (a llama_cpp.Llama here).
        prompt: fully formatted instruction prompt.
        top_p, top_k, temperature: sampling parameters passed straight through.
        repetition_penalty: forwarded as llama.cpp's `frequency_penalty`.
        max_length: forwarded as `max_tokens`.

    Yields:
        str: empty strings for the first three streamed chunks, then the text
        of each subsequent chunk.
    """
    stream = llm.create_completion(
        prompt,
        stream=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        # NOTE(review): the UI labels this "Repetition Penalty" but it is
        # passed as frequency_penalty — confirm this mapping is intended.
        frequency_penalty=repetition_penalty,
        max_tokens=max_length,
    )
    # The first three chunks are suppressed (replaced by "") — presumably
    # leading formatting/echo noise from the model; TODO confirm.
    for index, response in enumerate(stream):
        if index < 3:
            yield ""
        else:
            yield response["choices"][0]["text"]
# Unused placeholders — nothing in this file reads them. Kept for
# compatibility in case another module imports them; TODO confirm and remove.
model = None
tokenizer = None

st.set_page_config(page_title="Vietnamese to English Translation")
st.title(
    "🇻🇳 Vietnamese to 🇺🇸 English Translation but with Teencode and Slang understanding 🤯"
)

# Sidebar sliders: sampling parameters forwarded verbatim to translate().
st.sidebar.header("Translation Parameters")
top_p = st.sidebar.slider("Top p", min_value=0.0, max_value=1.0, value=0.95)
top_k = st.sidebar.slider("Top k", min_value=1, max_value=100, value=50)
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=2.0, value=0.3)
repetition_penalty = st.sidebar.slider(
    "Repetition Penalty", min_value=1.0, max_value=3.0, value=1.05
)
max_length = st.sidebar.slider("Max Length", min_value=10, max_value=512, value=128)
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
st.session_state.translating = False | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# Main chat handler: take Vietnamese input, preprocess it (teencode/slang
# normalisation), stream the model's English translation, store both turns,
# then rerun so the re-enabled input box is rendered.
if user_input := st.chat_input(
    "Vietnamese text goes here... 🇻🇳",
    disabled=st.session_state.translating,
    on_submit=disable_input,
):
    if user_input.strip() != "":
        st.session_state.translating = True
        preprocessed_input = preprocess_pipeline(user_input)
        # The raw text goes into the visible history; the preprocessed text
        # only into the model prompt.
        st.session_state.messages.append({"role": "user", "content": user_input})

        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant"):
            # Vistral/Mistral instruction format (Vietnamese instruction:
            # "Translate the following sentence from Vietnamese to English").
            prompt_template = """<s> [INST] Dịch câu sau từ tiếng Việt sang tiếng Anh:
Tiếng Việt: {} [/INST] """
            prompt = prompt_template.format(preprocessed_input)
            stream = translate(
                llm, prompt, top_p, top_k, temperature, repetition_penalty, max_length
            )
            translation = st.write_stream(stream)
            # NOTE(review): st.write_stream already renders the streamed text,
            # so this markdown call likely displays the translation twice —
            # confirm and drop if duplicated.
            st.markdown(translation)

        st.session_state.messages.append({"role": "assistant", "content": translation})

        # Re-enable the input field and rerun so the widget picks up the flag.
        st.session_state.translating = False
        st.rerun()