luulinh90s commited on
Commit
0dc341f
·
1 Parent(s): 8602eac
Files changed (1) hide show
  1. app.py +281 -281
app.py CHANGED
@@ -1,291 +1,291 @@
1
- # from flask import Flask, render_template, request, redirect, url_for, send_from_directory, session
2
- # import json
3
- # import random
4
- # import os
5
- # import string
6
- # from flask_session import Session
7
- # import logging
8
- #
9
- # # Set up logging
10
- # logging.basicConfig(level=logging.INFO,
11
- # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
12
- # handlers=[
13
- # logging.FileHandler("app.log"),
14
- # logging.StreamHandler()
15
- # ])
16
- # logger = logging.getLogger(__name__)
17
- #
18
- # app = Flask(__name__)
19
- # app.config['SECRET_KEY'] = 'supersecretkey' # Change this to a random secret key
20
- # app.config['SESSION_TYPE'] = 'filesystem'
21
- # Session(app)
22
- #
23
- # # Directories for visualizations
24
- # VISUALIZATION_DIRS_PLAN_OF_SQLS = {
25
- # "TP": "visualizations/TP",
26
- # "TN": "visualizations/TN",
27
- # "FP": "visualizations/FP",
28
- # "FN": "visualizations/FN"
29
- # }
30
- #
31
- # VISUALIZATION_DIRS_CHAIN_OF_TABLE = {
32
- # "TP": "htmls_COT/TP",
33
- # "TN": "htmls_COT/TN",
34
- # "FP": "htmls_COT/FP",
35
- # "FN": "htmls_COT/FN"
36
- # }
37
- #
38
- #
39
- # # Load all sample files from the directories based on the selected method
40
- # def load_samples(method):
41
- # logger.info(f"Loading samples for method: {method}")
42
- # if method == "Chain-of-Table":
43
- # visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
44
- # else:
45
- # visualization_dirs = VISUALIZATION_DIRS_PLAN_OF_SQLS
46
- #
47
- # samples = {"TP": [], "TN": [], "FP": [], "FN": []}
48
- # for category, dir_path in visualization_dirs.items():
49
- # try:
50
- # for filename in os.listdir(dir_path):
51
- # if filename.endswith(".html"):
52
- # samples[category].append(filename)
53
- # logger.info(f"Loaded {len(samples[category])} samples for category {category}")
54
- # except Exception as e:
55
- # logger.exception(f"Error loading samples from {dir_path}: {e}")
56
- # return samples
57
- #
58
- #
59
- # # Randomly select balanced samples
60
- # def select_balanced_samples(samples):
61
- # try:
62
- # tp_fp_samples = random.sample(samples["TP"] + samples["FP"], 5)
63
- # tn_fn_samples = random.sample(samples["TN"] + samples["FN"], 5)
64
- # logger.info(f"Selected balanced samples: {len(tp_fp_samples + tn_fn_samples)}")
65
- # return tp_fp_samples + tn_fn_samples
66
- # except Exception as e:
67
- # logger.exception("Error selecting balanced samples")
68
- # return []
69
- #
70
- #
71
- # def generate_random_string(length=8):
72
- # return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
73
- #
74
- #
75
- # @app.route('/', methods=['GET', 'POST'])
76
- # def index():
77
- # logger.info("Rendering index page.")
78
- # if request.method == 'POST':
79
- # username = request.form.get('username')
80
- # seed = request.form.get('seed')
81
- # method = request.form.get('method')
82
- #
83
- # if not username or not seed or not method:
84
- # logger.error("Missing username, seed, or method.")
85
- # return "Missing username, seed, or method", 400
86
- #
87
- # try:
88
- # seed = int(seed)
89
- # random.seed(seed)
90
- # all_samples = load_samples(method)
91
- # selected_samples = select_balanced_samples(all_samples)
92
- # random_string = generate_random_string()
93
- # filename = f'{username}_{seed}_{method}_{random_string}.json'
94
- #
95
- # logger.info(f"Generated filename: {filename}")
96
- #
97
- # session['selected_samples'] = selected_samples
98
- # session['responses'] = [] # Initialize responses list
99
- # session['method'] = method # Store the selected method
100
- #
101
- # return redirect(url_for('experiment', username=username, sample_index=0, seed=seed, filename=filename))
102
- # except Exception as e:
103
- # logger.exception(f"Error in index route: {e}")
104
- # return "An error occurred", 500
105
- # return render_template('index.html')
106
- #
107
- #
108
- # @app.route('/experiment/<username>/<sample_index>/<seed>/<filename>', methods=['GET'])
109
- # def experiment(username, sample_index, seed, filename):
110
- # try:
111
- # sample_index = int(sample_index)
112
- # selected_samples = session.get('selected_samples', [])
113
- # method = session.get('method') # Retrieve the selected method
114
- #
115
- # if sample_index >= len(selected_samples):
116
- # return redirect(url_for('completed', filename=filename))
117
- #
118
- # visualization_file = selected_samples[sample_index]
119
- # visualization_path = None
120
- #
121
- # # Determine the correct visualization directory based on the method
122
- # if method == "Chain-of-Table":
123
- # visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
124
- # else:
125
- # visualization_dirs = VISUALIZATION_DIRS_PLAN_OF_SQLS
126
- #
127
- # # Find the correct visualization path
128
- # for category, dir_path in visualization_dirs.items():
129
- # if visualization_file in os.listdir(dir_path):
130
- # visualization_path = f"{category}/{visualization_file}"
131
- # break
132
- #
133
- # if not visualization_path:
134
- # logger.error("Visualization file not found.")
135
- # return "Visualization file not found", 404
136
- #
137
- # statement = "Please make a decision to Accept/Reject the AI prediction based on the explanation."
138
- # return render_template('experiment.html',
139
- # sample_id=sample_index,
140
- # statement=statement,
141
- # visualization=visualization_path,
142
- # username=username,
143
- # seed=seed,
144
- # sample_index=sample_index,
145
- # filename=filename)
146
- # except Exception as e:
147
- # logger.exception(f"An error occurred in the experiment route: {e}")
148
- # return "An error occurred", 500
149
- #
150
- #
151
- # @app.route('/visualizations/<path:path>')
152
- # def send_visualization(path):
153
- # try:
154
- # method = session.get('method')
155
- # if method == "Chain-of-Table":
156
- # visualization_dir = 'htmls_COT'
157
- # else: # Default to Plan-of-SQLs
158
- # visualization_dir = 'visualizations'
159
- #
160
- # return send_from_directory(visualization_dir, path)
161
- # except Exception as e:
162
- # logger.exception(f"Error sending visualization: {e}")
163
- # return "An error occurred", 500
164
- #
165
- #
166
- # @app.route('/feedback', methods=['POST'])
167
- # def feedback():
168
- # try:
169
- # sample_id = request.form['sample_id']
170
- # feedback = request.form['feedback']
171
- # username = request.form['username']
172
- # seed = request.form['seed']
173
- # sample_index = int(request.form['sample_index'])
174
- # filename = request.form['filename']
175
- #
176
- # selected_samples = session.get('selected_samples', [])
177
- # responses = session.get('responses', [])
178
- #
179
- # responses.append({
180
- # 'sample_id': sample_id,
181
- # 'feedback': feedback
182
- # })
183
- # session['responses'] = responses
184
- #
185
- # result_dir = 'human_study'
186
- # os.makedirs(result_dir, exist_ok=True)
187
- #
188
- # filepath = os.path.join(result_dir, filename)
189
- # if os.path.exists(filepath):
190
- # with open(filepath, 'r') as f:
191
- # data = json.load(f)
192
- # else:
193
- # data = {}
194
- #
195
- # data[sample_index] = {
196
- # 'Username': username,
197
- # 'Seed': seed,
198
- # 'Sample ID': sample_id,
199
- # 'Task': f"Please make a decision to Accept/Reject the AI prediction based on the explanation.",
200
- # 'User Feedback': feedback
201
- # }
202
- #
203
- # with open(filepath, 'w') as f:
204
- # json.dump(data, f, indent=4)
205
- #
206
- # logger.info(f"Feedback saved for sample {sample_id}")
207
- #
208
- # next_sample_index = sample_index + 1
209
- # if next_sample_index >= len(selected_samples):
210
- # return redirect(url_for('completed', filename=filename))
211
- #
212
- # return redirect(
213
- # url_for('experiment', username=username, sample_index=next_sample_index, seed=seed, filename=filename))
214
- # except Exception as e:
215
- # logger.exception(f"Error in feedback route: {e}")
216
- # return "An error occurred", 500
217
- #
218
- #
219
- # @app.route('/completed/<filename>')
220
- # def completed(filename):
221
- # try:
222
- # responses = session.get('responses', [])
223
- # method = session.get('method')
224
- # if method == "Chain-of-Table":
225
- # json_file = 'Tabular_LLMs_human_study_vis_6_COT.json'
226
- # else: # Default to Plan-of-SQLs
227
- # json_file = 'Tabular_LLMs_human_study_vis_6.json'
228
- #
229
- # with open(json_file, 'r') as f:
230
- # ground_truth = json.load(f)
231
- #
232
- # correct_responses = 0
233
- # accept_count = 0
234
- # reject_count = 0
235
- #
236
- # for response in responses:
237
- # sample_id = response['sample_id']
238
- # feedback = response['feedback']
239
- # index = sample_id.split('-')[1].split('.')[0] # Extract index from filename
240
- #
241
- # if feedback.upper() == "TRUE":
242
- # accept_count += 1
243
- # elif feedback.upper() == "FALSE":
244
- # reject_count += 1
245
- #
246
- # if method == "Chain-of-Table":
247
- # ground_truth_key = f"COT_test-{index}.html"
248
- # else:
249
- # ground_truth_key = f"POS_test-{index}.html"
250
- #
251
- # if ground_truth_key in ground_truth and ground_truth[ground_truth_key][
252
- # 'answer'].upper() == feedback.upper():
253
- # correct_responses += 1
254
- # else:
255
- # logger.warning(f"Missing or mismatched key: {ground_truth_key}")
256
- #
257
- # accuracy = (correct_responses / len(responses)) * 100 if responses else 0
258
- # accuracy = round(accuracy, 2)
259
- #
260
- # accept_percentage = (accept_count / len(responses)) * 100 if len(responses) else 0
261
- # reject_percentage = (reject_count / len(responses)) * 100 if len(responses) else 0
262
- #
263
- # accept_percentage = round(accept_percentage, 2)
264
- # reject_percentage = round(reject_percentage, 2)
265
- #
266
- # return render_template('completed.html',
267
- # accuracy=accuracy,
268
- # accept_percentage=accept_percentage,
269
- # reject_percentage=reject_percentage)
270
- # except Exception as e:
271
- # logger.exception(f"Error in completed route: {e}")
272
- # return "An error occurred", 500
273
- #
274
- #
275
  # if __name__ == '__main__':
276
  # try:
277
  # app.run(debug=False, port=7860)
278
  # except Exception as e:
279
  # logger.exception(f"Failed to start app: {e}")
280
- #
281
-
282
- from flask import Flask
283
 
284
- app = Flask(__name__)
285
 
286
- @app.route('/')
287
- def index():
288
- return "Hello, world!"
 
 
 
 
289
 
290
  # if __name__ == '__main__':
291
  # app.run(debug=False, port=7860)
 
1
+ from flask import Flask, render_template, request, redirect, url_for, send_from_directory, session
2
+ import json
3
+ import random
4
+ import os
5
+ import string
6
+ from flask_session import Session
7
+ import logging
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO,
11
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
12
+ handlers=[
13
+ logging.FileHandler("app.log"),
14
+ logging.StreamHandler()
15
+ ])
16
+ logger = logging.getLogger(__name__)
17
+
18
+ app = Flask(__name__)
19
+ app.config['SECRET_KEY'] = 'supersecretkey' # Change this to a random secret key
20
+ app.config['SESSION_TYPE'] = 'filesystem'
21
+ Session(app)
22
+
23
+ # Directories for visualizations
24
+ VISUALIZATION_DIRS_PLAN_OF_SQLS = {
25
+ "TP": "visualizations/TP",
26
+ "TN": "visualizations/TN",
27
+ "FP": "visualizations/FP",
28
+ "FN": "visualizations/FN"
29
+ }
30
+
31
+ VISUALIZATION_DIRS_CHAIN_OF_TABLE = {
32
+ "TP": "htmls_COT/TP",
33
+ "TN": "htmls_COT/TN",
34
+ "FP": "htmls_COT/FP",
35
+ "FN": "htmls_COT/FN"
36
+ }
37
+
38
+
39
+ # Load all sample files from the directories based on the selected method
40
+ def load_samples(method):
41
+ logger.info(f"Loading samples for method: {method}")
42
+ if method == "Chain-of-Table":
43
+ visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
44
+ else:
45
+ visualization_dirs = VISUALIZATION_DIRS_PLAN_OF_SQLS
46
+
47
+ samples = {"TP": [], "TN": [], "FP": [], "FN": []}
48
+ for category, dir_path in visualization_dirs.items():
49
+ try:
50
+ for filename in os.listdir(dir_path):
51
+ if filename.endswith(".html"):
52
+ samples[category].append(filename)
53
+ logger.info(f"Loaded {len(samples[category])} samples for category {category}")
54
+ except Exception as e:
55
+ logger.exception(f"Error loading samples from {dir_path}: {e}")
56
+ return samples
57
+
58
+
59
+ # Randomly select balanced samples
60
+ def select_balanced_samples(samples):
61
+ try:
62
+ tp_fp_samples = random.sample(samples["TP"] + samples["FP"], 5)
63
+ tn_fn_samples = random.sample(samples["TN"] + samples["FN"], 5)
64
+ logger.info(f"Selected balanced samples: {len(tp_fp_samples + tn_fn_samples)}")
65
+ return tp_fp_samples + tn_fn_samples
66
+ except Exception as e:
67
+ logger.exception("Error selecting balanced samples")
68
+ return []
69
+
70
+
71
+ def generate_random_string(length=8):
72
+ return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
73
+
74
+
75
+ @app.route('/', methods=['GET', 'POST'])
76
+ def index():
77
+ logger.info("Rendering index page.")
78
+ if request.method == 'POST':
79
+ username = request.form.get('username')
80
+ seed = request.form.get('seed')
81
+ method = request.form.get('method')
82
+
83
+ if not username or not seed or not method:
84
+ logger.error("Missing username, seed, or method.")
85
+ return "Missing username, seed, or method", 400
86
+
87
+ try:
88
+ seed = int(seed)
89
+ random.seed(seed)
90
+ all_samples = load_samples(method)
91
+ selected_samples = select_balanced_samples(all_samples)
92
+ random_string = generate_random_string()
93
+ filename = f'{username}_{seed}_{method}_{random_string}.json'
94
+
95
+ logger.info(f"Generated filename: {filename}")
96
+
97
+ session['selected_samples'] = selected_samples
98
+ session['responses'] = [] # Initialize responses list
99
+ session['method'] = method # Store the selected method
100
+
101
+ return redirect(url_for('experiment', username=username, sample_index=0, seed=seed, filename=filename))
102
+ except Exception as e:
103
+ logger.exception(f"Error in index route: {e}")
104
+ return "An error occurred", 500
105
+ return render_template('index.html')
106
+
107
+
108
+ @app.route('/experiment/<username>/<sample_index>/<seed>/<filename>', methods=['GET'])
109
+ def experiment(username, sample_index, seed, filename):
110
+ try:
111
+ sample_index = int(sample_index)
112
+ selected_samples = session.get('selected_samples', [])
113
+ method = session.get('method') # Retrieve the selected method
114
+
115
+ if sample_index >= len(selected_samples):
116
+ return redirect(url_for('completed', filename=filename))
117
+
118
+ visualization_file = selected_samples[sample_index]
119
+ visualization_path = None
120
+
121
+ # Determine the correct visualization directory based on the method
122
+ if method == "Chain-of-Table":
123
+ visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
124
+ else:
125
+ visualization_dirs = VISUALIZATION_DIRS_PLAN_OF_SQLS
126
+
127
+ # Find the correct visualization path
128
+ for category, dir_path in visualization_dirs.items():
129
+ if visualization_file in os.listdir(dir_path):
130
+ visualization_path = f"{category}/{visualization_file}"
131
+ break
132
+
133
+ if not visualization_path:
134
+ logger.error("Visualization file not found.")
135
+ return "Visualization file not found", 404
136
+
137
+ statement = "Please make a decision to Accept/Reject the AI prediction based on the explanation."
138
+ return render_template('experiment.html',
139
+ sample_id=sample_index,
140
+ statement=statement,
141
+ visualization=visualization_path,
142
+ username=username,
143
+ seed=seed,
144
+ sample_index=sample_index,
145
+ filename=filename)
146
+ except Exception as e:
147
+ logger.exception(f"An error occurred in the experiment route: {e}")
148
+ return "An error occurred", 500
149
+
150
+
151
+ @app.route('/visualizations/<path:path>')
152
+ def send_visualization(path):
153
+ try:
154
+ method = session.get('method')
155
+ if method == "Chain-of-Table":
156
+ visualization_dir = 'htmls_COT'
157
+ else: # Default to Plan-of-SQLs
158
+ visualization_dir = 'visualizations'
159
+
160
+ return send_from_directory(visualization_dir, path)
161
+ except Exception as e:
162
+ logger.exception(f"Error sending visualization: {e}")
163
+ return "An error occurred", 500
164
+
165
+
166
+ @app.route('/feedback', methods=['POST'])
167
+ def feedback():
168
+ try:
169
+ sample_id = request.form['sample_id']
170
+ feedback = request.form['feedback']
171
+ username = request.form['username']
172
+ seed = request.form['seed']
173
+ sample_index = int(request.form['sample_index'])
174
+ filename = request.form['filename']
175
+
176
+ selected_samples = session.get('selected_samples', [])
177
+ responses = session.get('responses', [])
178
+
179
+ responses.append({
180
+ 'sample_id': sample_id,
181
+ 'feedback': feedback
182
+ })
183
+ session['responses'] = responses
184
+
185
+ result_dir = 'human_study'
186
+ os.makedirs(result_dir, exist_ok=True)
187
+
188
+ filepath = os.path.join(result_dir, filename)
189
+ if os.path.exists(filepath):
190
+ with open(filepath, 'r') as f:
191
+ data = json.load(f)
192
+ else:
193
+ data = {}
194
+
195
+ data[sample_index] = {
196
+ 'Username': username,
197
+ 'Seed': seed,
198
+ 'Sample ID': sample_id,
199
+ 'Task': f"Please make a decision to Accept/Reject the AI prediction based on the explanation.",
200
+ 'User Feedback': feedback
201
+ }
202
+
203
+ with open(filepath, 'w') as f:
204
+ json.dump(data, f, indent=4)
205
+
206
+ logger.info(f"Feedback saved for sample {sample_id}")
207
+
208
+ next_sample_index = sample_index + 1
209
+ if next_sample_index >= len(selected_samples):
210
+ return redirect(url_for('completed', filename=filename))
211
+
212
+ return redirect(
213
+ url_for('experiment', username=username, sample_index=next_sample_index, seed=seed, filename=filename))
214
+ except Exception as e:
215
+ logger.exception(f"Error in feedback route: {e}")
216
+ return "An error occurred", 500
217
+
218
+
219
+ @app.route('/completed/<filename>')
220
+ def completed(filename):
221
+ try:
222
+ responses = session.get('responses', [])
223
+ method = session.get('method')
224
+ if method == "Chain-of-Table":
225
+ json_file = 'Tabular_LLMs_human_study_vis_6_COT.json'
226
+ else: # Default to Plan-of-SQLs
227
+ json_file = 'Tabular_LLMs_human_study_vis_6.json'
228
+
229
+ with open(json_file, 'r') as f:
230
+ ground_truth = json.load(f)
231
+
232
+ correct_responses = 0
233
+ accept_count = 0
234
+ reject_count = 0
235
+
236
+ for response in responses:
237
+ sample_id = response['sample_id']
238
+ feedback = response['feedback']
239
+ index = sample_id.split('-')[1].split('.')[0] # Extract index from filename
240
+
241
+ if feedback.upper() == "TRUE":
242
+ accept_count += 1
243
+ elif feedback.upper() == "FALSE":
244
+ reject_count += 1
245
+
246
+ if method == "Chain-of-Table":
247
+ ground_truth_key = f"COT_test-{index}.html"
248
+ else:
249
+ ground_truth_key = f"POS_test-{index}.html"
250
+
251
+ if ground_truth_key in ground_truth and ground_truth[ground_truth_key][
252
+ 'answer'].upper() == feedback.upper():
253
+ correct_responses += 1
254
+ else:
255
+ logger.warning(f"Missing or mismatched key: {ground_truth_key}")
256
+
257
+ accuracy = (correct_responses / len(responses)) * 100 if responses else 0
258
+ accuracy = round(accuracy, 2)
259
+
260
+ accept_percentage = (accept_count / len(responses)) * 100 if len(responses) else 0
261
+ reject_percentage = (reject_count / len(responses)) * 100 if len(responses) else 0
262
+
263
+ accept_percentage = round(accept_percentage, 2)
264
+ reject_percentage = round(reject_percentage, 2)
265
+
266
+ return render_template('completed.html',
267
+ accuracy=accuracy,
268
+ accept_percentage=accept_percentage,
269
+ reject_percentage=reject_percentage)
270
+ except Exception as e:
271
+ logger.exception(f"Error in completed route: {e}")
272
+ return "An error occurred", 500
273
+
274
+
275
  # if __name__ == '__main__':
276
  # try:
277
  # app.run(debug=False, port=7860)
278
  # except Exception as e:
279
  # logger.exception(f"Failed to start app: {e}")
 
 
 
280
 
 
281
 
282
+ # from flask import Flask
283
+ #
284
+ # app = Flask(__name__)
285
+ #
286
+ # @app.route('/')
287
+ # def index():
288
+ # return "Hello, world!"
289
 
290
  # if __name__ == '__main__':
291
  # app.run(debug=False, port=7860)