File size: 1,233 Bytes
c4cced7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import streamlit as st
from datasets import load_dataset
import os
# Optional Hugging Face access token read from the environment;
# None when the HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Wide page layout plus a top-level heading for the app.
st.set_page_config(layout="wide", page_title="Logprobs inspection")
st.markdown("# Logprobs inspection")
@st.cache_data
def load_data():
    """Load the ``train`` split of ``HuggingFaceTB/sample_log_probs``.

    Wrapped in ``st.cache_data`` so that Streamlit reruns reuse the cached
    result instead of calling ``load_dataset`` again. ``HF_TOKEN`` is
    forwarded as the access token (it may be None).
    """
    dataset_id = "HuggingFaceTB/sample_log_probs"
    return load_dataset(dataset_id, split="train", token=HF_TOKEN)
ds = load_data()

# Pull the logprobs column into a plain list once and reuse it for both
# bounds, rather than asking the Dataset for the column twice.
logprobs = list(ds["logprobs"])
if not logprobs:
    # min()/max() raise ValueError on an empty sequence; stop cleanly instead.
    st.error("The dataset contains no rows to inspect.")
    st.stop()
min_log = min(logprobs)
max_log = max(logprobs)
# Two side-by-side sliders choose the [min_score, max_score] logprob window.
# Both share the same range and step; only label, default and key differ.
col_1, col_2 = st.columns(2)
shared_range = dict(min_value=min_log, max_value=max_log, step=0.2)
with col_1:
    min_score = st.slider(
        "Select minimum logprob", value=min_log, key="min_score", **shared_range
    )
with col_2:
    max_score = st.slider(
        "Select maximum logprob", value=max_log, key="max_score", **shared_range
    )

# Keep only rows whose logprob lies inside the selected window (inclusive).
filtered_ds = ds.filter(
    lambda example: min_score <= example["logprobs"] <= max_score
)
# st.slider's max_value is inclusive, so the last valid row index is
# n_samples - 1; using len() directly would allow an out-of-range pick.
n_samples = len(filtered_ds)
if n_samples == 0:
    # Happens e.g. when the minimum slider is set above the maximum slider.
    st.warning("No samples fall within the selected logprob range.")
    st.stop()
if n_samples == 1:
    # Only one choice: avoid building a slider whose min and max coincide.
    index = 0
else:
    index = st.slider("Select a sample", 0, n_samples - 1, 0)

# Fetch the selected row once and reuse it for every field below.
sample = filtered_ds[index]
with st.expander("The prompt"):
    st.markdown(sample["prompt"])
st.markdown(
    f"**Metadata:** log_prob is {sample['logprobs']:.2f}, seed: {sample['seed_data']}, "
    f"{sample['format']} for {sample['audience']}."
)
st.markdown(sample["text"])
|