clementsan commited on
Commit
ecf1633
·
1 Parent(s): 5db4902

Use HuggingFaceEndpoint

Browse files
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -9,7 +9,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain_community.llms import HuggingFacePipeline
10
  from langchain.chains import ConversationChain
11
  from langchain.memory import ConversationBufferMemory
12
- from langchain.llms import HuggingFaceHub
13
 
14
  from pathlib import Path
15
  import chromadb
@@ -101,32 +101,50 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
101
  # Warning: langchain issue
102
  # URL: https://github.com/langchain-ai/langchain/issues/6080
103
  if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
104
- llm = HuggingFaceHub(
105
  repo_id=llm_model,
106
- model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
 
 
 
 
107
  )
108
  elif llm_model == "microsoft/phi-2":
109
  raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
110
- llm = HuggingFaceHub(
111
  repo_id=llm_model,
112
- model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
 
 
 
 
 
113
  )
114
  elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
115
- llm = HuggingFaceHub(
116
  repo_id=llm_model,
117
- model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
 
 
 
118
  )
119
  elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
120
  raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
121
- llm = HuggingFaceHub(
122
  repo_id=llm_model,
123
- model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
 
 
 
124
  )
125
  else:
126
- llm = HuggingFaceHub(
127
  repo_id=llm_model,
128
  # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
129
- model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
 
 
 
130
  )
131
 
132
  progress(0.75, desc="Defining buffer memory...")
 
9
  from langchain_community.llms import HuggingFacePipeline
10
  from langchain.chains import ConversationChain
11
  from langchain.memory import ConversationBufferMemory
12
+ from langchain_community.llms import HuggingFaceEndpoint
13
 
14
  from pathlib import Path
15
  import chromadb
 
101
  # Warning: langchain issue
102
  # URL: https://github.com/langchain-ai/langchain/issues/6080
103
  if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
104
+ llm = HuggingFaceEndpoint(
105
  repo_id=llm_model,
106
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
107
+ temperature = temperature,
108
+ max_new_tokens = max_tokens,
109
+ top_k = top_k,
110
+ load_in_8bit = True,
111
  )
112
  elif llm_model == "microsoft/phi-2":
113
  raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
114
+ llm = HuggingFaceEndpoint(
115
  repo_id=llm_model,
116
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
117
+ temperature = temperature,
118
+ max_new_tokens = max_tokens,
119
+ top_k = top_k,
120
+ trust_remote_code = True,
121
+ torch_dtype = "auto",
122
  )
123
  elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
124
+ llm = HuggingFaceEndpoint(
125
  repo_id=llm_model,
126
+ # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
127
+ temperature = temperature,
128
+ max_new_tokens = 250,
129
+ top_k = top_k,
130
  )
131
  elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
132
  raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
133
+ llm = HuggingFaceEndpoint(
134
  repo_id=llm_model,
135
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
136
+ temperature = temperature,
137
+ max_new_tokens = max_tokens,
138
+ top_k = top_k,
139
  )
140
  else:
141
+ llm = HuggingFaceEndpoint(
142
  repo_id=llm_model,
143
  # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
144
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
145
+ temperature = temperature,
146
+ max_new_tokens = max_tokens,
147
+ top_k = top_k,
148
  )
149
 
150
  progress(0.75, desc="Defining buffer memory...")