Joshua Sundance Bailey committed
Commit dd9bfbd · Parent(s): 0609980

rag summarization

langchain-streamlit-demo/app.py CHANGED
@@ -27,7 +27,7 @@ from langsmith.client import Client
 from streamlit_feedback import streamlit_feedback

 from qagen import get_rag_qa_gen_chain
-from summarize import get_summarization_chain
+from summarize import get_rag_summarization_chain

 __version__ = "0.0.10"

@@ -421,14 +421,17 @@ if st.session_state.llm:
     full_response: Union[str, None]
     if use_document_chat:
         if document_chat_chain_type == "Summarization":
-            st.session_state.doc_chain = get_summarization_chain(
-                st.session_state.llm,
+            st.session_state.doc_chain = get_rag_summarization_chain(
                 prompt,
+                st.session_state.retriever,
+                st.session_state.llm,
             )
-            full_response = st.session_state.doc_chain.
+            full_response = st.session_state.doc_chain.invoke(
+                prompt,
+                dict(
+                    callbacks=callbacks,
+                    tags=["Streamlit Chat"],
+                ),
             )

     st.markdown(full_response)
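For reference, the new wiring in app.py can be exercised outside Streamlit. The following is a minimal sketch, assuming an OpenAI chat model and a small FAISS store as stand-ins for the app's configured LLM and retriever; those stand-in components are not part of this commit.

# Minimal sketch (not part of this commit): call get_rag_summarization_chain
# the same way app.py now does, with stand-in components.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

from summarize import get_rag_summarization_chain

llm = ChatOpenAI(temperature=0)  # stand-in for st.session_state.llm
retriever = FAISS.from_texts(
    ["LangChain composes LLM calls.", "Streamlit renders the chat UI."],
    OpenAIEmbeddings(),
).as_retriever()  # stand-in for st.session_state.retriever

user_prompt = "Summarize what these documents say about the app."

# Same call pattern as the diff: the user prompt is both the retrieval query
# and the instruction forwarded to the summarization prompt.
chain = get_rag_summarization_chain(user_prompt, retriever, llm)
full_response = chain.invoke(user_prompt, dict(tags=["Streamlit Chat"]))
print(full_response)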
langchain-streamlit-demo/summarize.py CHANGED

@@ -2,6 +2,8 @@ from langchain.chains.base import Chain
 from langchain.chains.summarize import load_summarize_chain
 from langchain.prompts import PromptTemplate
 from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.retriever import BaseRetriever
+from langchain.schema.runnable import RunnableSequence, RunnablePassthrough

 prompt_template = """Write a concise summary of the following text, based on the user input.
 User input: {query}
@@ -49,3 +51,16 @@ def get_summarization_chain(
         input_key="input_documents",
         output_key="output_text",
     )
+
+
+def get_rag_summarization_chain(
+    prompt: str,
+    retriever: BaseRetriever,
+    llm: BaseLanguageModel,
+    input_key: str = "prompt",
+) -> RunnableSequence:
+    return (
+        {"input_documents": retriever, input_key: RunnablePassthrough()}
+        | get_summarization_chain(llm, prompt)
+        | (lambda output: output["output_text"])
+    )
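A note on the composition in get_rag_summarization_chain: LCEL coerces the leading dict into a parallel step, so a single incoming query fans out to the retriever (filling input_documents) and to a passthrough (filling the prompt key) before the existing summarization chain runs and the final lambda pulls out output_text. Below is a toy sketch of that coercion; the lambda "retriever" and "summarizer" are illustrative stand-ins, not the project's components.

# Toy illustration (not part of this commit) of the dict -> parallel-step
# coercion that get_rag_summarization_chain relies on.
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

fake_retriever = RunnableLambda(lambda query: [f"doc matching: {query}"])
fake_summarizer = RunnableLambda(
    lambda inputs: {"output_text": f"summary of {inputs['input_documents']}"}
)

pipeline = (
    {"input_documents": fake_retriever, "prompt": RunnablePassthrough()}
    | fake_summarizer
    | (lambda output: output["output_text"])
)

print(pipeline.invoke("What does the app do?"))
# -> summary of ['doc matching: What does the app do?']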