Joshua Sundance Bailey committed
Commit: 47c2ffc
Parent(s): 622ac66

qagen & summarize
langchain-streamlit-demo/app.py
CHANGED
@@ -7,12 +7,12 @@ import anthropic
 import langsmith.utils
 import openai
 import streamlit as st
-from langchain import LLMChain
 from langchain.callbacks import StreamlitCallbackHandler
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
 from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
 from langchain.chains import RetrievalQA
+from langchain.chains.llm import LLMChain
 from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
 from langchain.document_loaders import PyPDFLoader
 from langchain.embeddings import OpenAIEmbeddings
@@ -26,6 +26,7 @@ from langsmith.client import Client
 from streamlit_feedback import streamlit_feedback
 
 from qagen import get_qa_gen_chain, combine_qa_pair_lists
+from summarize import get_summarization_chain
 
 __version__ = "0.0.6"
 
@@ -216,7 +217,14 @@ with sidebar:
     )
     document_chat_chain_type = st.selectbox(
        label="Document Chat Chain Type",
-        options=[
+        options=[
+            "stuff",
+            "refine",
+            "map_reduce",
+            "map_rerank",
+            "Q&A Generation",
+            "Summarization",
+        ],
        index=0,
        help=chain_type_help,
        disabled=not document_chat,
@@ -331,13 +339,7 @@ if st.session_state.llm:
     # --- Document Chat ---
     if st.session_state.retriever:
         if document_chat_chain_type == "Summarization":
-
-            # st.session_state.doc_chain = RetrievalQA.from_chain_type(
-            #     llm=st.session_state.llm,
-            #     chain_type=chain_type,
-            #     retriever=st.session_state.retriever,
-            #     memory=MEMORY,
-            # )
+            st.session_state.doc_chain = "summarization"
         elif document_chat_chain_type == "Q&A Generation":
             st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
 
@@ -393,7 +395,17 @@ if st.session_state.llm:
     full_response: Union[str, None]
     if use_document_chat:
         if document_chat_chain_type == "Summarization":
-
+            st.session_state.doc_chain = get_summarization_chain(
+                st.session_state.llm,
+                prompt,
+            )
+            full_response = st.session_state.doc_chain.run(
+                st.session_state.texts,
+                callbacks=callbacks,
+                tags=["Streamlit Chat"],
+            )
+
+            st.markdown(full_response)
         elif document_chat_chain_type == "Q&A Generation":
             config: Dict[str, Any] = dict(
                 callbacks=callbacks,
@@ -409,14 +421,21 @@ if st.session_state.llm:
                 config,
             )
             results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
-
-
-
+
+            def _to_str(idx, qap):
+                question_piece = f"{idx}. **Q:** {qap.question}"
+                whitespace = " " * (len(str(idx)) + 2)
+                answer_piece = f"{whitespace}**A:** {qap.answer}"
+                return f"{question_piece}\n{answer_piece}"
+
+            output_text = "\n\n".join(
+                [
+                    _to_str(idx, qap)
+                    for idx, qap in enumerate(results, start=1)
+                ],
             )
-
-
-            st.markdown(f"{idx}. **A:** {result.answer}")
-            st.markdown("\n")
+
+            st.markdown(output_text)
 
     else:
         st_handler = StreamlitCallbackHandler(st.container())
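For reference, the Q&A Generation branch now assembles one markdown string instead of calling `st.markdown` once per pair. Below is a minimal standalone sketch of that formatting; the `QuestionAnswerPair` dataclass is a hypothetical stand-in for the pydantic model defined in `qagen.py`.

```python
from dataclasses import dataclass


@dataclass
class QuestionAnswerPair:
    # Hypothetical stand-in for the qagen.py model, just for illustration.
    question: str
    answer: str


def _to_str(idx: int, qap: QuestionAnswerPair) -> str:
    question_piece = f"{idx}. **Q:** {qap.question}"
    # Pad with spaces matching the width of "N. " so A: sits under Q:.
    whitespace = " " * (len(str(idx)) + 2)
    answer_piece = f"{whitespace}**A:** {qap.answer}"
    return f"{question_piece}\n{answer_piece}"


pairs = [
    QuestionAnswerPair("What does the chain return?", "A list of Q&A pairs."),
    QuestionAnswerPair("How are pairs joined?", "With a blank line between them."),
]
print("\n\n".join(_to_str(i, qap) for i, qap in enumerate(pairs, start=1)))
```

Note also that the Summarization branch only stores a sentinel value at setup time; the real chain is built per message, since it needs the user's prompt to fill the `{query}` slot in the templates below.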
langchain-streamlit-demo/summarize.py
ADDED
@@ -0,0 +1,51 @@
+from langchain.chains.base import Chain
+from langchain.chains.summarize import load_summarize_chain
+from langchain.prompts import PromptTemplate
+from langchain.schema.language_model import BaseLanguageModel
+
+prompt_template = """Write a concise summary of the following text, based on the user input.
+User input: {query}
+Text:
+```
+{text}
+```
+CONCISE SUMMARY:"""
+
+refine_template = (
+    "You are iteratively crafting a summary of the text below based on the user input\n"
+    "User input: {query}"
+    "We have provided an existing summary up to a certain point: {existing_answer}\n"
+    "We have the opportunity to refine the existing summary"
+    "(only if needed) with some more context below.\n"
+    "------------\n"
+    "{text}\n"
+    "------------\n"
+    "Given the new context, refine the original summary.\n"
+    "If the context isn't useful, return the original summary.\n"
+    "If the context is useful, refine the summary to include the new context.\n"
+    "Your contribution is helping to build a comprehensive summary of a large body of knowledge.\n"
+    "You do not have the complete context, so do not discard pieces of the original summary."
+)
+
+
+def get_summarization_chain(
+    llm: BaseLanguageModel,
+    prompt: str,
+) -> Chain:
+    _prompt = PromptTemplate.from_template(
+        prompt_template,
+        partial_variables={"query": prompt},
+    )
+    refine_prompt = PromptTemplate.from_template(
+        refine_template,
+        partial_variables={"query": prompt},
+    )
+    return load_summarize_chain(
+        llm=llm,
+        chain_type="refine",
+        question_prompt=_prompt,
+        refine_prompt=refine_prompt,
+        return_intermediate_steps=False,
+        input_key="input_documents",
+        output_key="output_text",
+    )
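For context, a hedged usage sketch of the new helper, assuming the legacy `langchain` `Chain.run` API that the rest of this app uses; the `Document` contents and the `ChatOpenAI` settings are invented for illustration.

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import Document

from summarize import get_summarization_chain

# Two fake chunks standing in for st.session_state.texts in app.py.
docs = [
    Document(page_content="First chunk of the uploaded PDF."),
    Document(page_content="Second chunk of the uploaded PDF."),
]

# The user's chat message is baked into both prompts via partial_variables,
# so the chain's only remaining input is input_documents.
chain = get_summarization_chain(ChatOpenAI(temperature=0), "Summarize the key findings.")

# run() maps a single positional argument onto input_key="input_documents".
summary = chain.run(docs)
print(summary)
```

Because `chain_type="refine"`, the chain summarizes the first document with `question_prompt`, then folds each subsequent document into the running summary with `refine_prompt`, which is why both templates share the partialed `{query}` variable.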