File size: 3,181 Bytes
58d33f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Python file to serve as the frontend"""
from datetime import datetime

import wandb

from langchain.agents.agent_toolkits.sql.simple_sql import create_simple_sql_agent_excutor
from langchain.callbacks import WandbCallbackHandler, CallbackManager, StdOutCallbackHandler
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
# import faiss
from langchain import OpenAI, FAISS, LLMChain
from langchain.chains import VectorDBQAWithSourcesChain
import pickle

# root_dir = "/Users/jiefeng/Dropbox/Apps/admixer/neon_scrapy/data/"
# index_path = "".join([root_dir, "docs.index"])
# fass_store_path = "".join([root_dir, "faiss_store.pkl"])
# Load the LangChain.

from langchain.prompts import PromptTemplate
import os
from langchain import OpenAI, VectorDBQA
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
from langchain.agents.agent_toolkits.sql.toolkit import SimpleSQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI

# Runtime configuration.
#
# API credentials are NOT set here: OPENAI_API_KEY (and WANDB_API_KEY, if W&B
# tracing is re-enabled) must be provided by the deployment environment
# instead of being committed to source control.
# NOTE(review): the keys that were previously hard-coded on these lines are in
# version-control history and must be treated as leaked -- rotate them.


url = "https://langchain.readthedocs.io/en/latest/"
qa = None

# Database connection used by the SQL agent toolkit.
# The URI carries credentials, so it is read from the DATABASE_URL environment
# variable rather than being hard-coded.
# NOTE(review): the Postgres credentials previously embedded here are in
# version-control history and should be rotated.
_database_url = os.environ.get("DATABASE_URL")
if not _database_url:
    # Fail at startup with an actionable message instead of an opaque
    # connection error later.
    raise RuntimeError(
        "DATABASE_URL is not set; expected a SQLAlchemy URI such as "
        "postgresql+psycopg2://user:password@host:5432/dbname"
    )
db = SQLDatabase.from_uri(_database_url)
toolkit = SimpleSQLDatabaseToolkit(db=db)

# Timestamp label for this process (month.day.year_hour.minute.second).
# Only the optional W&B tracing described below would use it.
session_group = datetime.now().strftime("%m.%d.%Y_%H.%M.%S")

# Callbacks currently go only to StdOutCallbackHandler. To add Weights & Biases
# tracing, construct a WandbCallbackHandler (job_type="inference",
# project="marketing_questions", group=f"minimal_{session_group}", name="llm",
# tags=["test"]) and include it in the list below.
manager = CallbackManager([StdOutCallbackHandler()])

# LLM shared by the SQL agent: GPT-4 at temperature 0, verbose, reporting to
# the callback manager above.
llm = OpenAI(
    model_name="gpt-4",
    temperature=0,
    verbose=True,
    callback_manager=manager,
)

# Agent built from the LLM and the SQL database toolkit; the /api/ask handler
# passes user questions to agent_executor.run(). Example invocations:
#     agent_executor.run("What are the most popular pages visited by our visitors?")
#     agent_executor.run("how many visitors profiles are from the Unite States?")
agent_executor = create_simple_sql_agent_excutor(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    callback_manager=manager,
)
# From here down is the Flask HTTP API.


app = Flask(__name__)
# NOTE(review): CORS(app) applies flask_cors to every route of the app; with
# its default settings this presumably allows any origin -- confirm intended.
cors = CORS(app)

@app.route('/')
@cross_origin()
def hello_world():
    """Root endpoint: reply with a fixed plain-text greeting."""
    greeting = 'Hello, World!'
    return greeting


@app.route('/api/ask', methods=['POST'])
@cross_origin()
def submit():
    """Answer a natural-language question by running it through the SQL agent.

    Expects a JSON object body of the form ``{"question": "<text>"}``.

    Returns:
        A JSON response containing the agent's answer, or JSON ``null`` when
        ``question`` is present but empty/falsy. A malformed request gets a
        400 response with an ``{"error": ...}`` body instead of an unhandled
        exception (HTTP 500). Malformed means any of:

        * the body is not a JSON object;
        * the ``question`` key is missing;
        * ``question`` is non-empty but not a string.
    """
    print("request received")
    # silent=True makes get_json() return None (rather than raise) for a
    # missing or unparsable JSON body, so all malformed input is handled below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'question' not in data:
        return jsonify(
            {"error": "request body must be a JSON object with a 'question' field"}
        ), 400

    question = data['question']
    if not question:
        # Preserve the original contract: an empty question yields JSON null.
        return jsonify(None)
    if not isinstance(question, str):
        return jsonify({"error": "'question' must be a string"}), 400

    print(question)
    sql_data_result = agent_executor.run(question)
    # TODO: optionally ask the LLM which chart type best presents the result
    # (PromptTemplate + LLMChain over sql_data_result).
    return jsonify(sql_data_result)


if __name__ == "__main__":
    # Launch the app with Flask's built-in server, listening on port 7860.
    app.run(port=7860)