""" | |
Main file to execute the TRS Pipeline. | |
""" | |
import logging
import sys

from augmentation import prompt_generation as pg
from information_retrieval import info_retrieval as ir
from text_generation import text_generation as tg
from text_generation.models import (
    Llama3,
    Mistral,
    Gemma2,
    Llama3Point1,
    Llama3Instruct,
    MistralInstruct,
    Llama3Point1Instruct,
    Phi3SmallInstruct,
    GPT4,
    Gemini,
    Claude3Point5Sonnet,
)
from src.text_generation.mapper import MODEL_MAPPER
from src.post_processing.post_process import post_process_output

logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
logger = logging.getLogger(__name__)

TEST_DIR = "../tests/"
MODELS = {
    'GPT-4': GPT4,
    'Llama3': Llama3,
    'Mistral': Mistral,
    'Gemma2': Gemma2,
    'Llama3.1': Llama3Point1,
    'Llama3-Instruct': Llama3Instruct,
    'Mistral-Instruct': MistralInstruct,
    'Llama3.1-Instruct': Llama3Point1Instruct,
    'Phi3-Instruct': Phi3SmallInstruct,
    'Gemini-1.0-pro': Gemini,
    'Claude3.5-sonnet': Claude3Point5Sonnet,
}
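# Note: the keys above are the names `pipeline()` accepts as `model_name`;
# MODEL_MAPPER (imported above) resolves the same names to backend model ids,
# while MODELS holds the corresponding model classes.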
def pipeline(starting_point: str,
             query: str,
             model_name: str,
             test: int = 0, **params):
    """
    Executes the entire RAG pipeline, given the starting point, query, and model name.

    Args:
        - starting_point: str, the starting point of the trip, used during
          retrieval and prompt augmentation
        - query: str, the user's travel query
        - model_name: str, one of the keys of MODELS (e.g. 'GPT-4', 'Llama3',
          'Mistral', 'Gemma2', 'Llama3.1')
        - test: int, whether the pipeline is running a test (1) or not (0)
        - params: optional keyword arguments:
            - limit: number of retrieved results to be retained
            - reranking: binary, whether to rerank results using ColBERT or not
            - sustainability: binary, whether to account for sustainability
              during retrieval
            - max_tokens, temperature: generation settings, also forwarded to
              post-processing

    Returns:
        The post-processed response, or a (retrieved_cities, user_prompt,
        response) tuple in test mode; None if any stage fails.
    """
    try:
        model_id = MODEL_MAPPER[model_name]
    except KeyError:
        logger.error(f"Model {model_name} not found in the model mapper; falling back to Gemini-1.0-pro.")
        model_id = MODEL_MAPPER['Gemini-1.0-pro']
    # Retrieval defaults, overridden by any values supplied in params.
    context_params = {
        'limit': params.get('limit', 5),
        'reranking': params.get('reranking', 0),
        'sustainability': params.get('sustainability', 0),
    }
logger.info("Retrieving context..") | |
try: | |
context = ir.get_context(starting_point=starting_point, query=query, **context_params) | |
if test: | |
retrieved_cities = ir.get_cities(context) | |
else: | |
retrieved_cities = None | |
except Exception as e: | |
exc_type, exc_obj, exc_tb = sys.exc_info() | |
logger.error(f"Error at line {exc_tb.tb_lineno} while trying to get context: {e}") | |
return None | |
logger.info("Retrieved context, augmenting prompt..") | |
try: | |
prompt = pg.augment_prompt( | |
query=query, | |
starting_point=starting_point, | |
context=context, | |
params=context_params | |
) | |
except Exception as e: | |
exc_type, exc_obj, exc_tb = sys.exc_info() | |
logger.error(f"Error at line {exc_tb.tb_lineno} while trying to augment prompt: {e}") | |
return None | |
logger.info(f"Augmented prompt, initializing {model_name} and generating response..") | |
try: | |
response = tg.generate_response(model_id, prompt, **params) | |
except Exception as e: | |
exc_type, exc_obj, exc_tb = sys.exc_info() | |
logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}") | |
return None | |
    try:
        # Forward only the generation settings that were supplied, so an
        # omitted max_tokens/temperature no longer raises a KeyError here
        # (post_process_output is assumed to default them).
        model_params = {k: params[k] for k in ("max_tokens", "temperature") if k in params}
        post_processed_response = post_process_output(
            model_id=model_id, user_query=query,
            starting_point=starting_point,
            context=context, response=response, **model_params)
    except Exception as e:
        _, _, exc_tb = sys.exc_info()
        logger.error(f"Error at line {exc_tb.tb_lineno} while post-processing response: {e}")
        return None
    if test:
        # Test mode also returns the retrieved cities and the augmented user
        # prompt (prompt[1]['content'] is presumably the user message of the
        # chat-formatted prompt).
        return retrieved_cities, prompt[1]['content'], post_processed_response
    else:
        return post_processed_response
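
# A minimal sketch (not part of the original pipeline) showing how the
# test-mode return contract of pipeline() can be consumed; max_tokens and
# temperature are placeholder values, and the retrieval/generation backends
# are assumed to be configured.
def run_test_pipeline(starting_point: str, query: str, model_name: str = "GPT-4"):
    result = pipeline(starting_point, query, model_name,
                      test=1, max_tokens=512, temperature=0.7)
    if result is None:
        logger.error("Pipeline failed; see the log messages above.")
        return None
    retrieved_cities, user_prompt, response = result
    logger.debug(f"Retrieved cities: {retrieved_cities}")
    logger.debug(f"Augmented user prompt: {user_prompt}")
    return response
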
if __name__ == "__main__": | |
# sample_query = "I'm planning a trip in the summer and I love art, history, and visiting museums. Can you | |
# suggest " \ "some " \ "European cities? " | |
sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \ | |
"cities. " | |
model_name = "GPT-4" | |
pipeline_response = pipeline( | |
query=sample_query, | |
model_name=model_name, | |
sustainability=1 | |
) | |
print(pipeline_response) | |