justinxzhao committed
Commit a0dca54 · 1 Parent(s): 279a804

Factor out LLM chat rendering so that it persists even when the submit button isn't active.

Files changed (1)
  1. app.py +222 -269
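The idea behind the change: every chat is re-drawn from `st.session_state` on each Streamlit rerun, so the transcript survives reruns where the submit button evaluates to False. The snippet below is a minimal, self-contained sketch of that pattern, not code from app.py; `fake_token_stream` and `render_response` are illustrative stand-ins for the app's `get_llm_response_stream` and `st_render_responses`.

import time

import streamlit as st


def fake_token_stream(prompt):
    # Stand-in for the app's get_llm_response_stream(); yields tokens one at a time.
    for token in f"You asked: {prompt}".split():
        yield token + " "
        time.sleep(0.05)


def render_response(model, prompt):
    # Re-renders the cached text if present; otherwise streams it once and caches it.
    with st.chat_message(model):
        if model in st.session_state.responses:
            st.write(st.session_state.responses[model])
        else:
            placeholder = st.empty()
            st.session_state.responses[model] = placeholder.write_stream(
                fake_token_stream(prompt)
            )


if "responses" not in st.session_state:
    st.session_state.responses = {}

user_prompt = st.text_input("Prompt")
submit_button = st.button("Submit")

if submit_button and user_prompt:
    render_response("assistant", user_prompt)
elif st.session_state.responses:
    # On later reruns submit_button is False again, but the cached chat still renders.
    render_response("assistant", user_prompt)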
app.py CHANGED
@@ -516,6 +516,206 @@ def get_selected_models_to_streamlit_column_map(st_columns, selected_models):
     return selected_models_to_streamlit_column_map
 
 
+def get_aggregator_key(llm_aggregator):
+    return "agg__" + llm_aggregator
+
+
+def st_render_responses(user_prompt):
+    """Renders the responses from the LLMs.
+
+    Uses cached responses from the session state, if available.
+    Otherwise, streams the responses anew.
+
+    Assumes that the session state has already been set up with selected models and selected aggregator.
+    """
+    st.markdown("#### Responses")
+
+    response_columns = st.columns(3)
+    selected_models_to_streamlit_column_map = (
+        get_selected_models_to_streamlit_column_map(
+            response_columns, st.session_state.selected_models
+        )
+    )
+    for response_model in st.session_state.selected_models:
+        st_column = selected_models_to_streamlit_column_map.get(
+            response_model, response_columns[0]
+        )
+
+        with st_column.chat_message(
+            response_model,
+            avatar=get_llm_avatar(response_model),
+        ):
+            st.write(get_ui_friendly_name(response_model))
+            if response_model in st.session_state.responses:
+                # Use the cached response from session state.
+                st.write(st.session_state.responses[response_model])
+            else:
+                # Stream the response from the LLM.
+                message_placeholder = st.empty()
+                stream = get_llm_response_stream(response_model, user_prompt)
+                st.session_state.responses[response_model] = (
+                    message_placeholder.write_stream(stream)
+                )
+
+    # Render the aggregator response.
+    aggregator_prompt = get_default_aggregator_prompt(
+        user_prompt=user_prompt, llms=st.session_state.selected_models
+    )
+
+    # Streaming response from the aggregator.
+    with st.chat_message(
+        get_aggregator_key(st.session_state.selected_aggregator),
+        avatar="img/council_icon.png",
+    ):
+        st.write(
+            f"{get_ui_friendly_name(get_aggregator_key(st.session_state.selected_aggregator))}"
+        )
+        if (
+            get_aggregator_key(st.session_state.selected_aggregator)
+            in st.session_state.responses
+        ):
+            st.write(
+                st.session_state.responses[
+                    get_aggregator_key(st.session_state.selected_aggregator)
+                ]
+            )
+        else:
+            message_placeholder = st.empty()
+            aggregator_stream = get_llm_response_stream(
+                selected_aggregator, aggregator_prompt
+            )
+            if aggregator_stream:
+                st.session_state.responses[get_aggregator_key(selected_aggregator)] = (
+                    message_placeholder.write_stream(aggregator_stream)
+                )
+
+    st.session_state.responses_collected = True
+
+
+def st_direct_assessment_results(user_prompt, direct_assessment_prompt, criteria_list):
+    """Renders the direct assessment results block.
+
+    Uses session state to render results from LLMs. If the session state isn't set, then fetches the
+    responses from the LLMs services from scratch (and sets the session state).
+
+    Assumes that the session state has already been set up with responses.
+    """
+    responses_for_judging = st.session_state.responses
+
+    # Get judging responses.
+    response_judging_columns = st.columns(3)
+    responses_for_judging_to_streamlit_column_map = (
+        get_selected_models_to_streamlit_column_map(
+            response_judging_columns, responses_for_judging.keys()
+        )
+    )
+
+    for response_model, response in responses_for_judging.items():
+        st_column = responses_for_judging_to_streamlit_column_map[response_model]
+
+        with st_column:
+            st.write(f"Judging for {get_ui_friendly_name(response_model)}")
+            judging_prompt = get_direct_assessment_prompt(
+                direct_assessment_prompt=direct_assessment_prompt,
+                user_prompt=user_prompt,
+                response=response,
+                criteria_list=criteria_list,
+                options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
+            )
+
+            with st.expander("Final Judging Prompt"):
+                st.code(judging_prompt)
+
+            for judging_model in st.session_state.selected_models:
+                with st.expander(get_ui_friendly_name(judging_model), expanded=True):
+                    with st.chat_message(
+                        judging_model,
+                        avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
+                    ):
+                        if (
+                            judging_model
+                            in st.session_state.direct_assessment_judging_responses[
+                                response_model
+                            ]
+                        ):
+                            # Use the session state cached response.
+                            st.write(
+                                st.session_state.direct_assessment_judging_responses[
+                                    response_model
+                                ][judging_model]
+                            )
+                        else:
+                            message_placeholder = st.empty()
+                            # Get the judging response from the LLM.
+                            judging_stream = get_llm_response_stream(
+                                judging_model, judging_prompt
+                            )
+                            st.session_state.direct_assessment_judging_responses[
+                                response_model
+                            ][judging_model] = message_placeholder.write_stream(
+                                judging_stream
+                            )
+
+            # Extract actual scores from open-ended responses using structured outputs.
+            # Since we're extracting structured data for the first time, we can save the dataframe
+            # to the session state so that it's cached.
+            if response_model not in st.session_state.direct_assessment_judging_df:
+                judging_responses = (
+                    st.session_state.direct_assessment_judging_responses[response_model]
+                )
+                parse_judging_response_prompt = (
+                    get_parse_judging_response_for_direct_assessment_prompt(
+                        judging_responses,
+                        criteria_list,
+                        SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
+                    )
+                )
+                parsed_judging_responses = parse_judging_responses(
+                    parse_judging_response_prompt, judging_responses
+                )
+                st.session_state.direct_assessment_judging_df[response_model] = (
+                    create_dataframe_for_direct_assessment_judging_response(
+                        parsed_judging_responses
+                    )
+                )
+
+            # Uses the session state to plot the criteria scores and graphs for a given response
+            # model.
+            plot_criteria_scores(
+                st.session_state.direct_assessment_judging_df[response_model]
+            )
+
+            plot_per_judge_overall_scores(
+                st.session_state.direct_assessment_judging_df[response_model]
+            )
+
+            grouped = (
+                st.session_state.direct_assessment_judging_df[response_model]
+                .groupby(["judging_model"])
+                .agg({"score": ["mean"]})
+                .reset_index()
+            )
+            grouped.columns = ["judging_model", "overall_score"]
+
+            # Save the overall scores to the session state if it's not already there.
+            for record in grouped.to_dict(orient="records"):
+                if (
+                    response_model
+                    not in st.session_state.direct_assessment_overall_scores
+                ):
+                    st.session_state.direct_assessment_overall_scores[response_model][
+                        record["judging_model"]
+                    ] = record["overall_score"]
+
+            overall_score = grouped["overall_score"].mean()
+            controversy = grouped["overall_score"].std()
+            st.write(f"Overall Score: {overall_score:.2f}")
+            st.write(f"Controversy: {controversy:.2f}")
+
+    # Mark judging as complete.
+    st.session_state.judging_status = "complete"
+
+
 # Main Streamlit App
 def main():
     st.set_page_config(
@@ -632,71 +832,11 @@ def main():
         st.session_state.selected_aggregator = selected_aggregator
 
         # Render the chats.
-        response_columns = st.columns(3)
-
-        selected_models_to_streamlit_column_map = (
-            get_selected_models_to_streamlit_column_map(
-                response_columns, selected_models
-            )
-        )
+        st_render_responses(user_prompt)
 
-        # Fetching and streaming responses from each selected model
-        for selected_model in st.session_state.selected_models:
-            with selected_models_to_streamlit_column_map[selected_model]:
-                st.write(get_ui_friendly_name(selected_model))
-                with st.chat_message(
-                    selected_model,
-                    avatar=PROVIDER_TO_AVATAR_MAP[selected_model],
-                ):
-                    message_placeholder = st.empty()
-                    stream = get_llm_response_stream(selected_model, user_prompt)
-                    if stream:
-                        st.session_state["responses"][selected_model] = (
-                            message_placeholder.write_stream(stream)
-                        )
-
-        # Get the aggregator prompt.
-        aggregator_prompt = get_default_aggregator_prompt(
-            user_prompt=user_prompt, llms=selected_models
-        )
-
-        # Fetching and streaming response from the aggregator
-        st.write(f"{get_ui_friendly_name(selected_aggregator)}")
-        with st.chat_message(
-            selected_aggregator,
-            avatar="img/council_icon.png",
-        ):
-            message_placeholder = st.empty()
-            aggregator_stream = get_llm_response_stream(
-                selected_aggregator, aggregator_prompt
-            )
-            if aggregator_stream:
-                st.session_state.responses["agg__" + selected_aggregator] = (
-                    message_placeholder.write_stream(aggregator_stream)
-                )
-
-        st.session_state.responses_collected = True
-
-    # Render chats generally?
-    if st.session_state.responses and not submit_button:
-        st.markdown("#### Responses")
-
-        response_columns = st.columns(3)
-        selected_models_to_streamlit_column_map = (
-            get_selected_models_to_streamlit_column_map(
-                response_columns, st.session_state.selected_models
-            )
-        )
-        for response_model, response in st.session_state.responses.items():
-            st_column = selected_models_to_streamlit_column_map.get(
-                response_model, response_columns[0]
-            )
-            with st_column.chat_message(
-                response_model,
-                avatar=get_llm_avatar(response_model),
-            ):
-                st.write(get_ui_friendly_name(response_model))
-                st.write(response)
+    # Render chats generally even they are available, if the submit button isn't clicked.
+    elif st.session_state.responses:
+        st_render_responses(user_prompt)
 
     # Judging.
     if st.session_state.responses_collected:
@@ -727,228 +867,41 @@
         # TODO: Add option to edit criteria list with a basic text field.
         criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
 
+        with center_column:
            judging_submit_button = st.form_submit_button(
                "Submit Judging", use_container_width=True
            )
 
        if judging_submit_button:
+            # Update session state.
            st.session_state.assessment_type = assessment_type
-            st.session_state.direct_assessment_config = {
-                "prompt": direct_assessment_prompt,
-                "criteria_list": criteria_list,
-            }
-
-            responses_for_judging = st.session_state.responses
-
-            # Get judging responses.
-            response_judging_columns = st.columns(3)
-            responses_for_judging_to_streamlit_column_map = (
-                get_selected_models_to_streamlit_column_map(
-                    response_judging_columns, responses_for_judging.keys()
-                )
-            )
-
            if st.session_state.assessment_type == "Direct Assessment":
-                for response_model, response in responses_for_judging.items():
-                    st_column = responses_for_judging_to_streamlit_column_map[
-                        response_model
-                    ]
-
-                    with st_column:
-                        st.write(
-                            f"Judging for {get_ui_friendly_name(response_model)}"
-                        )
-                        judging_prompt = get_direct_assessment_prompt(
-                            direct_assessment_prompt=direct_assessment_prompt,
-                            user_prompt=user_prompt,
-                            response=response,
-                            criteria_list=criteria_list,
-                            options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
-                        )
-
-                        with st.expander("Final Judging Prompt"):
-                            st.code(judging_prompt)
-
-                        for judging_model in selected_models:
-                            with st.expander(
-                                get_ui_friendly_name(judging_model), expanded=True
-                            ):
-                                with st.chat_message(
-                                    judging_model,
-                                    avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
-                                ):
-                                    message_placeholder = st.empty()
-                                    judging_stream = get_llm_response_stream(
-                                        judging_model, judging_prompt
-                                    )
-                                    st.session_state[
-                                        "direct_assessment_judging_responses"
-                                    ][response_model][
-                                        judging_model
-                                    ] = message_placeholder.write_stream(
-                                        judging_stream
-                                    )
-                        # When all of the judging is finished for the given response, get the actual
-                        # values, parsed.
-                        judging_responses = st.session_state[
-                            "direct_assessment_judging_responses"
-                        ][response_model]
-
-                        if not judging_responses:
-                            st.error(f"No judging responses for {response_model}")
-                            quit()
-                        parse_judging_response_prompt = (
-                            get_parse_judging_response_for_direct_assessment_prompt(
-                                judging_responses,
-                                criteria_list,
-                                SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
-                            )
-                        )
-                        # Issue the prompt to openai mini with structured outputs
-                        parsed_judging_responses = parse_judging_responses(
-                            parse_judging_response_prompt, judging_responses
-                        )
-
-                        st.session_state["direct_assessment_judging_df"][
-                            response_model
-                        ] = create_dataframe_for_direct_assessment_judging_response(
-                            parsed_judging_responses
-                        )
-
-                        plot_criteria_scores(
-                            st.session_state["direct_assessment_judging_df"][
-                                response_model
-                            ]
-                        )
-
-                        # Find the overall score by finding the overall score for each judge, and then averaging
-                        # over all judges.
-                        plot_per_judge_overall_scores(
-                            st.session_state["direct_assessment_judging_df"][
-                                response_model
-                            ]
-                        )
-
-                        grouped = (
-                            st.session_state["direct_assessment_judging_df"][
-                                response_model
-                            ]
-                            .groupby(["judging_model"])
-                            .agg({"score": ["mean"]})
-                            .reset_index()
-                        )
-                        grouped.columns = ["judging_model", "overall_score"]
-
-                        # Save the overall scores to the session state.
-                        for record in grouped.to_dict(orient="records"):
-                            st.session_state["direct_assessment_overall_scores"][
-                                response_model
-                            ][record["judging_model"]] = record["overall_score"]
-
-                        overall_score = grouped["overall_score"].mean()
-                        controversy = grouped["overall_score"].std()
-                        st.write(f"Overall Score: {overall_score:.2f}")
-                        st.write(f"Controversy: {controversy:.2f}")
-
-                st.session_state.judging_status = "complete"
+                st.session_state.direct_assessment_config = {
+                    "prompt": direct_assessment_prompt,
+                    "criteria_list": criteria_list,
+                }
+                st_direct_assessment_results(
+                    user_prompt=st.session_state.user_prompt,
+                    direct_assessment_prompt=direct_assessment_prompt,
+                    criteria_list=criteria_list,
+                )
        # If judging is complete, but the submit button is cleared, still render the results.
        elif st.session_state.judging_status == "complete":
            if st.session_state.assessment_type == "Direct Assessment":
-                responses_for_judging = st.session_state.responses
-
-                # Get judging responses.
-                response_judging_columns = st.columns(3)
-                responses_for_judging_to_streamlit_column_map = (
-                    get_selected_models_to_streamlit_column_map(
-                        response_judging_columns, responses_for_judging.keys()
-                    )
+                st_direct_assessment_results(
+                    user_prompt=st.session_state.user_prompt,
+                    direct_assessment_prompt=direct_assessment_prompt,
+                    criteria_list=criteria_list,
                )
 
-                for response_model, response in responses_for_judging.items():
-                    st_column = responses_for_judging_to_streamlit_column_map[
-                        response_model
-                    ]
-
-                    with st_column:
-                        st.write(
-                            f"Judging for {get_ui_friendly_name(response_model)}"
-                        )
-                        judging_prompt = get_direct_assessment_prompt(
-                            direct_assessment_prompt=direct_assessment_prompt,
-                            user_prompt=user_prompt,
-                            response=response,
-                            criteria_list=criteria_list,
-                            options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
-                        )
-
-                        with st.expander("Final Judging Prompt"):
-                            st.code(judging_prompt)
-
-                        for judging_model in selected_models:
-                            with st.expander(
-                                get_ui_friendly_name(judging_model), expanded=True
-                            ):
-                                with st.chat_message(
-                                    judging_model,
-                                    avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
-                                ):
-                                    st.write(
-                                        st.session_state.direct_assessment_judging_responses[
-                                            response_model
-                                        ][
-                                            judging_model
-                                        ]
-                                    )
-                        # When all of the judging is finished for the given response, get the actual
-                        # values, parsed.
-                        judging_responses = (
-                            st.session_state.direct_assessment_judging_responses[
-                                response_model
-                            ]
-                        )
-
-                        parse_judging_response_prompt = (
-                            get_parse_judging_response_for_direct_assessment_prompt(
-                                judging_responses,
-                                criteria_list,
-                                SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
-                            )
-                        )
-
-                        plot_criteria_scores(
-                            st.session_state.direct_assessment_judging_df[
-                                response_model
-                            ]
-                        )
-
-                        plot_per_judge_overall_scores(
-                            st.session_state.direct_assessment_judging_df[
-                                response_model
-                            ]
-                        )
-
-                        grouped = (
-                            st.session_state.direct_assessment_judging_df[
-                                response_model
-                            ]
-                            .groupby(["judging_model"])
-                            .agg({"score": ["mean"]})
-                            .reset_index()
-                        )
-                        grouped.columns = ["judging_model", "overall_score"]
-
-                        overall_score = grouped["overall_score"].mean()
-                        controversy = grouped["overall_score"].std()
-                        st.write(f"Overall Score: {overall_score:.2f}")
-                        st.write(f"Controversy: {controversy:.2f}")
-
-    # Judging is complete, stuff that would be rendered that's not stream-specific.
+    # Judging is complete.
+    # Render stuff that would be rendered that's not stream-specific.
     # The session state now contains the overall scores for each response from each judge.
    if st.session_state.judging_status == "complete":
        st.write("#### Results")
 
        overall_scores_df_raw = pd.DataFrame(
-            st.session_state["direct_assessment_overall_scores"]
+            st.session_state.direct_assessment_overall_scores
        ).reset_index()
 
        overall_scores_df = pd.melt(
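
The new helpers' docstrings note that they assume the session state has already been set up; that setup is not part of this diff. A plausible initialization, using only the keys referenced in the code above, might look like the sketch below; the concrete default values (for example `judging_status = "incomplete"`) are assumptions rather than code from app.py.

from collections import defaultdict

import streamlit as st


def init_session_state():
    # Hypothetical defaults; key names mirror the ones used in app.py above,
    # but the initial values are assumptions.
    defaults = {
        "responses": {},  # model name (or "agg__<model>") -> full response text
        "responses_collected": False,
        "judging_status": "incomplete",
        "selected_models": [],
        "selected_aggregator": None,
        "user_prompt": "",
        "direct_assessment_judging_responses": defaultdict(dict),
        "direct_assessment_judging_df": {},
        "direct_assessment_overall_scores": defaultdict(dict),
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value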