|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.schema import HumanMessage |
|
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \ |
|
HumanMessagePromptTemplate |
|
from models import load_chat_agent, load_chained_agent, load_sales_agent, \ |
|
load_sqlite_agent, load_book_agent, load_earnings_agent |
|
|
|
import openai, numpy as np |
|
|
|
from collections import namedtuple |
|
import logging |
|
|
|
logger = logging.getLogger(__name__)


# Labels for how an answer was produced; agentController returns one of
# them next to the answer text.  Every field holds its own name as a
# string, e.g. GEN_TYPE.semantic == 'semantic'.
_struct = namedtuple('Struct', 'deterministic semantic creative')
GEN_TYPE = _struct(*_struct._fields)
|
|
|
|
|
from langchain.output_parsers import StructuredOutputParser, ResponseSchema |
|
|
|
# Instruction wrapper for free-form questions; it is formatted through
# instruct_prompt in agentController's fallback ("creative") branch.  The
# "not available" reply it asks for is the sentinel agentController checks
# to discard questions the model cannot answer.
# NOTE(review): every line inside this literal, blank ones included, is part
# of the prompt text sent to the model.
instruct_template = """

Please answer this question clearly with easy to follow reasoning:

{query}



If you don't know the answer, just reply: not available.

"""
|
|
|
# Prompt used by agentController's fallback branch; {query} is its only
# placeholder.
instruct_prompt = PromptTemplate(template=instruct_template,
                                 input_variables=["query"])

# Fields the structured output parser asks the model to return when
# extracting an artist/song pair from a user command (used by chat_prompt).
response_schemas = [
    ResponseSchema(name=field_name, description=field_text)
    for field_name, field_text in (
        ("artist", "The name of the musical artist"),
        ("song", "The name of the song that the artist plays"),
    )
]

output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()

# Lower-case phrases that send a question to a deterministic agent via
# check_deterministic: company/team phrasing routes to salesAgent,
# digital-media/database phrasing routes to chinookAgent.
LOCAL_MAGIC_TOKENS = ["my company", "for our team", "our sales", "my team"]
DIGITAL_MAGIC_TOKENS = ["digital media", "our database", "our digital"]
|
|
|
def check_deterministic(sentence, magic_tokens):
    """Return True if *sentence* contains any phrase from *magic_tokens*.

    Matching is a case-insensitive substring test: both the sentence and
    each token are case-folded, so mixed-case tokens match too.  Empty
    tokens are skipped, because an empty string is a substring of every
    sentence and would otherwise match everything.  A ``None`` or empty
    sentence never matches.

    Args:
        sentence: text to search.
        magic_tokens: iterable of phrases to look for.

    Returns:
        bool.
    """
    if not sentence:
        return False
    haystack = sentence.casefold()
    # Generator (not a list) so any() stops at the first hit.
    return any(token and token.casefold() in haystack for token in magic_tokens)
|
|
|
|
|
# Chat prompt that asks the model to extract the artist and song names from a
# user command.  format_instructions (from output_parser) is pre-filled as a
# partial variable, so only {user_prompt} must be supplied when formatting.
chat_prompt = ChatPromptTemplate(
    messages=[
        HumanMessagePromptTemplate.from_template(
            # Adjacent literals are concatenated at compile time.  A backslash
            # line continuation inside the literal would instead splice the
            # next source line's leading indentation into the prompt text.
            "Given a command from the user, extract the artist and "
            "song names \n{format_instructions}\n{user_prompt}")
    ],
    input_variables=["user_prompt"],
    partial_variables={"format_instructions": format_instructions}
)
|
|
|
|
|
def chatAgent(chat_message):
    """Send one user message to the chat agent and return its reply.

    Args:
        chat_message: text of the user's message.

    Returns:
        Whatever the agent returned by ``load_chat_agent`` produces for a
        single-message conversation, or the string
        "Please rephrase and try chat again." if loading or calling the
        agent fails.
    """
    try:
        agent = load_chat_agent(verbose=True)
        output = agent([HumanMessage(content=chat_message)])
    except Exception as e:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt and
        # SystemExit and left no trace of the failure.  Log with traceback.
        logger.exception(e)
        output = "Please rephrase and try chat again."
    return output
|
|
|
def salesAgent(instruction):
    """Run *instruction* through the sales agent and return its answer.

    On success the answer is also echoed to stdout with a ``panda> ``
    prefix.  If anything fails (loading the agent, running it, or echoing
    the answer), the exception is logged and a rephrase hint that embeds
    the error text is returned instead.
    """
    try:
        sales_agent = load_sales_agent(verbose=True)
        answer = sales_agent.run(instruction)
        print("panda> " + answer)
    except Exception as err:
        logger.error(err)
        return f"Rephrasing your prompt could get better sales results {err}"
    return answer
|
|
|
def chinookAgent(instruction, model_name):
    """Run *instruction* through the SQL agent built by ``load_sqlite_agent``.

    Args:
        instruction: the user's question.
        model_name: forwarded to ``load_sqlite_agent``.

    Returns:
        The agent's answer (also echoed to stdout with a ``chinook> ``
        prefix).  On any failure, a rephrase hint that includes the error
        text.
    """
    try:
        agent = load_sqlite_agent(model_name)
        output = agent.run(instruction)
        print("chinook> " + output)
    except Exception as e:
        logger.exception(e)
        # The f-prefix was missing, so users saw a literal "{e}" instead of
        # the error text.
        output = f"Rephrasing your prompt could get better db results {e}"
    return output
|
|
|
def check_semantic(string1, string2, *, threshold=0.8,
                   engine="text-similarity-davinci-001"):
    """Return True if *string1* and *string2* are semantically similar.

    Both strings are embedded in a single ``openai.Embedding.create``
    request.  The cosine similarity of the two vectors is compared
    (strictly greater than) against *threshold*.

    Args:
        string1: first text.
        string2: second text.
        threshold: similarity score above which the texts count as a match.
            Defaults to the previously hard-coded 0.8.
        engine: embedding engine passed to ``openai.Embedding.create``.
            Defaults to the previously hard-coded engine.

    Returns:
        bool.

    Raises:
        Whatever ``openai.Embedding.create`` raises; API and network errors
        are not caught here.
    """
    response = openai.Embedding.create(
        input=[string1, string2],
        engine=engine,
    )
    vec_a = np.asarray(response['data'][0]['embedding'], dtype=float)
    vec_b = np.asarray(response['data'][1]['embedding'], dtype=float)

    # A raw dot product equals cosine similarity only for unit-length
    # vectors.  OpenAI states its embeddings are normalized, so for them the
    # score is unchanged (verify for this engine); dividing by the norms
    # keeps the threshold meaningful for any other embedding source.
    norms = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    similarity_score = 0.0 if norms == 0 else float(np.dot(vec_a, vec_b) / norms)
    logger.info("similarity: %s", similarity_score)

    return similarity_score > threshold
|
|
|
|
|
def earningsAgent(query):
    """Answer *query* with the agent built by ``load_earnings_agent``.

    The agent is called with ``{"query": query}``, and the ``"result"``
    entry of its response is returned.

    Returns:
        The answer text, or on any failure a rephrase hint naming the
        earnings agent and including the error text.
    """
    try:
        # NOTE(review): the positional True is presumably ``verbose`` (the
        # other loaders in this module are called with verbose=True); confirm
        # in models.
        agent = load_earnings_agent(True)
        result = agent({"query": query})
        output = result['result']
        # Log/message text was copy-pasted from bookAgent and named the wrong
        # agent.
        logger.info(f"earnings response: {output}")
    except Exception as e:
        logger.exception(e)
        output = f"Rephrasing your prompt for the earnings agent {e}"
    return output
|
|
|
def bookAgent(query):
    """Answer *query* with the agent built by ``load_book_agent``.

    The agent is called with ``{"query": query}``, and the ``"result"``
    entry of its response is returned.

    Returns:
        The answer text, or on any failure a rephrase hint naming the book
        agent and including the error text.
    """
    try:
        # NOTE(review): the positional True is presumably ``verbose`` (the
        # other loaders in this module are called with verbose=True); confirm
        # in models.
        agent = load_book_agent(True)
        result = agent({"query": query})
        output = result['result']
        logger.info(f"book response: {output}")
    except Exception as e:
        logger.exception(e)
        # A space now separates the hint from the error text (previously
        # rendered as "...book agent<error>").
        output = f"Rephrasing your prompt for the book agent {e}"
    return output
|
|
|
|
|
def _creativeAgent(question_text, model_name):
    """Answer a free-form question with the chained agent (fallback route).

    Returns the agent's ``"output"`` text.  Returns ``""`` when the agent
    returns None or replies with instruct_prompt's "not available" sentinel.
    Any exception is logged and replaced by a fixed fallback message.
    """
    try:
        instruction = instruct_prompt.format(query=question_text)
        logger.info(f"instruction: {instruction}")
        agent = load_chained_agent(verbose=True, model_name=model_name)
        response = agent([instruction])
        if response is None or "not available" in response["output"]:
            return ""
        # .get(): a missing 'intermediate_steps' key must not turn a good
        # answer into the error fallback below.
        logger.info(f"🔹 Steps: {response.get('intermediate_steps')}")
        return response["output"]
    except Exception as e:
        logger.exception(e)
        return "Most likely ran out of tokens ..."


def agentController(question_text, model_name):
    """Route *question_text* to an agent and return ``(output, outputType)``.

    Routes are tried in this order; the first match wins:

    1. a phrase from LOCAL_MAGIC_TOKENS   -> salesAgent    (deterministic)
    2. a phrase from DIGITAL_MAGIC_TOKENS -> chinookAgent  (deterministic)
    3. semantically close to "Salesforce earnings call for Q4 2023"
                                          -> earningsAgent (semantic)
    4. semantically close to "how to govern" or "fight a war"
                                          -> bookAgent     (semantic)
    5. anything else -> chained agent with instruct_prompt (creative)

    Args:
        question_text: the user's question.
        model_name: forwarded to chinookAgent and to load_chained_agent.

    Returns:
        Tuple of (answer text, GEN_TYPE label).  In the creative route the
        answer is ``""`` when the model replies "not available".

    Note:
        check_semantic calls are not wrapped in try/except, so embedding API
        errors propagate to the caller.  Each semantic check issues its own
        embedding request, so a question can trigger up to three.
    """
    if check_deterministic(question_text, LOCAL_MAGIC_TOKENS):
        output = salesAgent(question_text)
        print(f"🔹 salesAgent: {output}")
        return output, GEN_TYPE.deterministic

    if check_deterministic(question_text, DIGITAL_MAGIC_TOKENS):
        output = chinookAgent(question_text, model_name)
        print(f"🔹 chinookAgent: {output}")
        return output, GEN_TYPE.deterministic

    if check_semantic(question_text, "Salesforce earnings call for Q4 2023"):
        output = earningsAgent(question_text)
        print(f"🔹 earningsAgent: {output}")
        return output, GEN_TYPE.semantic

    if (check_semantic(question_text, "how to govern")
            or check_semantic(question_text, "fight a war")):
        output = bookAgent(question_text)
        print(f"🔹 bookAgent: {output}")
        return output, GEN_TYPE.semantic

    return _creativeAgent(question_text, model_name), GEN_TYPE.creative
|
|