msy127 committed on
Commit
4b47b90
·
1 Parent(s): 2b86939

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -1
app.py CHANGED
@@ -1,11 +1,30 @@
 
 
1
  import gradio as gr
 
2
  from pydantic import BaseModel, Field
3
  from typing import Any, Optional, Dict, List
4
  from huggingface_hub import InferenceClient
5
  from langchain.llms.base import LLM
 
 
 
 
 
6
 
 
7
  hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
8
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  class KwArgsModel(BaseModel):
11
  kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -42,4 +61,124 @@ class CustomInferenceClient(LLM, KwArgsModel):
42
  def _identifying_params(self) -> dict:
43
  return {"model_name": self.model_name}
44
 
45
- kwargs = {"max_new_tokens":256, "temperature":0.9, "top_p":0.6, "repetition_penalty":1.3, "do_sample":True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
  import gradio as gr
4
+ import time
5
  from pydantic import BaseModel, Field
6
  from typing import Any, Optional, Dict, List
7
  from huggingface_hub import InferenceClient
8
  from langchain.llms.base import LLM
9
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
10
+ from langchain.vectorstores import Chroma
11
+ import os
12
+ from dotenv import load_dotenv
13
+ load_dotenv()
14
 
15
# Hugging Face API token taken from the environment (None when the variable is unset).
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Base directory under which the persisted Chroma collection lives.
path_work = "."

# Embedding model used for the vector store; pinned to the CPU.
embeddings = HuggingFaceInstructEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

# Chroma store backed by a directory on disk.
# NOTE(review): the directory name is spelled "cromadb" (sic) -- it presumably
# has to match the folder that was persisted; confirm before renaming.
vectordb = Chroma(
    persist_directory=f"{path_work}/cromadb_llama2-papers",
    embedding_function=embeddings,
)

# Retriever view of the store, configured with k=5 for each search.
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
28
 
29
  class KwArgsModel(BaseModel):
30
  kwargs: Dict[str, Any] = Field(default_factory=dict)
 
61
  def _identifying_params(self) -> dict:
62
  return {"model_name": self.model_name}
63
 
64
# Generation settings handed to CustomInferenceClient when load_model() builds a chain.
kwargs = dict(
    max_new_tokens=256,
    temperature=0.9,
    top_p=0.6,
    repetition_penalty=1.3,
    do_sample=True,
)

# Model identifiers offered by the model dropdown in the UI.
model_list = [
    "meta-llama/Llama-2-13b-chat-hf",
    "HuggingFaceH4/zephyr-7b-alpha",
    "meta-llama/Llama-2-70b-chat-hf",
    "tiiuae/falcon-180B-chat",
]

# Currently active QA chain; (re)assigned by load_model().
qa_chain = None
74
+
75
def load_model(model_selected):
    """Build the retrieval-QA chain for ``model_selected`` and make it the active chain.

    A ``CustomInferenceClient`` is created for the given model id with the
    module-level ``hf_token`` and ``kwargs``. It is combined with the
    module-level ``retriever`` into a ``RetrievalQA`` chain of type "stuff"
    that also returns its source documents, and the result is stored in the
    global ``qa_chain``.

    Args:
        model_selected: Model identifier, e.g. one of the entries in ``model_list``.

    Returns:
        The newly built chain (the same object that is stored in ``qa_chain``).
    """
    global qa_chain

    # Imported inside the function so langchain.chains is only loaded when a
    # chain is actually built.
    from langchain.chains import RetrievalQA

    llm = CustomInferenceClient(model_name=model_selected, hf_token=hf_token, kwargs=kwargs)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
    return qa_chain


# Build the default chain at import time so predict() has a chain to call
# before any model is picked in the UI.
load_model("meta-llama/Llama-2-70b-chat-hf")
91
+
92
def model_select(model_selected):
    """Switch the active chain to ``model_selected`` and return a status line.

    Args:
        model_selected: Model identifier forwarded unchanged to ``load_model``.

    Returns:
        str: Status message that embeds the selected model identifier.
    """
    load_model(model_selected)
    status_message = f"๋ชจ๋ธ {model_selected} ๋กœ๋”ฉ ์™„๋ฃŒ."
    return status_message
95
+
96
def predict(message, chatbot, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3):
    """Answer ``message`` with the active QA chain and stream the reply character by character.

    The chain's ``result`` is followed by a blank line, a header line and the
    list of ``metadata['source']`` values of the returned source documents.
    The full text is printed once for logging and then yielded incrementally.
    Line breaks are streamed like any other character, so the blank line
    between the answer and the source list is preserved in the output.

    Args:
        message: The user's question; passed as-is to the global ``qa_chain``.
        chatbot: Second positional argument supplied by the chat UI
            (presumably the conversation history); not used.
        temperature: Accepted for interface compatibility; not used here.
        max_new_tokens: Accepted for interface compatibility; not used here.
        top_p: Accepted for interface compatibility; not used here.
        repetition_penalty: Accepted for interface compatibility; not used here.

    Yields:
        str: The reply accumulated so far, one character longer on each step.

    Raises:
        KeyError: If the chain output lacks ``result`` / ``source_documents``
            or a source document has no ``source`` metadata entry.
    """
    llm_response = qa_chain(message)
    res_result = llm_response['result']
    res_relevant_doc = [doc.metadata['source'] for doc in llm_response["source_documents"]]

    response = f"{res_result}" + "\n\n" + "[๋‹ต๋ณ€ ๊ทผ๊ฑฐ ์†Œ์Šค ๋…ผ๋ฌธ (ctrl + click ํ•˜์„ธ์š”!)] :" + "\n" + f" \n {res_relevant_doc}"
    print("response: =====> \n", response, "\n\n")

    # Stream the text itself rather than newline-split fragments, so the
    # explicit line breaks above are not dropped from the displayed message.
    partial_message = ""
    for char in response:
        partial_message += char
        yield partial_message
        # Short pause between characters for a typing effect in the UI.
        time.sleep(0.01)
140
+
141
# Heading and introductory text passed to the chat interface.
title = "Llama-2 ๋ชจ๋ธ ๊ด€๋ จ ๋…ผ๋ฌธ Generative QA (with RAG) ์„œ๋น„์Šค (Llama-2-70b ๋ชจ๋ธ ๋“ฑ ํ™œ์šฉ)"
description = """Chat history ์œ ์ง€ ๋ณด๋‹ค๋Š” QA์— ์ถฉ์‹คํ•˜๋„๋ก ์ œ์ž‘๋˜์—ˆ์œผ๋ฏ€๋กœ Single turn์œผ๋กœ ํ™œ์šฉ ํ•˜์—ฌ ์ฃผ์„ธ์š”. Default๋กœ Llama-2 70b ๋ชจ๋ธ๋กœ ์„ค์ •๋˜์–ด ์žˆ์œผ๋‚˜ GPU ์„œ๋น„์Šค ํ•œ๋„ ์ดˆ๊ณผ๋กœ Error๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์œผ๋‹ˆ ์–‘ํ•ด๋ถ€ํƒ๋“œ๋ฆฌ๋ฉฐ, ํ™”๋ฉด ํ•˜๋‹จ์˜ ๋ชจ๋ธ ๋ณ€๊ฒฝ/๋กœ๋”ฉํ•˜์‹œ์–ด ๋‹ค๋ฅธ ๋ชจ๋ธ๋กœ ๋ณ€๊ฒฝํ•˜์—ฌ ์‚ฌ์šฉ์„ ๋ถ€ํƒ๋“œ๋ฆฝ๋‹ˆ๋‹ค. (๋‹ค๋งŒ, Llama-2 70b๊ฐ€ ๊ฐ€์žฅ ์ •ํ™•ํ•˜์˜ค๋‹ˆ ์ฐธ๊ณ ํ•˜์—ฌ ์ฃผ์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.) """

# Stylesheet rule that hides elements with the "toast-wrap" class.
css = ".toast-wrap { display: none !important } "

# Sample questions for the chat interface; each entry is a one-element list.
examples = [
    ["Can you tell me about the llama-2 model?"],
    ["What is percent accuracy, using the SPP layer as features on the SPP (ZF-5) model?"],
    ["How much less accurate is using the SPP layer as features on the SPP (ZF-5) model compared to using the same model on the undistorted full image?"],
    ["tell me about method for human pose estimation based on DNNs"],
]
145
+
146
def vote(data: gr.LikeData):
    """Log whether a chat response was upvoted or downvoted.

    Args:
        data: Like event payload; ``data.liked`` selects the wording and
            ``data.value`` is printed as the rated response.
    """
    verdict = "upvoted" if data.liked else "downvoted"
    # Formatted rather than concatenated so a non-string ``data.value``
    # cannot raise TypeError; the output for string values is unchanged.
    print(f"You {verdict} this response: {data.value}")
149
+
150
# Sliders for the generation parameters.
# NOTE(review): this list is not passed to gr.ChatInterface below and is not
# referenced anywhere else in the visible code; its defaults for
# "Max new tokens" (256) and "Repetition penalty" (1.2) also differ from
# predict()'s defaults (512 and 1.3) -- confirm whether it should be wired in.
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

# Chat display with two avatar images.
chatbot_stream = gr.Chatbot(
    avatar_images=(
        "https://drive.google.com/uc?id=18xKoNOHN15H_qmGhK__VKnGjKjirrquW",
        "https://drive.google.com/uc?id=1tfELAQW_VbPCy6QTRbexRlwAEYo8rSSv",
    ),
    bubble_full_width=False,
)

# Chat UI that drives predict() and displays its output in chatbot_stream.
chat_interface_stream = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    chatbot=chatbot_stream,
    css=css,
    examples=examples,
)
170
+
171
# Page layout: the chat tab plus a row for choosing and loading a model.
# NOTE(review): the nesting below was reconstructed from a view that had lost
# its indentation -- confirm the placement of the model row and click handler.
with gr.Blocks() as demo:
    with gr.Tab("์ŠคํŠธ๋ฆฌ๋ฐ"):
        # Register the like/dislike callback, then draw the chat interface.
        chatbot_stream.like(vote, None, None)
        chat_interface_stream.render()
    with gr.Row():
        with gr.Column(scale=6):
            with gr.Row():
                model_selector = gr.Dropdown(
                    model_list,
                    label="๋ชจ๋ธ ์„ ํƒ",
                    value="meta-llama/Llama-2-70b-chat-hf",
                    scale=5,
                )
                submit_btn1 = gr.Button(value="๋ชจ๋ธ ๋กœ๋“œ", scale=1)
        with gr.Column(scale=4):
            model_status = gr.Textbox(value="", label="๋ชจ๋ธ ์ƒํƒœ")
    # Clicking the button loads the selected model and shows the returned status text.
    submit_btn1.click(model_select, inputs=[model_selector], outputs=[model_status])

# Enable the request queue, then start the app.
demo.queue(concurrency_count=75, max_size=100)
demo.launch(debug=True)