"""
Main file to execute the TRS Pipeline.
"""
import logging
import sys

from augmentation import prompt_generation as pg
from information_retrieval import info_retrieval as ir
from src.post_processing.post_process import post_process_output
from src.text_generation.mapper import MODEL_MAPPER
from text_generation import text_generation as tg
from text_generation.models import (
    Llama3,
    Mistral,
    Gemma2,
    Llama3Point1,
    Llama3Instruct,
    MistralInstruct,
    Llama3Point1Instruct,
    Phi3SmallInstruct,
    GPT4,
    Gemini,
    Claude3Point5Sonnet,
)

logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)

TEST_DIR = "../tests/"

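# User-facing model names mapped to their wrapper classes from text_generation.models.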
MODELS = {
    'GPT-4': GPT4,
    'Llama3': Llama3,
    'Mistral': Mistral,
    'Gemma2': Gemma2,
    'Llama3.1': Llama3Point1,
    'Llama3-Instruct': Llama3Instruct,
    'Mistral-Instruct': MistralInstruct,
    'Llama3.1-Instruct': Llama3Point1Instruct,
    'Phi3-Instruct': Phi3SmallInstruct,
    "Gemini-1.0-pro": Gemini,
    "Claude3.5-sonnet": Claude3Point5Sonnet,
}


def pipeline(starting_point: str,
             query: str,
             model_name: str,
             test: int = 0, **params):
    """
    
    Executes the entire RAG pipeline, provided the query and model class name.

    Args: 
        - query: str
        - model_name: string, one of the following: Llama3, Mistral, Gemma2, Llama3Point1
        - test: whether the pipeline is running a test
        - params: 
            - limit (number of results to be retained) 
            - reranking (binary, whether to rerank results using ColBERT or not)
            - sustainability

    
    """
    try:
        model_id = MODEL_MAPPER[model_name]
    except KeyError:
        logger.error(f"Model {model_name} not found in the model mapper.")
        model_id = MODEL_MAPPER['Gemini-1.0-pro']
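    # Default retrieval parameters; any of these may be overridden via **params below.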
    context_params = {
        'limit': 5,
        'reranking': 0,
        'sustainability': 0,
    }

    for key in ('limit', 'reranking', 'sustainability'):
        if key in params:
            context_params[key] = params[key]

    logger.info("Retrieving context..")
    try:
        context = ir.get_context(starting_point=starting_point, query=query, **context_params)
        if test:
            retrieved_cities = ir.get_cities(context)
        else:
            retrieved_cities = None
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        logger.error(f"Error at line {exc_tb.tb_lineno} while trying to get context: {e}")
        return None

    logger.info("Retrieved context, augmenting prompt..")
    try:
        prompt = pg.augment_prompt(
            query=query,
            starting_point=starting_point,
            context=context,
            params=context_params
        )
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        logger.error(f"Error at line {exc_tb.tb_lineno} while trying to augment prompt: {e}")
        return None

    logger.info(f"Augmented prompt, initializing {model_name} and generating response..")
    try:
        response = tg.generate_response(model_id, prompt, **params)
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
        return None

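    # Post-process the raw model output, using the original query and the retrieved context.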
    try:
        model_params = {"max_tokens": params["max_tokens"], "temperature": params["temperature"]}
        post_processed_response = post_process_output(
            model_id=model_id, user_query=query,
            starting_point=starting_point,
            context=context, response=response, **model_params)
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
        return None
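
    # In test mode, also return the retrieved cities and the augmented prompt so that
    # retrieval and prompting can be evaluated alongside the final response.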
    if test:
        return retrieved_cities, prompt[1]['content'], post_processed_response
    else:
        return post_processed_response


if __name__ == "__main__":
    # sample_query = ("I'm planning a trip in the summer and I love art, history, and visiting museums. "
    #                 "Can you suggest some European cities?")
    sample_query = ("I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. "
                    "Recommend some cities.")
    sample_starting_point = "Munich"  # placeholder value; pipeline() requires a starting point
    model_name = "GPT-4"

    pipeline_response = pipeline(
        starting_point=sample_starting_point,
        query=sample_query,
        model_name=model_name,
        sustainability=1
    )

    print(pipeline_response)