fragger246 committed on
Commit
e04f4cc
·
verified ·
1 Parent(s): f5f22f2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ from transformers import BertTokenizer, BertModel, GPT2LMHeadModel, GPT2Tokenizer
5
+ import torch
6
+ import gradio as gr
7
+
8
# Load the pretrained BERT tokenizer and encoder once, at module level, so the
# embedding helper below can use them on every call.
bert_tokenizer, bert_model = (
    BertTokenizer.from_pretrained('bert-base-uncased'),
    BertModel.from_pretrained('bert-base-uncased'),
)
11
+
12
def get_bert_embeddings(texts):
    """Encode texts with BERT and return one vector per text.

    The texts are tokenized as a single padded, truncated batch and passed
    through the module-level ``bert_model`` with gradient tracking disabled.
    For each text, the last-layer hidden state at sequence position 0 is
    used as its embedding.

    Args:
        texts: Input accepted by ``bert_tokenizer`` (callers in this file
            pass a list of strings).

    Returns:
        A NumPy array with one row per input text.
    """
    encoded = bert_tokenizer(
        texts, truncation=True, padding=True, return_tensors='pt'
    )
    with torch.no_grad():
        last_layer = bert_model(**encoded).last_hidden_state
    first_token_states = last_layer[:, 0, :]
    return first_token_states.numpy()
17
+
18
def get_closest_question(user_query, questions, threshold=0.95):
    """Find the known question most similar to ``user_query``.

    The query and all candidate questions are embedded together with
    ``get_bert_embeddings`` and compared by cosine similarity.

    Args:
        user_query: The text entered by the user.
        questions: Sequence of candidate question strings.
        threshold: Minimum cosine similarity required to accept a match.

    Returns:
        A ``(question, similarity)`` tuple.  ``question`` is the best
        matching candidate when its similarity reaches ``threshold`` and
        ``None`` otherwise.  When ``questions`` is empty the result is
        ``(None, 0.0)``.
    """
    # Accept any sequence (tuple, generator, ...), not only lists, since the
    # concatenation below requires a list.
    questions = list(questions)
    if not questions:
        # Nothing to compare against; np.max on an empty array would raise.
        return None, 0.0

    embeddings = get_bert_embeddings(questions + [user_query])
    query_vec, question_vecs = embeddings[-1], embeddings[:-1]
    cosine_similarities = np.dot(query_vec, question_vecs.T) / (
        np.linalg.norm(query_vec) * np.linalg.norm(question_vecs, axis=1)
    )

    best_index = int(np.argmax(cosine_similarities))
    max_similarity = cosine_similarities[best_index]
    if max_similarity >= threshold:
        return questions[best_index], max_similarity
    return None, max_similarity
30
+
31
def generate_gpt2_response(prompt, model, tokenizer, max_length=100):
    """Generate a text continuation of ``prompt``.

    Args:
        prompt: Text to continue.
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Object exposing ``encode``, ``decode``, ``pad_token_id``
            and ``eos_token_id``.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first generated sequence decoded to a string with special
        tokens skipped.
    """
    inputs = tokenizer.encode(prompt, return_tensors='pt')

    # GPT-2 ships without a pad token.  Pass the padding id explicitly
    # (falling back to EOS) instead of relying on generate()'s implicit
    # fallback, which logs a warning on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs,
        max_length=max_length,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+
36
# Initialize data
# Each FAQ entry is written as a (question, answer) pair; the parallel
# "questions" and "answers" lists in ``data_dict`` are derived from the pairs.
_faq_entries = [
    (
        "What is Rookus?",
        "Rookus is a startup that leverages AI to create unique designs for various products such as clothes, posters, and different arts and crafts.",
    ),
    (
        "How does Rookus use AI in its designs?",
        "Rookus uses advanced AI algorithms to generate innovative and aesthetically pleasing designs. These AI models are trained on vast datasets of art and design to produce high-quality mockups.",
    ),
    (
        "What products does Rookus offer?",
        "Rookus offers a variety of products, including clothing, posters, and a range of arts and crafts items, all featuring AI-generated designs.",
    ),
    (
        "Can I see samples of Rookus' designs?",
        "Yes, Rookus provides samples of its designs on its website. You can view a gallery of products showcasing the AI-generated artwork.",
    ),
    (
        "How can I join the waitlist for Rookus?",
        "To join the waitlist for Rookus, visit our website and sign up with your email. You'll receive updates on our launch and exclusive early access opportunities.",
    ),
    (
        "How does Rookus ensure the quality of its AI-generated designs?",
        "Rookus ensures the quality of its AI-generated designs through rigorous testing and refinement. Each design goes through multiple review stages to ensure it meets our high standards.",
    ),
    (
        "Is there a custom design option available at Rookus?",
        "Yes, Rookus offers custom design options. You can submit your preferences, and our AI will generate a design tailored to your specifications.",
    ),
    (
        "How long does it take to receive a product from Rookus?",
        "The delivery time for products from Rookus varies based on the product type and location. Typically, it takes 2-4 weeks for production and delivery.",
    ),
]

data_dict = {
    "questions": [question for question, _ in _faq_entries],
    "answers": [answer for _, answer in _faq_entries],
    "default_answer": "I'm sorry, I cannot answer this right now. Your question has been saved, and we will get back to you with a response soon.",
}
60
+
61
# Load the pretrained GPT-2 tokenizer and language-model head at module level.
gpt2_tokenizer, gpt2_model = (
    GPT2Tokenizer.from_pretrained('gpt2'),
    GPT2LMHeadModel.from_pretrained('gpt2'),
)
64
+
65
# Workbook that collects questions the chatbot could not answer.
excel_file = 'data.xlsx'

# On first start the workbook does not exist yet: write an empty sheet that
# has only the 'question' column so later reads find the expected header.
if not os.path.isfile(excel_file):
    df = pd.DataFrame(columns=['question'])
    df.to_excel(excel_file, index=False)
70
+
71
def _save_unanswered_question(user_query):
    """Append ``user_query`` as a new row to the workbook at ``excel_file``."""
    new_row = pd.DataFrame({'question': [user_query]})
    try:
        logged = pd.read_excel(excel_file)
    except FileNotFoundError:
        # The workbook is only created at import time; if it has been removed
        # since then, start a fresh log instead of failing the request.
        logged = pd.DataFrame(columns=['question'])
    logged = pd.concat([logged, new_row], ignore_index=True)
    with pd.ExcelWriter(excel_file, engine='openpyxl', mode='w') as writer:
        logged.to_excel(writer, index=False)


def chatbot(user_query):
    """Answer ``user_query`` from the FAQ, or record it for follow-up.

    The query is matched against ``data_dict['questions']``.  When a
    question is similar enough, its paired answer is returned.  Otherwise
    the query is appended to the Excel workbook and the default answer is
    returned.

    Args:
        user_query: Text entered by the user.

    Returns:
        The answer string to show to the user.
    """
    # Single source of truth for the cut-off used both for the lookup and
    # for the acceptance check below.
    threshold = 0.95
    questions = data_dict['questions']

    closest_question, similarity = get_closest_question(
        user_query, questions, threshold=threshold
    )
    if closest_question is not None and similarity >= threshold:
        return data_dict['answers'][questions.index(closest_question)]

    _save_unanswered_question(user_query)
    return data_dict['default_answer']
85
+
86
# Wrap the chatbot in a text-in / text-out Gradio interface and start it.
iface = gr.Interface(inputs="text", outputs="text", fn=chatbot)
iface.launch()