File size: 4,952 Bytes
c323312
6ae5e8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c323312
 
6ae5e8b
 
 
 
c323312
 
 
 
 
 
 
6ae5e8b
 
 
 
 
 
 
bad08ae
 
 
 
 
 
 
 
 
 
 
 
6ae5e8b
 
bad08ae
 
 
 
 
 
 
6ae5e8b
bad08ae
 
 
 
 
 
6ae5e8b
bad08ae
6ae5e8b
bad08ae
 
 
6ae5e8b
 
c323312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
import xml.etree.ElementTree as ET
import re

# System prompt for the "Zeta" research assistant. The trailing `{context}` is a
# template variable filled by the chains below. Any literal `{` or `}` added to
# this text would need to be doubled, because ChatPromptTemplate treats braces
# as template variables.
# NOTE(review): the "Conflict Resolution" paragraph asks the model to explain its
# reasoning inside the citation section. parse_model_response hands each
# <citations>...</citations> span to ElementTree, so a bare `&` or `<` in that
# explanation would make parsing fail and the raw text would be returned instead.
qa_system_prompt = """As Zeta, your mission is to assist users in navigating the vast sea of machine learning research with ease and insight. When responding to inquiries, adhere to the following guidelines to ensure the utmost accuracy and utility:

Contextual Understanding: When presented with a question, apply your understanding of machine learning concepts to interpret the context provided accurately. Utilize this context to guide your search for answers within the specified research papers.

Answer Provision: Always provide an answer that is directly supported by the research papers' content. If the information needed to answer the question is not available, clearly state, "I don't know."

Citation Requirement: For every answer given, include multiple citations from the research papers. A citation must include a direct quote from the paper that supports your answer, along with the identification (ID) of the paper. This ensures that all provided information can be traced back to its source, maintaining a high level of credibility and transparency.

Formatting Guidelines: Present your citations in the following structured format at the end of your answer to maintain clarity and consistency:


<citations>
    <citation><source_id>[Source ID]</source_id><quote>[Direct quote from the source]</quote></citation>
    ...
</citations>


Conflict Resolution: In cases where multiple sources offer conflicting information, evaluate the context, relevance, and credibility of each source to determine the most accurate answer. Explain your reasoning within the citation section to provide insight into your decision-making process.

User Engagement: Encourage user engagement by asking clarifying questions if the initial inquiry is ambiguous or lacks specific context. This helps in providing more targeted and relevant responses.

Continual Learning: Although you are not expected to generate new text or insights beyond the provided papers, be open to learning from new information as it becomes available to you through user interactions and queries.

By following these guidelines, you ensure that users receive valuable, accurate, and source-backed insights into their inquiries, making their exploration of machine learning research more productive and enlightening.

{context}"""
# Two-message chat template: the system prompt above (needs `context`) followed
# by the user's turn (needs `question`).
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        ("human", "{question}"),
    ]
)


def format_docs(docs):
    """Join retrieved documents into one context string for the prompt.

    Each document object is rendered as ``"<chunk_id>: <page_content>"``, using
    ``doc.metadata['chunk_id']`` and ``doc.page_content``. Plain strings
    (including ``str`` subclasses) are used verbatim. Entries are separated by
    a blank line.

    Args:
        docs: Iterable of document objects and/or strings.

    Returns:
        The joined text; ``""`` for an empty iterable.

    Raises:
        KeyError: If a document's metadata has no ``'chunk_id'``.
    """
    return "\n\n".join(
        doc if isinstance(doc, str) else f"{doc.metadata['chunk_id']}: {doc.page_content}"
        for doc in docs
    )


def rag_chain(retriever, llm):
    """Build a retrieval-augmented QA chain.

    The chain's input goes to two places. It goes to *retriever*, whose output
    is rendered by ``format_docs`` into the prompt's ``context``. It is also
    passed through unchanged as the ``question``. The filled ``qa_prompt`` is
    then sent to *llm*.

    Args:
        retriever: Runnable-compatible retriever (presumably a LangChain
            retriever returning documents — confirm against callers).
        llm: The chat model that receives the rendered prompt.

    Returns:
        The composed runnable.
    """
    return (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | qa_prompt
        | llm
    )

def qa_chain(llm):
    """Build a QA chain that does no retrieval.

    Both the prompt's ``context`` and ``question`` slots are filled with the
    chain's entire input, passed through unchanged by ``RunnablePassthrough``.
    The filled ``qa_prompt`` is then sent to *llm*.

    NOTE(review): because both slots get the same value, invoking with a dict
    such as ``{"context": ..., "question": ...}`` would put the whole dict into
    each slot. Confirm that is intended; ``operator.itemgetter("context")`` /
    ``itemgetter("question")`` would select the individual keys instead.

    Args:
        llm: The chat model that receives the rendered prompt.

    Returns:
        The composed runnable.
    """
    return (
        {"context": RunnablePassthrough(), "question": RunnablePassthrough()}
        | qa_prompt
        | llm
    )


# Matches one <citations>...</citations> block, possibly spanning several lines.
_CITATIONS_BLOCK_RE = re.compile(r"<citations>.*?</citations>", re.DOTALL)


def parse_model_response(input_string):
    """Split a model reply into its prose answer and its structured citations.

    Every ``<citations>...</citations>`` block (the format requested by
    ``qa_system_prompt``) is cut out of the text. The ``<citation>`` children
    of each block are collected; the remaining text, concatenated, becomes the
    answer.

    Args:
        input_string: Raw text returned by the model.

    Returns:
        ``{"answer": ..., "citations": [{"source_id": ..., "quote": ...}, ...]}``.
        For each citation, a missing ``<source_id>``/``<quote>`` child yields
        ``None`` and an empty one yields ``""``.

        The input is returned unchanged as ``answer`` with an empty
        ``citations`` list when any of these hold:

        - there are no citation blocks;
        - any block is not well-formed XML;
        - *input_string* is not a ``str``.
    """
    parsed_data = {"answer": input_string, "citations": []}
    if not isinstance(input_string, str):
        # Non-text input has always fallen back to being passed through.
        return parsed_data

    outside_text_parts = []
    citations = []  # collected locally, committed only if every block parses
    last_end_pos = 0
    try:
        for match in _CITATIONS_BLOCK_RE.finditer(input_string):
            outside_text_parts.append(input_string[last_end_pos : match.start()])
            last_end_pos = match.end()
            # NOTE(review): ElementTree parses model output here; see the Python
            # "XML vulnerabilities" docs if this ever receives untrusted input.
            root = ET.fromstring(match.group(0))
            for citation in root.findall("citation"):
                citations.append(
                    {
                        "source_id": citation.findtext("source_id"),
                        "quote": citation.findtext("quote"),
                    }
                )
    except ET.ParseError:
        # Malformed XML (e.g. a bare "&" in a quote): return the raw text and no
        # citations, so the answer and the citation list never overlap.
        return parsed_data

    if not outside_text_parts:
        # No citation blocks were found.
        return parsed_data

    outside_text_parts.append(input_string[last_end_pos:])
    parsed_data["answer"] = "".join(outside_text_parts)
    parsed_data["citations"] = citations
    return parsed_data


def parse_context_and_question(inputs):
    """Extract a bracketed list of context IDs from a user query.

    The first ``[...]`` group in *inputs* (on a single line) is read as IDs
    separated by whitespace and/or commas, e.g. ``"[id1, id2] What is X?"``.
    That group is removed from the text. Whatever remains, with surrounding
    whitespace left as-is, is the question.

    Args:
        inputs: Raw user text.

    Returns:
        ``(context_ids, question)``. When no bracket group is present,
        ``context_ids`` is ``""`` (kept for backward compatibility) and
        ``question`` is *inputs* unchanged. ``"[]"`` yields an empty list.
    """
    match = re.search(r"\[(.*?)\]", inputs)
    if match is None:
        return "", inputs
    # Split on commas as well as whitespace so "[a, b]" gives ["a", "b"]
    # rather than ["a,", "b"]; empty pieces are dropped.
    context = [part for part in re.split(r"[\s,]+", match.group(1)) if part]
    question = inputs[: match.start()] + inputs[match.end() :]
    return context, question


def format_citations(citations):
    """Render citations as ``"<quote> ... [<source_id>]"`` entries.

    Args:
        citations: Iterable of dicts with ``"quote"`` and ``"source_id"`` keys,
            as produced by ``parse_model_response``.

    Returns:
        The entries joined by blank lines; ``""`` for an empty iterable.
    """
    return "\n\n".join(
        f"{citation['quote']} ... [{citation['source_id']}]" for citation in citations
    )


def ai_response_format(message, references):
    """Append formatted citations to *message* below a ``---`` separator.

    Args:
        message: The answer text.
        references: Citation dicts (see ``format_citations``). Any falsy value
            (``""``, ``[]`` or ``None``) means there are no citations.

    Returns:
        *message* alone when there are no references. Otherwise *message*, a
        blank line, ``---``, a blank line, and the formatted citations.
    """
    # Truthiness rather than `!= ""`: an empty citation list must not leave a
    # dangling separator, and None must not crash format_citations.
    if not references:
        return message
    return f"{message}\n\n---\n\n{format_citations(references)}"