luulinh90s commited on
Commit
0fe5446
·
1 Parent(s): 39d7f47
Files changed (1) hide show
  1. app.py +129 -11
app.py CHANGED
@@ -100,21 +100,44 @@ def save_session_data_to_hf(session_id, data):
100
  logger.info(f"Session data saved for session {session_id} in Hugging Face Data Space")
101
  except Exception as e:
102
  logger.exception(f"Error saving session data for session {session_id}: {e}")
103
-
104
- def load_samples():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  common_samples = []
106
  categories = ["TP", "TN", "FP", "FN"]
107
 
108
  for category in categories:
109
- files = set(os.listdir(f'htmls_NO_XAI_mod/{category}'))
110
- for method in ["Dater", "Chain-of-Table", "Plan-of-SQLs", "Text2SQL"]:
111
- method_dir = VISUALIZATION_DIRS[method]
112
- files &= set(os.listdir(f'{method_dir}/{category}'))
113
 
114
  for file in files:
115
- common_samples.append({'category': category, 'file': file})
 
 
 
 
 
 
 
 
116
 
117
- logger.info(f"Found {len(common_samples)} common samples across all methods")
118
  return common_samples
119
 
120
  def select_balanced_samples(samples):
@@ -149,6 +172,45 @@ def select_balanced_samples(samples):
149
  @app.route('/attribution')
150
  def attribution():
151
  return render_template('attribution.html')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  @app.route('/index', methods=['GET', 'POST'])
154
  def index():
@@ -158,12 +220,28 @@ def index():
158
  method = request.form.get('method')
159
  if not username or not seed or not method:
160
  return render_template('index.html', error="Please fill in all fields and select a method.")
161
- if method not in ['Chain-of-Table', 'Plan-of-SQLs', 'Dater', 'Text2SQL']:
162
  return render_template('index.html', error="Invalid method selected.")
163
  try:
164
  seed = int(seed)
165
  random.seed(seed)
166
- all_samples = load_samples()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  selected_samples = select_balanced_samples(all_samples)
168
  if len(selected_samples) == 0:
169
  return render_template('index.html', error="No common samples were found")
@@ -212,6 +290,39 @@ def explanation(session_id):
212
  else:
213
  logger.error(f"Invalid method '{method}' for session ID: {session_id}")
214
  return redirect(url_for('index'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  @app.route('/experiment/<session_id>', methods=['GET', 'POST'])
217
  def experiment(session_id):
@@ -231,6 +342,12 @@ def experiment(session_id):
231
  visualization_dir = VISUALIZATION_DIRS[method]
232
  visualization_path = f"{visualization_dir}/{sample['category']}/{sample['file']}"
233
 
 
 
 
 
 
 
234
  statement = """
235
  Please note that in select row function, starting index is 0 for Chain-of-Table and 1 for Dater and Index * represents the selection for all rows.
236
  """
@@ -240,7 +357,8 @@ Please note that in select row function, starting index is 0 for Chain-of-Table
240
  statement=statement,
241
  visualization=url_for('send_visualization', filename=visualization_path),
242
  session_id=session_id,
243
- method=method)
 
244
  except Exception as e:
245
  logger.exception(f"An error occurred in the experiment route: {e}")
246
  return "An error occurred", 500
 
100
  logger.info(f"Session data saved for session {session_id} in Hugging Face Data Space")
101
  except Exception as e:
102
  logger.exception(f"Error saving session data for session {session_id}: {e}")
103
+ #
104
+ # def load_samples():
105
+ # common_samples = []
106
+ # categories = ["TP", "TN", "FP", "FN"]
107
+ #
108
+ # for category in categories:
109
+ # files = set(os.listdir(f'htmls_NO_XAI_mod/{category}'))
110
+ # for method in ["Dater", "Chain-of-Table", "Plan-of-SQLs", "Text2SQL"]:
111
+ # method_dir = VISUALIZATION_DIRS[method]
112
+ # files &= set(os.listdir(f'{method_dir}/{category}'))
113
+ #
114
+ # for file in files:
115
+ # common_samples.append({'category': category, 'file': file})
116
+ #
117
+ # logger.info(f"Found {len(common_samples)} common samples across all methods")
118
+ # return common_samples
119
+
120
+ def load_samples(method, metadata):
121
  common_samples = []
122
  categories = ["TP", "TN", "FP", "FN"]
123
 
124
  for category in categories:
125
+ # files = set(os.listdir(f'htmls_NO_XAI_mod/{category}'))
126
+ method_dir = VISUALIZATION_DIRS[method]
127
+ files = set(os.listdir(f'{method_dir}/{category}'))
 
128
 
129
  for file in files:
130
+ index = file.split('-')[1].split('.')[0]
131
+ metadata_key = f"{get_method_dir(method)}_test-{index}.html"
132
+ sample_metadata = metadata.get(metadata_key, {})
133
+
134
+ common_samples.append({
135
+ 'category': category,
136
+ 'file': file,
137
+ 'metadata': sample_metadata
138
+ })
139
 
140
+ logger.info(f"Found {len(common_samples)} samples for method {method}")
141
  return common_samples
142
 
143
  def select_balanced_samples(samples):
 
172
  @app.route('/attribution')
173
  def attribution():
174
  return render_template('attribution.html')
175
+ #
176
+ # @app.route('/index', methods=['GET', 'POST'])
177
+ # def index():
178
+ # if request.method == 'POST':
179
+ # username = request.form.get('username')
180
+ # seed = request.form.get('seed')
181
+ # method = request.form.get('method')
182
+ # if not username or not seed or not method:
183
+ # return render_template('index.html', error="Please fill in all fields and select a method.")
184
+ # if method not in ['Chain-of-Table', 'Plan-of-SQLs', 'Dater', 'Text2SQL']:
185
+ # return render_template('index.html', error="Invalid method selected.")
186
+ # try:
187
+ # seed = int(seed)
188
+ # random.seed(seed)
189
+ # all_samples = load_samples()
190
+ # selected_samples = select_balanced_samples(all_samples)
191
+ # if len(selected_samples) == 0:
192
+ # return render_template('index.html', error="No common samples were found")
193
+ # start_time = datetime.now().isoformat()
194
+ # session_id = generate_session_id()
195
+ # session_data = {
196
+ # 'username': username,
197
+ # 'seed': str(seed),
198
+ # 'method': method,
199
+ # 'selected_samples': selected_samples,
200
+ # 'current_index': 0,
201
+ # 'responses': [],
202
+ # 'start_time': start_time,
203
+ # 'session_id': session_id
204
+ # }
205
+ # save_session_data(session_id, session_data)
206
+ # logger.info(f"Session data stored for user {username}, method {method}, session_id {session_id}")
207
+ #
208
+ # # Redirect to explanation for all methods
209
+ # return redirect(url_for('explanation', session_id=session_id))
210
+ # except Exception as e:
211
+ # logger.exception(f"Error in index route: {e}")
212
+ # return render_template('index.html', error="An error occurred. Please try again.")
213
+ # return render_template('index.html', show_no_xai=False)
214
 
215
  @app.route('/index', methods=['GET', 'POST'])
216
  def index():
 
220
  method = request.form.get('method')
221
  if not username or not seed or not method:
222
  return render_template('index.html', error="Please fill in all fields and select a method.")
223
+ if method not in ['Chain-of-Table', 'Plan-of-SQLs', 'Dater', 'Text2SQL', 'No-XAI']:
224
  return render_template('index.html', error="Invalid method selected.")
225
  try:
226
  seed = int(seed)
227
  random.seed(seed)
228
+
229
+ # Load the appropriate metadata file
230
+ if method == "Chain-of-Table":
231
+ json_file = 'Tabular_LLMs_human_study_vis_6_COT.json'
232
+ elif method == "Plan-of-SQLs":
233
+ json_file = 'Tabular_LLMs_human_study_vis_6_POS.json'
234
+ elif method == "Dater":
235
+ json_file = 'Tabular_LLMs_human_study_vis_6_DATER.json'
236
+ elif method == "No-XAI":
237
+ json_file = 'Tabular_LLMs_human_study_vis_6_NO_XAI.json'
238
+ elif method == "Text2SQL":
239
+ json_file = 'Tabular_LLMs_human_study_vis_6_Text2SQL.json'
240
+
241
+ with open(json_file, 'r') as f:
242
+ metadata = json.load(f)
243
+
244
+ all_samples = load_samples(method, metadata)
245
  selected_samples = select_balanced_samples(all_samples)
246
  if len(selected_samples) == 0:
247
  return render_template('index.html', error="No common samples were found")
 
290
  else:
291
  logger.error(f"Invalid method '{method}' for session ID: {session_id}")
292
  return redirect(url_for('index'))
293
+ #
294
+ # @app.route('/experiment/<session_id>', methods=['GET', 'POST'])
295
+ # def experiment(session_id):
296
+ # try:
297
+ # session_data = load_session_data(session_id)
298
+ # if not session_data:
299
+ # return redirect(url_for('index'))
300
+ #
301
+ # selected_samples = session_data['selected_samples']
302
+ #
303
+ # method = session_data['method']
304
+ # current_index = session_data['current_index']
305
+ #
306
+ # if current_index >= len(selected_samples):
307
+ # return redirect(url_for('completed', session_id=session_id))
308
+ #
309
+ # sample = selected_samples[current_index]
310
+ # visualization_dir = VISUALIZATION_DIRS[method]
311
+ # visualization_path = f"{visualization_dir}/{sample['category']}/{sample['file']}"
312
+ #
313
+ # statement = """
314
+ # Please note that in select row function, starting index is 0 for Chain-of-Table and 1 for Dater and Index * represents the selection for all rows.
315
+ # """
316
+ #
317
+ # return render_template('experiment.html',
318
+ # sample_id=current_index,
319
+ # statement=statement,
320
+ # visualization=url_for('send_visualization', filename=visualization_path),
321
+ # session_id=session_id,
322
+ # method=method)
323
+ # except Exception as e:
324
+ # logger.exception(f"An error occurred in the experiment route: {e}")
325
+ # return "An error occurred", 500
326
 
327
  @app.route('/experiment/<session_id>', methods=['GET', 'POST'])
328
  def experiment(session_id):
 
342
  visualization_dir = VISUALIZATION_DIRS[method]
343
  visualization_path = f"{visualization_dir}/{sample['category']}/{sample['file']}"
344
 
345
+ # Extract metadata
346
+ metadata = sample.get('metadata', {})
347
+
348
+ # Log the metadata
349
+ logger.info(f"Sample metadata for session {session_id}, method {method}, index {current_index}: {metadata}")
350
+
351
  statement = """
352
  Please note that in select row function, starting index is 0 for Chain-of-Table and 1 for Dater and Index * represents the selection for all rows.
353
  """
 
357
  statement=statement,
358
  visualization=url_for('send_visualization', filename=visualization_path),
359
  session_id=session_id,
360
+ method=method,
361
+ metadata=metadata) # Pass metadata to the template
362
  except Exception as e:
363
  logger.exception(f"An error occurred in the experiment route: {e}")
364
  return "An error occurred", 500