File size: 4,339 Bytes
eca00c9
34e9d48
 
 
9ff10ae
eca00c9
 
 
 
 
 
 
 
 
 
 
 
 
 
9ff10ae
d64ab4a
 
 
 
 
 
 
 
eca00c9
 
 
 
 
 
 
d64ab4a
 
27ff1f5
c9d3334
856d35d
c9d3334
 
 
 
 
 
 
 
 
 
 
 
 
 
27ff1f5
 
 
 
 
 
 
 
 
eca00c9
 
0f04f5f
27ff1f5
0f04f5f
27ff1f5
eca00c9
d64ab4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bf069f
 
d64ab4a
 
 
 
 
 
 
 
 
 
 
 
856d35d
d64ab4a
 
 
 
 
 
 
 
c9d3334
d64ab4a
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# ------------------- LIBRARIES -------------------- #
import os, logging, torch, streamlit as st
from transformers import (
    AutoTokenizer, AutoModelForCausalLM)

# --------------------- HELPER --------------------- #
# ANSI escape prefixes (bold + colour) keyed by colour name. ``None`` maps to
# the plain reset sequence, which is also used as the suffix of every result.
_ANSI_CODES = {
    "red": "\033[01;31m",
    "green": "\033[01;32m",
    "yellow": "\033[01;33m",
    "blue": "\033[01;34m",
    "magenta": "\033[01;35m",
    "cyan": "\033[01;36m",
    None: "\033[0m",
}


def C(text, color="yellow"):
    """Wrap *text* in ANSI escape codes so it shows in *color* on a terminal.

    Args:
        text: Value to colourise; it is interpolated into an f-string.
        color: One of "red", "green", "yellow", "blue", "magenta", "cyan",
            or ``None`` (emits the reset code as the prefix). Any other
            name leaves the text uncoloured instead of corrupting it.

    Returns:
        The prefixed text, always terminated by the ANSI reset code.
    """
    # Unknown colour names fall back to "" so no stray "None" text is emitted.
    prefix = _ANSI_CODES.get(color, "")
    return f"{prefix}{text}{_ANSI_CODES[None]}"

def stcache():
    """Pick a Streamlit caching decorator that exists in the installed version.

    Returns:
        A one-argument decorator. On Streamlit older than 1.18 it applies the
        legacy ``st.cache(suppress_st_warning=True)``; otherwise it applies
        ``st.cache_resource()``.
    """
    from packaging import version

    is_legacy = version.parse(st.__version__) < version.parse("1.18")

    def _decorate(func):
        cache = (
            st.cache(suppress_st_warning=True)
            if is_legacy
            else st.cache_resource()
        )
        return cache(func)

    return _decorate
    
st.title("`ckip-joint/bloom-1b1-zh` demo")

# ------------------ ENVIRONMENT ------------------- #
# NOTE(review): huggingface_hub typically reads HF_ENDPOINT when it is first
# imported (already triggered by the `transformers` import above), so this
# assignment may have no effect -- confirm, and move it above the imports if
# a non-default endpoint is ever required.
os.environ["HF_ENDPOINT"] = "https://huggingface.co"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# NOTE(review): this file configures no logging level/handler, so INFO
# records may be filtered out by the root logger -- confirm.
logging.info(C("[INFO] " f"device = {device}"))

# ------------------ INITIALIZE ------------------- #
# Version-appropriate caching decorator returned by `stcache()`, applied
# below via `@stdec`.
stdec = stcache()
@stdec
def model_init():
    """Load the ckip-joint/bloom-1b1-zh tokenizer and causal language model.

    Decorated with ``stdec`` (from ``stcache()``), so Streamlit caches the
    returned objects instead of reloading the weights on every rerun.

    Returns:
        ``(tokenizer, model)``. The model is switched to eval mode and moved
        to the module-level ``device``.
    """
    model_id = "ckip-joint/bloom-1b1-zh"
    logging.info(C("[INFO] Model init start!"))

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Options suggested upstream but intentionally not enabled here:
    # `torch_dtype="auto", device_map="auto"` (ref.: Eric, thanks!) and
    # casting with `.half()` (ref.: Chan-Jan, thanks!).
    model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(device)

    # Show Streamlit's balloon animation once loading has finished.
    st.balloons()
    logging.info(C("[INFO] Model init success!"))
    return tokenizer, model

tokenizer, model = model_init()


def _red_bold_lines(text):
    """Render every line of *text* as bold red Streamlit markdown.

    Lines are joined with a backslash followed by a newline, which Markdown
    renders as a hard line break.
    """
    return "\\\n".join(f"**:red[{line}]**" for line in text.split("\n"))


try:
    # ===================== INPUT ====================== #
    prompt = st.text_input("Prompt: ")

    # =================== INFERENCE ==================== #
    if prompt:
        with st.container():
            # Echo the prompt while generation is running.
            st.markdown(f":violet[{prompt}]⋯⋯")

            with torch.no_grad():
                inputs = tokenizer(prompt, return_tensors="pt").to(device)
                # Batch of one prompt -> unpack the single output sequence.
                [texts_out] = model.generate(
                    **inputs,
                    min_new_tokens=0,
                    max_new_tokens=100,
                )
            output_text = tokenizer.decode(texts_out, skip_special_tokens=True)

        # The decoded text normally echoes the prompt; keep only the
        # continuation. If decoding does not reproduce the prompt verbatim,
        # show the whole output and drop the separate prompt prefix.
        if output_text.startswith(prompt):
            out_gens = output_text[len(prompt):]
        else:
            out_gens = output_text
            prompt = ""
        st.balloons()

        # Keep only the first generated line.
        out_gens = out_gens.split("\n")[0]

        st.caption("Result: ")
        # `_red_bold_lines` already applies the bold/red markup; wrapping it a
        # second time would nest `**:red[...]**` inside itself.
        st.markdown(f":blue[{prompt}]{_red_bold_lines(out_gens)}")

except Exception as err:
    # UI boundary: record the traceback server-side, then show the message.
    logging.exception("Generation failed")
    st.write(str(err))
    st.snow()


    # import streamlit as st

    # st.markdown('Streamlit is **_really_ cool**.')
    # st.markdown("This text is :red[colored red], and this is **:blue[colored]** and bold.")
    # st.markdown(":green[$\sqrt{x^2+y^2}=1$] is a Pythagorean identity. :pencil:")
# def multiline(string):
#     lines = string.split('\n')
#     return '\\\n'.join([f"**:red[{l}]**"
#     for l in lines])
# st.markdown(multiline("1234 \n5616"))
# st.markdown("1234\\\n5616")
# https://docs.streamlit.io/library/api-reference/status/st.spinner
# https://stackoverflow.com/questions/32402502/how-to-change-the-time-zone-in-python-logging