luulinh90s commited on
Commit
eb2147a
·
1 Parent(s): 69d37ef
Files changed (2) hide show
  1. app.py +87 -112
  2. templates/index.html +8 -8
app.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import string
6
  import logging
7
  from datetime import datetime
 
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO,
@@ -15,144 +16,129 @@ logging.basicConfig(level=logging.INFO,
15
  ])
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
 
 
 
18
  app = Flask(__name__)
19
  app.config['SECRET_KEY'] = 'supersecretkey' # Change this to a random secret key
20
 
21
  # Directories for visualizations
22
- VISUALIZATION_DIRS_PLAN_OF_SQLS = {
23
- "TP": "htmls_POS/TP",
24
- "TN": "htmls_POS/TN",
25
- "FP": "htmls_POS/FP",
26
- "FN": "htmls_POS/FN"
27
- }
28
-
29
- VISUALIZATION_DIRS_CHAIN_OF_TABLE = {
30
- "TP": "htmls_COT/TP",
31
- "TN": "htmls_COT/TN",
32
- "FP": "htmls_COT/FP",
33
- "FN": "htmls_COT/FN"
34
  }
35
 
 
36
 
37
  def save_session_data(username, data):
38
  try:
39
- base_dir = os.path.dirname(os.path.abspath(__file__))
40
- session_dir = os.path.join(base_dir, 'session_data')
41
- os.makedirs(session_dir, exist_ok=True)
42
-
43
- file_path = os.path.join(session_dir, f'{username}_session.json')
44
-
45
- with open(file_path, 'w') as f:
46
- json.dump(data, f, indent=4)
47
-
48
- logger.info(f"Session data saved for user {username} at {file_path}")
 
 
 
 
 
49
  except Exception as e:
50
  logger.exception(f"Error saving session data for user {username}: {e}")
51
 
52
-
53
- # Similarly, update the load_session_data function
54
  def load_session_data(username):
55
  try:
56
- base_dir = os.path.dirname(os.path.abspath(__file__))
57
- file_path = os.path.join(base_dir, 'session_data', f'{username}_session.json')
58
-
 
 
 
 
 
59
  with open(file_path, 'r') as f:
60
  data = json.load(f)
61
-
62
- logger.info(f"Session data loaded for user {username} from {file_path}")
63
  return data
64
- except FileNotFoundError:
65
- logger.warning(f"No session data found for user {username}")
66
- return None
67
  except Exception as e:
68
  logger.exception(f"Error loading session data for user {username}: {e}")
69
  return None
70
 
71
- # Load all sample files from the directories based on the selected method
72
- def load_samples(method):
73
- logger.info(f"Loading samples for method: {method}")
74
- if method == "Chain-of-Table":
75
- visualization_dirs = VISUALIZATION_DIRS_CHAIN_OF_TABLE
76
- else:
77
- visualization_dirs = VISUALIZATION_DIRS_PLAN_OF_SQLS
78
 
79
- samples = {"TP": [], "TN": [], "FP": [], "FN": []}
80
- for category, dir_path in visualization_dirs.items():
81
- try:
82
- for filename in os.listdir(dir_path):
83
- if filename.endswith(".html"):
84
- samples[category].append(filename)
85
- logger.info(f"Loaded {len(samples[category])} samples for category {category}")
86
- except Exception as e:
87
- logger.exception(f"Error loading samples from {dir_path}: {e}")
88
- return samples
 
89
 
90
- # Randomly select balanced samples
91
  def select_balanced_samples(samples):
92
  try:
93
- tp_fp_samples = random.sample(samples["TP"] + samples["FP"], 5)
94
- tn_fn_samples = random.sample(samples["TN"] + samples["FN"], 5)
95
- logger.info(f"Selected balanced samples: {len(tp_fp_samples + tn_fn_samples)}")
96
- return tp_fp_samples + tn_fn_samples
 
 
 
97
  except Exception as e:
98
  logger.exception("Error selecting balanced samples")
99
  return []
100
 
101
- def generate_random_string(length=8):
102
- return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
103
-
104
-
105
  @app.route('/', methods=['GET', 'POST'])
106
  def index():
107
- logger.info("Rendering index page.")
108
  if request.method == 'POST':
109
  username = request.form.get('username')
110
  seed = request.form.get('seed')
111
-
112
- if not username or not seed:
113
- logger.error("Missing username or seed.")
114
- return "Missing username or seed", 400
115
-
116
  try:
117
  seed = int(seed)
118
  random.seed(seed)
119
-
120
- # Use only one method (e.g., "Chain-of-Table")
121
- method = "Chain-of-Table"
122
- all_samples = load_samples(method)
123
  selected_samples = select_balanced_samples(all_samples)
124
- logger.info(f"Number of selected samples: {len(selected_samples)}")
125
-
126
  if len(selected_samples) == 0:
127
- logger.error("No samples were selected.")
128
- return "No samples were selected", 500
129
-
130
- filename = f'{username}_{seed}_{method}_{generate_random_string()}.json'
131
- logger.info(f"Generated filename: {filename}")
132
-
133
- # Save session data
134
  session_data = {
135
- 'responses': [],
136
  'username': username,
137
- 'selected_samples': selected_samples,
138
  'method': method,
139
- 'filename': filename,
140
- 'current_index': 0
 
 
141
  }
142
  save_session_data(username, session_data)
143
- logger.info(f"Session data saved for user: {username}")
144
-
145
  return redirect(url_for('experiment', username=username))
146
  except Exception as e:
147
  logger.exception(f"Error in index route: {e}")
148
  return "An error occurred", 500
149
  return render_template('index.html')
 
150
  @app.route('/experiment/<username>', methods=['GET', 'POST'])
151
  def experiment(username):
152
  try:
153
  session_data = load_session_data(username)
154
  if not session_data:
155
- logger.error(f"No session data found for user: {username}")
156
  return redirect(url_for('index'))
157
 
158
  selected_samples = session_data['selected_samples']
@@ -162,20 +148,9 @@ def experiment(username):
162
  if current_index >= len(selected_samples):
163
  return redirect(url_for('completed', username=username))
164
 
165
- visualization_file = selected_samples[current_index]
166
-
167
- vis_dir = 'htmls_COT' if method == "Chain-of-Table" else 'htmls_POS'
168
-
169
- # Determine the correct visualization directory based on the category
170
- for category, dir_path in VISUALIZATION_DIRS_CHAIN_OF_TABLE.items():
171
- if visualization_file in os.listdir(dir_path):
172
- visualization_path = os.path.join(vis_dir, category, visualization_file)
173
- break
174
- else:
175
- logger.error(f"Visualization file {visualization_file} not found.")
176
- return "Visualization file not found", 404
177
-
178
- logger.info(f"Rendering experiment page with visualization: {visualization_path}")
179
 
180
  statement = """
181
  Based on the explanation provided, what do you think the AI model will predict?
@@ -186,11 +161,13 @@ Will it predict the statement as TRUE or FALSE?
186
  sample_id=current_index,
187
  statement=statement,
188
  visualization=url_for('send_visualization', filename=visualization_path),
189
- username=username)
 
190
  except Exception as e:
191
  logger.exception(f"An error occurred in the experiment route: {e}")
192
  return "An error occurred", 500
193
 
 
194
  @app.route('/feedback', methods=['POST'])
195
  def feedback():
196
  try:
@@ -202,16 +179,12 @@ def feedback():
202
  logger.error(f"No session data found for user: {username}")
203
  return redirect(url_for('index'))
204
 
205
- # Store the user's prediction
206
  session_data['responses'].append({
207
  'sample_id': session_data['current_index'],
208
  'user_prediction': prediction
209
  })
210
 
211
- # Move to the next sample
212
  session_data['current_index'] += 1
213
-
214
- # Save updated session data
215
  save_session_data(username, session_data)
216
  logger.info(f"Prediction saved for user {username}, sample {session_data['current_index'] - 1}")
217
 
@@ -223,7 +196,6 @@ def feedback():
223
  logger.exception(f"Error in feedback route: {e}")
224
  return "An error occurred", 500
225
 
226
-
227
  @app.route('/completed/<username>')
228
  def completed(username):
229
  try:
@@ -232,10 +204,11 @@ def completed(username):
232
  logger.error(f"No session data found for user: {username}")
233
  return redirect(url_for('index'))
234
 
 
235
  responses = session_data['responses']
236
  method = session_data['method']
237
 
238
- json_file = 'Tabular_LLMs_human_study_vis_6_COT.json' if method == "Chain-of-Table" else 'Tabular_LLMs_human_study_vis_6_POS.json'
239
 
240
  with open(json_file, 'r') as f:
241
  ground_truth = json.load(f)
@@ -247,10 +220,10 @@ def completed(username):
247
  for response in responses:
248
  sample_id = response['sample_id']
249
  user_prediction = response['user_prediction']
250
- visualization_file = session_data['selected_samples'][sample_id]
251
- index = visualization_file.split('-')[1].split('.')[0] # Extract index from filename
252
 
253
- ground_truth_key = f"COT_test-{index}.html" if method == "Chain-of-Table" else f"POS_test-{index}.html"
254
 
255
  if ground_truth_key in ground_truth:
256
  model_prediction = ground_truth[ground_truth_key]['answer'].upper()
@@ -273,6 +246,12 @@ def completed(username):
273
  true_percentage = round(true_percentage, 2)
274
  false_percentage = round(false_percentage, 2)
275
 
 
 
 
 
 
 
276
  return render_template('completed.html',
277
  accuracy=accuracy,
278
  true_percentage=true_percentage,
@@ -281,11 +260,9 @@ def completed(username):
281
  logger.exception(f"An error occurred in the completed route: {e}")
282
  return "An error occurred", 500
283
 
284
-
285
  @app.route('/visualizations/<path:filename>')
286
  def send_visualization(filename):
287
  logger.info(f"Attempting to serve file: {filename}")
288
- # Ensure the path is safe and doesn't allow access to files outside the intended directory
289
  base_dir = os.getcwd()
290
  file_path = os.path.normpath(os.path.join(base_dir, filename))
291
  if not file_path.startswith(base_dir):
@@ -299,7 +276,5 @@ def send_visualization(filename):
299
  logger.info(f"Serving file from directory: {directory}, filename: {file_name}")
300
  return send_from_directory(directory, file_name)
301
 
302
-
303
  if __name__ == "__main__":
304
- os.makedirs('session_data', exist_ok=True) # Ensure the directory for session files exists
305
  app.run(host="0.0.0.0", port=7860, debug=True)
 
5
  import string
6
  import logging
7
  from datetime import datetime
8
+ from huggingface_hub import login, HfApi, hf_hub_download
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO,
 
16
  ])
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Use the Hugging Face token from environment variables
20
+ hf_token = os.environ.get("HF_TOKEN")
21
+ if hf_token:
22
+ login(token=hf_token)
23
+ else:
24
+ logger.error("HF_TOKEN not found in environment variables")
25
+
26
  app = Flask(__name__)
27
  app.config['SECRET_KEY'] = 'supersecretkey' # Change this to a random secret key
28
 
29
  # Directories for visualizations
30
+ VISUALIZATION_DIRS = {
31
+ "No-XAI": "htmls_NO_XAI",
32
+ "Dater": "htmls_DATER",
33
+ "Chain-of-Table": "htmls_COT",
34
+ "Plan-of-SQLs": "htmls_POS"
 
 
 
 
 
 
 
35
  }
36
 
37
+ METHODS = ["No-XAI", "Dater", "Chain-of-Table", "Plan-of-SQLs"]
38
 
39
  def save_session_data(username, data):
40
  try:
41
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
42
+ file_name = f'{username}_{timestamp}_session.json'
43
+ json_data = json.dumps(data, indent=4)
44
+ temp_file_path = f"/tmp/{file_name}"
45
+ with open(temp_file_path, 'w') as f:
46
+ f.write(json_data)
47
+ api = HfApi()
48
+ api.upload_file(
49
+ path_or_fileobj=temp_file_path,
50
+ path_in_repo=f"session_data_foward_simulation/{file_name}",
51
+ repo_id="luulinh90s/Tabular-LLM-Study-Data",
52
+ repo_type="space",
53
+ )
54
+ os.remove(temp_file_path)
55
+ logger.info(f"Session data saved for user {username} in Hugging Face Data Space")
56
  except Exception as e:
57
  logger.exception(f"Error saving session data for user {username}: {e}")
58
 
 
 
59
  def load_session_data(username):
60
  try:
61
+ api = HfApi()
62
+ files = api.list_repo_files(repo_id="luulinh90s/Tabular-LLM-Study-Data", repo_type="space")
63
+ user_files = [f for f in files if f.startswith(f'session_data_graded/{username}_') and f.endswith('_session.json')]
64
+ if not user_files:
65
+ logger.warning(f"No session data found for user {username}")
66
+ return None
67
+ latest_file = sorted(user_files, reverse=True)[0]
68
+ file_path = hf_hub_download(repo_id="luulinh90s/Tabular-LLM-Study-Data", repo_type="space", filename=latest_file)
69
  with open(file_path, 'r') as f:
70
  data = json.load(f)
71
+ logger.info(f"Session data loaded for user {username} from Hugging Face Data Space")
 
72
  return data
 
 
 
73
  except Exception as e:
74
  logger.exception(f"Error loading session data for user {username}: {e}")
75
  return None
76
 
77
+ def load_samples():
78
+ common_samples = []
79
+ categories = ["TP", "TN", "FP", "FN"]
 
 
 
 
80
 
81
+ for category in categories:
82
+ files = set(os.listdir(f'htmls_NO_XAI/{category}'))
83
+ for method in ["Dater", "Chain-of-Table", "Plan-of-SQLs"]:
84
+ method_dir = VISUALIZATION_DIRS[method]
85
+ files &= set(os.listdir(f'{method_dir}/{category}'))
86
+
87
+ for file in files:
88
+ common_samples.append({'category': category, 'file': file})
89
+
90
+ logger.info(f"Found {len(common_samples)} common samples across all methods")
91
+ return common_samples
92
 
 
93
  def select_balanced_samples(samples):
94
  try:
95
+ if len(samples) < 10:
96
+ logger.warning(f"Not enough common samples. Only {len(samples)} available.")
97
+ return samples
98
+
99
+ selected_samples = random.sample(samples, 10)
100
+ logger.info(f"Selected 10 unique samples")
101
+ return selected_samples
102
  except Exception as e:
103
  logger.exception("Error selecting balanced samples")
104
  return []
105
 
 
 
 
 
106
  @app.route('/', methods=['GET', 'POST'])
107
  def index():
 
108
  if request.method == 'POST':
109
  username = request.form.get('username')
110
  seed = request.form.get('seed')
111
+ method = request.form.get('method')
112
+ if not username or not seed or not method:
113
+ return "Please fill in all fields and select a method.", 400
 
 
114
  try:
115
  seed = int(seed)
116
  random.seed(seed)
117
+ all_samples = load_samples()
 
 
 
118
  selected_samples = select_balanced_samples(all_samples)
 
 
119
  if len(selected_samples) == 0:
120
+ return "No common samples were found", 500
 
 
 
 
 
 
121
  session_data = {
 
122
  'username': username,
123
+ 'seed': seed,
124
  'method': method,
125
+ 'selected_samples': selected_samples,
126
+ 'current_index': 0,
127
+ 'responses': [],
128
+ 'start_time': datetime.now().isoformat()
129
  }
130
  save_session_data(username, session_data)
 
 
131
  return redirect(url_for('experiment', username=username))
132
  except Exception as e:
133
  logger.exception(f"Error in index route: {e}")
134
  return "An error occurred", 500
135
  return render_template('index.html')
136
+
137
  @app.route('/experiment/<username>', methods=['GET', 'POST'])
138
  def experiment(username):
139
  try:
140
  session_data = load_session_data(username)
141
  if not session_data:
 
142
  return redirect(url_for('index'))
143
 
144
  selected_samples = session_data['selected_samples']
 
148
  if current_index >= len(selected_samples):
149
  return redirect(url_for('completed', username=username))
150
 
151
+ sample = selected_samples[current_index]
152
+ visualization_dir = VISUALIZATION_DIRS[method]
153
+ visualization_path = f"{visualization_dir}/{sample['category']}/{sample['file']}"
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  statement = """
156
  Based on the explanation provided, what do you think the AI model will predict?
 
161
  sample_id=current_index,
162
  statement=statement,
163
  visualization=url_for('send_visualization', filename=visualization_path),
164
+ username=username,
165
+ method=method)
166
  except Exception as e:
167
  logger.exception(f"An error occurred in the experiment route: {e}")
168
  return "An error occurred", 500
169
 
170
+
171
  @app.route('/feedback', methods=['POST'])
172
  def feedback():
173
  try:
 
179
  logger.error(f"No session data found for user: {username}")
180
  return redirect(url_for('index'))
181
 
 
182
  session_data['responses'].append({
183
  'sample_id': session_data['current_index'],
184
  'user_prediction': prediction
185
  })
186
 
 
187
  session_data['current_index'] += 1
 
 
188
  save_session_data(username, session_data)
189
  logger.info(f"Prediction saved for user {username}, sample {session_data['current_index'] - 1}")
190
 
 
196
  logger.exception(f"Error in feedback route: {e}")
197
  return "An error occurred", 500
198
 
 
199
  @app.route('/completed/<username>')
200
  def completed(username):
201
  try:
 
204
  logger.error(f"No session data found for user: {username}")
205
  return redirect(url_for('index'))
206
 
207
+ session_data['end_time'] = datetime.now().isoformat()
208
  responses = session_data['responses']
209
  method = session_data['method']
210
 
211
+ json_file = f'Tabular_LLMs_human_study_vis_6_{method.upper().replace("-", "_")}.json'
212
 
213
  with open(json_file, 'r') as f:
214
  ground_truth = json.load(f)
 
220
  for response in responses:
221
  sample_id = response['sample_id']
222
  user_prediction = response['user_prediction']
223
+ visualization_file = session_data['selected_samples'][sample_id]['file']
224
+ index = visualization_file.split('-')[1].split('.')[0]
225
 
226
+ ground_truth_key = f"{method.upper().replace('-', '_')}_test-{index}.html"
227
 
228
  if ground_truth_key in ground_truth:
229
  model_prediction = ground_truth[ground_truth_key]['answer'].upper()
 
246
  true_percentage = round(true_percentage, 2)
247
  false_percentage = round(false_percentage, 2)
248
 
249
+ session_data['accuracy'] = accuracy
250
+ session_data['true_percentage'] = true_percentage
251
+ session_data['false_percentage'] = false_percentage
252
+
253
+ save_session_data(username, session_data)
254
+
255
  return render_template('completed.html',
256
  accuracy=accuracy,
257
  true_percentage=true_percentage,
 
260
  logger.exception(f"An error occurred in the completed route: {e}")
261
  return "An error occurred", 500
262
 
 
263
  @app.route('/visualizations/<path:filename>')
264
  def send_visualization(filename):
265
  logger.info(f"Attempting to serve file: {filename}")
 
266
  base_dir = os.getcwd()
267
  file_path = os.path.normpath(os.path.join(base_dir, filename))
268
  if not file_path.startswith(base_dir):
 
276
  logger.info(f"Serving file from directory: {directory}, filename: {file_name}")
277
  return send_from_directory(directory, file_name)
278
 
 
279
  if __name__ == "__main__":
 
280
  app.run(host="0.0.0.0", port=7860, debug=True)
templates/index.html CHANGED
@@ -43,12 +43,12 @@
43
  .method-buttons {
44
  display: flex;
45
  flex-wrap: wrap;
46
- justify-content: space-between;
47
  margin-bottom: 20px;
48
  gap: 20px;
49
  }
50
  .method-button {
51
- width: calc(50% - 10px);
52
  padding: 15px;
53
  font-size: 20px;
54
  border-radius: 10px;
@@ -66,11 +66,11 @@
66
  background-color: #ffcc80;
67
  color: #e65100;
68
  }
69
- .method-button.cot-ext {
70
  background-color: #e8f5e9;
71
  color: #4caf50;
72
  }
73
- .method-button.pos-ext {
74
  background-color: #fff3e0;
75
  color: #ff9800;
76
  }
@@ -139,11 +139,11 @@
139
  <div class="method-button pos" onclick="selectMethod('Plan-of-SQLs')">
140
  Plan-of-SQLs
141
  </div>
142
- <div class="method-button cot-ext" onclick="selectMethod('Chain-of-Table-Ext')">
143
- Chain-of-Table-Ext
144
  </div>
145
- <div class="method-button pos-ext" onclick="selectMethod('Plan-of-SQLs-Ext')">
146
- Plan-of-SQLs-Ext
147
  </div>
148
  </div>
149
 
 
43
  .method-buttons {
44
  display: flex;
45
  flex-wrap: wrap;
46
+ justify-content: center;
47
  margin-bottom: 20px;
48
  gap: 20px;
49
  }
50
  .method-button {
51
+ width: calc(45% - 10px);
52
  padding: 15px;
53
  font-size: 20px;
54
  border-radius: 10px;
 
66
  background-color: #ffcc80;
67
  color: #e65100;
68
  }
69
+ .method-button.dater {
70
  background-color: #e8f5e9;
71
  color: #4caf50;
72
  }
73
+ .method-button.no-xai {
74
  background-color: #fff3e0;
75
  color: #ff9800;
76
  }
 
139
  <div class="method-button pos" onclick="selectMethod('Plan-of-SQLs')">
140
  Plan-of-SQLs
141
  </div>
142
+ <div class="method-button dater" onclick="selectMethod('Dater')">
143
+ Dater
144
  </div>
145
+ <div class="method-button no-xai" onclick="selectMethod('No-XAI')">
146
+ No-XAI
147
  </div>
148
  </div>
149