Spaces:
Runtime error
Runtime error
File size: 4,339 Bytes
eca00c9 34e9d48 9ff10ae eca00c9 9ff10ae d64ab4a eca00c9 d64ab4a 27ff1f5 c9d3334 856d35d c9d3334 27ff1f5 eca00c9 0f04f5f 27ff1f5 0f04f5f 27ff1f5 eca00c9 d64ab4a 2bf069f d64ab4a 856d35d d64ab4a c9d3334 d64ab4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# ------------------- LIBRARIES -------------------- #
import os, logging, torch, streamlit as st
from transformers import (
AutoTokenizer, AutoModelForCausalLM)
# --------------------- HELPER --------------------- #
def C(text, color="yellow"):
    """Wrap *text* in ANSI escape codes so it prints coloured on a terminal.

    Args:
        text: Value to colourise; it is formatted with an f-string, so any
            object with a ``__str__`` works.
        color: One of ``"red"``, ``"green"``, ``"yellow"``, ``"blue"``,
            ``"magenta"``, ``"cyan"``. ``None`` prefixes the reset code
            itself. Any other name leaves *text* uncoloured.

    Returns:
        ``<colour code><text><reset code>``; the colour code is empty for
        unknown colour names.
    """
    color_dict: dict = dict(
        red="\033[01;31m",
        green="\033[01;32m",
        yellow="\033[01;33m",
        blue="\033[01;34m",
        magenta="\033[01;35m",
        cyan="\033[01;36m",
    )
    reset = "\033[0m"
    color_dict[None] = reset
    # Default to an empty prefix: a ``None`` default would be rendered by the
    # f-string as the literal text "None" in front of the message.
    return f"{color_dict.get(color, '')}{text}{reset}"
def stcache():
    """Pick the Streamlit caching decorator matching the installed version.

    On Streamlit >= 1.18 the returned decorator applies
    ``st.cache_resource()``; on older versions it falls back to the legacy
    ``st.cache(suppress_st_warning=True)``.

    Returns:
        A one-argument decorator that wraps a function with the chosen cache.
    """
    from packaging import version

    if version.parse(st.__version__) >= version.parse("1.18"):
        def decorate(func):
            return st.cache_resource()(func)
    else:
        def decorate(func):
            return st.cache(suppress_st_warning=True)(func)
    return decorate
st.title("`ckip-joint/bloom-1b1-zh` demo")
# ------------------- ENVIRONMENT ------------------- #
# NOTE(review): huggingface_hub presumably reads HF_ENDPOINT when it is first
# imported, and transformers is already imported above, so this assignment may
# come too late to take effect — move it above the imports if it matters.
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
device = "cuda" if torch.cuda.is_available() else "cpu"
# The root logger defaults to WARNING, so without a configured level every
# logging.info(...) call in this app would be silently dropped. basicConfig
# does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
logging.info(C(f"[INFO] device = {device}"))
# ------------------- INITIALIZE -------------------- #
MODEL_NAME = "ckip-joint/bloom-1b1-zh"

stdec = stcache()


@stdec
def model_init():
    """Load the demo's tokenizer and causal language model.

    The model is switched to eval mode and moved to the module-level
    ``device``. The function is wrapped by the caching decorator returned
    from ``stcache()``, so Streamlit reruns reuse the loaded objects instead
    of loading them again.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    logging.info(C("[INFO] Model init start!"))
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Loaded with default dtype/placement; ``torch_dtype="auto"`` and
    # ``device_map="auto"`` were left disabled in the original code.
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).eval().to(device)
    st.balloons()
    logging.info(C("[INFO] Model init success!"))
    return tokenizer, model


tokenizer, model = model_init()
def _split_generation(output_text, prompt):
    """Separate the echoed prompt from the newly generated continuation.

    Args:
        output_text: Full decoded model output.
        prompt: The text the user submitted.

    Returns:
        ``(prompt_to_show, continuation)``. When the decoded output does not
        start with the prompt verbatim, the prompt is dropped from the
        display and the whole output is treated as the continuation.
    """
    if output_text.startswith(prompt):
        return prompt, output_text[len(prompt):]
    # Presumably the tokenize/decode round-trip altered the prompt text, so
    # it cannot be cleanly stripped; show the full output instead.
    return "", output_text


def multiline(string):
    """Render each line of *string* as bold red Streamlit markdown.

    Lines are joined with a backslash + newline (a markdown hard line break).
    """
    return "\\\n".join(f"**:red[{line}]**" for line in string.split("\n"))


try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")
    # =================== INFERENCE ==================== #
    if prompt:
        with st.container():
            st.markdown(f":violet[{prompt}]⋯⋯")
        with torch.no_grad():
            # Batch of one prompt, so generate() returns exactly one sequence.
            [texts_out] = model.generate(
                **tokenizer(prompt, return_tensors="pt").to(device),
                min_new_tokens=0,
                max_new_tokens=100,
            )
        output_text = tokenizer.decode(texts_out, skip_special_tokens=True)
        shown_prompt, out_gens = _split_generation(output_text, prompt)
        st.balloons()
        # Only the first generated line is displayed.
        out_gens = out_gens.split("\n")[0]
        st.caption("Result: ")
        # multiline() already wraps each line in **:red[...]**; wrapping it
        # again produced nested, malformed bold/colour markup.
        st.markdown(f":blue[{shown_prompt}]{multiline(out_gens)}")
except Exception as err:
    # Top-level UI boundary: keep the app alive, but record the traceback
    # instead of only echoing the message to the page.
    logging.exception(C("[ERROR] generation failed", "red"))
    st.write(str(err))
    st.snow()
# import streamlit as st
# st.markdown('Streamlit is **_really_ cool**.')
# st.markdown("This text is :red[colored red], and this is **:blue[colored]** and bold.")
# st.markdown(":green[$\sqrt{x^2+y^2}=1$] is a Pythagorean identity. :pencil:")
# def multiline(string):
# lines = string.split('\n')
# return '\\\n'.join([f"**:red[{l}]**"
# for l in lines])
# st.markdown(multiline("1234 \n5616"))
# st.markdown("1234\\\n5616")
# https://docs.streamlit.io/library/api-reference/status/st.spinner
# https://stackoverflow.com/questions/32402502/how-to-change-the-time-zone-in-python-logging |