# VARCO_Arena/pages/see_results.py
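"""Streamlit page for browsing VARCO Arena results.

Lets the user pick a saved run, a task, a tournament (test scenario), and a match,
then renders the ratings, the tournament bracket, per-match model outputs,
match statistics, and LLM-Judge bias reference plots.
"""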
import pandas as pd
import streamlit as st
import analysis_utils as au
from analysis_utils import number_breakdown_from_df
from app import load_and_cache_data
# from app import VA_ROOT
from query_comp import QueryWrapper, get_base_url
from varco_arena.varco_arena_core.prompts import load_prompt
from view_utils import (
default_page_setting,
escape_markdown,
set_nav_bar,
show_linebreak_in_md,
)
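# Shared Plotly layout defaults (font family and sizes) applied to every chart on this page.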
DEFAULT_LAYOUT_DICT = {
"title": {"font": {"size": 20, "family": "Gothic A1"}},
"font": {"size": 16, "family": "Gothic A1"},
"xaxis": {"tickfont": {"size": 12, "family": "Gothic A1"}},
"yaxis": {"tickfont": {"size": 12, "family": "Gothic A1"}},
"legend": {"font": {"size": 12, "family": "Gothic A1"}},
}
def navigate(t, source, key, val):
    """Step the selection stored in st.session_state[key] by `val` positions within the list `t`.

    `source` is the currently selected item; do nothing when nothing is selected yet.
    """
    if source is None:
        return
    target_index = t.index(source) + val
    if 0 <= target_index < len(t):
        st.session_state[key] = t[target_index]
        st.rerun()
def main():
sidebar_placeholder = default_page_setting(layout="wide")
set_nav_bar(
False,
sidebar_placeholder=sidebar_placeholder,
toggle_hashstr="see_results_init",
)
    # load the data
most_recent_run = st.session_state.get("result_file_path", None)
most_recent_run = str(most_recent_run) if most_recent_run is not None else None
(
st.session_state["all_result_dict"],
st.session_state["df_dict"],
) = load_and_cache_data(result_file_path=most_recent_run)
# side bar
st.sidebar.title("Select Result:")
result_select = QueryWrapper("expname")(
st.sidebar.selectbox,
list(st.session_state["all_result_dict"].keys()),
)
if result_select is None:
if st.session_state.korean:
st.markdown("๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•˜๋ ค๋ฉด ๋จผ์ € **๐Ÿ”ฅVARCO Arena๋ฅผ ๊ตฌ๋™**ํ•˜์…”์•ผ ํ•ฉ๋‹ˆ๋‹ค")
else:
st.markdown("You should **๐Ÿ”ฅRun VARCO Arena** first to see results")
st.stop()
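    # the last path component of the selected result name doubles as the evaluation prompt name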
eval_prompt_name = result_select.split("/")[-1].strip()
if st.sidebar.button("Clear Cache"):
st.cache_data.clear()
st.cache_resource.clear()
st.rerun()
if result_select:
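        # drop the cached letter-to-model-name mapping so it is rebuilt for the newly selected result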
if "alpha2names" in st.session_state:
del st.session_state["alpha2names"]
fig_dict_per_task = st.session_state["all_result_dict"][result_select]
task_list = list(fig_dict_per_task.keys())
elo_rating_by_task = fig_dict_per_task["Overall"]["elo_rating_by_task"]
df_dict_per_task = st.session_state["df_dict"][result_select]
default_layout_dict = DEFAULT_LAYOUT_DICT
task = QueryWrapper("task", "Select Task")(st.selectbox, task_list)
if task is None:
st.stop()
figure_dict = fig_dict_per_task[task]
judgename = figure_dict["judgename"]
df = df_dict_per_task[task]
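        # summarize the result table: human-readable interpretation, number of models, and testset size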
interpretation, n_models, size_testset = number_breakdown_from_df(df)
if st.session_state.korean:
st.markdown(f"## ๊ฒฐ๊ณผ ({task})")
st.markdown(f"##### Judge ๋ชจ๋ธ: {judgename} / ํ‰๊ฐ€ํ”„๋กฌ: {eval_prompt_name}")
st.markdown(f"##### ํ…Œ์ŠคํŠธ์…‹ ์‚ฌ์ด์ฆˆ: {int(size_testset)} ํ–‰")
else:
st.markdown(f"## Results ({task})")
st.markdown(f"##### Judge Model: {judgename} / prompt: {eval_prompt_name}")
st.markdown(f"##### Size of Testset: {int(size_testset)} rows")
col1, col2 = st.columns(2)
with col1:
with st.container(border=True):
st.markdown(f"#### Ratings ({task})")
st.table(figure_dict["elo_rating"])
st.write(show_linebreak_in_md(escape_markdown(interpretation)))
with col2:
with st.container(border=True):
st.plotly_chart(
elo_rating_by_task.update_layout(**default_layout_dict),
use_container_width=True,
key=f"{task}_elo_rating_by_task",
)
st.divider()
if st.session_state.korean:
st.markdown("### ํ† ๋„ˆ๋จผํŠธ (ํ…Œ์ŠคํŠธ ์‹œ๋‚˜๋ฆฌ์˜ค) ๋ณ„๋กœ ๋ณด๊ธฐ")
else:
st.markdown("### Tournament Results by Test Scenario")
d = list(df.idx_inst_src.unique())
default_idx = st.session_state.get("selected_tournament", None)
cols = st.columns((1, 18, 1))
with cols[0]:
if st.button("โ—€", key="prev_tournament"):
navigate(d, default_idx, "selected_tournament", -1)
with cols[1]:
tournament_prm_select = QueryWrapper("tournament", "Select Tournament")(
st.selectbox,
d,
default_idx,
key=f"{task}_tournament_select",
on_change=lambda: st.session_state.update(
selected_tournament=st.session_state.get(f"{task}_tournament_select"),
selected_match=None,
),
label_visibility="collapsed",
)
with cols[2]:
if st.button("โ–ถ", key="next_tournament"):
navigate(d, default_idx, "selected_tournament", 1)
        st.session_state["selected_tournament"] = tournament_prm_select
df_now_processed = None
if tournament_prm_select:
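            # build the bracket dataframe for this scenario plus the letter aliases used to label models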
df_now = df[df.idx_inst_src == tournament_prm_select]
df_now_processed, _alpha2names = au.init_tournament_dataframe(
df_now,
alpha2names=st.session_state["alpha2names"]
if "alpha2names" in st.session_state.keys()
else None,
)
if "alpha2names" not in st.session_state:
st.session_state["alpha2names"] = _alpha2names
try:
bracket_drawing = au.draw(
df_now_processed,
alpha2names=st.session_state["alpha2names"],
)
legend = au.make_legend_str(
df_now_processed, st.session_state["alpha2names"]
)
st.code(bracket_drawing + legend)
m = list(df_now_processed.human_readable_idx)
default_idx = st.session_state.get("selected_match", None)
cols = st.columns((1, 18, 1))
with cols[0]:
if st.button("โ—€", key="prev_match"):
navigate(m, default_idx, "selected_match", -1)
with cols[1]:
match_idx_human = QueryWrapper("match", "Select Match")(
st.selectbox,
m,
default_idx,
key=f"{task}_match_select",
label_visibility="collapsed",
)
with cols[2]:
if st.button("โ–ถ", key="next_match"):
navigate(m, default_idx, "selected_match", 1)
                st.session_state["selected_match"] = match_idx_human
if match_idx_human:
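                    # entries look like "<match_idx>: ..."; recover the integer index to locate the match row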
match_idx = int(match_idx_human.split(": ")[0])
row = df_now_processed.loc[match_idx]
st.markdown("#### Current Test Scenario:")
with st.expander(
f"### Evaluation Prompt (evalprompt: {eval_prompt_name}--{task})"
):
prompt = load_prompt(eval_prompt_name, task=task)
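                        # keep the template placeholders ({inst}, {src}, ...) as literal braces so the prompt is displayed unfilled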
kwargs = dict(
inst="{inst}",
src="{src}",
out_a="{out_a}",
out_b="{out_b}",
task=task,
)
if eval_prompt_name == "translation_pair":
kwargs["source_lang"] = "{source_lang}"
kwargs["target_lang"] = "{target_lang}"
prompt_cmpl = prompt.complete_prompt(**kwargs)
for msg in prompt_cmpl:
st.markdown(f"**{msg['role']}**")
st.info(show_linebreak_in_md(escape_markdown(msg["content"])))
st.info(show_linebreak_in_md(tournament_prm_select))
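                    # highlight the winning output in green (st.success) and the losing one in red (st.error)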
winner = row.winner
col1, col2 = st.columns(2)
winnerbox = st.success
loserbox = st.error
with col1:
iswinner = winner == "model_a"
writemsg = winnerbox if iswinner else loserbox
st.markdown(f"#### ({row.model_a}) {row.human_readable_model_a}")
writemsg(
show_linebreak_in_md(row.generated_a),
icon="โœ…" if iswinner else "โŒ",
)
with col2:
iswinner = winner == "model_b"
writemsg = winnerbox if iswinner else loserbox
st.markdown(f"#### ({row.model_b}) {row.human_readable_model_b}")
writemsg(
show_linebreak_in_md(row.generated_b),
icon="โœ…" if iswinner else "โŒ",
)
except Exception as e:
import traceback
traceback.print_exc()
st.markdown(
"**Bug: ์•„๋ž˜ ํ‘œ๋ฅผ ๋ณต์‚ฌํ•ด์„œ ์ด์Šˆ๋กœ ๋‚จ๊ฒจ์ฃผ์‹œ๋ฉด ๊ฐœ์„ ์— ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค. ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค๐Ÿ™**"
if st.session_state.korean
else "Bug: Please open issue and attach the table output below to help me out. Thanks in advance.๐Ÿ™"
)
st.error(e)
st.info(tournament_prm_select)
st.table(
df_now_processed[
[
"depth",
"round",
"winner_nodes",
"winner_resolved",
"winner",
"model_a",
"model_b",
]
]
)
st.write("Sharable link")
st.code(f"{get_base_url()}/see_results?{QueryWrapper.get_sharable_link()}")
st.divider()
if st.session_state.korean:
st.markdown("### ๋งค์น˜ ํ†ต๊ณ„")
else:
st.markdown("### Match Stats.")
        col1, col2 = st.columns(2)
with col1:
with st.container(border=True):
st.plotly_chart(
figure_dict[
"fraction_of_model_a_wins_for_all_a_vs_b_matches"
].update_layout(autosize=True, **default_layout_dict),
use_container_width=True,
key=f"{task}_fraction_of_model_a_wins_for_all_a_vs_b_matches",
)
with col2:
with st.container(border=True):
st.plotly_chart(
figure_dict["match_count_of_each_combination_of_models"].update_layout(
autosize=True, **default_layout_dict
),
use_container_width=True,
key=f"{task}_match_count_of_each_combination_of_models",
)
with col1:
with st.container(border=True):
st.plotly_chart(
figure_dict["match_count_for_each_model"].update_layout(
**default_layout_dict
),
use_container_width=True,
key=f"{task}_match_count_for_each_model",
)
if st.session_state.korean:
st.markdown("### ์ฐธ๊ณ ์šฉ LLM Judge ํŽธํ–ฅ ์ •๋ณด")
else:
st.markdown("### FYI: How biased is your LLM Judge?")
with st.expander("ํŽผ์ณ์„œ ๋ณด๊ธฐ" if st.session_state.korean else "Expand to show"):
st.info(
"""
Varco Arena์—์„œ๋Š” position bias์˜ ์˜ํ–ฅ์„ ์ตœ์†Œํ™”ํ•˜๊ธฐ ์œ„ํ•ด ๋ชจ๋“  ๋ชจ๋ธ์ด A๋‚˜ B์œ„์น˜์— ๋ฒˆ๊ฐˆ์•„ ์œ„์น˜ํ•˜๋„๋ก ํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ LLM Judge ํ˜น์€ Prompt์˜ ์„ฑ๋Šฅ์ด ๋ถ€์กฑํ•˜๋‹ค๊ณ  ๋Š๊ปด์ง„๋‹ค๋ฉด, ์•„๋ž˜ ์•Œ๋ ค์ง„ LLM Judge bias๊ฐ€ ์ฐธ๊ณ ๊ฐ€ ๋ ๊ฒ๋‹ˆ๋‹ค.
* position bias (์™ผ์ชฝ)
* length bias (์˜ค๋ฅธ์ชฝ)
๊ฒฐ๊ณผ์˜ ์™œ๊ณก์ด LLM Judge์˜ ๋ถ€์กฑํ•จ ๋–„๋ฌธ์ด์—ˆ๋‹ค๋Š” ์ ์„ ๊ทœ๋ช…ํ•˜๋ ค๋ฉด ์‚ฌ์šฉํ•˜์‹  LLM Judge์™€ Prompt์˜ binary classification ์ •ํ™•๋„๋ฅผ ์ธก์ •ํ•ด๋ณด์‹œ๊ธธ ๋ฐ”๋ž๋‹ˆ๋‹ค (Varco Arena๋ฅผ ํ™œ์šฉํ•˜์—ฌ ์ด๋ฅผ ์ˆ˜ํ–‰ํ•ด๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค!).""".strip()
if st.session_state.korean
else """
In Varco Arena, to minimize the effect of position bias, all models are alternately positioned in either position A or B. However, if you feel the LLM Judge or Prompt performance is insufficient, the following known LLM Judge biases may be helpful to reference:
* position bias (left)
* length bias (right)
To determine whether result distortion was due to LLM Judge limitations, please measure the binary classification accuracy of your LLM Judge and Prompt (you could use Varco Arena for this purpose!).
""".strip()
)
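            # reference plots: winner counts by position (position-bias check) and output-length comparison (length-bias check)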
st.markdown(f"#### {judgename} + prompt = {eval_prompt_name}")
col1, col2 = st.columns(2)
with col1:
with st.container(border=True):
st.plotly_chart(
figure_dict["counts_of_match_winners"].update_layout(
**default_layout_dict
),
use_container_width=True,
key=f"{task}_counts_of_match_winners",
)
with col2:
with st.container(border=True):
st.plotly_chart(
figure_dict["length_bias"].update_layout(**default_layout_dict),
use_container_width=True,
key=f"{task}_length_bias",
)
st.table(figure_dict["length_bias_df"].groupby("category").describe().T)
if __name__ == "__main__":
main()