etweedy commited on
Commit
a7b21ef
·
1 Parent(s): 151dae7

Upload 23 files

Browse files
Files changed (4) hide show
  1. README.md +16 -11
  2. app.py +50 -28
  3. app_data.pickle +2 -2
  4. lib/.DS_Store +0 -0
README.md CHANGED
@@ -1,12 +1,17 @@
1
- ---
2
- title: BikeSaferPA
3
- emoji: 🏃
4
- colorFrom: blue
5
- colorTo: pink
6
- sdk: streamlit
7
- sdk_version: 1.25.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## BikeSaferPA: understanding cyclist outcomes
 
 
 
 
 
 
 
 
 
2
 
3
+ This web app provides a suite of tools to accompany Eamonn Tweedy's [BikeSaferPA project](https://github.com/e-tweedy/BikeSaferPA). These tools allow the user to:
4
+ - Visualize data related to crashes involving bicycles in Pennsylvania during the years 2002-2021, which was collected from a publically available [PENNDOT crash dataset](https://pennshare.maps.arcgis.com/apps/webappviewer/index.html?id=8fdbf046e36e41649bbfd9d7dd7c7e7e).
5
+ - Experiment with the BikeSaferPA model, which was trained on this cyclist crash data and designed to predict severity outcomes for cyclists based on crash data.
6
+
7
+ ### [Visit the web app](https://bike-safer-pa.streamlit.app/)
8
+
9
+ ### Repository components:
10
+ - 'cyclists.csv' and 'crashes.csv' : datasets used for analysis
11
+ - 'app.py' : main streamlit app page
12
+ - 'study.pkl' : trained BikeSaferPA machine learning model
13
+ - 'app_data.pkl' : prepared data used for user input widget labels
14
+ - 'lib' : directory of custom modules
15
+ - 'vis_data.py' : data visualization functions
16
+ - 'transform_data'py' : data transformation functions
17
+ - 'study_classif.py' : class for studying machine learning classifiers
app.py CHANGED
@@ -114,6 +114,9 @@ You also have the option to restrict to Philadelpha county only, or the PA count
114
  Expand the toolbox below to choose plot options.
115
  """)
116
 
 
 
 
117
  ### User input - settings for plot ###
118
 
119
  with time_settings_container:
@@ -163,41 +166,42 @@ with time_settings_container:
163
  time_bin_data[k][feat][2]=st.checkbox(time_bin_data[k][feat][0],key=f'time_{feat}')
164
  # if checked, filter samples and add feature to plot title addendum
165
  if time_bin_data[k][feat][2]:
166
- crashes = crashes[crashes[time_bin_data[k][feat][1]]==1]
167
  title_add+= ', '+time_bin_data[k][feat][0].split('one ')[-1]
168
 
169
  ### Post-process user-selected setting data ###
170
 
 
171
  # Geographic restriction
172
  if geo != 'statewide':
173
- crashes = crashes[crashes.COUNTY.isin(geo_data[geo][1])]
174
  # Relegate rare categories to 'other' for plot readability
175
  if stratify=='int_type':
176
- crashes['INTERSECT_TYPE']=crashes['INTERSECT_TYPE']\
177
- .replace({cat:'other' for cat in crashes.INTERSECT_TYPE.value_counts().index[3:]})
178
  if stratify=='coll_type':
179
- crashes['COLLISION_TYPE']=crashes['COLLISION_TYPE']\
180
- .replace({cat:'other' for cat in crashes.COLLISION_TYPE.value_counts().index[6:]})
181
  if stratify=='weather':
182
- crashes['WEATHER']=crashes['WEATHER']\
183
- .replace({cat:'other' for cat in crashes.WEATHER.value_counts().index[5:]})
184
  if stratify=='tcd':
185
- crashes['TCD_TYPE']=crashes['TCD_TYPE']\
186
- .replace({cat:'other' for cat in crashes.TCD_TYPE.value_counts().index[3:]})
187
- crashes=crashes.dropna(subset=period_data[period][1])
188
 
189
  # Order categories in descending order by frequency
190
- category_orders = {time_cat_data[cat][1]:list(crashes[time_cat_data[cat][1]].value_counts().index) for cat in time_cat_data}
191
 
192
  # Define cohort
193
  if cohort == 'inj':
194
- crashes = crashes[crashes.BICYCLE_SUSP_SERIOUS_INJ_COUNT > 0]
195
  elif cohort == 'fat':
196
- crashes = crashes[crashes.BICYCLE_DEATH_COUNT > 0]
197
 
198
  # Replace day,month numbers with string labels
199
  if period in ['day','month']:
200
- crashes[period_data[period][1]] = crashes[period_data[period][1]].apply(lambda x:period_data[period][2][x-1])
201
 
202
  # Plot title addendum
203
  if len(title_add)>0:
@@ -214,8 +218,8 @@ else:
214
 
215
  with time_plot_container:
216
  # Plot samples if any, else report no samples remain
217
- if crashes.shape[0]>0:
218
- fig = px.histogram(crashes,
219
  x=period_data[period][1],
220
  color=color,
221
  nbins=len(period_data[period][2]),
@@ -249,6 +253,9 @@ This tool provides interactive maps of crash events, either statewide or in one
249
  Expand the menu below to adjust map options.
250
  """)
251
 
 
 
 
252
  ### User input - settings for map plot ###
253
 
254
  with map_settings_container:
@@ -290,14 +297,14 @@ else:
290
 
291
  if county is not None:
292
  if animate_by == 'year':
293
- color_dots = len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002')\
294
  .BICYCLE_DEATH_COUNT.unique())+\
295
- len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002')\
296
  .BICYCLE_SUSP_SERIOUS_INJ_COUNT.unique()) > 3
297
  else:
298
- color_dots = len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002 and CRASH_MONTH==1')\
299
  .BICYCLE_DEATH_COUNT.unique())+\
300
- len(crashes.query('COUNTY==@county[0] and CRASH_YEAR==2002 and CRASH_MONTH==1')\
301
  .BICYCLE_SUSP_SERIOUS_INJ_COUNT.unique()) > 3
302
  if color_dots==False:
303
  st.markdown("""
@@ -312,7 +319,7 @@ from lib.vis_data import plot_map
312
 
313
  with map_plot_container:
314
  fig = plot_map(
315
- df=crashes,county=county,animate=animate,
316
  color_dots=color_dots,animate_by=animate_by,
317
  show_fig=False,return_fig=True,
318
  )
@@ -337,6 +344,9 @@ Expand the following menu to choose a feature, and the graph will show its distr
337
  Pay particular attention to feature values which become more or less prevalent among cyclists suffering serious injury or death - for instance, 6.2% of all cyclists statewide were involved in a head-on collision, whereas 11.8% of those with serious injury or fatality were in a head-on collision.
338
  """)
339
 
 
 
 
340
  ### User input - settings for plot ###
341
 
342
  with feature_settings_container:
@@ -360,13 +370,13 @@ with feature_settings_container:
360
  from lib.vis_data import feat_perc,feat_perc_bar
361
  # Geographic restriction
362
  if geo != 'statewide':
363
- cyclists = cyclists[cyclists.COUNTY.isin(geo_data[geo][1])]
364
 
365
  # Recast binary and day of week data
366
  if feature not in ord_features:
367
- cyclists[feature]=cyclists[feature].replace({1:'yes',0:'no'})
368
  if feature == 'DAY_OF_WEEK':
369
- cyclists[feature]=cyclists[feature].astype(str)
370
 
371
  ### Build and display plot ###
372
 
@@ -375,7 +385,7 @@ with feature_plot_container:
375
  # Generate plot
376
  sort = False if feature in ord_features else True
377
  fig = feat_perc_bar(
378
- feature,cyclists, feat_name=feature_names[feature],
379
  return_fig=True,show_fig=False,sort=sort
380
  )
381
 
@@ -550,9 +560,21 @@ The force plot will update as you adjust input features in the menu above.
550
  # shap_values = explainer(sample_trans)
551
  # shap_values_list.append(shap_values.values)
552
  # shap_values = np.array(shap_values_list).sum(axis=0) / len(shap_values_list)
553
- explainer = shap.TreeExplainer(pipe[-1], feature_names = pipe['col'].get_feature_names_out())
 
 
 
 
554
  shap_values = explainer(sample_trans)
555
- sample_trans = pd.DataFrame(sample_trans,columns=pipe['col'].get_feature_names_out())
 
 
 
 
 
 
 
 
556
  # def st_shap(plot, height=None):
557
  # shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
558
  # components.html(shap_html, height=height)
 
114
  Expand the toolbox below to choose plot options.
115
  """)
116
 
117
+ # Copy dataframe for this tab
118
+ crashes_time = crashes.copy()
119
+
120
  ### User input - settings for plot ###
121
 
122
  with time_settings_container:
 
166
  time_bin_data[k][feat][2]=st.checkbox(time_bin_data[k][feat][0],key=f'time_{feat}')
167
  # if checked, filter samples and add feature to plot title addendum
168
  if time_bin_data[k][feat][2]:
169
+ crashes_time = crashes_time[crashes_time[time_bin_data[k][feat][1]]==1]
170
  title_add+= ', '+time_bin_data[k][feat][0].split('one ')[-1]
171
 
172
  ### Post-process user-selected setting data ###
173
 
174
+
175
  # Geographic restriction
176
  if geo != 'statewide':
177
+ crashes_time[crashes_time.COUNTY.isin(geo_data[geo][1])]
178
  # Relegate rare categories to 'other' for plot readability
179
  if stratify=='int_type':
180
+ crashes_time['INTERSECT_TYPE']=crashes_time['INTERSECT_TYPE']\
181
+ .replace({cat:'other' for cat in crashes_time.INTERSECT_TYPE.value_counts().index[3:]})
182
  if stratify=='coll_type':
183
+ crashes_time['COLLISION_TYPE']=crashes_time['COLLISION_TYPE']\
184
+ .replace({cat:'other' for cat in crashes_time.COLLISION_TYPE.value_counts().index[6:]})
185
  if stratify=='weather':
186
+ crashes_time['WEATHER']=crashes_time['WEATHER']\
187
+ .replace({cat:'other' for cat in crashes_time.WEATHER.value_counts().index[5:]})
188
  if stratify=='tcd':
189
+ crashes_time['TCD_TYPE']=crashes_time['TCD_TYPE']\
190
+ .replace({cat:'other' for cat in crashes_time.TCD_TYPE.value_counts().index[3:]})
191
+ crashes_time=crashes_time.dropna(subset=period_data[period][1])
192
 
193
  # Order categories in descending order by frequency
194
+ category_orders = {time_cat_data[cat][1]:list(crashes_time[time_cat_data[cat][1]].value_counts().index) for cat in time_cat_data}
195
 
196
  # Define cohort
197
  if cohort == 'inj':
198
+ crashes_time = crashes_time[crashes_time.BICYCLE_SUSP_SERIOUS_INJ_COUNT > 0]
199
  elif cohort == 'fat':
200
+ crashes_time = crashes_time[crashes_time.BICYCLE_DEATH_COUNT > 0]
201
 
202
  # Replace day,month numbers with string labels
203
  if period in ['day','month']:
204
+ crashes_time[period_data[period][1]] = crashes_time[period_data[period][1]].apply(lambda x:period_data[period][2][x-1])
205
 
206
  # Plot title addendum
207
  if len(title_add)>0:
 
218
 
219
  with time_plot_container:
220
  # Plot samples if any, else report no samples remain
221
+ if crashes_time.shape[0]>0:
222
+ fig = px.histogram(crashes_time,
223
  x=period_data[period][1],
224
  color=color,
225
  nbins=len(period_data[period][2]),
 
253
  Expand the menu below to adjust map options.
254
  """)
255
 
256
+ # Copy dataframe for this tab
257
+ crashes_map = crashes.copy()
258
+
259
  ### User input - settings for map plot ###
260
 
261
  with map_settings_container:
 
297
 
298
  if county is not None:
299
  if animate_by == 'year':
300
+ color_dots = len(crashes_map.query('COUNTY==@county[0] and CRASH_YEAR==2002')\
301
  .BICYCLE_DEATH_COUNT.unique())+\
302
+ len(crashes_map.query('COUNTY==@county[0] and CRASH_YEAR==2002')\
303
  .BICYCLE_SUSP_SERIOUS_INJ_COUNT.unique()) > 3
304
  else:
305
+ color_dots = len(crashes_map.query('COUNTY==@county[0] and CRASH_YEAR==2002 and CRASH_MONTH==1')\
306
  .BICYCLE_DEATH_COUNT.unique())+\
307
+ len(crashes_map.query('COUNTY==@county[0] and CRASH_YEAR==2002 and CRASH_MONTH==1')\
308
  .BICYCLE_SUSP_SERIOUS_INJ_COUNT.unique()) > 3
309
  if color_dots==False:
310
  st.markdown("""
 
319
 
320
  with map_plot_container:
321
  fig = plot_map(
322
+ df=crashes_map,county=county,animate=animate,
323
  color_dots=color_dots,animate_by=animate_by,
324
  show_fig=False,return_fig=True,
325
  )
 
344
  Pay particular attention to feature values which become more or less prevalent among cyclists suffering serious injury or death - for instance, 6.2% of all cyclists statewide were involved in a head-on collision, whereas 11.8% of those with serious injury or fatality were in a head-on collision.
345
  """)
346
 
347
+ # Copy dataframe for this tab
348
+ cyclists_feat = cyclists.copy()
349
+
350
  ### User input - settings for plot ###
351
 
352
  with feature_settings_container:
 
370
  from lib.vis_data import feat_perc,feat_perc_bar
371
  # Geographic restriction
372
  if geo != 'statewide':
373
+ cyclists_feat = cyclists_feat[cyclists_feat.COUNTY.isin(geo_data[geo][1])]
374
 
375
  # Recast binary and day of week data
376
  if feature not in ord_features:
377
+ cyclists_feat[feature]=cyclists_feat[feature].replace({1:'yes',0:'no'})
378
  if feature == 'DAY_OF_WEEK':
379
+ cyclists_feat[feature]=cyclists_feat[feature].astype(str)
380
 
381
  ### Build and display plot ###
382
 
 
385
  # Generate plot
386
  sort = False if feature in ord_features else True
387
  fig = feat_perc_bar(
388
+ feature,cyclists_feat, feat_name=feature_names[feature],
389
  return_fig=True,show_fig=False,sort=sort
390
  )
391
 
 
560
  # shap_values = explainer(sample_trans)
561
  # shap_values_list.append(shap_values.values)
562
  # shap_values = np.array(shap_values_list).sum(axis=0) / len(shap_values_list)
563
+
564
+ #Retrieve feature names
565
+ feature_names = pipe['col'].get_feature_names_out()
566
+
567
+ explainer = shap.TreeExplainer(pipe[-1], feature_names = feature_names)
568
  shap_values = explainer(sample_trans)
569
+ sample_trans = pd.DataFrame(sample_trans,columns=feature_names)
570
+
571
+ # Get arrays of category names from OrdinalEncoder
572
+ cat_names = study.pipe_fitted[-2].transformers_[0][1][-1].categories_
573
+ for ind,feature in enumerate(feature_names):
574
+ if ind < 8:
575
+ cat_dict = {k:v for k,v in enumerate(cat_names[ind])}
576
+ sample_trans[feature] = sample_trans[feature].replace(cat_dict)
577
+
578
  # def st_shap(plot, height=None):
579
  # shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
580
  # components.html(shap_html, height=height)
app_data.pickle CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c8882e6c8ec43e8a4e96724b21f6f1c11347cc9e18317d1a0dbbd5621bd93812
3
- size 4990
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aafa91e4ef4ad7f43b14b54d7c313f4de761f5f645f420541f02f87318dd975c
3
+ size 5000
lib/.DS_Store CHANGED
Binary files a/lib/.DS_Store and b/lib/.DS_Store differ