Spaces:
Runtime error
Runtime error
# ------------------- LIBRARIES -------------------- # | |
import os, logging, torch, streamlit as st | |
from transformers import ( | |
AutoTokenizer, AutoModelForCausalLM) | |
# --------------------- HELPER --------------------- # | |
def C(text, color="yellow"):
    """Wrap *text* in ANSI escape codes for coloured terminal/log output.

    Args:
        text: Value to colour; it is interpolated with an f-string, so any
            object with a ``str()`` works.
        color: One of ``"red"``, ``"green"``, ``"yellow"``, ``"blue"``,
            ``"magenta"``, ``"cyan"``. ``None`` emits the reset sequence as
            the prefix. Any other name leaves the text uncoloured.

    Returns:
        The colour prefix, then ``text``, then the ANSI reset sequence.
    """
    codes = {
        "red": "\033[01;31m",
        "green": "\033[01;32m",
        "yellow": "\033[01;33m",
        "blue": "\033[01;34m",
        "magenta": "\033[01;35m",
        "cyan": "\033[01;36m",
        None: "\033[0m",
    }
    # Unknown colour names fall back to no prefix. The old `.get(color, None)`
    # interpolated the literal text "None" into the output.
    return f"{codes.get(color, '')}{text}{codes[None]}"
def stcache():
    """Return a caching decorator suited to the installed Streamlit version.

    ``st.cache_resource`` is used on Streamlit >= 1.18; older releases fall
    back to the legacy ``st.cache``. The returned callable takes a function
    and returns its cached wrapper.
    """
    from packaging import version

    if version.parse(st.__version__) < version.parse("1.18"):
        # Legacy `st.cache` hashes the return value on every cache hit to
        # detect mutation. For a large object such as a model this is costly
        # (and may not even be hashable), so opt out with
        # `allow_output_mutation=True`. `suppress_st_warning` silences the
        # warning about calling Streamlit commands inside a cached function.
        return lambda f: st.cache(
            suppress_st_warning=True, allow_output_mutation=True)(f)
    return lambda f: st.cache_resource()(f)
st.title("`ckip-joint/bloom-1b1-zh` demo")
# ------------------ ENVIRONMENT ------------------- #
# The root logger defaults to WARNING with no handler configured, so the
# `logging.info` calls in this script would otherwise be silently dropped.
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# NOTE(review): huggingface_hub may read HF_ENDPOINT when it is first
# imported; if so, setting it after `transformers` is imported has no
# effect — confirm, and move this above the imports if an override matters.
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(C("[INFO] " f"device = {device}"))
# ------------------ INITIALIZE -------------------- #
# Caching decorator matching the installed Streamlit version.
stdec = stcache()


@stdec
def model_init():
    """Load the ``ckip-joint/bloom-1b1-zh`` tokenizer and causal-LM model.

    Streamlit re-executes the whole script on every interaction. The
    ``@stdec`` cache ensures the checkpoint is loaded once and reused,
    rather than reloaded on every rerun. (Previously ``stdec`` was built
    but never applied.)

    Returns:
        ``(tokenizer, model)``. The model is switched to eval mode and moved
        to the module-level ``device``.
    """
    logging.info(C("[INFO] " "Model init start!"))
    tokenizer = AutoTokenizer.from_pretrained("ckip-joint/bloom-1b1-zh")
    model = AutoModelForCausalLM.from_pretrained(
        "ckip-joint/bloom-1b1-zh",
        # Ref.: Eric, Thanks!
        # torch_dtype="auto",
        # device_map="auto",
        # Ref. for `half`: Chan-Jan, Thanks!
    ).eval().to(device)
    st.balloons()
    logging.info(C("[INFO] " "Model init success!"))
    return tokenizer, model


tokenizer, model = model_init()
def multiline(string):
    """Render every line of *string* as bold red Streamlit markdown.

    Lines are joined with a backslash followed by a newline, which Markdown
    treats as a hard line break.
    """
    return "\\\n".join(f"**:red[{line}]**" for line in string.split("\n"))


try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")
    # =================== INFERENCE ==================== #
    if prompt:
        with st.container():
            # Echo the prompt while generation runs.
            st.markdown(f":violet[{prompt}]⋯⋯")
        with torch.no_grad():
            # A single prompt gives a batch of one; unpack that one sequence.
            [texts_out] = model.generate(
                **tokenizer(prompt, return_tensors="pt").to(device),
                min_new_tokens=0,
                max_new_tokens=100,
            )
        output_text = tokenizer.decode(texts_out, skip_special_tokens=True)
        # Strip the echoed prompt so only the continuation is highlighted.
        # If decoding did not reproduce the prompt verbatim, show the whole
        # output and drop the prompt prefix instead.
        if output_text.startswith(prompt):
            out_gens = output_text[len(prompt):]
        else:
            out_gens = output_text
            prompt = ""
        st.balloons()
        # Keep only the first generated line.
        out_gens = out_gens.split("\n")[0]
        st.caption("Result: ")
        # `multiline` already wraps each line in **:red[...]**. Wrapping it a
        # second time produced nested bold/colour markup.
        st.markdown(f":blue[{prompt}]{multiline(out_gens)}")
except Exception as err:
    # Top-level UI boundary: show the error on the page instead of crashing.
    st.write(str(err))
    st.snow()
# import streamlit as st | |
# st.markdown('Streamlit is **_really_ cool**.') | |
# st.markdown("This text is :red[colored red], and this is **:blue[colored]** and bold.") | |
# st.markdown(":green[$\sqrt{x^2+y^2}=1$] is a Pythagorean identity. :pencil:") | |
# def multiline(string): | |
# lines = string.split('\n') | |
# return '\\\n'.join([f"**:red[{l}]**" | |
# for l in lines]) | |
# st.markdown(multiline("1234 \n5616")) | |
# st.markdown("1234\\\n5616") | |
# https://docs.streamlit.io/library/api-reference/status/st.spinner | |
# https://stackoverflow.com/questions/32402502/how-to-change-the-time-zone-in-python-logging |