File size: 8,974 Bytes
507a14d
 
9ceb843
e5d5995
8e499f4
df04a09
ab74236
bbe05a0
507a14d
 
 
 
 
df04a09
ab74236
7f5f365
df04a09
9ceb843
e5d5995
7f5f365
507a14d
 
9ceb843
df04a09
 
9ceb843
e4cd4cd
9ceb843
 
507a14d
 
 
df04a09
 
6ce351e
 
df04a09
 
63c5ebf
df04a09
507a14d
8e499f4
df04a09
 
 
e5d5995
 
df04a09
 
 
 
 
e5d5995
 
 
ab74236
8e499f4
 
df04a09
e5d5995
4a1518a
df04a09
8799e00
 
 
df04a09
 
8799e00
 
 
 
31bff5a
8799e00
9f4ce43
 
4a1518a
 
 
df04a09
63c5ebf
df04a09
 
9f4ce43
8799e00
874c0c9
df04a09
8799e00
bbe05a0
31bff5a
507a14d
521165c
874c0c9
521165c
31bff5a
 
 
 
908984c
31bff5a
9ceb843
df04a09
31bff5a
f89f357
df04a09
 
 
9ceb843
06fd8bd
31bff5a
df04a09
 
 
06fd8bd
9ceb843
31bff5a
df04a09
 
 
 
 
59b52cf
8799e00
df04a09
56fcfaf
df04a09
 
 
908984c
 
ab74236
df04a09
 
 
 
 
8799e00
 
df04a09
 
 
 
 
 
59b52cf
06fd8bd
9ceb843
 
 
8e499f4
 
 
 
92c7f09
df04a09
8e499f4
 
 
 
 
e5d5995
8799e00
df04a09
 
 
 
31bff5a
bbe05a0
 
 
149a173
17f167a
bd17252
 
 
 
149a173
bbe05a0
 
 
 
e5d5995
9ceb843
e5d5995
 
 
c8a4819
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import re

import gradio as gr
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi, snapshot_download

from src.css import custom_css
from src.md import ABOUT_TEXT, TOP_TEXT
from src.utils import load_all_data, prep_df, sort_by_category

# Hugging Face Hub client; used below to restart this Space.
api = HfApi()

# Access token taken from the environment (None when COLLAB_TOKEN is unset).
COLLAB_TOKEN = os.environ.get("COLLAB_TOKEN")

# Hub dataset repos: evaluation results, and the validation prompt set.
evals_repo, eval_set_repo = "alrope/href_results", "allenai/href_validation"
# Local directory the results repo is downloaded into.
local_result_dir = "./results/"

def restart_space():
    """Ask the Hub API to restart the ``allenai/href`` Space, authenticating with ``COLLAB_TOKEN``."""
    space_id = "allenai/href"
    api.restart_space(token=COLLAB_TOKEN, repo_id=space_id)

print("Pulling evaluation results")
# Mirror the results dataset repo into local_result_dir.
repo = snapshot_download(
    repo_id=evals_repo,
    repo_type="dataset",
    local_dir=local_result_dir,
    ignore_patterns=[],
    use_auth_token=COLLAB_TOKEN,
    tqdm_class=None,
    etag_timeout=30,
)

# One leaderboard frame per decoding temperature: greedy (0.0) first, then sampled (1.0).
href_data_greedy, href_data_nongreedy = (
    prep_df(load_all_data(local_result_dir, subdir=f"temperature={temp}"))
    for temp in ("0.0", "1.0")
)

# gr.Dataframe datatypes: a number column, a markdown column, then number columns.
# NOTE(review): the visible table gets roughly half as many trailing number types as
# the hidden one, although both use the full greedy header list — confirm intent.
_n_trailing_cols = len(href_data_greedy.columns) - 1
col_types_href = ["number", "markdown"] + ["number"] * (_n_trailing_cols // 2)
col_types_href_hidden = ["number", "markdown"] + ["number"] * _n_trailing_cols

# Choices for the "Sorted By" dropdowns.
categories = [
    'Average', 'Brainstorm', 'Open QA', 'Closed QA', 'Extract', 'Generation',
    'Rewrite', 'Summarize', 'Classify',
    "Reasoning Over Numerical Data", "Multi-Document Synthesis", "Fact Checking or Attributed QA",
]
# categories = ['Average', 'Brainstorm', 'Open QA', 'Closed QA', 'Extract', 'Generation', 'Rewrite', 'Summarize', 'Classify']

# Validation examples shown in the "Dataset Viewer" tab.
# NOTE(review): recent `datasets` releases replace `use_auth_token` with `token` —
# confirm against the pinned version.
eval_set = load_dataset(eval_set_repo, use_auth_token=COLLAB_TOKEN, split="dev")


def random_sample(r: gr.Request, category):
    """
    Pick one validation example uniformly at random and render it as Markdown.

    Args:
        r: Gradio request object, filled in by Gradio through the type
            annotation; not used here.
        category: ``None`` or empty to sample from the whole set. Otherwise a
            category name, or a list of names, restricting the sample to
            examples whose "category" is one of them.

    Returns:
        Markdown containing one "**field**:" section per field of the sampled
        example. If no example matches the selected categories, a short notice.
    """
    if not category:
        pool = eval_set
    else:
        # The dropdown is multiselect, but accept a single name too.
        wanted = [category] if isinstance(category, str) else category
        pool = eval_set.filter(lambda x: x["category"] in wanted)

    if len(pool) == 0:
        return "No samples found for the selected category."

    # randint's upper bound is exclusive. Passing len(pool) makes every row
    # reachable; `len(pool) - 1` skipped the last row and failed on 1-row pools.
    sample = pool[int(np.random.randint(0, len(pool)))]
    return '\n\n'.join([f"**{key}**:\n\n{value}" for key, value in sample.items()])


# Distinct category values offered in the Dataset Viewer's category dropdown.
subsets = eval_set.unique("category")


def regex_table(dataframe, regex, selected_category, style=True):
    """
    Filter a leaderboard table by model name and order it by a category.

    Args:
        dataframe: leaderboard table with a "Model" column.
        regex: comma-separated search terms. Each term is a regular
            expression, and a row is kept when its "Model" value matches any
            term, case-insensitively. Surrounding whitespace and empty terms
            are ignored. An empty query keeps every row with a model name, and
            a trailing comma no longer matches everything. If the combined
            pattern is not a valid regular expression, the terms are matched
            as literal text instead.
        selected_category: forwarded to ``sort_by_category`` to order the rows
            (presumably sorts by that column — see src.utils).
        style: if True, return a pandas ``Styler`` with display formatting;
            otherwise return the filtered DataFrame.

    Returns:
        The matching rows with a fresh 0..n-1 index, styled or plain according
        to ``style``. Rows whose "Model" value is missing are always dropped.
    """
    dataframe = sort_by_category(dataframe, selected_category)

    # Split on commas, trim whitespace, and drop empty terms. Otherwise a query
    # like "llama," becomes "llama|", whose empty branch matches every row.
    terms = [term.strip() for term in (regex or "").split(",")]
    terms = [term for term in terms if term]
    # With no terms this is "", which matches every non-missing model name.
    combined_regex = "|".join(terms)
    try:
        re.compile(combined_regex)
    except re.error:
        # Search-box input such as "c++" or "(" is not a valid regex. Match it
        # literally instead of letting the change callback raise.
        combined_regex = "|".join(re.escape(term) for term in terms)

    data = dataframe[dataframe["Model"].str.contains(combined_regex, case=False, na=False)]
    # Reassign rather than resetting in place on a boolean-filtered copy.
    data = data.reset_index(drop=True)

    if style:
        # One decimal for score columns and two for Average. Model, Rank and
        # 95% CI keep their default rendering; missing values show as blank.
        unformatted = ('Average', 'Model', 'Rank', '95% CI')
        format_dict = {col: "{:.1f}" for col in data.columns if col not in unformatted}
        format_dict['Average'] = "{:.2f}"
        data = data.style.format(format_dict, na_rep='').set_properties(**{'text-align': 'right'})
    return data


# Number of rows in the unfiltered greedy leaderboard, shown in the header text.
total_models = len(regex_table(href_data_greedy.copy(), "", "Average", style=False).values)

with gr.Blocks(css=custom_css) as app:
    # Layout: header text and logo, then tabs for the greedy leaderboard, the
    # non-greedy (temperature=1.0) leaderboard, the About text, and a dataset viewer.
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(TOP_TEXT.format(str(total_models)))
        with gr.Column(scale=4):
            gr.Markdown("""
                        <img src="file/src/logo.png" height="130">
                        """)
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏆 HREF Leaderboard"):
            with gr.Row():
                search_1 = gr.Textbox(label="Model Search (delimit with , )",
                                      show_label=True)
                category_selector_1 = gr.Dropdown(categories, label="Sorted By", value="Average", multiselect=False, show_label=True)
            with gr.Row():
                # Hidden, unfiltered copy of the greedy results. The search/sort
                # callbacks below re-filter from it rather than from the visible table.
                # (The `rewardbench_*` names look carried over from another leaderboard.)
                rewardbench_table_hidden = gr.Dataframe(
                    href_data_greedy.values,
                    datatype=col_types_href_hidden,
                    headers=href_data_greedy.columns.tolist(),
                    visible=False,
                )
                rewardbench_table = gr.Dataframe(
                    regex_table(href_data_greedy.copy(), "", "Average"),
                    datatype=col_types_href,
                    headers=href_data_greedy.columns.tolist(),
                    elem_id="href_data_greedy",
                    interactive=False,
                    height=1000,
                )
        with gr.TabItem("Non-Greedy"):
            with gr.Row():
                search_2 = gr.Textbox(label="Model Search (delimit with , )",
                                      show_label=True)
                category_selector_2 = gr.Dropdown(categories, label="Sorted By", value="Average",
                                                    multiselect=False, show_label=True, elem_id="category_selector")
            with gr.Row():
                # Hidden, unfiltered copy of the non-greedy results (source for re-filtering).
                rewardbench_table_hidden_nongreedy = gr.Dataframe(
                    href_data_nongreedy.values,
                    datatype=col_types_href_hidden,
                    headers=href_data_nongreedy.columns.tolist(),
                    visible=False,
                )
                rewardbench_table_nongreedy = gr.Dataframe(
                    regex_table(href_data_nongreedy.copy(), "", "Average"),
                    datatype=col_types_href,
                    headers=href_data_nongreedy.columns.tolist(),
                    elem_id="href_data_nongreedy",
                    interactive=False,
                    height=1000,
                )
        with gr.TabItem("About"):
            with gr.Row():
                gr.Markdown(ABOUT_TEXT)

        with gr.TabItem("Dataset Viewer"):
            with gr.Row():
                # loads one sample
                gr.Markdown("""## Random Dataset Sample Viewer""")
                subset_selector = gr.Dropdown(subsets, label="Category", value=None, multiselect=True)
                button = gr.Button("Show Random Sample")

            with gr.Row():
                sample_display = gr.Markdown("{sampled data loads here}")

            button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])

    # Re-filter and re-sort each leaderboard whenever its search box or sort dropdown changes.
    search_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    category_selector_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    search_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)
    category_selector_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)

    with gr.Row():
        with gr.Accordion("📚 Citation", open=False):
            # NOTE(review): this entry cites RewardBench and its Space, not HREF. It was
            # probably carried over from the reward-bench leaderboard — confirm the
            # intended citation. The `howpublished` field was missing its closing brace.
            citation_button = gr.Textbox(
                value=r"""@misc{RewardBench,
    title={RewardBench: Evaluating Reward Models for Language Modeling},
    author={Lambert, Nathan and Pyatkin, Valentina and Morrison, Jacob and Miranda, LJ and Lin, Bill Yuchen and Chandu, Khyathi and Dziri, Nouha and Kumar, Sachin and Zick, Tom and Choi, Yejin and Smith, Noah A. and Hajishirzi, Hannaneh},
    year={2024},
    howpublished={\url{https://huggingface.co/spaces/allenai/reward-bench}}
}""",
                lines=7,
                label="Copy the following to cite these results.",
                elem_id="citation-button",
                show_copy_button=True,
            )


# Restart the Space every 3 hours in a background thread. Presumably this makes
# the Space re-pull the evaluation results at startup.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", hours=3)
scheduler.start()

# `allowed_paths` lets the header's <img> load src/logo.png.
# An earlier version called .queue() before launch; it is unclear whether that is needed.
app.launch(allowed_paths=['src/'])