"""
Main file to execute the TRS Pipeline.
"""
import sys
from augmentation import prompt_generation as pg
from information_retrieval import info_retrieval as ir
from text_generation.models import (
    Llama3,
    Mistral,
    Gemma2,
    Llama3Point1,
    Llama3Instruct,
    MistralInstruct,
    Llama3Point1Instruct,
    Phi3SmallInstruct,
    GPT4,
    Gemini,
    Claude3Point5Sonnet,
)
from text_generation import text_generation as tg
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
from src.text_generation.mapper import MODEL_MAPPER
from src.post_processing.post_process import post_process_output
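
# Location of the test directory; defined here but not referenced directly in this module.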
TEST_DIR = "../tests/"
MODELS = {
    'GPT-4': GPT4,
    'Llama3': Llama3,
    'Mistral': Mistral,
    'Gemma2': Gemma2,
    'Llama3.1': Llama3Point1,
    'Llama3-Instruct': Llama3Instruct,
    'Mistral-Instruct': MistralInstruct,
    'Llama3.1-Instruct': Llama3Point1Instruct,
    'Phi3-Instruct': Phi3SmallInstruct,
    'Gemini-1.0-pro': Gemini,
    'Claude3.5-sonnet': Claude3Point5Sonnet,
}
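
# NOTE: MODELS maps the user-facing model names to their text_generation classes.
# It is not referenced in this module; pipeline() resolves names through MODEL_MAPPER
# instead, so the two mappings are assumed to share the same keys.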
def pipeline(starting_point: str,
             query: str,
             model_name: str,
             test: int = 0, **params):
"""
Executes the entire RAG pipeline, provided the query and model class name.
Args:
- query: str
- model_name: string, one of the following: Llama3, Mistral, Gemma2, Llama3Point1
- test: whether the pipeline is running a test
- params:
- limit (number of results to be retained)
- reranking (binary, whether to rerank results using ColBERT or not)
- sustainability
"""
    try:
        model_id = MODEL_MAPPER[model_name]
    except KeyError:
        logger.error(f"Model {model_name} not found in the model mapper.")
        model_id = MODEL_MAPPER['Gemini-1.0-pro']

    context_params = {
        'limit': 5,
        'reranking': 0,
        'sustainability': 0,
    }
    if 'limit' in params:
        context_params['limit'] = params['limit']
    if 'reranking' in params:
        context_params['reranking'] = params['reranking']
    if 'sustainability' in params:
        context_params['sustainability'] = params['sustainability']
logger.info("Retrieving context..")
try:
context = ir.get_context(starting_point=starting_point, query=query, **context_params)
if test:
retrieved_cities = ir.get_cities(context)
else:
retrieved_cities = None
except Exception as e:
exc_type, exc_obj, exc_tb = sys.exc_info()
logger.error(f"Error at line {exc_tb.tb_lineno} while trying to get context: {e}")
return None
logger.info("Retrieved context, augmenting prompt..")
try:
prompt = pg.augment_prompt(
query=query,
starting_point=starting_point,
context=context,
params=context_params
)
except Exception as e:
exc_type, exc_obj, exc_tb = sys.exc_info()
logger.error(f"Error at line {exc_tb.tb_lineno} while trying to augment prompt: {e}")
return None
    # return prompt
    logger.info(f"Augmented prompt, initializing {model_name} and generating response..")
    try:
        response = tg.generate_response(model_id, prompt, **params)
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        logger.error(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
        return None
    try:
        # Forward max_tokens/temperature only if the caller supplied them.
        model_params = {key: params[key] for key in ("max_tokens", "temperature") if key in params}
        post_processed_response = post_process_output(
            model_id=model_id, user_query=query,
            starting_point=starting_point,
            context=context, response=response, **model_params)
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        logger.error(f"Error at line {exc_tb.tb_lineno} while post-processing response: {e}")
        return None
    if test:
        return retrieved_cities, prompt[1]['content'], post_processed_response
    else:
        return post_processed_response
if __name__ == "__main__":
# sample_query = "I'm planning a trip in the summer and I love art, history, and visiting museums. Can you
# suggest " \ "some " \ "European cities? "
sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \
"cities. "
model_name = "GPT-4"
pipeline_response = pipeline(
query=sample_query,
model_name=model_name,
sustainability=1
)
print(pipeline_response)
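
    # A hedged sketch of a test-mode run: with test=1 the pipeline is expected to return the
    # retrieved cities, the augmented user prompt, and the post-processed response. The limit
    # and reranking values below are illustrative, not values prescribed by this repository.
    # retrieved_cities, augmented_prompt, test_response = pipeline(
    #     starting_point=sample_starting_point,
    #     query=sample_query,
    #     model_name=model_name,
    #     test=1,
    #     limit=3,
    #     reranking=1,
    #     sustainability=1,
    # )
    # print(retrieved_cities)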