File size: 2,434 Bytes
8a1d0f1
2fda096
 
 
 
 
 
 
 
 
 
8a1d0f1
 
 
 
2fda096
8a1d0f1
 
 
 
 
 
 
 
 
 
 
 
2fda096
70bb6ea
44df589
 
88468c1
44df589
2fda096
 
 
88468c1
 
 
 
 
 
 
 
06a6958
2fda096
44df589
88468c1
2fda096
 
8a1d0f1
 
2fda096
 
 
 
 
06a6958
2fda096
 
 
 
 
5ac5327
2fda096
5ac5327
 
 
 
 
 
 
 
 
2fda096
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import html

import streamlit as st
from datasets import load_dataset

# Kept as the very first Streamlit call: older Streamlit releases require
# set_page_config to precede every other Streamlit command.
st.set_page_config(layout="wide", page_icon="🧊")

_INTRO_TEXT = (
    "This is an application for viewing different generations for the same "
    "prompt. The generations vary depending on the checkpoint used and also "
    "the parameters used for the generation."
)
st.write(_INTRO_TEXT)

# Hugging Face access token, read from the Streamlit secrets store and used
# when loading the dataset.
HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
# Colour applied to the generated text (despite its name, this file uses it
# for generations, not prompts).
PROMPT_COLOR = "#CA437E"


def safe_text(text):
    """Render *text* as an HTML ``<pre>`` block that is safe to embed.

    The result is passed to ``st.markdown(..., unsafe_allow_html=True)``, so
    the text is HTML-escaped first. That way characters such as ``<``, ``>``
    and ``&`` in prompts or model generations are shown literally instead of
    being parsed as markup. Newlines are then turned into ``<br>`` tags
    (presumably so the markdown renderer keeps the whole block together).

    Args:
        text: Raw text to display.

    Returns:
        The escaped text, with newlines as ``<br>``, wrapped in
        ``<pre>...</pre>``.
    """
    # quote=False: quotes are harmless in element content, so leave them as-is.
    escaped = html.escape(text, quote=False)
    escaped = escaped.replace("\n", "<br>")
    return f"<pre>{escaped}</pre>"


def prompt_markup_format(text):
    """Wrap an HTML fragment in a ``<font>`` element that renders it black.

    Args:
        text: HTML fragment to wrap. It is inserted verbatim, without
            escaping.

    Returns:
        ``text`` wrapped in ``<font color="black">...</font>``.
    """
    # The previous version emitted "<*font ...>" / "</*font>", which are not
    # valid tags, so the colour was never applied.
    return f'<font color="black">{text}</font>'


def generation_markup_format(text, color=None):
    """Wrap an HTML fragment in a ``<font>`` element that colours it.

    Args:
        text: HTML fragment to colour, inserted verbatim without escaping.
            In this file it is the ``<pre>`` block produced by ``safe_text``,
            so it already carries its own ``<pre>``/``</pre>`` pair.
        color: Colour value for the ``color`` attribute. When ``None``
            (the default), the module-level ``PROMPT_COLOR`` is used.

    Returns:
        ``text`` wrapped in ``<font color="...">...</font>``.
    """
    if color is None:
        color = PROMPT_COLOR
    # No extra "</pre>" here: the previous version appended an unmatched
    # closing tag. The attribute value is also quoted now.
    return f'<font color="{color}">{text}</font>'


# Only the "train" split of the generations dataset is used by the app.
# NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour
# of `token`; kept as-is for the (unseen) pinned version — confirm.
ds = load_dataset(
    "SaulLu/bloom-generations",
    use_auth_token=HF_API_TOKEN,
)["train"]

# Distinct language tags, used to populate the language selector.
possible_langs = ds.unique("lang")

col_1, col_2 = st.columns(2)

# Left column: choose a language and a prompt, then show the prompt text.
with col_1:
    st.markdown("<h1 style='text-align: center'>Prompt</h1>", unsafe_allow_html=True)
    # "all" comes after the real languages, so selectbox's default (index 0)
    # is the first language, not "all".
    lang_choices = possible_langs + ["all"]
    chosen_lang = st.selectbox("Choose a lang", lang_choices)
    ds_lang = (
        ds
        if chosen_lang == "all"
        else ds.filter(
            lambda batch: [row_lang == chosen_lang for row_lang in batch["lang"]],
            batched=True,
        )
    )
    # Distinct prompts available for the selected language(s).
    possible_prompts = ds_lang.unique("prompt")
    chosen_prompt = st.selectbox("Choose a prompt", possible_prompts)
    st.markdown(safe_text(chosen_prompt), unsafe_allow_html=True)

def _is_chosen_prompt(batch):
    """Batched predicate: flag each row whose prompt equals the chosen prompt."""
    return [row_prompt == chosen_prompt for row_prompt in batch["prompt"]]


# All generations (one per row) produced from the selected prompt.
sub_ds = ds_lang.filter(_is_chosen_prompt, batched=True)


# Right column: choose one generation for the prompt, optionally truncate it,
# and show the remaining columns of that row as its configuration.
with col_2:
    st.markdown(
        "<h1 style='text-align: center'>Generation</h1>", unsafe_allow_html=True
    )
    if len(sub_ds) == 0:
        # Without this guard max_value would be -1 while value is 0, and
        # st.number_input raises a StreamlitAPIException.
        st.warning("No generation found for the chosen prompt.")
        st.stop()
    index_sample = st.number_input(
        "Index of the chosen generation",
        min_value=0,
        max_value=len(sub_ds) - 1,
        value=0,
        step=1,
    )

    sample = sub_ds[index_sample]
    generation = sample["generation"]
    # Defaults to the full length, so the whole generation is shown at first.
    stop_index_sample = st.number_input(
        "Stop generation at character number",
        min_value=0,
        max_value=len(generation),
        value=len(generation),
        step=1,
    )
    markdown_text = generation_markup_format(safe_text(generation[:stop_index_sample]))
    st.markdown(markdown_text, unsafe_allow_html=True)
    st.markdown(
        "<h2 style='text-align: center'>Generation configuration</h2>",
        unsafe_allow_html=True,
    )
    # Every column except the prompt and the generated text is treated as
    # generation configuration (presumably checkpoint and sampling params).
    config = {
        key: value
        for key, value in sample.items()
        if key not in ("prompt", "generation")
    }
    # Explicit call rather than a bare `config` expression: Streamlit "magic"
    # renders bare expressions only when magic is enabled.
    st.write(config)