Spaces:
Paused
Paused
from langchain.chains import RetrievalQA | |
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceInstructEmbeddings | |
from langchain.llms import HuggingFacePipeline | |
from constants import CHROMA_SETTINGS, PERSIST_DIRECTORY | |
from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline | |
import click | |
from constants import CHROMA_SETTINGS | |
def load_model(): | |
''' | |
Select a model on huggingface. | |
If you are running this for the first time, it will download a model for you. | |
subsequent runs will use the model from the disk. | |
''' | |
model_id = "TheBloke/vicuna-7B-1.1-HF" | |
tokenizer = LlamaTokenizer.from_pretrained(model_id) | |
model = LlamaForCausalLM.from_pretrained(model_id, | |
# load_in_8bit=True, # set these options if your GPU supports them! | |
# device_map=1#'auto', | |
# torch_dtype=torch.float16, | |
# low_cpu_mem_usage=True | |
) | |
pipe = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
max_length=2048, | |
temperature=0, | |
top_p=0.95, | |
repetition_penalty=1.15 | |
) | |
local_llm = HuggingFacePipeline(pipeline=pipe) | |
return local_llm | |
def main(device_type, ): | |
# load the instructorEmbeddings | |
if device_type in ['cpu', 'CPU']: | |
device='cpu' | |
else: | |
device='cuda' | |
print(f"Running on: {device}") | |
embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl", | |
model_kwargs={"device": device}) | |
# load the vectorstore | |
db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings, client_settings=CHROMA_SETTINGS) | |
retriever = db.as_retriever() | |
# Prepare the LLM | |
# callbacks = [StreamingStdOutCallbackHandler()] | |
# load the LLM for generating Natural Language responses. | |
llm = load_model() | |
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True) | |
# Interactive questions and answers | |
while True: | |
query = input("\nEnter a query: ") | |
if query == "exit": | |
break | |
# Get the answer from the chain | |
res = qa(query) | |
answer, docs = res['result'], res['source_documents'] | |
# Print the result | |
print("\n\n> Question:") | |
print(query) | |
print("\n> Answer:") | |
print(answer) | |
# # Print the relevant sources used for the answer | |
print("----------------------------------SOURCE DOCUMENTS---------------------------") | |
for document in docs: | |
print("\n> " + document.metadata["source"] + ":") | |
print(document.page_content) | |
print("----------------------------------SOURCE DOCUMENTS---------------------------") | |
if __name__ == "__main__": | |
main() | |