"""Streamlit demo for the `ckip-joint/bloom-1b1-zh` causal language model.

Type a prompt; the app generates up to 100 new tokens. It shows the prompt
(blue) followed by the first generated line (bold red).
"""
# ------------------- LIBRARIES -------------------- #
import logging
import os

# ------------------ ENVIRONMENT ------------------- #
# huggingface_hub (imported by transformers) reads HF_ENDPOINT once, at import
# time, so this override must be in place *before* `transformers` is imported,
# otherwise it is silently ignored.
os.environ["HF_ENDPOINT"] = "https://huggingface.co"

import streamlit as st  # noqa: E402
import torch  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402

# The root logger defaults to WARNING, which would drop every `logging.info`
# below. basicConfig is a no-op once handlers exist, so Streamlit reruns of
# this script do not stack duplicate handlers.
logging.basicConfig(level=logging.INFO)

# --------------------- HELPER --------------------- #
# ANSI escape sequences used by `C` to colour terminal log output.
_ANSI_COLORS = dict(
    red="\033[01;31m",
    green="\033[01;32m",
    yellow="\033[01;33m",
    blue="\033[01;34m",
    magenta="\033[01;35m",
    cyan="\033[01;36m",
)
_ANSI_RESET = "\033[0m"


def C(text, color="yellow"):
    """Wrap *text* in ANSI escape codes so it prints in *color* on a terminal.

    Args:
        text: Anything; it is formatted with ``str()``.
        color: One of the keys of ``_ANSI_COLORS``. Unknown values (including
            ``None``) leave the text uncoloured.

    Returns:
        The coloured string, always terminated by the ANSI reset code.
    """
    # Unknown colours get no prefix (previously a literal "None" was emitted).
    return f"{_ANSI_COLORS.get(color, '')}{text}{_ANSI_RESET}"


def stcache():
    """Return a decorator that caches a resource across Streamlit reruns.

    Streamlit >= 1.18 provides ``st.cache_resource`` for unhashable global
    objects such as models. Older releases only have ``st.cache``.
    ``allow_output_mutation=True`` stops ``st.cache`` from re-hashing the
    returned (large) model on every rerun. ``suppress_st_warning=True``
    allows the ``st.*`` calls made inside the cached function.
    """
    from packaging import version

    if version.parse(st.__version__) < version.parse("1.18"):
        return st.cache(suppress_st_warning=True, allow_output_mutation=True)
    return st.cache_resource()


st.title("`ckip-joint/bloom-1b1-zh` demo")

device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(C(f"[INFO] device = {device}"))

# ------------------- INITIALIZE ------------------- #
MODEL_NAME = "ckip-joint/bloom-1b1-zh"
stdec = stcache()


@stdec
def model_init():
    """Load the tokenizer and model once; later reruns reuse the cached pair.

    Returns:
        ``(tokenizer, model)``; the model is in eval mode and moved to
        ``device``.
    """
    logging.info(C("[INFO] Model init start!"))
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        # Memory-saving options suggested by Eric: torch_dtype="auto",
        # device_map="auto"; Chan-Jan suggested `.half()`. Not enabled here.
    ).eval().to(device)
    st.balloons()
    logging.info(C("[INFO] Model init success!"))
    return tokenizer, model


tokenizer, model = model_init()


def multiline(string):
    """Render *string* as bold red Streamlit markdown, keeping line breaks.

    Each non-empty line becomes ``**:red[line]**``. Lines are joined with a
    markdown hard break (backslash + newline). Empty lines stay empty so no
    bare ``**:red[]**`` markup is emitted.
    """
    return "\\\n".join(
        f"**:red[{line}]**" if line else "" for line in string.split("\n")
    )


try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")

    # =================== INFERENCE ==================== #
    if prompt:
        with st.container():
            st.markdown(f":violet[{prompt}]⋯⋯")

        with torch.no_grad():
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            # A single prompt is encoded, so generate() returns one sequence.
            [texts_out] = model.generate(
                **inputs,
                min_new_tokens=0,
                max_new_tokens=100,
            )
        output_text = tokenizer.decode(texts_out, skip_special_tokens=True)

        if output_text.startswith(prompt):
            # Keep only the newly generated continuation.
            out_gens = output_text[len(prompt):]
        else:
            # Decoding did not reproduce the prompt verbatim (presumably
            # tokenizer normalisation), so the prompt cannot be split off.
            # Show the whole decoded text as the generated part instead.
            out_gens = output_text
            prompt = ""

        st.balloons()
        # Only the first generated line is displayed.
        out_gens = out_gens.split("\n")[0]

        st.caption("Result: ")
        # `multiline` already applies the bold-red styling; wrapping it again
        # would nest `**` delimiters and break the emphasis.
        prompt_md = f":blue[{prompt}]" if prompt else ""
        st.markdown(f"{prompt_md}{multiline(out_gens)}")
except Exception as err:  # top-level UI boundary: report instead of crashing
    logging.exception("Generation failed")
    st.write(str(err))
    st.snow()

# References:
# https://docs.streamlit.io/library/api-reference/status/st.spinner
# https://stackoverflow.com/questions/32402502/how-to-change-the-time-zone-in-python-logging