Spaces:
Sleeping
Sleeping
justinxzhao
commited on
Commit
·
a0dca54
1
Parent(s):
279a804
Factor out LLM chat rendering so that it persists even when the submit button isn't active.
Browse files
app.py
CHANGED
@@ -516,6 +516,206 @@ def get_selected_models_to_streamlit_column_map(st_columns, selected_models):
|
|
516 |
return selected_models_to_streamlit_column_map
|
517 |
|
518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
519 |
# Main Streamlit App
|
520 |
def main():
|
521 |
st.set_page_config(
|
@@ -632,71 +832,11 @@ def main():
|
|
632 |
st.session_state.selected_aggregator = selected_aggregator
|
633 |
|
634 |
# Render the chats.
|
635 |
-
|
636 |
-
|
637 |
-
selected_models_to_streamlit_column_map = (
|
638 |
-
get_selected_models_to_streamlit_column_map(
|
639 |
-
response_columns, selected_models
|
640 |
-
)
|
641 |
-
)
|
642 |
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
st.write(get_ui_friendly_name(selected_model))
|
647 |
-
with st.chat_message(
|
648 |
-
selected_model,
|
649 |
-
avatar=PROVIDER_TO_AVATAR_MAP[selected_model],
|
650 |
-
):
|
651 |
-
message_placeholder = st.empty()
|
652 |
-
stream = get_llm_response_stream(selected_model, user_prompt)
|
653 |
-
if stream:
|
654 |
-
st.session_state["responses"][selected_model] = (
|
655 |
-
message_placeholder.write_stream(stream)
|
656 |
-
)
|
657 |
-
|
658 |
-
# Get the aggregator prompt.
|
659 |
-
aggregator_prompt = get_default_aggregator_prompt(
|
660 |
-
user_prompt=user_prompt, llms=selected_models
|
661 |
-
)
|
662 |
-
|
663 |
-
# Fetching and streaming response from the aggregator
|
664 |
-
st.write(f"{get_ui_friendly_name(selected_aggregator)}")
|
665 |
-
with st.chat_message(
|
666 |
-
selected_aggregator,
|
667 |
-
avatar="img/council_icon.png",
|
668 |
-
):
|
669 |
-
message_placeholder = st.empty()
|
670 |
-
aggregator_stream = get_llm_response_stream(
|
671 |
-
selected_aggregator, aggregator_prompt
|
672 |
-
)
|
673 |
-
if aggregator_stream:
|
674 |
-
st.session_state.responses["agg__" + selected_aggregator] = (
|
675 |
-
message_placeholder.write_stream(aggregator_stream)
|
676 |
-
)
|
677 |
-
|
678 |
-
st.session_state.responses_collected = True
|
679 |
-
|
680 |
-
# Render chats generally?
|
681 |
-
if st.session_state.responses and not submit_button:
|
682 |
-
st.markdown("#### Responses")
|
683 |
-
|
684 |
-
response_columns = st.columns(3)
|
685 |
-
selected_models_to_streamlit_column_map = (
|
686 |
-
get_selected_models_to_streamlit_column_map(
|
687 |
-
response_columns, st.session_state.selected_models
|
688 |
-
)
|
689 |
-
)
|
690 |
-
for response_model, response in st.session_state.responses.items():
|
691 |
-
st_column = selected_models_to_streamlit_column_map.get(
|
692 |
-
response_model, response_columns[0]
|
693 |
-
)
|
694 |
-
with st_column.chat_message(
|
695 |
-
response_model,
|
696 |
-
avatar=get_llm_avatar(response_model),
|
697 |
-
):
|
698 |
-
st.write(get_ui_friendly_name(response_model))
|
699 |
-
st.write(response)
|
700 |
|
701 |
# Judging.
|
702 |
if st.session_state.responses_collected:
|
@@ -727,228 +867,41 @@ def main():
|
|
727 |
# TODO: Add option to edit criteria list with a basic text field.
|
728 |
criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
|
729 |
|
|
|
730 |
judging_submit_button = st.form_submit_button(
|
731 |
"Submit Judging", use_container_width=True
|
732 |
)
|
733 |
|
734 |
if judging_submit_button:
|
|
|
735 |
st.session_state.assessment_type = assessment_type
|
736 |
-
st.session_state.direct_assessment_config = {
|
737 |
-
"prompt": direct_assessment_prompt,
|
738 |
-
"criteria_list": criteria_list,
|
739 |
-
}
|
740 |
-
|
741 |
-
responses_for_judging = st.session_state.responses
|
742 |
-
|
743 |
-
# Get judging responses.
|
744 |
-
response_judging_columns = st.columns(3)
|
745 |
-
responses_for_judging_to_streamlit_column_map = (
|
746 |
-
get_selected_models_to_streamlit_column_map(
|
747 |
-
response_judging_columns, responses_for_judging.keys()
|
748 |
-
)
|
749 |
-
)
|
750 |
-
|
751 |
if st.session_state.assessment_type == "Direct Assessment":
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
judging_prompt = get_direct_assessment_prompt(
|
762 |
-
direct_assessment_prompt=direct_assessment_prompt,
|
763 |
-
user_prompt=user_prompt,
|
764 |
-
response=response,
|
765 |
-
criteria_list=criteria_list,
|
766 |
-
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
767 |
-
)
|
768 |
-
|
769 |
-
with st.expander("Final Judging Prompt"):
|
770 |
-
st.code(judging_prompt)
|
771 |
-
|
772 |
-
for judging_model in selected_models:
|
773 |
-
with st.expander(
|
774 |
-
get_ui_friendly_name(judging_model), expanded=True
|
775 |
-
):
|
776 |
-
with st.chat_message(
|
777 |
-
judging_model,
|
778 |
-
avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
|
779 |
-
):
|
780 |
-
message_placeholder = st.empty()
|
781 |
-
judging_stream = get_llm_response_stream(
|
782 |
-
judging_model, judging_prompt
|
783 |
-
)
|
784 |
-
st.session_state[
|
785 |
-
"direct_assessment_judging_responses"
|
786 |
-
][response_model][
|
787 |
-
judging_model
|
788 |
-
] = message_placeholder.write_stream(
|
789 |
-
judging_stream
|
790 |
-
)
|
791 |
-
# When all of the judging is finished for the given response, get the actual
|
792 |
-
# values, parsed.
|
793 |
-
judging_responses = st.session_state[
|
794 |
-
"direct_assessment_judging_responses"
|
795 |
-
][response_model]
|
796 |
-
|
797 |
-
if not judging_responses:
|
798 |
-
st.error(f"No judging responses for {response_model}")
|
799 |
-
quit()
|
800 |
-
parse_judging_response_prompt = (
|
801 |
-
get_parse_judging_response_for_direct_assessment_prompt(
|
802 |
-
judging_responses,
|
803 |
-
criteria_list,
|
804 |
-
SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
805 |
-
)
|
806 |
-
)
|
807 |
-
# Issue the prompt to openai mini with structured outputs
|
808 |
-
parsed_judging_responses = parse_judging_responses(
|
809 |
-
parse_judging_response_prompt, judging_responses
|
810 |
-
)
|
811 |
-
|
812 |
-
st.session_state["direct_assessment_judging_df"][
|
813 |
-
response_model
|
814 |
-
] = create_dataframe_for_direct_assessment_judging_response(
|
815 |
-
parsed_judging_responses
|
816 |
-
)
|
817 |
-
|
818 |
-
plot_criteria_scores(
|
819 |
-
st.session_state["direct_assessment_judging_df"][
|
820 |
-
response_model
|
821 |
-
]
|
822 |
-
)
|
823 |
-
|
824 |
-
# Find the overall score by finding the overall score for each judge, and then averaging
|
825 |
-
# over all judges.
|
826 |
-
plot_per_judge_overall_scores(
|
827 |
-
st.session_state["direct_assessment_judging_df"][
|
828 |
-
response_model
|
829 |
-
]
|
830 |
-
)
|
831 |
-
|
832 |
-
grouped = (
|
833 |
-
st.session_state["direct_assessment_judging_df"][
|
834 |
-
response_model
|
835 |
-
]
|
836 |
-
.groupby(["judging_model"])
|
837 |
-
.agg({"score": ["mean"]})
|
838 |
-
.reset_index()
|
839 |
-
)
|
840 |
-
grouped.columns = ["judging_model", "overall_score"]
|
841 |
-
|
842 |
-
# Save the overall scores to the session state.
|
843 |
-
for record in grouped.to_dict(orient="records"):
|
844 |
-
st.session_state["direct_assessment_overall_scores"][
|
845 |
-
response_model
|
846 |
-
][record["judging_model"]] = record["overall_score"]
|
847 |
-
|
848 |
-
overall_score = grouped["overall_score"].mean()
|
849 |
-
controversy = grouped["overall_score"].std()
|
850 |
-
st.write(f"Overall Score: {overall_score:.2f}")
|
851 |
-
st.write(f"Controversy: {controversy:.2f}")
|
852 |
-
|
853 |
-
st.session_state.judging_status = "complete"
|
854 |
# If judging is complete, but the submit button is cleared, still render the results.
|
855 |
elif st.session_state.judging_status == "complete":
|
856 |
if st.session_state.assessment_type == "Direct Assessment":
|
857 |
-
|
858 |
-
|
859 |
-
|
860 |
-
|
861 |
-
responses_for_judging_to_streamlit_column_map = (
|
862 |
-
get_selected_models_to_streamlit_column_map(
|
863 |
-
response_judging_columns, responses_for_judging.keys()
|
864 |
-
)
|
865 |
)
|
866 |
|
867 |
-
|
868 |
-
|
869 |
-
response_model
|
870 |
-
]
|
871 |
-
|
872 |
-
with st_column:
|
873 |
-
st.write(
|
874 |
-
f"Judging for {get_ui_friendly_name(response_model)}"
|
875 |
-
)
|
876 |
-
judging_prompt = get_direct_assessment_prompt(
|
877 |
-
direct_assessment_prompt=direct_assessment_prompt,
|
878 |
-
user_prompt=user_prompt,
|
879 |
-
response=response,
|
880 |
-
criteria_list=criteria_list,
|
881 |
-
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
882 |
-
)
|
883 |
-
|
884 |
-
with st.expander("Final Judging Prompt"):
|
885 |
-
st.code(judging_prompt)
|
886 |
-
|
887 |
-
for judging_model in selected_models:
|
888 |
-
with st.expander(
|
889 |
-
get_ui_friendly_name(judging_model), expanded=True
|
890 |
-
):
|
891 |
-
with st.chat_message(
|
892 |
-
judging_model,
|
893 |
-
avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
|
894 |
-
):
|
895 |
-
st.write(
|
896 |
-
st.session_state.direct_assessment_judging_responses[
|
897 |
-
response_model
|
898 |
-
][
|
899 |
-
judging_model
|
900 |
-
]
|
901 |
-
)
|
902 |
-
# When all of the judging is finished for the given response, get the actual
|
903 |
-
# values, parsed.
|
904 |
-
judging_responses = (
|
905 |
-
st.session_state.direct_assessment_judging_responses[
|
906 |
-
response_model
|
907 |
-
]
|
908 |
-
)
|
909 |
-
|
910 |
-
parse_judging_response_prompt = (
|
911 |
-
get_parse_judging_response_for_direct_assessment_prompt(
|
912 |
-
judging_responses,
|
913 |
-
criteria_list,
|
914 |
-
SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
915 |
-
)
|
916 |
-
)
|
917 |
-
|
918 |
-
plot_criteria_scores(
|
919 |
-
st.session_state.direct_assessment_judging_df[
|
920 |
-
response_model
|
921 |
-
]
|
922 |
-
)
|
923 |
-
|
924 |
-
plot_per_judge_overall_scores(
|
925 |
-
st.session_state.direct_assessment_judging_df[
|
926 |
-
response_model
|
927 |
-
]
|
928 |
-
)
|
929 |
-
|
930 |
-
grouped = (
|
931 |
-
st.session_state.direct_assessment_judging_df[
|
932 |
-
response_model
|
933 |
-
]
|
934 |
-
.groupby(["judging_model"])
|
935 |
-
.agg({"score": ["mean"]})
|
936 |
-
.reset_index()
|
937 |
-
)
|
938 |
-
grouped.columns = ["judging_model", "overall_score"]
|
939 |
-
|
940 |
-
overall_score = grouped["overall_score"].mean()
|
941 |
-
controversy = grouped["overall_score"].std()
|
942 |
-
st.write(f"Overall Score: {overall_score:.2f}")
|
943 |
-
st.write(f"Controversy: {controversy:.2f}")
|
944 |
-
|
945 |
-
# Judging is complete, stuff that would be rendered that's not stream-specific.
|
946 |
# The session state now contains the overall scores for each response from each judge.
|
947 |
if st.session_state.judging_status == "complete":
|
948 |
st.write("#### Results")
|
949 |
|
950 |
overall_scores_df_raw = pd.DataFrame(
|
951 |
-
st.session_state
|
952 |
).reset_index()
|
953 |
|
954 |
overall_scores_df = pd.melt(
|
|
|
516 |
return selected_models_to_streamlit_column_map
|
517 |
|
518 |
|
519 |
+
def get_aggregator_key(llm_aggregator):
|
520 |
+
return "agg__" + llm_aggregator
|
521 |
+
|
522 |
+
|
523 |
+
def st_render_responses(user_prompt):
|
524 |
+
"""Renders the responses from the LLMs.
|
525 |
+
|
526 |
+
Uses cached responses from the session state, if available.
|
527 |
+
Otherwise, streams the responses anew.
|
528 |
+
|
529 |
+
Assumes that the session state has already been set up with selected models and selected aggregator.
|
530 |
+
"""
|
531 |
+
st.markdown("#### Responses")
|
532 |
+
|
533 |
+
response_columns = st.columns(3)
|
534 |
+
selected_models_to_streamlit_column_map = (
|
535 |
+
get_selected_models_to_streamlit_column_map(
|
536 |
+
response_columns, st.session_state.selected_models
|
537 |
+
)
|
538 |
+
)
|
539 |
+
for response_model in st.session_state.selected_models:
|
540 |
+
st_column = selected_models_to_streamlit_column_map.get(
|
541 |
+
response_model, response_columns[0]
|
542 |
+
)
|
543 |
+
|
544 |
+
with st_column.chat_message(
|
545 |
+
response_model,
|
546 |
+
avatar=get_llm_avatar(response_model),
|
547 |
+
):
|
548 |
+
st.write(get_ui_friendly_name(response_model))
|
549 |
+
if response_model in st.session_state.responses:
|
550 |
+
# Use the cached response from session state.
|
551 |
+
st.write(st.session_state.responses[response_model])
|
552 |
+
else:
|
553 |
+
# Stream the response from the LLM.
|
554 |
+
message_placeholder = st.empty()
|
555 |
+
stream = get_llm_response_stream(response_model, user_prompt)
|
556 |
+
st.session_state.responses[response_model] = (
|
557 |
+
message_placeholder.write_stream(stream)
|
558 |
+
)
|
559 |
+
|
560 |
+
# Render the aggregator response.
|
561 |
+
aggregator_prompt = get_default_aggregator_prompt(
|
562 |
+
user_prompt=user_prompt, llms=st.session_state.selected_models
|
563 |
+
)
|
564 |
+
|
565 |
+
# Streaming response from the aggregator.
|
566 |
+
with st.chat_message(
|
567 |
+
get_aggregator_key(st.session_state.selected_aggregator),
|
568 |
+
avatar="img/council_icon.png",
|
569 |
+
):
|
570 |
+
st.write(
|
571 |
+
f"{get_ui_friendly_name(get_aggregator_key(st.session_state.selected_aggregator))}"
|
572 |
+
)
|
573 |
+
if (
|
574 |
+
get_aggregator_key(st.session_state.selected_aggregator)
|
575 |
+
in st.session_state.responses
|
576 |
+
):
|
577 |
+
st.write(
|
578 |
+
st.session_state.responses[
|
579 |
+
get_aggregator_key(st.session_state.selected_aggregator)
|
580 |
+
]
|
581 |
+
)
|
582 |
+
else:
|
583 |
+
message_placeholder = st.empty()
|
584 |
+
aggregator_stream = get_llm_response_stream(
|
585 |
+
selected_aggregator, aggregator_prompt
|
586 |
+
)
|
587 |
+
if aggregator_stream:
|
588 |
+
st.session_state.responses[get_aggregator_key(selected_aggregator)] = (
|
589 |
+
message_placeholder.write_stream(aggregator_stream)
|
590 |
+
)
|
591 |
+
|
592 |
+
st.session_state.responses_collected = True
|
593 |
+
|
594 |
+
|
595 |
+
def st_direct_assessment_results(user_prompt, direct_assessment_prompt, criteria_list):
|
596 |
+
"""Renders the direct assessment results block.
|
597 |
+
|
598 |
+
Uses session state to render results from LLMs. If the session state isn't set, then fetches the
|
599 |
+
responses from the LLMs services from scratch (and sets the session state).
|
600 |
+
|
601 |
+
Assumes that the session state has already been set up with responses.
|
602 |
+
"""
|
603 |
+
responses_for_judging = st.session_state.responses
|
604 |
+
|
605 |
+
# Get judging responses.
|
606 |
+
response_judging_columns = st.columns(3)
|
607 |
+
responses_for_judging_to_streamlit_column_map = (
|
608 |
+
get_selected_models_to_streamlit_column_map(
|
609 |
+
response_judging_columns, responses_for_judging.keys()
|
610 |
+
)
|
611 |
+
)
|
612 |
+
|
613 |
+
for response_model, response in responses_for_judging.items():
|
614 |
+
st_column = responses_for_judging_to_streamlit_column_map[response_model]
|
615 |
+
|
616 |
+
with st_column:
|
617 |
+
st.write(f"Judging for {get_ui_friendly_name(response_model)}")
|
618 |
+
judging_prompt = get_direct_assessment_prompt(
|
619 |
+
direct_assessment_prompt=direct_assessment_prompt,
|
620 |
+
user_prompt=user_prompt,
|
621 |
+
response=response,
|
622 |
+
criteria_list=criteria_list,
|
623 |
+
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
624 |
+
)
|
625 |
+
|
626 |
+
with st.expander("Final Judging Prompt"):
|
627 |
+
st.code(judging_prompt)
|
628 |
+
|
629 |
+
for judging_model in st.session_state.selected_models:
|
630 |
+
with st.expander(get_ui_friendly_name(judging_model), expanded=True):
|
631 |
+
with st.chat_message(
|
632 |
+
judging_model,
|
633 |
+
avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
|
634 |
+
):
|
635 |
+
if (
|
636 |
+
judging_model
|
637 |
+
in st.session_state.direct_assessment_judging_responses[
|
638 |
+
response_model
|
639 |
+
]
|
640 |
+
):
|
641 |
+
# Use the session state cached response.
|
642 |
+
st.write(
|
643 |
+
st.session_state.direct_assessment_judging_responses[
|
644 |
+
response_model
|
645 |
+
][judging_model]
|
646 |
+
)
|
647 |
+
else:
|
648 |
+
message_placeholder = st.empty()
|
649 |
+
# Get the judging response from the LLM.
|
650 |
+
judging_stream = get_llm_response_stream(
|
651 |
+
judging_model, judging_prompt
|
652 |
+
)
|
653 |
+
st.session_state.direct_assessment_judging_responses[
|
654 |
+
response_model
|
655 |
+
][judging_model] = message_placeholder.write_stream(
|
656 |
+
judging_stream
|
657 |
+
)
|
658 |
+
|
659 |
+
# Extract actual scores from open-ended responses using structured outputs.
|
660 |
+
# Since we're extracting structured data for the first time, we can save the dataframe
|
661 |
+
# to the session state so that it's cached.
|
662 |
+
if response_model not in st.session_state.direct_assessment_judging_df:
|
663 |
+
judging_responses = (
|
664 |
+
st.session_state.direct_assessment_judging_responses[response_model]
|
665 |
+
)
|
666 |
+
parse_judging_response_prompt = (
|
667 |
+
get_parse_judging_response_for_direct_assessment_prompt(
|
668 |
+
judging_responses,
|
669 |
+
criteria_list,
|
670 |
+
SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
671 |
+
)
|
672 |
+
)
|
673 |
+
parsed_judging_responses = parse_judging_responses(
|
674 |
+
parse_judging_response_prompt, judging_responses
|
675 |
+
)
|
676 |
+
st.session_state.direct_assessment_judging_df[response_model] = (
|
677 |
+
create_dataframe_for_direct_assessment_judging_response(
|
678 |
+
parsed_judging_responses
|
679 |
+
)
|
680 |
+
)
|
681 |
+
|
682 |
+
# Uses the session state to plot the criteria scores and graphs for a given response
|
683 |
+
# model.
|
684 |
+
plot_criteria_scores(
|
685 |
+
st.session_state.direct_assessment_judging_df[response_model]
|
686 |
+
)
|
687 |
+
|
688 |
+
plot_per_judge_overall_scores(
|
689 |
+
st.session_state.direct_assessment_judging_df[response_model]
|
690 |
+
)
|
691 |
+
|
692 |
+
grouped = (
|
693 |
+
st.session_state.direct_assessment_judging_df[response_model]
|
694 |
+
.groupby(["judging_model"])
|
695 |
+
.agg({"score": ["mean"]})
|
696 |
+
.reset_index()
|
697 |
+
)
|
698 |
+
grouped.columns = ["judging_model", "overall_score"]
|
699 |
+
|
700 |
+
# Save the overall scores to the session state if it's not already there.
|
701 |
+
for record in grouped.to_dict(orient="records"):
|
702 |
+
if (
|
703 |
+
response_model
|
704 |
+
not in st.session_state.direct_assessment_overall_scores
|
705 |
+
):
|
706 |
+
st.session_state.direct_assessment_overall_scores[response_model][
|
707 |
+
record["judging_model"]
|
708 |
+
] = record["overall_score"]
|
709 |
+
|
710 |
+
overall_score = grouped["overall_score"].mean()
|
711 |
+
controversy = grouped["overall_score"].std()
|
712 |
+
st.write(f"Overall Score: {overall_score:.2f}")
|
713 |
+
st.write(f"Controversy: {controversy:.2f}")
|
714 |
+
|
715 |
+
# Mark judging as complete.
|
716 |
+
st.session_state.judging_status = "complete"
|
717 |
+
|
718 |
+
|
719 |
# Main Streamlit App
|
720 |
def main():
|
721 |
st.set_page_config(
|
|
|
832 |
st.session_state.selected_aggregator = selected_aggregator
|
833 |
|
834 |
# Render the chats.
|
835 |
+
st_render_responses(user_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
836 |
|
837 |
+
# Render chats generally even they are available, if the submit button isn't clicked.
|
838 |
+
elif st.session_state.responses:
|
839 |
+
st_render_responses(user_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
840 |
|
841 |
# Judging.
|
842 |
if st.session_state.responses_collected:
|
|
|
867 |
# TODO: Add option to edit criteria list with a basic text field.
|
868 |
criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
|
869 |
|
870 |
+
with center_column:
|
871 |
judging_submit_button = st.form_submit_button(
|
872 |
"Submit Judging", use_container_width=True
|
873 |
)
|
874 |
|
875 |
if judging_submit_button:
|
876 |
+
# Update session state.
|
877 |
st.session_state.assessment_type = assessment_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
878 |
if st.session_state.assessment_type == "Direct Assessment":
|
879 |
+
st.session_state.direct_assessment_config = {
|
880 |
+
"prompt": direct_assessment_prompt,
|
881 |
+
"criteria_list": criteria_list,
|
882 |
+
}
|
883 |
+
st_direct_assessment_results(
|
884 |
+
user_prompt=st.session_state.user_prompt,
|
885 |
+
direct_assessment_prompt=direct_assessment_prompt,
|
886 |
+
criteria_list=criteria_list,
|
887 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
888 |
# If judging is complete, but the submit button is cleared, still render the results.
|
889 |
elif st.session_state.judging_status == "complete":
|
890 |
if st.session_state.assessment_type == "Direct Assessment":
|
891 |
+
st_direct_assessment_results(
|
892 |
+
user_prompt=st.session_state.user_prompt,
|
893 |
+
direct_assessment_prompt=direct_assessment_prompt,
|
894 |
+
criteria_list=criteria_list,
|
|
|
|
|
|
|
|
|
895 |
)
|
896 |
|
897 |
+
# Judging is complete.
|
898 |
+
# Render stuff that would be rendered that's not stream-specific.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
899 |
# The session state now contains the overall scores for each response from each judge.
|
900 |
if st.session_state.judging_status == "complete":
|
901 |
st.write("#### Results")
|
902 |
|
903 |
overall_scores_df_raw = pd.DataFrame(
|
904 |
+
st.session_state.direct_assessment_overall_scores
|
905 |
).reset_index()
|
906 |
|
907 |
overall_scores_df = pd.melt(
|