Spaces (status: Sleeping)
Megatron17 committed
Commit · 4ff740e · 1 Parent(s): ac18860
Upload 3 files

Files changed:
- Dockerfile +11 -0
- legalminds.py +164 -0
- requirements.txt +17 -0
Dockerfile
ADDED
@@ -0,0 +1,11 @@
+FROM python:3.11
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+WORKDIR $HOME/app
+COPY --chown=user . $HOME/app
+COPY ./requirements.txt ~/app/requirements.txt
+RUN pip install -r requirements.txt
+COPY . .
+CMD ["chainlit", "run", "legalminds.py", "--port", "7860"]
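For reference, a quick local sanity check of this image (a minimal sketch, assuming the image was built with a tag of legalminds and started via docker run -p 7860:7860 legalminds; the tag is illustrative, not part of the commit):

import urllib.request

# The CMD above binds Chainlit to port 7860, the port Hugging Face Spaces
# routes traffic to; a 200 here means the app came up inside the container.
resp = urllib.request.urlopen("http://localhost:7860", timeout=10)
print(resp.status)  # expect 200 once Chainlit has finished starting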
legalminds.py
ADDED
@@ -0,0 +1,164 @@
+import os
+from dotenv import load_dotenv
+load_dotenv()
+
+import numpy as np
+import pandas as pd
+import time
+from tqdm import tqdm
+import warnings
+from langchain.chains import RetrievalQA
+from langchain.callbacks import StdOutCallbackHandler
+import chainlit as cl  # importing chainlit for our app
+from chainlit.prompt import Prompt, PromptMessage
+from chainlit.playground.providers.openai import ChatOpenAI  # importing ChatOpenAI tools
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma, DeepLake
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.document_loaders.dataframe import DataFrameLoader
+from langchain.chat_models import ChatOpenAI
+from langchain.chains import RetrievalQA
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationalRetrievalChain
+from langchain.llms.openai import OpenAIChat
+from langchain.agents.agent_toolkits import create_retriever_tool
+from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
+from langchain.utilities import SerpAPIWrapper
+from langchain.agents import load_tools
+from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
+from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
+from langchain.schema.messages import SystemMessage
+from langchain.prompts import MessagesPlaceholder
+from langchain.agents import AgentExecutor
+
+warnings.filterwarnings("ignore")
+os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
+
+# review_df = pd.read_csv("./data/justice.csv")
+
+# data = review_df
+
+# text_splitter = RecursiveCharacterTextSplitter(
+#     chunk_size = 7000,  # the character length of the chunk
+#     chunk_overlap = 700,  # the character length of the overlap between chunks
+#     length_function = len,  # the length function - in this case, character length (aka the python len() fn.)
+# )
+
+# loader = DataFrameLoader(review_df, page_content_column="facts")
+# base_docs = loader.load()
+# docs = text_splitter.split_documents(base_docs)
+
+embedder = OpenAIEmbeddings()
+
+# This is needed for both the memory and the prompt
+memory_key = "history"
+# Embed and persist db
+persist_directory = "./data/chroma"
+
+vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedder)
+# vectorstore = DeepLake(dataset_path="./legalminds/", embedding=embedder, overwrite=True)
+# vectorstore.add_documents(docs)
+
+memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+primary_qa_llm = ChatOpenAI(
+    model="gpt-3.5-turbo-16k",
+    temperature=0,
+)
+
+retriever = vectorstore.as_retriever()
+CUSTOM_TOOL_N_DOCS = 3  # number of retrieved docs from deep lake to consider
+CUSTOM_TOOL_DOCS_SEPARATOR = "\n\n"  # how to join together the retrieved docs to form a single string
+
+def retrieve_n_docs_tool(query: str) -> str:
+    """Searches for relevant documents that may contain the answer to the query."""
+    docs = retriever.get_relevant_documents(query)[:CUSTOM_TOOL_N_DOCS]
+    texts = [doc.page_content for doc in docs]
+    texts_merged = CUSTOM_TOOL_DOCS_SEPARATOR.join(texts)
+    return texts_merged
+
+
+serp_tool = load_tools(["serpapi"])
+# print("Serp Tool:", serp_tool[0])
+
+data_tool = create_retriever_tool(
+    retriever,
+    "retrieve_n_docs_tool",
+    "Searches and returns documents regarding the query asked."
+)
+tools = [data_tool, serp_tool[0]]
+
+# llm = OpenAIChat(model="gpt-3.5-turbo", temperature=0)
+llm = ChatOpenAI(temperature=0)
+
+
+memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)
+system_message = SystemMessage(
+    content=(
+        "Do your best to answer the questions. "
+        "Feel free to use any tools available to look up "
+        "relevant information, only if necessary"
+    )
+)
+
+prompt = OpenAIFunctionsAgent.create_prompt(
+    system_message=system_message,
+    extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)]
+)
+agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+
+handler = StdOutCallbackHandler()
+
+
+@cl.on_chat_start  # marks a function that will be executed at the start of a user session
+async def start_chat():
+
+    agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True,
+                                   return_intermediate_steps=True)
+    # agent_executor = create_conversational_retrieval_agent(llm, tools, verbose=True)
+    # qa_with_sources_chain = RetrievalQA.from_chain_type(
+    #     llm=llm,
+    #     retriever=retriever,
+    #     callbacks=[handler],
+    #     return_source_documents=True
+    # )
+
+    cl.user_session.set("agent", agent_executor)
+
+@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
+async def main(message: str):
+    agent_executor = cl.user_session.get("agent")
+
+    # prompt = Prompt(
+    #     provider=ChatOpenAI.id,
+    #     messages=[
+    #         PromptMessage(
+    #             role="system",
+    #             # template=RAQA_PROMPT_TEMPLATE,
+    #             # formatted=RAQA_PROMPT_TEMPLATE,
+    #         ),
+    #         PromptMessage(
+    #             role="user",
+    #             # template=user_template,
+    #             # formatted=user_template.format(input=message),
+    #         ),
+    #     ],
+    #     inputs={"input": message},
+    #     # settings=settings,
+    # )
+
+
+
+    # result = await qa_with_sources_chain.acall({"query": message})  # , callbacks=[cl.AsyncLangchainCallbackHandler()])
+    result = agent_executor({"input": message})
+    # print("result Dict:", result)
+
+    msg = cl.Message(content=result["output"])
+    print("message:", msg)
+    print("output message:", msg.content)
+    # Update the prompt object with the completion
+    # msg.content = result["output"]
+    # prompt.completion = msg.content
+    # msg.prompt = prompt
+    # print("message_content: ", msg.content)
+    await msg.send()
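The ingestion pipeline that presumably built the persisted store at ./data/chroma is only present above as commented-out lines. Reassembled as a standalone one-time script, it would look roughly like this (a sketch under the same assumptions as those comments: ./data/justice.csv exists with a "facts" column and OPENAI_API_KEY is set; Chroma.from_documents stands in for the separate construct-then-add_documents steps):

import pandas as pd
from langchain.document_loaders.dataframe import DataFrameLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma

# Load the case data; each row's "facts" column becomes one document.
review_df = pd.read_csv("./data/justice.csv")
loader = DataFrameLoader(review_df, page_content_column="facts")
base_docs = loader.load()

# Same splitter settings as the commented-out code in legalminds.py.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=7000,     # character length of each chunk
    chunk_overlap=700,   # character overlap between consecutive chunks
    length_function=len,
)
docs = text_splitter.split_documents(base_docs)

# Embed and persist, so the app can reopen the store read-only at startup.
vectorstore = Chroma.from_documents(
    docs, OpenAIEmbeddings(), persist_directory="./data/chroma"
)
vectorstore.persist()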
requirements.txt
ADDED
@@ -0,0 +1,17 @@
+chainlit==0.7.0
+numpy==1.25.2
+openai==0.27.8
+python-dotenv==1.0.0
+wandb==0.15.11
+chromadb
+langchain
+tiktoken
+pandas
+scipy
+scikit-learn
+ipykernel
+matplotlib
+plotly
+deeplake
+google-search-results
+
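Note that legalminds.py loads secrets with load_dotenv(), so a deployment needs API keys in the environment: OpenAI for the LLM and embeddings, SerpAPI for the serpapi tool, and Weights & Biases because LANGCHAIN_WANDB_TRACING is enabled. A small pre-flight check (a sketch; these are the standard variable names each library reads, not something this commit pins down):

import os
from dotenv import load_dotenv

load_dotenv()  # mirrors the app's own startup
for key in ("OPENAI_API_KEY", "SERPAPI_API_KEY", "WANDB_API_KEY"):
    # Report presence only; never print the secret itself.
    print(f"{key}: {'set' if os.getenv(key) else 'MISSING'}")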