justinxzhao committed on
Commit
6fae7e2
·
1 Parent(s): 16d72cb

Added general rendering of chats so that they don't disappear during app saving.

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. app.py +455 -340
  3. constants.py +50 -18
  4. img/qwen.webp +0 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  env/
2
  client_secret.json
3
- __pycache__
 
 
1
  env/
2
  client_secret.json
3
+ __pycache__
4
+ .env
app.py CHANGED
@@ -7,6 +7,7 @@ import anthropic
7
  from together import Together
8
  import google.generativeai as genai
9
  import time
 
10
  from typing import List, Optional, Literal, Union, Dict
11
  from constants import (
12
  LLM_COUNCIL_MEMBERS,
@@ -51,7 +52,7 @@ anthropic_client = anthropic.Anthropic()
51
  client = OpenAI()
52
 
53
 
54
- def anthropic_streamlit_streamer(stream):
55
  """
56
  Process the Anthropic streaming response and yield content from the deltas.
57
 
@@ -67,6 +68,18 @@ def anthropic_streamlit_streamer(stream):
67
  if text_delta:
68
  yield text_delta
69
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Handle message completion events (optional if needed)
71
  elif event.type == "message_stop":
72
  break # End of message, stop streaming
@@ -83,22 +96,34 @@ def get_ui_friendly_name(llm):
83
 
84
 
85
  def google_streamlit_streamer(stream):
 
86
  for chunk in stream:
87
  yield chunk.text
88
 
89
 
90
- def together_streamlit_streamer(stream):
 
91
  for chunk in stream:
 
 
 
 
92
  yield chunk.choices[0].delta.content
93
 
94
 
95
  def llm_streamlit_streamer(stream, llm):
96
  if llm.startswith("anthropic"):
97
- return anthropic_streamlit_streamer(stream)
 
98
  elif llm.startswith("vertex"):
 
99
  return google_streamlit_streamer(stream)
100
  elif llm.startswith("together"):
101
- return together_streamlit_streamer(stream)
 
 
 
 
102
 
103
 
104
  # Helper functions for LLM council and aggregator selection
@@ -152,9 +177,13 @@ def get_llm_response_stream(model_identifier, prompt):
152
  if provider == "openai":
153
  return get_openai_response(model_name, prompt)
154
  elif provider == "anthropic":
155
- return anthropic_streamlit_streamer(get_anthropic_response(model_name, prompt))
 
 
156
  elif provider == "together":
157
- return together_streamlit_streamer(get_together_response(model_name, prompt))
 
 
158
  elif provider == "vertex":
159
  return google_streamlit_streamer(get_google_response(model_name, prompt))
160
  else:
@@ -174,7 +203,7 @@ def create_dataframe_for_direct_assessment_judging_response(
174
  for criteria_score in judging_model.criteria_scores:
175
  data.append(
176
  {
177
- "llm_judge_model": model_name,
178
  "criteria": criteria_score.criterion,
179
  "score": criteria_score.score,
180
  "explanation": criteria_score.explanation,
@@ -283,58 +312,62 @@ def get_parse_judging_response_for_direct_assessment_prompt(
283
  )
284
 
285
 
286
- DEBUG_MODE = True
287
-
288
-
289
  def parse_judging_responses(
290
  prompt: str, judging_responses: dict[str, str]
291
  ) -> DirectAssessmentJudgingResponse:
292
- if DEBUG_MODE:
293
- return DirectAssessmentJudgingResponse(
294
- judging_models=[
295
- DirectAssessmentCriteriaScores(
296
- model="together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
297
- criteria_scores=[
298
- DirectAssessmentCriterionScore(
299
- criterion="helpfulness", score=3, explanation="explanation1"
300
- ),
301
- DirectAssessmentCriterionScore(
302
- criterion="conciseness", score=4, explanation="explanation2"
303
- ),
304
- DirectAssessmentCriterionScore(
305
- criterion="relevance", score=5, explanation="explanation3"
306
- ),
307
- ],
308
- ),
309
- DirectAssessmentCriteriaScores(
310
- model="together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
311
- criteria_scores=[
312
- DirectAssessmentCriterionScore(
313
- criterion="helpfulness", score=1, explanation="explanation1"
314
- ),
315
- DirectAssessmentCriterionScore(
316
- criterion="conciseness", score=2, explanation="explanation2"
317
- ),
318
- DirectAssessmentCriterionScore(
319
- criterion="relevance", score=3, explanation="explanation3"
320
- ),
321
- ],
322
- ),
323
- ]
324
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  else:
326
- completion = client.beta.chat.completions.parse(
327
- model="gpt-4o-mini",
328
- messages=[
329
- {
330
- "role": "system",
331
- "content": "Parse the judging responses into structured data.",
332
- },
333
- {"role": "user", "content": prompt},
334
- ],
335
- response_format=DirectAssessmentJudgingResponse,
336
- )
337
- return completion.choices[0].message.parsed
338
 
339
 
340
  def plot_criteria_scores(df):
@@ -401,11 +434,11 @@ def plot_overall_scores(overall_scores_df):
401
  ax = sns.barplot(
402
  x="ui_friendly_name",
403
  y="mean_score",
404
- hue="ui_friendly_name", # Add this line
405
  data=summary,
406
  palette="prism",
407
  capsize=0.1,
408
- legend=False, # Add this line
409
  )
410
 
411
  # Add error bars manually
@@ -420,15 +453,20 @@ def plot_overall_scores(overall_scores_df):
420
  zorder=10, # Ensure error bars are on top
421
  )
422
 
423
- # Add text annotations
424
- for i, row in summary.iterrows():
 
 
 
 
 
425
  ax.text(
426
- i,
427
- row["mean_score"],
428
- f"{row['mean_score']:.2f}",
429
  ha="center",
430
  va="bottom",
431
- fontweight="bold",
432
  color="black",
433
  bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
434
  )
@@ -446,23 +484,24 @@ def plot_overall_scores(overall_scores_df):
446
  def plot_per_judge_overall_scores(df):
447
  # Find the overall score by finding the overall score for each judge, and then averaging
448
  # over all judges.
449
- grouped = df.groupby(["llm_judge_model"]).agg({"score": ["mean"]}).reset_index()
450
- grouped.columns = ["llm_judge_model", "overall_score"]
451
 
452
  # Create the horizontal bar plot
453
  plt.figure(figsize=(10, 6))
454
  ax = sns.barplot(
455
  data=grouped,
456
- y="llm_judge_model",
457
- x="overall_score",
458
- hue="llm_judge_model",
459
- orient="h",
 
460
  )
461
 
462
  # Customize the plot
463
- plt.title("Overall Scores by LLM Judge Model")
464
  plt.xlabel("Overall Score")
465
- plt.ylabel("LLM Judge Model")
466
 
467
  # Adjust layout and display the plot
468
  plt.tight_layout()
@@ -510,41 +549,63 @@ def main():
510
  cols = st.columns([2, 1, 2])
511
  if not st.session_state.authenticated:
512
  with cols[1]:
513
- password = st.text_input("Password", type="password")
514
- if st.button("Login", use_container_width=True):
515
- if password == PASSWORD:
516
- st.session_state.authenticated = True
517
- else:
518
- st.error("Invalid credentials")
 
 
 
 
 
519
 
520
  if st.session_state.authenticated:
521
- # cols[1].success("Logged in successfully!")
522
- st.markdown("#### LLM Council Member Selection")
523
-
524
- # Council and aggregator selection
525
- selected_models = llm_council_selector()
526
-
527
- # st.write("Selected Models:", selected_models)
528
-
529
- selected_aggregator = aggregator_selector()
530
-
531
  # Initialize session state for collecting responses.
532
  if "responses" not in st.session_state:
533
- st.session_state.responses = {}
534
- # if "aggregator_response" not in st.session_state:
535
- # st.session_state.aggregator_response = {}
536
-
537
- # Prompt input
538
- st.markdown("#### Enter your prompt")
539
- _, center_column, _ = st.columns([3, 5, 3])
540
- with center_column:
541
- user_prompt = st.text_area(
542
- "Enter your prompt", value="Say 'Hello World'", key="user_prompt"
543
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
 
545
- if center_column.button("Submit", use_container_width=True):
546
  st.markdown("#### Responses")
547
 
 
 
 
 
 
548
  response_columns = st.columns(3)
549
 
550
  selected_models_to_streamlit_column_map = {
@@ -552,7 +613,7 @@ def main():
552
  }
553
 
554
  # Fetching and streaming responses from each selected model
555
- for selected_model in selected_models:
556
  with selected_models_to_streamlit_column_map[selected_model]:
557
  st.write(get_ui_friendly_name(selected_model))
558
  with st.chat_message(
@@ -571,11 +632,8 @@ def main():
571
  user_prompt=user_prompt, llms=selected_models
572
  )
573
 
574
- with st.expander("Aggregator Prompt"):
575
- st.code(aggregator_prompt)
576
-
577
  # Fetching and streaming response from the aggregator
578
- st.write(f"Mixture-of-Agents ({get_ui_friendly_name(selected_aggregator)})")
579
  with st.chat_message(
580
  selected_aggregator,
581
  avatar="img/council_icon.png",
@@ -589,272 +647,329 @@ def main():
589
  message_placeholder.write_stream(aggregator_stream)
590
  )
591
 
592
- # st.write("Responses (in session state):")
593
- # st.write(st.session_state["responses"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
 
595
  # Judging.
596
- st.markdown("#### Judging Configuration")
 
597
 
598
- # Choose the type of assessment
599
- assessment_type = st.radio(
600
- "Select the type of assessment",
601
- options=["Direct Assessment", "Pairwise Comparison"],
602
- )
603
 
604
- _, center_column, _ = st.columns([3, 5, 3])
605
 
606
- # Depending on the assessment type, render different forms
607
- if assessment_type == "Direct Assessment":
608
 
609
- # Initialize session state for direct assessment.
610
- if "direct_assessment_overall_score" not in st.session_state:
611
- st.session_state["direct_assessment_overall_score"] = {}
612
- if "direct_assessment_judging_df" not in st.session_state:
613
- st.session_state["direct_assessment_judging_df"] = {}
614
- for response_model in selected_models:
 
 
 
 
615
  st.session_state["direct_assessment_judging_df"][
616
- response_model
617
  ] = {}
618
- # aggregator model
619
- st.session_state["direct_assessment_judging_df"][
620
- "agg__" + selected_aggregator
621
- ] = {}
622
- if "direct_assessment_judging_responses" not in st.session_state:
623
- st.session_state["direct_assessment_judging_responses"] = {}
624
- for response_model in selected_models:
625
  st.session_state["direct_assessment_judging_responses"][
626
- response_model
627
  ] = {}
628
- # aggregator model
629
- st.session_state["direct_assessment_judging_responses"][
630
- "agg__" + selected_aggregator
631
- ] = {}
632
- if "direct_assessment_overall_scores" not in st.session_state:
633
- st.session_state["direct_assessment_overall_scores"] = {}
634
- for response_model in selected_models:
635
  st.session_state["direct_assessment_overall_scores"][
636
- response_model
637
  ] = {}
638
- st.session_state["direct_assessment_overall_scores"][
639
- "agg__" + selected_aggregator
640
- ] = {}
641
- if "judging_status" not in st.session_state:
642
- st.session_state["judging_status"] = "incomplete"
643
-
644
- # Direct assessment prompt.
645
- with center_column.expander("Direct Assessment Prompt"):
646
- direct_assessment_prompt = st.text_area(
647
- "Prompt for the Direct Assessment",
648
- value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
649
- height=500,
650
- key="direct_assessment_prompt",
651
- )
652
-
653
- # TODO: Add option to edit criteria list with a basic text field.
654
- criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
655
-
656
- # Create DirectAssessment object when form is submitted
657
- if center_column.button(
658
- "Submit Direct Assessment", use_container_width=True
659
- ):
660
 
661
- # Submit direct asssessment.
662
- responses_for_judging = st.session_state["responses"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
- # st.write("Responses for judging (in session state):")
665
- # st.write(responses_for_judging)
666
 
667
- response_judging_columns = st.columns(3)
668
 
669
- responses_for_judging_to_streamlit_column_map = {
670
- model: response_judging_columns[i % 3]
671
- for i, model in enumerate(responses_for_judging.keys())
672
- }
673
 
674
- # Get judging responses.
675
- for response_model, response in responses_for_judging.items():
676
 
677
- st_column = responses_for_judging_to_streamlit_column_map[
678
- response_model
679
- ]
680
 
681
- with st_column:
682
- if "agg__" in response_model:
683
- judging_model_header = "Mixture-of-Agents Response"
684
- else:
685
- judging_model_header = get_ui_friendly_name(response_model)
686
- st.write(f"Judging for {judging_model_header}")
687
- # st.write("Response being judged: ")
688
- # st.write(response)
689
- judging_prompt = get_direct_assessment_prompt(
690
- direct_assessment_prompt=direct_assessment_prompt,
691
- user_prompt=user_prompt,
692
- response=response,
693
- criteria_list=criteria_list,
694
- options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
695
- )
696
 
697
- with st.expander("Final Judging Prompt"):
698
- st.code(judging_prompt)
699
 
700
- for judging_model in selected_models:
701
- with st.expander(
702
- get_ui_friendly_name(judging_model), expanded=False
703
- ):
704
- with st.chat_message(
705
- judging_model,
706
- avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
707
  ):
708
- message_placeholder = st.empty()
709
- judging_stream = get_llm_response_stream(
710
- judging_model, judging_prompt
711
- )
712
- # if judging_stream:
713
- st.session_state[
714
- "direct_assessment_judging_responses"
715
- ][response_model][
716
- judging_model
717
- ] = message_placeholder.write_stream(
718
- judging_stream
719
- )
720
- # When all of the judging is finished for the given response, get the actual
721
- # values, parsed (use gpt-4o-mini for now) with json mode.
722
- # TODO.
723
- judging_responses = st.session_state[
724
- "direct_assessment_judging_responses"
725
- ][response_model]
726
-
727
- # st.write("Judging responses (in session state):")
728
- # st.write(judging_responses)
729
-
730
- if not judging_responses:
731
- st.error(f"No judging responses for {response_model}")
732
- quit()
733
- parse_judging_response_prompt = (
734
- get_parse_judging_response_for_direct_assessment_prompt(
735
- judging_responses,
736
- criteria_list,
737
- SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
 
 
 
 
 
738
  )
739
- )
740
- with st.expander("Parse Judging Response Prompt"):
741
- st.code(parse_judging_response_prompt)
742
- # Issue the prompt to openai mini with structured outputs
743
- parsed_judging_responses = parse_judging_responses(
744
- parse_judging_response_prompt, judging_responses
745
- )
746
-
747
- st.session_state["direct_assessment_judging_df"][
748
- response_model
749
- ] = create_dataframe_for_direct_assessment_judging_response(
750
- parsed_judging_responses
751
- )
752
- st.write(
753
- st.session_state["direct_assessment_judging_df"][
754
- response_model
755
- ]
756
- )
757
 
758
- plot_criteria_scores(
759
  st.session_state["direct_assessment_judging_df"][
760
  response_model
761
- ]
762
- )
 
763
 
764
- # Find the overall score by finding the overall score for each judge, and then averaging
765
- # over all judges.
766
- plot_per_judge_overall_scores(
767
- st.session_state["direct_assessment_judging_df"][
768
- response_model
769
- ]
770
- )
771
 
772
- grouped = (
773
- st.session_state["direct_assessment_judging_df"][
774
- response_model
775
- ]
776
- .groupby(["llm_judge_model"])
777
- .agg({"score": ["mean"]})
778
- .reset_index()
779
- )
780
- grouped.columns = ["llm_judge_model", "overall_score"]
781
 
782
- # st.write(
783
- # "Extracting overall scores from this grouped dataframe:"
784
- # )
785
- # st.write(grouped)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786
 
787
- # Save the overall scores to the session state.
788
- for record in grouped.to_dict(orient="records"):
789
- st.session_state["direct_assessment_overall_scores"][
790
- response_model
791
- ][record["llm_judge_model"]] = record["overall_score"]
792
-
793
- overall_score = grouped["overall_score"].mean()
794
- controversy = grouped["overall_score"].std()
795
- st.write(f"Overall Score: {overall_score:.2f}")
796
- st.write(f"Controversy: {controversy:.2f}")
797
-
798
- st.session_state["judging_status"] = "complete"
799
-
800
- # Judging is complete.
801
- st.write("#### Results")
802
- # The session state now contains the overall scores for each response from each judge.
803
- if st.session_state["judging_status"] == "complete":
804
- overall_scores_df_raw = pd.DataFrame(
805
- st.session_state["direct_assessment_overall_scores"]
806
- ).reset_index()
807
-
808
- overall_scores_df = pd.melt(
809
- overall_scores_df_raw,
810
- id_vars=["index"],
811
- var_name="response_model",
812
- value_name="score",
813
- ).rename(columns={"index": "judging_model"})
814
-
815
- # Print the overall winner.
816
- overall_winner = overall_scores_df.loc[
817
- overall_scores_df["score"].idxmax()
818
- ]
819
-
820
- st.write(
821
- f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
822
- )
823
- # Find how much the standard deviation overlaps with other models.
824
- # Calculate separability.
825
- # TODO.
826
- st.write(f"**Confidence:** {overall_winner['score']:.2f}")
827
-
828
- left_column, right_column = st.columns([1, 1])
829
- with left_column:
830
- plot_overall_scores(overall_scores_df)
831
-
832
- with right_column:
833
- st.dataframe(overall_scores_df)
834
-
835
- elif assessment_type == "Pairwise Comparison":
836
- pass
837
- # pairwise_comparison_prompt = st.text_area(
838
- # "Prompt for the Pairwise Comparison"
839
- # )
840
- # granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
841
- # ties_allowed = st.checkbox("Are ties allowed?")
842
- # position_swapping = st.checkbox("Enable position swapping?")
843
- # reference_model = st.text_input("Reference Model")
844
-
845
- # # Create PairwiseComparison object when form is submitted
846
- # if st.button("Submit Pairwise Comparison"):
847
- # pairwise_comparison_config = PairwiseComparison(
848
- # type="pairwise_comparison",
849
- # granularity=granularity,
850
- # ties_allowed=ties_allowed,
851
- # position_swapping=position_swapping,
852
- # reference_model=reference_model,
853
- # prompt=prompt,
854
- # )
855
- # st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")
856
- # # Submit pairwise comparison.
857
- # responses_for_judging = st.session_state["responses"]
 
 
 
 
858
 
859
  else:
860
  with cols[1]:
 
7
  from together import Together
8
  import google.generativeai as genai
9
  import time
10
+ from collections import defaultdict
11
  from typing import List, Optional, Literal, Union, Dict
12
  from constants import (
13
  LLM_COUNCIL_MEMBERS,
 
52
  client = OpenAI()
53
 
54
 
55
+ def anthropic_streamlit_streamer(stream, llm):
56
  """
57
  Process the Anthropic streaming response and yield content from the deltas.
58
 
 
68
  if text_delta:
69
  yield text_delta
70
 
71
+ # Count input token usage.
72
+ if event.type == "message_start":
73
+ input_token_usage = event["usage"]["input_tokens"]
74
+ output_token_usage = event["usage"]["output_tokens"]
75
+ st.session_state["input_token_usage"][llm] += input_token_usage
76
+ st.session_state["output_token_usage"][llm] += output_token_usage
77
+
78
+ # Count output token usage.
79
+ if event.type == "message_delta":
80
+ output_token_usage = event["usage"]["output_tokens"]
81
+ st.session_state["output_token_usage"][llm] += output_token_usage
82
+
83
  # Handle message completion events (optional if needed)
84
  elif event.type == "message_stop":
85
  break # End of message, stop streaming
 
96
 
97
 
98
  def google_streamlit_streamer(stream):
99
+ # TODO: Count token usage.
100
  for chunk in stream:
101
  yield chunk.text
102
 
103
 
104
+ def together_streamlit_streamer(stream, llm):
105
+ # https://docs.together.ai/docs/chat-overview#streaming-responses
106
  for chunk in stream:
107
+ if chunk.usage:
108
+ st.session_state["input_token_usage"][llm] += chunk.usage.prompt_tokens
109
+ if chunk.usage:
110
+ st.session_state["output_token_usage"][llm] += chunk.usage.completion_tokens
111
  yield chunk.choices[0].delta.content
112
 
113
 
114
  def llm_streamlit_streamer(stream, llm):
115
  if llm.startswith("anthropic"):
116
+ print(f"Using Anthropic streamer for llm: {llm}")
117
+ return anthropic_streamlit_streamer(stream, llm)
118
  elif llm.startswith("vertex"):
119
+ print(f"Using Vertex streamer for llm: {llm}")
120
  return google_streamlit_streamer(stream)
121
  elif llm.startswith("together"):
122
+ print(f"Using Together streamer for llm: {llm}")
123
+ return together_streamlit_streamer(stream, llm)
124
+ else:
125
+ print(f"Using OpenAI streamer for llm: {llm}")
126
+ return openai_streamlit_streamer(stream, llm)
127
 
128
 
129
  # Helper functions for LLM council and aggregator selection
 
177
  if provider == "openai":
178
  return get_openai_response(model_name, prompt)
179
  elif provider == "anthropic":
180
+ return anthropic_streamlit_streamer(
181
+ get_anthropic_response(model_name, prompt), model_identifier
182
+ )
183
  elif provider == "together":
184
+ return together_streamlit_streamer(
185
+ get_together_response(model_name, prompt), model_identifier
186
+ )
187
  elif provider == "vertex":
188
  return google_streamlit_streamer(get_google_response(model_name, prompt))
189
  else:
 
203
  for criteria_score in judging_model.criteria_scores:
204
  data.append(
205
  {
206
+ "judging_model": model_name,
207
  "criteria": criteria_score.criterion,
208
  "score": criteria_score.score,
209
  "explanation": criteria_score.explanation,
 
312
  )
313
 
314
 
 
 
 
315
  def parse_judging_responses(
316
  prompt: str, judging_responses: dict[str, str]
317
  ) -> DirectAssessmentJudgingResponse:
318
+ # if os.getenv("DEBUG_MODE") == "True":
319
+ # return DirectAssessmentJudgingResponse(
320
+ # judging_models=[
321
+ # DirectAssessmentCriteriaScores(
322
+ # model="together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
323
+ # criteria_scores=[
324
+ # DirectAssessmentCriterionScore(
325
+ # criterion="helpfulness", score=3, explanation="explanation1"
326
+ # ),
327
+ # DirectAssessmentCriterionScore(
328
+ # criterion="conciseness", score=4, explanation="explanation2"
329
+ # ),
330
+ # DirectAssessmentCriterionScore(
331
+ # criterion="relevance", score=5, explanation="explanation3"
332
+ # ),
333
+ # ],
334
+ # ),
335
+ # DirectAssessmentCriteriaScores(
336
+ # model="together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
337
+ # criteria_scores=[
338
+ # DirectAssessmentCriterionScore(
339
+ # criterion="helpfulness", score=1, explanation="explanation1"
340
+ # ),
341
+ # DirectAssessmentCriterionScore(
342
+ # criterion="conciseness", score=2, explanation="explanation2"
343
+ # ),
344
+ # DirectAssessmentCriterionScore(
345
+ # criterion="relevance", score=3, explanation="explanation3"
346
+ # ),
347
+ # ],
348
+ # ),
349
+ # ]
350
+ # )
351
+ # else:
352
+ completion = client.beta.chat.completions.parse(
353
+ model="gpt-4o-mini",
354
+ messages=[
355
+ {
356
+ "role": "system",
357
+ "content": "Parse the judging responses into structured data.",
358
+ },
359
+ {"role": "user", "content": prompt},
360
+ ],
361
+ response_format=DirectAssessmentJudgingResponse,
362
+ )
363
+ return completion.choices[0].message.parsed
364
+
365
+
366
+ def get_llm_avatar(model_identifier):
367
+ if "agg__" in model_identifier:
368
+ return "img/council_icon.png"
369
  else:
370
+ return PROVIDER_TO_AVATAR_MAP[model_identifier]
 
 
 
 
 
 
 
 
 
 
 
371
 
372
 
373
  def plot_criteria_scores(df):
 
434
  ax = sns.barplot(
435
  x="ui_friendly_name",
436
  y="mean_score",
437
+ hue="ui_friendly_name",
438
  data=summary,
439
  palette="prism",
440
  capsize=0.1,
441
+ legend=False,
442
  )
443
 
444
  # Add error bars manually
 
453
  zorder=10, # Ensure error bars are on top
454
  )
455
 
456
+ # Add text annotations using the actual positions of the bars
457
+ for patch, row in zip(ax.patches, summary.itertuples()):
458
+ # Get the center of each bar (x position)
459
+ x = patch.get_x() + patch.get_width() / 2
460
+ y = patch.get_height()
461
+
462
+ # Add the text annotation
463
  ax.text(
464
+ x,
465
+ y,
466
+ f"{row.mean_score:.2f}",
467
  ha="center",
468
  va="bottom",
469
+ # fontweight="bold",
470
  color="black",
471
  bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
472
  )
 
484
  def plot_per_judge_overall_scores(df):
485
  # Find the overall score by finding the overall score for each judge, and then averaging
486
  # over all judges.
487
+ grouped = df.groupby(["judging_model"]).agg({"score": ["mean"]}).reset_index()
488
+ grouped.columns = ["judging_model", "overall_score"]
489
 
490
  # Create the horizontal bar plot
491
  plt.figure(figsize=(10, 6))
492
  ax = sns.barplot(
493
  data=grouped,
494
+ x="judging_model",
495
+ y="overall_score",
496
+ hue="judging_model",
497
+ orient="v",
498
+ palette="rainbow",
499
  )
500
 
501
  # Customize the plot
502
+ plt.title("Overall Score from each LLM Judge")
503
  plt.xlabel("Overall Score")
504
+ plt.ylabel("LLM Judge")
505
 
506
  # Adjust layout and display the plot
507
  plt.tight_layout()
 
549
  cols = st.columns([2, 1, 2])
550
  if not st.session_state.authenticated:
551
  with cols[1]:
552
+ with st.form("login_form"):
553
+ password = st.text_input("Password", type="password")
554
+ submit_button = st.form_submit_button("Login", use_container_width=True)
555
+
556
+ if submit_button:
557
+ if password == PASSWORD:
558
+ st.session_state.authenticated = True
559
+ st.success("Logged in successfully!")
560
+ st.rerun()
561
+ else:
562
+ st.error("Invalid credentials")
563
 
564
  if st.session_state.authenticated:
565
+ if "responses_collected" not in st.session_state:
566
+ st.session_state["responses_collected"] = False
 
 
 
 
 
 
 
 
567
  # Initialize session state for collecting responses.
568
  if "responses" not in st.session_state:
569
+ st.session_state.responses = defaultdict(str)
570
+ # Initialize session state for token usage.
571
+ if "input_token_usage" not in st.session_state:
572
+ st.session_state["input_token_usage"] = defaultdict(int)
573
+ if "output_token_usage" not in st.session_state:
574
+ st.session_state["output_token_usage"] = defaultdict(int)
575
+ if "selected_models" not in st.session_state:
576
+ st.session_state["selected_models"] = []
577
+ if "selected_aggregator" not in st.session_state:
578
+ st.session_state["selected_aggregator"] = None
579
+
580
+ with st.form(key="prompt_form"):
581
+ st.markdown("#### LLM Council Member Selection")
582
+
583
+ # Council and aggregator selection
584
+ selected_models = llm_council_selector()
585
+ selected_aggregator = aggregator_selector()
586
+
587
+ # Prompt input and submission form
588
+ st.markdown("#### Enter your prompt")
589
+ _, center_column, _ = st.columns([3, 5, 3])
590
+ with center_column:
591
+ user_prompt = st.text_area(
592
+ "Enter your prompt",
593
+ value="Say 'Hello World'",
594
+ key="user_prompt",
595
+ label_visibility="hidden",
596
+ )
597
+ submit_button = st.form_submit_button(
598
+ "Submit", use_container_width=True
599
+ )
600
 
601
+ if submit_button:
602
  st.markdown("#### Responses")
603
 
604
+ # Update state.
605
+ st.session_state.selected_models = selected_models
606
+ st.session_state.selected_aggregator = selected_aggregator
607
+
608
+ # Render the chats.
609
  response_columns = st.columns(3)
610
 
611
  selected_models_to_streamlit_column_map = {
 
613
  }
614
 
615
  # Fetching and streaming responses from each selected model
616
+ for selected_model in st.session_state.selected_models:
617
  with selected_models_to_streamlit_column_map[selected_model]:
618
  st.write(get_ui_friendly_name(selected_model))
619
  with st.chat_message(
 
632
  user_prompt=user_prompt, llms=selected_models
633
  )
634
 
 
 
 
635
  # Fetching and streaming response from the aggregator
636
+ st.write(f"{get_ui_friendly_name(selected_aggregator)}")
637
  with st.chat_message(
638
  selected_aggregator,
639
  avatar="img/council_icon.png",
 
647
  message_placeholder.write_stream(aggregator_stream)
648
  )
649
 
650
+ st.session_state.responses_collected = True
651
+
652
+ # Render chats generally?
653
+ if st.session_state.responses and not submit_button:
654
+ st.markdown("#### Responses")
655
+
656
+ response_columns = st.columns(3)
657
+ selected_models_to_streamlit_column_map = {
658
+ model: response_columns[i]
659
+ for i, model in enumerate(st.session_state.selected_models)
660
+ }
661
+ for response_model, response in st.session_state.responses.items():
662
+ st_column = selected_models_to_streamlit_column_map.get(
663
+ response_model, response_columns[0]
664
+ )
665
+ with st_column.chat_message(
666
+ response_model,
667
+ avatar=get_llm_avatar(response_model),
668
+ ):
669
+ st.write(get_ui_friendly_name(response_model))
670
+ st.write(response)
671
 
672
  # Judging.
673
+ if st.session_state.responses_collected:
674
+ st.markdown("#### Judging Configuration")
675
 
676
+ # Choose the type of assessment
677
+ assessment_type = st.radio(
678
+ "Select the type of assessment",
679
+ options=["Direct Assessment", "Pairwise Comparison"],
680
+ )
681
 
682
+ _, center_column, _ = st.columns([3, 5, 3])
683
 
684
+ # Depending on the assessment type, render different forms
685
+ if assessment_type == "Direct Assessment":
686
 
687
+ # Initialize session state for direct assessment.
688
+ if "direct_assessment_overall_score" not in st.session_state:
689
+ st.session_state["direct_assessment_overall_score"] = {}
690
+ if "direct_assessment_judging_df" not in st.session_state:
691
+ st.session_state["direct_assessment_judging_df"] = {}
692
+ for response_model in selected_models:
693
+ st.session_state["direct_assessment_judging_df"][
694
+ response_model
695
+ ] = {}
696
+ # aggregator model
697
  st.session_state["direct_assessment_judging_df"][
698
+ "agg__" + selected_aggregator
699
  ] = {}
700
+ if "direct_assessment_judging_responses" not in st.session_state:
701
+ st.session_state["direct_assessment_judging_responses"] = {}
702
+ for response_model in selected_models:
703
+ st.session_state["direct_assessment_judging_responses"][
704
+ response_model
705
+ ] = {}
706
+ # aggregator model
707
  st.session_state["direct_assessment_judging_responses"][
708
+ "agg__" + selected_aggregator
709
  ] = {}
710
+ if "direct_assessment_overall_scores" not in st.session_state:
711
+ st.session_state["direct_assessment_overall_scores"] = {}
712
+ for response_model in selected_models:
713
+ st.session_state["direct_assessment_overall_scores"][
714
+ response_model
715
+ ] = {}
 
716
  st.session_state["direct_assessment_overall_scores"][
717
+ "agg__" + selected_aggregator
718
  ] = {}
719
+ if "judging_status" not in st.session_state:
720
+ st.session_state["judging_status"] = "incomplete"
721
+
722
+ # Direct assessment prompt.
723
+ with center_column.expander("Direct Assessment Prompt"):
724
+ direct_assessment_prompt = st.text_area(
725
+ "Prompt for the Direct Assessment",
726
+ value=get_default_direct_assessment_prompt(
727
+ user_prompt=user_prompt
728
+ ),
729
+ height=500,
730
+ key="direct_assessment_prompt",
731
+ )
 
 
 
 
 
 
 
 
 
732
 
733
+ # TODO: Add option to edit criteria list with a basic text field.
734
+ criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
735
+
736
+ # Create DirectAssessment object when form is submitted
737
+ if center_column.button(
738
+ "Submit Direct Assessment", use_container_width=True
739
+ ):
740
+
741
+ # Render the chats.
742
+ response_columns = st.columns(3)
743
+ selected_models_to_streamlit_column_map = {
744
+ model: response_columns[i]
745
+ for i, model in enumerate(selected_models)
746
+ }
747
+ for response_model, response in st.session_state[
748
+ "responses"
749
+ ].items():
750
+ st_column = selected_models_to_streamlit_column_map.get(
751
+ response_model, response_columns[0]
752
+ )
753
+ with st_column:
754
+ with st.chat_message(
755
+ get_ui_friendly_name(response_model),
756
+ avatar=get_llm_avatar(response_model),
757
+ ):
758
+ st.write(get_ui_friendly_name(response_model))
759
+ st.write(response)
760
 
761
+ # Submit direct assessment.
762
+ responses_for_judging = st.session_state["responses"]
763
 
764
+ response_judging_columns = st.columns(3)
765
 
766
+ responses_for_judging_to_streamlit_column_map = {
767
+ model: response_judging_columns[i % 3]
768
+ for i, model in enumerate(responses_for_judging.keys())
769
+ }
770
 
771
+ # Get judging responses.
772
+ for response_model, response in responses_for_judging.items():
773
 
774
+ st_column = responses_for_judging_to_streamlit_column_map[
775
+ response_model
776
+ ]
777
 
778
+ with st_column:
779
+ st.write(
780
+ f"Judging for {get_ui_friendly_name(response_model)}"
781
+ )
782
+ judging_prompt = get_direct_assessment_prompt(
783
+ direct_assessment_prompt=direct_assessment_prompt,
784
+ user_prompt=user_prompt,
785
+ response=response,
786
+ criteria_list=criteria_list,
787
+ options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
788
+ )
 
 
 
 
789
 
790
+ with st.expander("Final Judging Prompt"):
791
+ st.code(judging_prompt)
792
 
793
+ for judging_model in selected_models:
794
+ with st.expander(
795
+ get_ui_friendly_name(judging_model), expanded=False
 
 
 
 
796
  ):
797
+ with st.chat_message(
798
+ judging_model,
799
+ avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
800
+ ):
801
+ message_placeholder = st.empty()
802
+ judging_stream = get_llm_response_stream(
803
+ judging_model, judging_prompt
804
+ )
805
+ st.session_state[
806
+ "direct_assessment_judging_responses"
807
+ ][response_model][
808
+ judging_model
809
+ ] = message_placeholder.write_stream(
810
+ judging_stream
811
+ )
812
+ # When all of the judging is finished for the given response, get the actual
813
+ # values, parsed.
814
+ # TODO.
815
+ judging_responses = st.session_state[
816
+ "direct_assessment_judging_responses"
817
+ ][response_model]
818
+
819
+ if not judging_responses:
820
+ st.error(f"No judging responses for {response_model}")
821
+ quit()
822
+ parse_judging_response_prompt = (
823
+ get_parse_judging_response_for_direct_assessment_prompt(
824
+ judging_responses,
825
+ criteria_list,
826
+ SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
827
+ )
828
+ )
829
+ # Issue the prompt to openai mini with structured outputs
830
+ parsed_judging_responses = parse_judging_responses(
831
+ parse_judging_response_prompt, judging_responses
832
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
833
 
 
834
  st.session_state["direct_assessment_judging_df"][
835
  response_model
836
+ ] = create_dataframe_for_direct_assessment_judging_response(
837
+ parsed_judging_responses
838
+ )
839
 
840
+ plot_criteria_scores(
841
+ st.session_state["direct_assessment_judging_df"][
842
+ response_model
843
+ ]
844
+ )
 
 
845
 
846
+ # Find the overall score by finding the overall score for each judge, and then averaging
847
+ # over all judges.
848
+ plot_per_judge_overall_scores(
849
+ st.session_state["direct_assessment_judging_df"][
850
+ response_model
851
+ ]
852
+ )
 
 
853
 
854
+ grouped = (
855
+ st.session_state["direct_assessment_judging_df"][
856
+ response_model
857
+ ]
858
+ .groupby(["judging_model"])
859
+ .agg({"score": ["mean"]})
860
+ .reset_index()
861
+ )
862
+ grouped.columns = ["judging_model", "overall_score"]
863
+
864
+ # Save the overall scores to the session state.
865
+ for record in grouped.to_dict(orient="records"):
866
+ st.session_state["direct_assessment_overall_scores"][
867
+ response_model
868
+ ][record["judging_model"]] = record["overall_score"]
869
+
870
+ overall_score = grouped["overall_score"].mean()
871
+ controversy = grouped["overall_score"].std()
872
+ st.write(f"Overall Score: {overall_score:.2f}")
873
+ st.write(f"Controversy: {controversy:.2f}")
874
+
875
+ st.session_state["judging_status"] = "complete"
876
+
877
+ # Judging is complete.
878
+ # The session state now contains the overall scores for each response from each judge.
879
+ if st.session_state["judging_status"] == "complete":
880
+ st.write("#### Results")
881
+
882
+ overall_scores_df_raw = pd.DataFrame(
883
+ st.session_state["direct_assessment_overall_scores"]
884
+ ).reset_index()
885
+
886
+ overall_scores_df = pd.melt(
887
+ overall_scores_df_raw,
888
+ id_vars=["index"],
889
+ var_name="response_model",
890
+ value_name="score",
891
+ ).rename(columns={"index": "judging_model"})
892
+
893
+ # Print the overall winner.
894
+ overall_winner = overall_scores_df.loc[
895
+ overall_scores_df["score"].idxmax()
896
+ ]
897
 
898
+ st.write(
899
+ f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
900
+ )
901
+ # Find how much the standard deviation overlaps with other models.
902
+ # Calculate separability.
903
+ # TODO.
904
+ st.write(f"**Confidence:** {overall_winner['score']:.2f}")
905
+
906
+ left_column, right_column = st.columns([1, 1])
907
+ with left_column:
908
+ plot_overall_scores(overall_scores_df)
909
+
910
+ with right_column:
911
+ # All overall scores.
912
+ overall_scores_df = overall_scores_df[
913
+ ["response_model", "judging_model", "score"]
914
+ ]
915
+ overall_scores_df["response_model"] = overall_scores_df[
916
+ "response_model"
917
+ ].apply(get_ui_friendly_name)
918
+ overall_scores_df["judging_model"] = overall_scores_df[
919
+ "judging_model"
920
+ ].apply(get_ui_friendly_name)
921
+
922
+ with st.expander("Overall scores from all judges"):
923
+ st.dataframe(overall_scores_df)
924
+
925
+ # All criteria scores.
926
+ with right_column:
927
+ all_scores_df = pd.DataFrame()
928
+ for response_model, score_df in st.session_state[
929
+ "direct_assessment_judging_df"
930
+ ].items():
931
+ score_df["response_model"] = response_model
932
+ all_scores_df = pd.concat([all_scores_df, score_df])
933
+ all_scores_df = all_scores_df.reset_index()
934
+ all_scores_df = all_scores_df.drop(columns="index")
935
+
936
+ # Reorder the columns
937
+ all_scores_df = all_scores_df[
938
+ [
939
+ "response_model",
940
+ "judging_model",
941
+ "criteria",
942
+ "score",
943
+ "explanation",
944
+ ]
945
+ ]
946
+ all_scores_df["response_model"] = all_scores_df[
947
+ "response_model"
948
+ ].apply(get_ui_friendly_name)
949
+ all_scores_df["judging_model"] = all_scores_df[
950
+ "judging_model"
951
+ ].apply(get_ui_friendly_name)
952
+
953
+ with st.expander(
954
+ "Criteria-specific scores and explanations from all judges"
955
+ ):
956
+ st.dataframe(all_scores_df)
957
+
958
+ elif assessment_type == "Pairwise Comparison":
959
+ pass
960
+
961
+ # Token usage.
962
+ with st.expander("Token Usage"):
963
+ st.write("Input tokens used.")
964
+ st.write(st.session_state.input_token_usage)
965
+ st.write(
966
+ f"Input Tokens Total: {sum(st.session_state.input_token_usage.values())}"
967
+ )
968
+ st.write("Output tokens used.")
969
+ st.write(st.session_state.output_token_usage)
970
+ st.write(
971
+ f"Output Tokens Total: {sum(st.session_state.output_token_usage.values())}"
972
+ )
973
 
974
  else:
975
  with cols[1]:
constants.py CHANGED
@@ -1,18 +1,42 @@
1
- LLM_COUNCIL_MEMBERS = {
2
- "Smalls": [
3
- # "openai://gpt-4o-mini",
4
- "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
5
- "together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
6
- # "vertex://gemini-1.5-flash-001",
7
- # "anthropic://claude-3-haiku-20240307",
8
- ],
9
- "Flagships": [
10
- "openai://gpt-4o",
11
- "together://meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
12
- "vertex://gemini-1.5-pro-001",
13
- "anthropic://claude-3-5-sonnet",
14
- ],
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  PROVIDER_TO_AVATAR_MAP = {
18
  "openai://gpt-4o-mini": "",
@@ -34,9 +58,17 @@ LLM_TO_UI_NAME_MAP = {
34
  "anthropic://claude-3-haiku-20240307": "Claude 3 Haiku",
35
  }
36
 
37
- # AGGREGATORS = ["openai://gpt-4o-mini", "openai://gpt-4o"]
38
- AGGREGATORS = ["together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"]
39
-
 
 
 
 
 
 
 
 
40
 
41
  # Fix the aggregator step.
42
  # Add a judging step.
 
1
+ import os
2
+ import dotenv
3
+
4
+ dotenv.load_dotenv()
5
+
6
+
7
+ if os.getenv("DEBUG_MODE") == "True":
8
+ LLM_COUNCIL_MEMBERS = {
9
+ "Smalls": [
10
+ "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
11
+ "together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
12
+ # "anthropic://claude-3-haiku-20240307",
13
+ ],
14
+ "Flagships": [
15
+ "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
16
+ "together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
17
+ "anthropic://claude-3-haiku-20240307",
18
+ ],
19
+ }
20
+ else:
21
+ LLM_COUNCIL_MEMBERS = {
22
+ "Smalls": [
23
+ "openai://gpt-4o-mini",
24
+ "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
25
+ "together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
26
+ "vertex://gemini-1.5-flash-001",
27
+ "anthropic://claude-3-haiku-20240307",
28
+ ],
29
+ "Flagships": [
30
+ "openai://gpt-4o",
31
+ "together://meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
32
+ "vertex://gemini-1.5-pro-002",
33
+ "anthropic://claude-3-5-sonnet",
34
+ ],
35
+ "OpenAI": [
36
+ "openai://gpt-4o",
37
+ "openai://gpt-4o-mini",
38
+ ],
39
+ }
40
 
41
  PROVIDER_TO_AVATAR_MAP = {
42
  "openai://gpt-4o-mini": "",
 
58
  "anthropic://claude-3-haiku-20240307": "Claude 3 Haiku",
59
  }
60
 
61
+ if os.getenv("DEBUG_MODE") == "True":
62
+ AGGREGATORS = ["together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"]
63
+ else:
64
+ AGGREGATORS = [
65
+ "anthropic://claude-3-haiku-20240307",
66
+ "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
67
+ "together://meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
68
+ "together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
69
+ "openai://gpt-4o",
70
+ "openai://gpt-4o-mini",
71
+ ]
72
 
73
  # Fix the aggregator step.
74
  # Add a judging step.
img/qwen.webp ADDED