from typing import List

from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnablePassthrough, RunnableSequence
from pydantic import BaseModel, Field


class QuestionAnswerPair(BaseModel):
    """A single generated question and its answer."""

    question: str = Field(..., description="The question that will be answered.")
    answer: str = Field(..., description="The answer to the question that was asked.")

    def to_str(self, idx: int) -> str:
        # Render as a numbered Markdown item, indenting the answer so it lines
        # up under the question, e.g. "1. **Q:** ...\n\n   **A:** ...".
        question_piece = f"{idx}. **Q:** {self.question}"
        whitespace = " " * (len(str(idx)) + 2)
        answer_piece = f"{whitespace}**A:** {self.answer}"
        return f"{question_piece}\n\n{answer_piece}"


class QuestionAnswerPairList(BaseModel):
    """The full set of question/answer pairs the LLM is asked to generate."""

    QuestionAnswerPairs: List[QuestionAnswerPair]

    def to_str(self) -> str:
        # Number the pairs from 1 and join them into one Markdown-formatted string.
        return "\n\n".join(
            [
                qap.to_str(idx)
                for idx, qap in enumerate(self.QuestionAnswerPairs, start=1)
            ],
        )


# Parses the raw LLM output into a QuestionAnswerPairList.
PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(
    pydantic_object=QuestionAnswerPairList,
)


# System message: describes the task and embeds the parser's format instructions.
templ1 = """You are a smart assistant designed to help college professors come up with reading comprehension questions.
Given a piece of text, you must come up with question and answer pairs that can be used to test a student's reading comprehension abilities.
Generate as many question/answer pairs as you can.
When coming up with the question/answer pairs, you must respond in the following format:
{format_instructions}

Do not provide additional commentary and do not wrap your response in Markdown formatting. Return RAW, VALID JSON.
"""
# Human message: the caller's instruction followed by the retrieved context.
templ2 = """{prompt}
Please create question/answer pairs, in the specified JSON format, for the following text:
----------------
{context}"""
# format_instructions is supplied as a zero-argument callable, so the parser's
# instructions are rendered whenever the prompt is formatted.
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", templ1),
        ("human", templ2),
    ],
).partial(format_instructions=PYDANTIC_PARSER.get_format_instructions)


def get_rag_qa_gen_chain(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Build an LCEL chain that retrieves context and generates Q/A pairs from it."""
    return (
        # The incoming string is fanned out: the retriever fetches relevant
        # documents for "context" while RunnablePassthrough forwards the
        # original input under `input_key`.
        {"context": retriever, input_key: RunnablePassthrough()}
        | CHAT_PROMPT
        | llm
        # If the model's JSON fails to parse, OutputFixingParser asks the LLM
        # to repair it before validating against QuestionAnswerPairList.
        | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
        # Render the parsed pairs as a numbered Markdown string.
        | (lambda parsed_output: parsed_output.to_str())
    )
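# Illustrative usage sketch, not part of the module above: it assumes an
# OpenAI-backed chat model and a FAISS vector store (requires faiss-cpu and
# OPENAI_API_KEY), but any BaseRetriever/BaseLanguageModel pair would work.
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.vectorstores import FAISS

    # Build a tiny in-memory retriever over a sample passage.
    retriever = FAISS.from_texts(
        ["Mitochondria generate most of the cell's supply of ATP."],
        OpenAIEmbeddings(),
    ).as_retriever()

    chain = get_rag_qa_gen_chain(retriever=retriever, llm=ChatOpenAI(temperature=0))
    # The chain takes a single string: it is sent to the retriever and also
    # forwarded to the prompt as {prompt}.
    print(chain.invoke("Write reading comprehension questions for this text."))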
|