# Hugging Face Space app: Verbalized Rebus Solver demo (Gradio).
import re | |
import spaces | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from unidecode import unidecode | |
from gradio_i18n import gettext, Translate | |
from datasets import load_dataset | |
from style import custom_css, solution_style, letter_style, definition_style | |
# Chat prompt for the rebus-solving model, written with <|user|>/<|assistant|>
# markup. Fill it with template.format(rebus=..., key=...).
# The Italian instruction reads: "Solve the clues in brackets to obtain a first
# pass, and use the reading key to obtain the solution of the rebus."
template = (
    "<s><|user|>\n"
    "Risolvi gli indizi tra parentesi per ottenere una prima lettura, "
    "e usa la chiave di lettura per ottenere la soluzione del rebus.\n"
    "Rebus: {rebus}\n"
    "Chiave di lettura: {key}<|end|>\n"
    "<|assistant|>"
)
# EurekaRebus examples shown in the guessing game. Revision "1.0" of the
# "llm_sft" config is loaded from the id_test.jsonl and ood_test.jsonl files
# (presumably in-/out-of-distribution test splits). When data_files is a plain
# list, `datasets` exposes them all as a single "train" split.
eureka5_test_data = load_dataset(
    "gsarti/eureka-rebus",
    "llm_sft",
    revision="1.0",
    data_files=["id_test.jsonl", "ood_test.jsonl"],
    split="train",
)
# Root of the published per-model prediction CSVs.
OUTPUTS_BASE_URL = "https://raw.githubusercontent.com/gsarti/verbalized-rebus/main/outputs/"

# One CSV split per model. Split names are the keys the UI indexes with
# model_outputs[<name>]; each path is relative to OUTPUTS_BASE_URL.
model_outputs = load_dataset(
    "csv",
    data_files={
        split_name: OUTPUTS_BASE_URL + relative_path
        for split_name, relative_path in (
            ("gpt4", "prompted_models/gpt4o_results.csv"),
            ("claude3_5_sonnet", "prompted_models/claude3_5_sonnet_results.csv"),
            ("llama3_70b", "prompted_models/llama3_70b_results.csv"),
            ("qwen_72b", "prompted_models/qwen_72b_results.csv"),
            ("phi3_mini", "phi3_mini/phi3_mini_results_step_5070.csv"),
            ("gemma2", "gemma2_2b/gemma2_2b_results_step_5070.csv"),
            ("llama3_1_8b", "llama3.1_8b/llama3.1_8b_results_step_5070.csv"),
        )
    },
)
def extract(span_text: str, tag: str = "span") -> str:
    """Concatenate the inner text of every ``<tag ...>...</tag>`` element.

    Args:
        span_text: HTML-ish string to scan.
        tag: element name to look for (defaults to ``span``).

    Returns:
        The inner contents of all matching elements joined with no separator,
        or ``""`` when none match. Matching is non-greedy and does not cross
        newlines.
    """
    inner_parts = re.findall(rf'<{tag}[^>]*>(.*?)<\/{tag}>', span_text)
    return "".join(inner_parts)
def parse_rebus(ex_idx: int):
    """Parse one EurekaRebus test example into HTML fragments for the UI.

    Args:
        ex_idx: 1-based position in ``eureka5_test_data``. Floats such as
            ``3.0`` from numeric widgets are accepted and cast to int.

    Returns:
        dict with keys:
          - "rebus": the rebus with loose letters wrapped in ``letter_style``
            and bracketed definitions wrapped in ``definition_style``.
          - "key": the reading key with each number wrapped in ``solution_style``.
          - "key_split": the raw reading-key string.
          - "fp_elements": ``(clue_html, value)`` pairs from the answer's
            ``- X = Y`` lines; definition clues are styled on the left only,
            letter groups on both sides.
          - "fp": text after ``Prima lettura:`` in the answer.
          - "fp_empty": the rebus with every definition replaced by a styled
            ``___`` blank.
          - "s_elements": ``(number, text)`` pairs from ``N = ...`` lines.
          - "s": text after ``Soluzione:`` in the answer.
    """
    # Fetch the row once; the UI counts examples from 1.
    conversation = eureka5_test_data[int(ex_idx) - 1]["conversations"]
    prompt_text = conversation[0]["value"]
    answer_text = conversation[1]["value"]

    rebus = prompt_text.split("Rebus: ")[1].split("\n")[0]
    key = prompt_text.split("Chiave di lettura: ")[1].split("\n")[0]

    # Hide bracketed definitions behind a placeholder so the letter regex only
    # styles the loose letters between them.
    rebus_letters = re.sub(r"\[.*?\]", "<<<>>>", rebus)
    rebus_letters = re.sub(r"([a-zA-Z]+)", rf"{letter_style}\1</span>", rebus_letters)
    fp_empty = rebus_letters.replace("<<<>>>", f"{definition_style}___</span>")
    key_highlighted = re.sub(r"(\d+)", rf"{solution_style}\1</span>", key)

    raw_elements = re.findall(r"- (.*) = (.*)", answer_text)
    # Raw bracketed clues, in order, used below to refill the placeholders.
    definitions = [clue for clue, _ in raw_elements if clue.startswith("[")]
    fp_elements = [
        (re.sub(r"\[(.*?)\]", rf"{definition_style}[\1]</span>", clue), value)
        if clue.startswith("[")
        else (f"{letter_style}{clue}</span>", f"{letter_style}{value}</span>")
        for clue, value in raw_elements
    ]

    fp = re.findall(r"Prima lettura: (.*)", answer_text)[0]
    s_elements = re.findall(r"(\d+) = (.*)", answer_text)
    s = re.findall(r"Soluzione: (.*)", answer_text)[0]

    # Put the definitions back into the placeholders in order, then style them.
    for definition in definitions:
        rebus_letters = rebus_letters.replace("<<<>>>", definition, 1)
    rebus_highlighted = re.sub(r"\[(.*?)\]", rf"{definition_style}[\1]</span>", rebus_letters)

    return {
        "rebus": rebus_highlighted,
        "key": key_highlighted,
        "key_split": key,
        "fp_elements": fp_elements,
        "fp": fp,
        "fp_empty": fp_empty,
        "s_elements": s_elements,
        "s": s,
    }
#tokenizer = AutoTokenizer.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16") | |
#model = AutoModelForCausalLM.from_pretrained("gsarti/phi3-mini-rebus-solver-fp16") | |
def solve_verbalized_rebus(example, history):
    """Build the solver prompt for a verbalized rebus (chat callback).

    Args:
        example: user text containing a ``Rebus: ...`` line and a reading-key
            line. The key line may be ``Chiave di lettura: ...`` or
            ``Chiave risolutiva: ...``; the latter is the form used in the
            commented-out ChatInterface example.
        history: chat history passed by the chat UI; unused.

    Returns:
        ``template`` filled with the parsed rebus and key. Model inference is
        currently disabled, so the prompt itself is returned.
    """
    # ``template`` exposes {rebus} and {key} fields. The previous
    # ``template.format(input=example)`` therefore always raised KeyError('rebus').
    rebus_match = re.search(r"Rebus:\s*(.*)", example)
    key_match = re.search(r"Chiave (?:di lettura|risolutiva):\s*(.*)", example)
    # Without a "Rebus:" label, treat the whole message as the rebus.
    rebus = rebus_match.group(1).strip() if rebus_match else example.strip()
    key = key_match.group(1).strip() if key_match else ""
    prompt = template.format(rebus=rebus, key=key)
    # TODO: re-enable generation once the model is loaded:
    # inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    # outputs = model.generate(input_ids=inputs, max_new_tokens=500, use_cache=True)
    # return tokenizer.batch_decode(outputs)[0]
    return prompt
#demo = gr.ChatInterface(fn=solve_verbalized_rebus, examples=["Rebus: [Materiale espulso dai vulcani] R O [Strumento del calzolaio] [Si trovano ai lati del bacino] C I [Si ingrassano con la polenta] E I N [Contiene scorte di cibi] B [Isola in francese]\nChiave risolutiva: 1 ' 5 6 5 3 3 1 14"], title="Verbalized Rebus Solver") | |
#demo.launch() | |
with gr.Blocks(css=custom_css) as demo:
    # UI language selector. The Translate context below resolves every
    # gettext(...) key from translations.yaml for the chosen language.
    lang = gr.Dropdown([("English", "en"), ("Italian", "it")], value="it", label="Select language:", interactive=True)
    with Translate("translations.yaml", lang, placeholder_langs=["en", "it"]):
        gr.Markdown(gettext("Title"))
        gr.Markdown(gettext("Intro"))
        with gr.Tab(gettext("GuessingGame")):
            with gr.Row():
                with gr.Column():
                    example_id = gr.Number(1, label=gettext("CurrentExample"), minimum=1, maximum=2000, step=1, interactive=True)
                with gr.Column():
                    show_length_hints = gr.Checkbox(False, label=gettext("ShowLengthHints"), interactive=True)

            # show_example creates components at call time, which Gradio only
            # supports inside a @gr.render block. Without the decorator the
            # function is never invoked and the game never appears. Gradio
            # re-runs it whenever either input changes.
            @gr.render(inputs=[example_id, show_length_hints])
            def show_example(example_number, show_length_hints):
                """Render the guessing game for one 1-based test example."""
                if example_number is None:
                    # The number field was cleared: nothing to render.
                    return
                # gr.Number may deliver a float; dataset and CSV rows need an int index.
                example_number = int(example_number)
                parsed_rebus = parse_rebus(example_number)

                def split_by_key(clean_text, key_lengths):
                    """Cut clean_text into upper-cased words whose lengths follow the key."""
                    words = []
                    start = 0
                    for n_char in key_lengths.split():
                        # NOTE(review): assumes every key token is an integer;
                        # a non-numeric token would raise ValueError here.
                        words.append(clean_text[start:start + int(n_char)].upper())
                        start += int(n_char)
                    return words

                def solution_word_md(n_char, word):
                    """Markdown line showing one solution word and its expected length."""
                    return gr.Markdown(gettext("SolutionWord") + f"{n_char}: {solution_style}{word}</span></h4>")

                def model_solution_md(label, model_name):
                    """Hidden Markdown holding a model's stored solution for this example."""
                    model_solution = model_outputs[model_name][example_number - 1]["solution"]
                    return gr.Markdown(label + f"{solution_style}{model_solution}</span></h4>", visible=False)

                gr.Markdown(gettext("Instructions"))
                gr.Markdown(gettext("Rebus") + f"{parsed_rebus['rebus']}</h4>")
                gr.Markdown(gettext("Key") + f"{parsed_rebus['key']}</h4>")
                gr.Markdown("<br><br>")
                with gr.Row():
                    answers: list[gr.Textbox] = []
                    with gr.Column(scale=2):
                        gr.Markdown(gettext("ProceedToResolution"))
                        for el_key, el_value in parsed_rebus["fp_elements"]:
                            # Definition clues get a textbox; letter groups show their value directly.
                            is_definition = el_key.startswith('<span class="definition"')
                            with gr.Row():
                                with gr.Column(scale=0.2, min_width=250):
                                    gr.Markdown(f"<p>{el_key} = </p>")
                                    if is_definition and show_length_hints:
                                        gr.Markdown(f"<p>({len(el_value)} lettere)</p>")
                                with gr.Column(scale=0.2, min_width=150):
                                    if is_definition:
                                        answers.append(gr.Textbox(show_label=False, placeholder="Guess...", interactive=True, max_lines=3))
                                    else:
                                        gr.Markdown(el_value)
                            gr.Markdown("<hr>")
                    with gr.Column(scale=3):
                        # Hidden copies of the raw key and the blank first pass, fed back into update_fp.
                        key_value = gr.Markdown(parsed_rebus["key_split"], visible=False)
                        fp_empty = gr.Markdown(parsed_rebus["fp_empty"], visible=False)
                        fp = gr.Markdown(gettext("FirstPass") + f"{parsed_rebus['fp_empty']}</h4><br>")
                        # Split the raw blank first pass, exactly as update_fp does,
                        # so the translated "FirstPass" prefix never leaks into the words.
                        clean_solution_words = split_by_key(extract(parsed_rebus["fp_empty"]), parsed_rebus["key_split"])
                        solution_words: list[gr.Markdown] = [
                            solution_word_md(n_char, word)
                            for n_char, word in zip(parsed_rebus["key_split"].split(), clean_solution_words)
                        ]
                        gr.Markdown("<br>")
                        solution = gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(clean_solution_words)}</span></h4>")
                        correct_solution = gr.Markdown(gettext("CorrectSolution") + f"{solution_style}{parsed_rebus['s'].upper()}</span></h4>", visible=False)
                        correct_solution_shown = gr.Checkbox(False, visible=False)
                        gr.Markdown("<hr>")
                        prompted_models = gr.Markdown(gettext("PromptedModels"), visible=False)
                        gpt4_solution = model_solution_md(gettext("GPT4Solution"), "gpt4")
                        claude_solution = model_solution_md(gettext("ClaudeSolution"), "claude3_5_sonnet")
                        llama3_70b_solution = model_solution_md(gettext("LLaMA370BSolution"), "llama3_70b")
                        qwen_72b_solution = model_solution_md(gettext("Qwen72BSolution"), "qwen_72b")
                        models_separator = gr.Markdown("<hr>", visible=False)
                        trained_models = gr.Markdown(gettext("TrainedModels"), visible=False)
                        llama3_1_8b_solution = model_solution_md(gettext("LLaMA318BSolution"), "llama3_1_8b")
                        phi3_mini_solution = model_solution_md(gettext("Phi3MiniSolution"), "phi3_mini")
                        gemma2_solution = model_solution_md(gettext("Gemma22BSolution"), "gemma2")
                        models_solutions_shown = gr.Checkbox(False, visible=False)
                        with gr.Row():
                            btn_check = gr.Button(gettext("CheckSolution"), variant="primary")
                            btn_show = gr.Button(gettext("ShowSolution"))
                            btn_show_models_solutions = gr.Button(gettext("ShowModelsSolutions"))

                # Components toggled together by the "show models' solutions" button,
                # in the order show_models_solutions receives and returns them.
                model_components = [
                    gpt4_solution, claude_solution, llama3_70b_solution, qwen_72b_solution,
                    llama3_1_8b_solution, phi3_mini_solution, gemma2_solution,
                    prompted_models, trained_models, models_separator,
                ]

                def update_fp(fp_template, key_lengths, *answer_values):
                    """Fill the blanks with the typed answers and recompute the solution words."""
                    # Fill each "___" blank with the answer at the same position.
                    # An empty answer keeps its blank, so later answers stay aligned
                    # with their own definitions. Sequential replace() would shift
                    # them into earlier blanks.
                    pieces = fp_template.split("___")
                    filled = [pieces[0]]
                    for pos, tail in enumerate(pieces[1:]):
                        value = answer_values[pos] if pos < len(answer_values) else None
                        filled.append(value if value else "___")
                        filled.append(tail)
                    fp_filled = "".join(filled)
                    words = split_by_key(extract(fp_filled), key_lengths)
                    return [
                        gr.Markdown(gettext("FirstPass") + f"{fp_filled}</h4><br>"),
                        gr.Markdown(gettext("Solution") + f"{solution_style}{' '.join(words)}</span></h4>"),
                    ] + [solution_word_md(n_char, word) for n_char, word in zip(key_lengths.split(), words)]

                def check_solution(solution, correct_solution):
                    """Show a gr.Info message saying whether the solution matches (via unidecode)."""
                    if unidecode(extract(solution)) == unidecode(extract(correct_solution)):
                        gr.Info(gettext("CorrectSolutionMsg"))
                    else:
                        gr.Info(gettext("IncorrectSolutionMsg"))

                def show_solution(correct_solution, btn_show, shown):
                    """Toggle the reference solution and relabel its button."""
                    if shown:
                        return gr.Markdown(correct_solution, visible=False), gr.Button(gettext("ShowSolution")), gr.Checkbox(False, visible=False)
                    return gr.Markdown(correct_solution, visible=True), gr.Button(gettext("HideSolution")), gr.Checkbox(True, visible=False)

                def show_models_solutions(models_solutions_shown, btn_show_models_solutions, *model_values):
                    """Toggle every model-solution component and relabel the button."""
                    visible = not models_solutions_shown
                    button = gr.Button(gettext("HideModelsSolutions")) if visible else gr.Button(gettext("ShowModelsSolutions"))
                    return (*[gr.Markdown(value, visible=visible) for value in model_values], button, gr.Checkbox(visible, visible=False))

                for answer in answers:
                    answer.change(update_fp, [fp_empty, key_value, *answers], [fp, solution, *solution_words])
                btn_check.click(check_solution, [solution, correct_solution], None)
                btn_show.click(show_solution, [correct_solution, btn_show, correct_solution_shown], [correct_solution, btn_show, correct_solution_shown])
                btn_show_models_solutions.click(
                    show_models_solutions,
                    [models_solutions_shown, btn_show_models_solutions, *model_components],
                    [*model_components, btn_show_models_solutions, models_solutions_shown],
                )

        with gr.Tab(gettext("ModelEvaluation")):
            gr.Markdown("<i>This section is under construction! Check again later 🙏</i>")

demo.launch(show_api=False)