Joshua Sundance Bailey
refactor
923e6fa
raw
history blame
2.53 kB
from typing import List
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from langchain.prompts.chat import (
ChatPromptTemplate,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnablePassthrough, RunnableSequence
from pydantic import BaseModel, Field
# One generated question together with its answer.
# NOTE: no class docstring on purpose — pydantic would publish it as the schema
# "description", which feeds the parser's format instructions given to the LLM.
class QuestionAnswerPair(BaseModel):
    question: str = Field(..., description="The question that will be answered.")
    answer: str = Field(..., description="The answer to the question that was asked.")

    def to_str(self, idx: int) -> str:
        """Render this pair as item ``idx`` of a numbered Markdown list.

        The answer is placed after a blank line and indented by the width of the
        ``"{idx}. "`` list marker, so it lines up under the question text.
        """
        marker = f"{idx}. "
        indent = " " * len(marker)
        return f"{marker}**Q:** {self.question}\n\n{indent}**A:** {self.answer}"
# Container the LLM's JSON reply is parsed into (see PYDANTIC_PARSER).
# No class docstring: pydantic would expose it in the JSON schema shown to the LLM.
class QuestionAnswerPairList(BaseModel):
    # PascalCase name kept deliberately: it is the JSON key the parser's schema
    # asks the model to emit, so renaming it would change the expected output.
    QuestionAnswerPairs: List[QuestionAnswerPair]

    def to_str(self) -> str:
        """Render all pairs as a numbered Markdown list (1-based), blank-line separated."""
        rendered = []
        for position, pair in enumerate(self.QuestionAnswerPairs):
            rendered.append(pair.to_str(position + 1))
        return "\n\n".join(rendered)
# Parses the model's JSON reply into a QuestionAnswerPairList. Its format
# instructions (built from that model's schema) are injected into CHAT_PROMPT.
PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(pydantic_object=QuestionAnswerPairList)
# System-message template. "{format_instructions}" is the only template variable;
# it is filled via CHAT_PROMPT's partial. The string ends with a newline.
templ1 = (
    "You are a smart assistant designed to help college professors come up with reading comprehension questions.\n"
    "Given a piece of text, you must come up with question and answer pairs that can be used to test a student's reading comprehension abilities.\n"
    "Generate as many question/answer pairs as you can.\n"
    "When coming up with the question/answer pairs, you must respond in the following format:\n"
    "{format_instructions}\n"
    "Do not provide additional commentary and do not wrap your response in Markdown formatting. Return RAW, VALID JSON.\n"
)
# Human-message template: the user's request ("{prompt}"), then an instruction,
# a dashed separator, and the retrieved text ("{context}"). No trailing newline.
templ2 = (
    "{prompt}\n"
    "Please create question/answer pairs, in the specified JSON format, for the following text:\n"
    "----------------\n"
    "{context}"
)
# Two-turn chat prompt: templ1 as the system turn (output-format contract) and
# templ2 as the human turn (request + retrieved context).
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", templ1), ("human", templ2)]
).partial(
    # The bound method itself (not its result) is passed. LangChain's partial()
    # accepts zero-arg callables and presumably calls them at format time —
    # see the "partial with functions" docs.
    format_instructions=PYDANTIC_PARSER.get_format_instructions,
)
def _format_docs(docs) -> str:
    """Join the retriever's documents into one plain-text block for ``{context}``.

    ``docs`` is expected to be the retriever's output (LangChain ``Document``
    objects exposing ``page_content``). Joining the text avoids interpolating the
    list's repr, which would leak ``Document(...)`` wrappers, metadata, and
    escaped newlines into the prompt.
    """
    return "\n\n".join(doc.page_content for doc in docs)


def get_rag_qa_gen_chain(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Build an LCEL chain that generates question/answer pairs from retrieved text.

    The chain's input is the user's text. It is used both as the retrieval query
    and as the ``{prompt}`` variable of ``CHAT_PROMPT``. The LLM's JSON reply is
    parsed into ``QuestionAnswerPairList`` (with an LLM-backed fixing pass on
    parse failure) and rendered as a numbered Markdown string.

    Args:
        retriever: Source of documents whose text becomes ``{context}``.
        llm: Model used both to generate the pairs and to repair malformed output.
        input_key: Retained for backward compatibility only. The human template
            always reads the user's text from ``{prompt}``, so the text is
            supplied under that key regardless of this value. Previously, a
            non-default value left ``{prompt}`` unfilled and the chain failed at
            invocation.

    Returns:
        A runnable sequence producing the rendered Q/A pairs as a string.
    """
    return (
        {
            # Retrieve documents for the user's text and flatten them to plain text.
            "context": retriever | _format_docs,
            # templ2 hard-codes "{prompt}", so the user text must use this exact key.
            "prompt": RunnablePassthrough(),
        }
        | CHAT_PROMPT
        | llm
        # On a parse failure, OutputFixingParser asks `llm` to repair the reply.
        | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
        | (lambda parsed_output: parsed_output.to_str())
    )