jfeng1115's picture
init commit
58d33f0
raw
history blame
3.18 kB
"""Python file to serve as the frontend"""
from datetime import datetime
import wandb
from langchain.agents.agent_toolkits.sql.simple_sql import create_simple_sql_agent_excutor
from langchain.callbacks import WandbCallbackHandler, CallbackManager, StdOutCallbackHandler
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
# import faiss
from langchain import OpenAI, FAISS, LLMChain
from langchain.chains import VectorDBQAWithSourcesChain
import pickle
# root_dir = "/Users/jiefeng/Dropbox/Apps/admixer/neon_scrapy/data/"
# index_path = "".join([root_dir, "docs.index"])
# fass_store_path = "".join([root_dir, "faiss_store.pkl"])
# Load the LangChain.
from langchain.prompts import PromptTemplate
import os
from langchain import OpenAI, VectorDBQA
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
from langchain.agents.agent_toolkits.sql.toolkit import SimpleSQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
# Configuration and LangChain SQL-agent setup (runs once, at import time).
# ---------------------------------------------------------------------------
# Secrets are NOT hard-coded here. The OpenAI, W&B and Postgres credentials
# previously committed in this file were exposed and must be rotated.
# Supply them through the environment instead:
#   OPENAI_API_KEY - read from the environment by LangChain's OpenAI wrapper
#                    (no key is passed to OpenAI(...) below)
#   WANDB_API_KEY  - only needed if the W&B callback below is re-enabled
#   DATABASE_URL   - SQLAlchemy URI, e.g.
#                    postgresql+psycopg2://user:password@host:5432/dbname
# ---------------------------------------------------------------------------
url = "https://langchain.readthedocs.io/en/latest/"  # NOTE(review): unused in this file
qa = None  # NOTE(review): unused in this file

_database_url = os.environ.get("DATABASE_URL")
if not _database_url:
    # Fail fast at import with an actionable message instead of an opaque
    # SQLAlchemy error later on.
    raise RuntimeError(
        "DATABASE_URL is not set; expected a SQLAlchemy URI such as "
        "postgresql+psycopg2://user:password@host:5432/dbname"
    )
db = SQLDatabase.from_uri(_database_url)
toolkit = SimpleSQLDatabaseToolkit(db=db)

# Timestamp label for grouping runs; only referenced by the disabled W&B callback.
session_group = datetime.now().strftime("%m.%d.%Y_%H.%M.%S")
# wandb_callback = WandbCallbackHandler(
#     job_type="inference",
#     project="marketing_questions",
#     group=f"minimal_{session_group}",
#     name="llm",
#     tags=["test"],
# )

# Stream LLM/agent events to stdout so each request's reasoning is visible in the logs.
manager = CallbackManager([StdOutCallbackHandler()])
llm = OpenAI(
    temperature=0,  # minimise sampling randomness when generating SQL
    model_name="gpt-4",
    callback_manager=manager,
    verbose=True,
)
# Agent that turns a natural-language question into SQL against `db` and
# returns a textual answer via .run(question).
agent_executor = create_simple_sql_agent_excutor(
    llm=llm,
    toolkit=toolkit,
    callback_manager=manager,
    verbose=True,
)
# agent_executor.run("What are the most popular pages visited by our visitors?")
# agent_executor.run("how many visitors profiles are from the Unite States?")
# From here down is the Flask HTTP API.
# The Flask application, with Flask-CORS enabled on it using its default settings.
app = Flask(import_name=__name__)
cors = CORS(app=app)
@app.route('/')
@cross_origin()
def hello_world():
    """Liveness endpoint: reply with a fixed greeting string."""
    greeting = 'Hello, World!'
    return greeting
@app.route('/api/ask', methods=['POST'])
@cross_origin()
def submit():
    """Answer a natural-language question by running it through the SQL agent.

    Expects a JSON object body of the form ``{"question": "<text>"}``.

    Returns:
        - 200 with the agent's answer serialized as JSON;
        - 200 with JSON ``null`` when the question is missing or empty
          (same as the original ``if question:`` guard);
        - 400 with ``{"error": ...}`` when the body is not a JSON object or
          ``question`` is not a string;
        - 500 with ``{"error": ...}`` when the agent raises.
    """
    print("request received")
    # silent=True: a missing/invalid JSON body yields None instead of raising,
    # so the client gets a JSON error rather than an HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    question = data.get('question')
    if question is not None and not isinstance(question, str):
        return jsonify({"error": "'question' must be a string"}), 400

    sql_data_result = None
    if question:
        print(question)
        try:
            sql_data_result = agent_executor.run(question)
        except Exception:
            # Request boundary: log the full traceback server-side and return a
            # JSON error instead of Flask's default HTML 500 page.
            app.logger.exception("SQL agent failed for question %r", question)
            return jsonify({"error": "failed to answer the question"}), 500
    return jsonify(sql_data_result)
if __name__ == '__main__':
    # HOST/PORT may be overridden from the environment. With HOST unset, Flask
    # uses its default bind address, exactly as before. PORT defaults to 7860.
    app.run(
        host=os.environ.get("HOST"),
        port=int(os.environ.get("PORT", "7860")),
    )