justinxzhao committed
Commit 279a804 · 1 Parent(s): 6fae7e2

Factor out judge results code so that it persists when the submit button is inactivated.
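
As background, a minimal sketch (hypothetical key and widget names, not the app's actual code) of the Streamlit pattern this change leans on: results computed on the rerun where the form's submit button is pressed are stored in st.session_state, and a separate branch re-renders them on later reruns, when the button's value has reset to False.

    import streamlit as st

    # Hypothetical keys; app.py uses its own judging-specific session state.
    if "judging_status" not in st.session_state:
        st.session_state.judging_status = "incomplete"
    if "results" not in st.session_state:
        st.session_state.results = None

    with st.form(key="judging_form"):
        prompt = st.text_area("Judging prompt")
        submitted = st.form_submit_button("Submit Judging")

    if submitted:
        # Expensive work runs only on the rerun where the button was pressed.
        st.session_state.results = f"Scores for: {prompt}"
        st.session_state.judging_status = "complete"
        st.write(st.session_state.results)
    elif st.session_state.judging_status == "complete":
        # On later reruns the submit button is inactive again, so render the
        # persisted results from session state instead of recomputing them.
        st.write(st.session_state.results)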

Files changed (1)
  1. app.py +255 -189
app.py CHANGED
@@ -508,6 +508,14 @@ def plot_per_judge_overall_scores(df):
 st.pyplot(plt)


+ def get_selected_models_to_streamlit_column_map(st_columns, selected_models):
+ selected_models_to_streamlit_column_map = {
+ model: st_columns[i % len(st_columns)]
+ for i, model in enumerate(selected_models)
+ }
+ return selected_models_to_streamlit_column_map
+
+
 # Main Streamlit App
 def main():
 st.set_page_config(
@@ -577,6 +585,24 @@ def main():
 if "selected_aggregator" not in st.session_state:
 st.session_state["selected_aggregator"] = None

+ # Initialize session state for direct assessment judging.
+ if "direct_assessment_overall_score" not in st.session_state:
+ st.session_state.direct_assessment_overall_score = {}
+ if "direct_assessment_judging_df" not in st.session_state:
+ st.session_state.direct_assessment_judging_df = defaultdict(dict)
+ if "direct_assessment_judging_responses" not in st.session_state:
+ st.session_state.direct_assessment_judging_responses = defaultdict(dict)
+ if "direct_assessment_overall_scores" not in st.session_state:
+ st.session_state.direct_assessment_overall_scores = defaultdict(dict)
+ if "judging_status" not in st.session_state:
+ st.session_state.judging_status = "incomplete"
+ if "direct_assessment_config" not in st.session_state:
+ st.session_state.direct_assessment_config = {}
+ if "pairwise_comparison_config" not in st.session_state:
+ st.session_state.pairwise_comparison_config = {}
+ if "assessment_type" not in st.session_state:
+ st.session_state.assessment_type = None
+
 with st.form(key="prompt_form"):
 st.markdown("#### LLM Council Member Selection")

@@ -608,9 +634,11 @@ def main():
 # Render the chats.
 response_columns = st.columns(3)

- selected_models_to_streamlit_column_map = {
- model: response_columns[i] for i, model in enumerate(selected_models)
- }
+ selected_models_to_streamlit_column_map = (
+ get_selected_models_to_streamlit_column_map(
+ response_columns, selected_models
+ )
+ )

 # Fetching and streaming responses from each selected model
 for selected_model in st.session_state.selected_models:
@@ -643,7 +671,7 @@ def main():
 selected_aggregator, aggregator_prompt
 )
 if aggregator_stream:
- st.session_state["responses"]["agg__" + selected_aggregator] = (
+ st.session_state.responses["agg__" + selected_aggregator] = (
 message_placeholder.write_stream(aggregator_stream)
 )

@@ -654,10 +682,11 @@ def main():
 st.markdown("#### Responses")

 response_columns = st.columns(3)
- selected_models_to_streamlit_column_map = {
- model: response_columns[i]
- for i, model in enumerate(st.session_state.selected_models)
- }
+ selected_models_to_streamlit_column_map = (
+ get_selected_models_to_streamlit_column_map(
+ response_columns, st.session_state.selected_models
+ )
+ )
 for response_model, response in st.session_state.responses.items():
 st_column = selected_models_to_streamlit_column_map.get(
 response_model, response_columns[0]
@@ -671,106 +700,56 @@ def main():

 # Judging.
 if st.session_state.responses_collected:
- st.markdown("#### Judging Configuration")
+ with st.form(key="judging_form"):
+ st.markdown("#### Judging Configuration")

- # Choose the type of assessment
- assessment_type = st.radio(
- "Select the type of assessment",
- options=["Direct Assessment", "Pairwise Comparison"],
- )
+ # Choose the type of assessment
+ assessment_type = st.radio(
+ "Select the type of assessment",
+ options=["Direct Assessment", "Pairwise Comparison"],
+ )

- _, center_column, _ = st.columns([3, 5, 3])
+ _, center_column, _ = st.columns([3, 5, 3])
+
+ # Depending on the assessment type, render different forms
+ if assessment_type == "Direct Assessment":
+ # Direct assessment prompt.
+ with center_column.expander("Direct Assessment Prompt"):
+ direct_assessment_prompt = st.text_area(
+ "Prompt for the Direct Assessment",
+ value=get_default_direct_assessment_prompt(
+ user_prompt=user_prompt
+ ),
+ height=500,
+ key="direct_assessment_prompt",
+ )

- # Depending on the assessment type, render different forms
- if assessment_type == "Direct Assessment":
+ # TODO: Add option to edit criteria list with a basic text field.
+ criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST

- # Initialize session state for direct assessment.
- if "direct_assessment_overall_score" not in st.session_state:
- st.session_state["direct_assessment_overall_score"] = {}
- if "direct_assessment_judging_df" not in st.session_state:
- st.session_state["direct_assessment_judging_df"] = {}
- for response_model in selected_models:
- st.session_state["direct_assessment_judging_df"][
- response_model
- ] = {}
- # aggregator model
- st.session_state["direct_assessment_judging_df"][
- "agg__" + selected_aggregator
- ] = {}
- if "direct_assessment_judging_responses" not in st.session_state:
- st.session_state["direct_assessment_judging_responses"] = {}
- for response_model in selected_models:
- st.session_state["direct_assessment_judging_responses"][
- response_model
- ] = {}
- # aggregator model
- st.session_state["direct_assessment_judging_responses"][
- "agg__" + selected_aggregator
- ] = {}
- if "direct_assessment_overall_scores" not in st.session_state:
- st.session_state["direct_assessment_overall_scores"] = {}
- for response_model in selected_models:
- st.session_state["direct_assessment_overall_scores"][
- response_model
- ] = {}
- st.session_state["direct_assessment_overall_scores"][
- "agg__" + selected_aggregator
- ] = {}
- if "judging_status" not in st.session_state:
- st.session_state["judging_status"] = "incomplete"
-
- # Direct assessment prompt.
- with center_column.expander("Direct Assessment Prompt"):
- direct_assessment_prompt = st.text_area(
- "Prompt for the Direct Assessment",
- value=get_default_direct_assessment_prompt(
- user_prompt=user_prompt
- ),
- height=500,
- key="direct_assessment_prompt",
+ judging_submit_button = st.form_submit_button(
+ "Submit Judging", use_container_width=True
 )

- # TODO: Add option to edit criteria list with a basic text field.
- criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
-
- # Create DirectAssessment object when form is submitted
- if center_column.button(
- "Submit Direct Assessment", use_container_width=True
- ):
-
- # Render the chats.
- response_columns = st.columns(3)
- selected_models_to_streamlit_column_map = {
- model: response_columns[i]
- for i, model in enumerate(selected_models)
- }
- for response_model, response in st.session_state[
- "responses"
- ].items():
- st_column = selected_models_to_streamlit_column_map.get(
- response_model, response_columns[0]
- )
- with st_column:
- with st.chat_message(
- get_ui_friendly_name(response_model),
- avatar=get_llm_avatar(response_model),
- ):
- st.write(get_ui_friendly_name(response_model))
- st.write(response)
-
- # Submit direct asssessment.
- responses_for_judging = st.session_state["responses"]
+ if judging_submit_button:
+ st.session_state.assessment_type = assessment_type
+ st.session_state.direct_assessment_config = {
+ "prompt": direct_assessment_prompt,
+ "criteria_list": criteria_list,
+ }

- response_judging_columns = st.columns(3)
+ responses_for_judging = st.session_state.responses

- responses_for_judging_to_streamlit_column_map = {
- model: response_judging_columns[i % 3]
- for i, model in enumerate(responses_for_judging.keys())
- }
+ # Get judging responses.
+ response_judging_columns = st.columns(3)
+ responses_for_judging_to_streamlit_column_map = (
+ get_selected_models_to_streamlit_column_map(
+ response_judging_columns, responses_for_judging.keys()
+ )
+ )

- # Get judging responses.
+ if st.session_state.assessment_type == "Direct Assessment":
 for response_model, response in responses_for_judging.items():
-
 st_column = responses_for_judging_to_streamlit_column_map[
 response_model
 ]
@@ -792,7 +771,7 @@ def main():

 for judging_model in selected_models:
 with st.expander(
- get_ui_friendly_name(judging_model), expanded=False
+ get_ui_friendly_name(judging_model), expanded=True
 ):
 with st.chat_message(
 judging_model,
@@ -811,7 +790,6 @@ def main():
 )
 # When all of the judging is finished for the given response, get the actual
 # values, parsed.
- # TODO.
 judging_responses = st.session_state[
 "direct_assessment_judging_responses"
 ][response_model]
@@ -872,104 +850,192 @@ def main():
 st.write(f"Overall Score: {overall_score:.2f}")
 st.write(f"Controversy: {controversy:.2f}")

- st.session_state["judging_status"] = "complete"
+ st.session_state.judging_status = "complete"
+ # If judging is complete, but the submit button is cleared, still render the results.
+ elif st.session_state.judging_status == "complete":
+ if st.session_state.assessment_type == "Direct Assessment":
+ responses_for_judging = st.session_state.responses

- # Judging is complete.
- # The session state now contains the overall scores for each response from each judge.
- if st.session_state["judging_status"] == "complete":
- st.write("#### Results")
+ # Get judging responses.
+ response_judging_columns = st.columns(3)
+ responses_for_judging_to_streamlit_column_map = (
+ get_selected_models_to_streamlit_column_map(
+ response_judging_columns, responses_for_judging.keys()
+ )
+ )

- overall_scores_df_raw = pd.DataFrame(
- st.session_state["direct_assessment_overall_scores"]
- ).reset_index()
+ for response_model, response in responses_for_judging.items():
+ st_column = responses_for_judging_to_streamlit_column_map[
+ response_model
+ ]

- overall_scores_df = pd.melt(
- overall_scores_df_raw,
- id_vars=["index"],
- var_name="response_model",
- value_name="score",
- ).rename(columns={"index": "judging_model"})
+ with st_column:
+ st.write(
+ f"Judging for {get_ui_friendly_name(response_model)}"
+ )
+ judging_prompt = get_direct_assessment_prompt(
+ direct_assessment_prompt=direct_assessment_prompt,
+ user_prompt=user_prompt,
+ response=response,
+ criteria_list=criteria_list,
+ options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
+ )

- # Print the overall winner.
- overall_winner = overall_scores_df.loc[
- overall_scores_df["score"].idxmax()
- ]
+ with st.expander("Final Judging Prompt"):
+ st.code(judging_prompt)

- st.write(
- f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
- )
- # Find how much the standard deviation overlaps with other models.
- # Calculate separability.
- # TODO.
- st.write(f"**Confidence:** {overall_winner['score']:.2f}")
-
- left_column, right_column = st.columns([1, 1])
- with left_column:
- plot_overall_scores(overall_scores_df)
-
- with right_column:
- # All overall scores.
- overall_scores_df = overall_scores_df[
- ["response_model", "judging_model", "score"]
- ]
- overall_scores_df["response_model"] = overall_scores_df[
- "response_model"
- ].apply(get_ui_friendly_name)
- overall_scores_df["judging_model"] = overall_scores_df[
- "judging_model"
- ].apply(get_ui_friendly_name)
-
- with st.expander("Overall scores from all judges"):
- st.dataframe(overall_scores_df)
-
- # All criteria scores.
- with right_column:
- all_scores_df = pd.DataFrame()
- for response_model, score_df in st.session_state[
- "direct_assessment_judging_df"
- ].items():
- score_df["response_model"] = response_model
- all_scores_df = pd.concat([all_scores_df, score_df])
- all_scores_df = all_scores_df.reset_index()
- all_scores_df = all_scores_df.drop(columns="index")
-
- # Reorder the columns
- all_scores_df = all_scores_df[
- [
- "response_model",
- "judging_model",
- "criteria",
- "score",
- "explanation",
- ]
- ]
- all_scores_df["response_model"] = all_scores_df[
- "response_model"
- ].apply(get_ui_friendly_name)
- all_scores_df["judging_model"] = all_scores_df[
- "judging_model"
- ].apply(get_ui_friendly_name)
+ for judging_model in selected_models:
+ with st.expander(
+ get_ui_friendly_name(judging_model), expanded=True
+ ):
+ with st.chat_message(
+ judging_model,
+ avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
+ ):
+ st.write(
+ st.session_state.direct_assessment_judging_responses[
+ response_model
+ ][
+ judging_model
+ ]
+ )
+ # When all of the judging is finished for the given response, get the actual
+ # values, parsed.
+ judging_responses = (
+ st.session_state.direct_assessment_judging_responses[
+ response_model
+ ]
+ )

- with st.expander(
- "Criteria-specific scores and explanations from all judges"
- ):
- st.dataframe(all_scores_df)
+ parse_judging_response_prompt = (
+ get_parse_judging_response_for_direct_assessment_prompt(
+ judging_responses,
+ criteria_list,
+ SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
+ )
+ )

- elif assessment_type == "Pairwise Comparison":
- pass
+ plot_criteria_scores(
+ st.session_state.direct_assessment_judging_df[
+ response_model
+ ]
+ )
+
+ plot_per_judge_overall_scores(
+ st.session_state.direct_assessment_judging_df[
+ response_model
+ ]
+ )
+
+ grouped = (
+ st.session_state.direct_assessment_judging_df[
+ response_model
+ ]
+ .groupby(["judging_model"])
+ .agg({"score": ["mean"]})
+ .reset_index()
+ )
+ grouped.columns = ["judging_model", "overall_score"]
+
+ overall_score = grouped["overall_score"].mean()
+ controversy = grouped["overall_score"].std()
+ st.write(f"Overall Score: {overall_score:.2f}")
+ st.write(f"Controversy: {controversy:.2f}")
+
+ # Judging is complete, stuff that would be rendered that's not stream-specific.
+ # The session state now contains the overall scores for each response from each judge.
+ if st.session_state.judging_status == "complete":
+ st.write("#### Results")
+
+ overall_scores_df_raw = pd.DataFrame(
+ st.session_state["direct_assessment_overall_scores"]
+ ).reset_index()
+
+ overall_scores_df = pd.melt(
+ overall_scores_df_raw,
+ id_vars=["index"],
+ var_name="response_model",
+ value_name="score",
+ ).rename(columns={"index": "judging_model"})
+
+ # Print the overall winner.
+ overall_winner = overall_scores_df.loc[
+ overall_scores_df["score"].idxmax()
+ ]
+
+ st.write(
+ f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
+ )
+ # Find how much the standard deviation overlaps with other models
+ # TODO: Calculate separability.
+ st.write(f"**Confidence:** {overall_winner['score']:.2f}")
+
+ left_column, right_column = st.columns([1, 1])
+ with left_column:
+ plot_overall_scores(overall_scores_df)
+
+ with right_column:
+ # All overall scores.
+ overall_scores_df = overall_scores_df[
+ ["response_model", "judging_model", "score"]
+ ]
+ overall_scores_df["response_model"] = overall_scores_df[
+ "response_model"
+ ].apply(get_ui_friendly_name)
+ overall_scores_df["judging_model"] = overall_scores_df[
+ "judging_model"
+ ].apply(get_ui_friendly_name)
+
+ with st.expander("Overall scores from all judges"):
+ st.dataframe(overall_scores_df)
+
+ # All criteria scores.
+ with right_column:
+ all_scores_df = pd.DataFrame()
+ for response_model, score_df in st.session_state[
+ "direct_assessment_judging_df"
+ ].items():
+ score_df["response_model"] = response_model
+ all_scores_df = pd.concat([all_scores_df, score_df])
+ all_scores_df = all_scores_df.reset_index()
+ all_scores_df = all_scores_df.drop(columns="index")
+
+ # Reorder the columns
+ all_scores_df = all_scores_df[
+ [
+ "response_model",
+ "judging_model",
+ "criteria",
+ "score",
+ "explanation",
+ ]
+ ]
+ all_scores_df["response_model"] = all_scores_df[
+ "response_model"
+ ].apply(get_ui_friendly_name)
+ all_scores_df["judging_model"] = all_scores_df[
+ "judging_model"
+ ].apply(get_ui_friendly_name)
+
+ with st.expander(
+ "Criteria-specific scores and explanations from all judges"
+ ):
+ st.dataframe(all_scores_df)

 # Token usage.
- with st.expander("Token Usage"):
- st.write("Input tokens used.")
- st.write(st.session_state.input_token_usage)
- st.write(
- f"Input Tokens Total: {sum(st.session_state.input_token_usage.values())}"
- )
- st.write("Output tokens used.")
- st.write(st.session_state.output_token_usage)
- st.write(
- f"Output Tokens Total: {sum(st.session_state.output_token_usage.values())}"
- )
+ if st.session_state.responses:
+ st.divider()
+ with st.expander("Token Usage"):
+ st.write("Input tokens used.")
+ st.write(st.session_state.input_token_usage)
+ st.write(
+ f"Input Tokens Total: {sum(st.session_state.input_token_usage.values())}"
+ )
+ st.write("Output tokens used.")
+ st.write(st.session_state.output_token_usage)
+ st.write(
+ f"Output Tokens Total: {sum(st.session_state.output_token_usage.values())}"
+ )

 else:
 with cols[1]:
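
A quick usage sketch of the factored-out helper introduced in this commit (model names are placeholders): because the column index is taken modulo len(st_columns), any number of models is distributed round-robin across the three response columns.

    import streamlit as st

    # Placeholder model names; the real app builds this list from user selection.
    models = ["model_a", "model_b", "model_c", "model_d"]
    columns = st.columns(3)

    # Assumes get_selected_models_to_streamlit_column_map from this commit is in
    # scope; the fourth model wraps back to the first column via i % len(st_columns).
    column_map = get_selected_models_to_streamlit_column_map(columns, models)
    for model, col in column_map.items():
        with col:
            st.write(model)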