# Hugging Face Space app by nguyen1207 — commit 3f1aa29 ("update model, remove comment")
from random import randint
from time import sleep
import streamlit as st
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from preprocessing import preprocess_pipeline
# Download the GGUF-quantized Vistral-7B Vietnamese→English model from the
# Hugging Face Hub (cached under ./models) and load it with llama.cpp.
vistral_path = hf_hub_download(
    repo_id="nguyen1207/Vistral-7B-MT-GGUF",
    filename="unsloth.Q4_K_M.gguf",
    # NOTE(review): `resume_download` is deprecated (a no-op) in recent
    # huggingface_hub releases — confirm the pinned version before removing.
    resume_download=True,
    cache_dir="models",
)
# Module-level model instance shared by every Streamlit rerun of this script.
llm = Llama(model_path=vistral_path)
def disable_input():
    """Mark the app as busy so the chat input is disabled while translating."""
    st.session_state["translating"] = True
def translate(llm, prompt, top_p, top_k, temperature, repetition_penalty, max_length):
    """Stream translated text from the llama.cpp model, one chunk at a time.

    The first three stream events are emitted as empty strings (their text is
    discarded — presumably leading whitespace/echo from the prompt; the count
    of 3 is inherited from the original implementation). Every subsequent
    event yields its generated text. Note `repetition_penalty` is forwarded
    as llama.cpp's `frequency_penalty` keyword, matching the original code.
    """
    completion_stream = llm.create_completion(
        prompt,
        stream=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        frequency_penalty=repetition_penalty,
        max_tokens=max_length,
    )
    for chunk_index, chunk in enumerate(completion_stream):
        if chunk_index < 3:
            yield ""
        else:
            yield chunk["choices"][0]["text"]
# NOTE(review): `model` and `tokenizer` are never read in this file —
# presumably leftovers from a pre-GGUF version; confirm before deleting.
model = None
tokenizer = None
# Page chrome.
st.set_page_config(page_title="Vietnamese to English Translation")
st.title(
    "🇻🇳 Vietnamese to 🇺🇸 English Translation but with Teencode and Slang understanding 🤯"
)
# Sidebar sliders exposing the sampling parameters forwarded to `translate`.
st.sidebar.header("Translation Parameters")
top_p = st.sidebar.slider("Top p", min_value=0.0, max_value=1.0, value=0.95)
top_k = st.sidebar.slider("Top k", min_value=1, max_value=100, value=50)
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=2.0, value=0.3)
repetition_penalty = st.sidebar.slider(
    "Repetition Penalty", min_value=1.0, max_value=3.0, value=1.05
)
max_length = st.sidebar.slider("Max Length", min_value=10, max_value=512, value=128)
# First run: initialize the chat history and the "busy" flag that disables
# the input box while a translation is streaming.
if "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.translating = False
# Replay the stored conversation — Streamlit re-executes this whole script
# on every user interaction, so prior messages must be redrawn each time.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Main chat handler: take Vietnamese input, preprocess it, stream the English
# translation from the model, and store both turns in the session history.
if user_input := st.chat_input(
    "Vietnamese text goes here... 🇻🇳",
    disabled=st.session_state.translating,
    on_submit=disable_input,
):
    if user_input.strip() != "":
        st.session_state.translating = True
        # Normalize teencode/slang before building the prompt.
        preprocessed_input = preprocess_pipeline(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.chat_message("assistant"):
            prompt_template = """<s> [INST] Dịch câu sau từ tiếng Việt sang tiếng Anh:
Tiếng Việt: {} [/INST] """
            prompt = prompt_template.format(preprocessed_input)
            stream = translate(
                llm, prompt, top_p, top_k, temperature, repetition_penalty, max_length
            )
            translation = st.write_stream(stream)
            st.markdown(translation)
        st.session_state.messages.append({"role": "assistant", "content": translation})
        # Re-enable the input field and rerun so the UI picks up the change.
        st.session_state.translating = False
        st.rerun()
    else:
        # BUG FIX: `on_submit=disable_input` already set translating=True for
        # this submission; previously a whitespace-only input skipped the body
        # above and left the chat input disabled forever. Reset and rerun.
        st.session_state.translating = False
        st.rerun()