|
import gradio as gr |
|
import json |
|
from tqdm import tqdm |
|
import numpy as np |
|
import random |
|
import torch |
|
import ast |
|
from difflib import HtmlDiff |
|
|
|
from src.kg.main import script2kg |
|
from src.summary.summarizer import Summarizer |
|
from src.summary.utils import preprocess_script, chunk_script_gpt |
|
from src.summary.prompt import build_summarizer_prompt |
|
from src.fact.narrativefactscore import NarrativeFactScore |
|
|
|
def _set_seed(seed): |
|
np.random.seed(seed) |
|
random.seed(seed) |
|
torch.manual_seed(seed) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed_all(seed) |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = False |
|
|
|
def parse_scenes(scene_text):
    """Turn user-supplied scene text into a Python object.

    The text is first decoded as JSON. If that fails it is evaluated as a
    Python literal with ``ast.literal_eval``, which lets single-quoted lists
    through. Errors raised by ``ast.literal_eval`` propagate to the caller.
    """
    try:
        parsed = json.loads(scene_text)
    except json.JSONDecodeError:
        parsed = ast.literal_eval(scene_text)
    return parsed
|
|
|
def set_name_list(dataset, data_type):
    """Return the selectable work names for a dataset split.

    Args:
        dataset: Dataset name; "MovieSum" and "MENSA" are recognised.
        data_type: Split name; "train", "validation" and "test" are recognised.

    Returns:
        A new list of work names, or an empty list for any other combination.
    """
    names_by_split = {
        ("MovieSum", "train"): [
            '8MM_1999', 'The Iron Lady_2011', 'Adventureland_2009', 'Napoleon_2023',
            'Kubo and the Two Strings_2016', 'The Woman King_2022', 'What They Had_2018',
            'Synecdoche, New York_2008', 'Black Christmas_2006', 'Superbad_2007',
        ],
        ("MovieSum", "validation"): [
            'The Boondock Saints_1999', 'The House with a Clock in Its Walls_2018',
            'The Unbelievable Truth_1989', 'Insidious_2010', 'If Beale Street Could Talk_2018',
            'The Battle of Shaker Heights_2003', '20th Century Women_2016',
            'Captain Phillips_2013', 'Conspiracy Theory_1997', 'Domino_2005',
        ],
        ("MovieSum", "test"): [
            'A Nightmare on Elm Street 3: Dream Warriors_1987', 'Van Helsing_2004',
            'Oppenheimer_2023', 'Armored_2009', 'The Martian_2015',
        ],
        ("MENSA", "train"): [
            'The_Ides_of_March_(film)', 'An_American_Werewolf_in_Paris',
            'Batman_&_Robin_(film)', 'Airplane_II:_The_Sequel', 'Krull_(film)',
        ],
        ("MENSA", "validation"): [
            'Pleasantville_(film)', 'V_for_Vendetta_(film)',
            'Mary_Shelleys_Frankenstein_(film)', 'Rapture_(1965_film)', 'Get_Out',
        ],
        ("MENSA", "test"): [
            'Knives_Out', 'Black_Panther', 'Pet_Sematary_(film)',
            'Panic_Room', 'The_Village_(2004_film)',
        ],
    }
    return names_by_split.get((dataset, data_type), [])
|
|
|
def update_name_list_interface(dataset, data_type):
    """Build the component updates that follow a dataset selection.

    Returns a 3-tuple of ``gr.update`` objects, in order: the script
    dropdown, the custom-script textbox and the script-content textbox.
    For the bundled datasets the dropdown is shown with that split's works and
    the custom textbox is hidden; for anything else the visibility is swapped
    and the content box shows a hint to continue.
    """
    if dataset not in ["MovieSum", "MENSA"]:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(value="Click next 'Knowledge Graph' to continue"),
        )

    available_works = set_name_list(dataset, data_type)
    dropdown_update = gr.update(choices=available_works, value=None, visible=True)
    return dropdown_update, gr.update(visible=False), gr.update(value="")
|
|
|
def read_data(dataset, data_type):
    """Load one dataset split from ``dataset/<dataset>/<data_type>.jsonl``.

    Args:
        dataset: Name of the dataset directory under ``dataset/``.
        data_type: Split name, used as the file stem.

    Returns:
        A list with one decoded JSON value per non-blank line, or an empty
        list when the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    file_path = f"dataset/{dataset}/{data_type}.jsonl"
    try:
        with open(file_path, 'r', encoding='utf8') as f:
            # Blank lines (for example a trailing newline at end of file) are
            # not JSON documents; skip them instead of failing the whole load.
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        return []
|
|
|
def find_work_index(data, work_name):
    """Locate the first entry whose ``"name"`` equals ``work_name``.

    Returns:
        ``(index, entry)`` for the first match, otherwise
        ``(None, "Work not found in the selected dataset.")``.
    """
    first_match = next(
        (
            (position, record)
            for position, record in enumerate(data)
            if record.get("name") == work_name
        ),
        None,
    )
    if first_match is None:
        return None, "Work not found in the selected dataset."
    return first_match
|
|
|
def get_narrative_content(dataset, data_type, work):
    """Return the ``'scenes'`` value of the named work in a dataset split.

    The split is loaded with ``read_data``; the first entry whose ``"name"``
    equals ``work`` wins. If nothing matches, a "not found" message string is
    returned instead.
    """
    for record in read_data(dataset, data_type):
        if record.get("name") != work:
            continue
        return record['scenes']
    return "Work not found in the selected dataset."
|
|
|
def get_narrative_content_with_index(dataset, data_type, work):
    """Look up a work and return its content, its position and the split.

    Returns:
        ``(content, index, data)`` for the first entry named ``work``. For the
        "MovieSum" and "MENSA" datasets ``content`` is the entry's scenes
        joined with newlines; for any other dataset it is the entry itself.
        When no entry matches, ``("Work not found in the selected dataset.",
        None, None)`` is returned.
    """
    records = read_data(dataset, data_type)
    is_bundled_dataset = dataset in ["MovieSum", "MENSA"]

    for position, record in enumerate(records):
        if record.get("name") != work:
            continue
        content = "\n".join(record['scenes']) if is_bundled_dataset else record
        return content, position, records

    return "Work not found in the selected dataset.", None, None
|
|
|
def show_diff(original, revised):
    """Render an HTML diff table comparing two summaries.

    Both texts are split into lines (line endings kept) and passed to
    ``difflib.HtmlDiff.make_table`` in context mode with two lines of context.
    The columns are titled "Original Summary" and "Refined Summary".
    """
    before_lines, after_lines = (
        text.splitlines(keepends=True) for text in (original, revised)
    )
    return HtmlDiff().make_table(
        before_lines,
        after_lines,
        fromdesc='Original Summary',
        todesc='Refined Summary',
        context=True,
        numlines=2,
    )
|
|
|
def extract_initial_summary(summary_result):
    """Return the aggregated summary text stored under ``summary_agg``."""
    aggregated = summary_result['summary_agg']
    return aggregated['summaries']
|
|
|
def extract_factuality_score_and_details(fact_score_result):
    """Pull the overall score and a readable feedback digest from a result.

    Returns:
        ``(fact_score, details)`` where ``details`` has one line per chunk
        that has at least one non-blank feedback, formatted as
        ``"In chunk <n>: <feedback>; <feedback>"`` with 1-based chunk numbers.
    """
    overall_score = fact_score_result['fact_score']

    digest_lines = []
    pairs = fact_score_result['summary_feedback_pairs']
    for chunk_number, pair in enumerate(pairs, start=1):
        non_blank = [text for text in pair['feedbacks'] if text.strip()]
        if not non_blank:
            continue
        joined = '; '.join(non_blank)
        digest_lines.append(f"In chunk {chunk_number}: {joined}")

    return overall_score, "\n".join(digest_lines)
|
|
|
def build_kg(script, idx, api_key, model_id):
    """Build a knowledge graph for a dataset entry via ``script2kg``.

    ``script`` must provide ``'scenes'`` and ``'name'``; both are forwarded
    together with ``idx``, ``api_key`` and ``model_id``.
    """
    scenes = script['scenes']
    title = script['name']
    return script2kg(scenes, idx, title, api_key, model_id)
|
|
|
def build_kg_custom(scenes, idx, api_key, model_id):
    """Build a knowledge graph for user-provided scenes, named "custom"."""
    return script2kg(scenes, idx, "custom", api_key, model_id)
|
|
|
def build_kg_with_data(data, work_index, custom_scenes, api_key, model_id):
    """Build a knowledge graph from a selected work or from custom scenes.

    A selected dataset work (``data`` non-empty and ``work_index`` set) takes
    precedence over ``custom_scenes``. Custom scenes are parsed with
    ``parse_scenes`` and must be a list.

    Returns:
        ``(kg, status_message)``; ``kg`` is ``None`` whenever the graph could
        not be built, and the message says why.
    """
    if data and work_index is not None:
        position = int(work_index)
        script = data[position]
        try:
            graph = script2kg(script['scenes'], position, script['name'], api_key, model_id)
        except Exception as e:
            return None, f"Error building knowledge graph: {str(e)}"
        return graph, "Knowledge Graph built successfully!"

    if not custom_scenes:
        return None, "Please select a work or input custom scenes."

    try:
        scenes = parse_scenes(custom_scenes)
        if not isinstance(scenes, list):
            return None, "Invalid format. Please provide scenes as a list."
        graph = build_kg_custom(scenes, 0, api_key, model_id)
    except (json.JSONDecodeError, SyntaxError, ValueError) as e:
        return None, f"Invalid format. Error: {str(e)}"
    except Exception as e:
        return None, f"Error building knowledge graph: {str(e)}"
    return graph, "Knowledge Graph built successfully!"
|
|
|
def generate_summary(script, idx, api_key, model_id):
    """Summarise a script chunk by chunk.

    Each scene is prefixed with ``s#<index>``, the scenes are joined into one
    text and split by ``chunk_script_gpt`` (chunk size 2048). Every chunk is
    preprocessed, wrapped in the ``external_summary`` prompt template and
    summarised by the model. ``idx`` is accepted but not used here.

    Returns:
        A dict with three keys:
        ``summary_sep`` - per-chunk dicts (index, chunk, summary);
        ``summary_agg`` - space-joined chunks and space-joined summaries;
        ``processed_dataset`` - the joined script, tagged scenes, chunks and
        summaries.
    """
    _set_seed(42)
    scripty_summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42,
    )

    scenes = [f"s#{number}\n{scene}" for number, scene in enumerate(script['scenes'])]
    script = "\n\n".join(scenes)
    script_chunks = chunk_script_gpt(script=script, model_id=model_id, chunk_size=2048)

    def summarise_chunk(raw_chunk):
        # Only the prompt sees the preprocessed text; the stored chunk stays raw.
        cleaned = preprocess_script(raw_chunk)
        prompt = build_summarizer_prompt(
            prompt_template="./templates/external_summary.txt",
            input_text_list=[cleaned],
        )
        return scripty_summarizer.inference_with_gpt(prompt=prompt).strip()

    script_summaries = [summarise_chunk(raw_chunk) for raw_chunk in tqdm(script_chunks)]

    agg_dict = {
        'script': ' '.join(script_chunks),
        'summaries': ' '.join(script_summaries),
    }
    elem_dict_list = [
        {"chunk_index": position, "chunk": raw_chunk.strip(), "summary": text.strip()}
        for position, (raw_chunk, text) in enumerate(zip(script_chunks, script_summaries))
    ]
    processed_dataset = {
        "script": script,
        "scenes": scenes,
        "script_chunks": script_chunks,
        "script_summaries": script_summaries,
    }

    return {
        "summary_sep": elem_dict_list,
        "summary_agg": agg_dict,
        "processed_dataset": processed_dataset,
    }
|
|
|
def generate_summary_with_data(data, work_index, custom_scenes, api_key, model_id):
    """Generate the initial summary for a selected work or for custom scenes.

    A selected dataset work (``data`` non-empty and ``work_index`` set) takes
    precedence over ``custom_scenes``. Custom scenes are parsed with
    ``parse_scenes`` and must be a list.

    Returns:
        ``(summary_result, text)``. On success ``text`` is the aggregated
        summary; on failure ``summary_result`` is ``None`` and ``text`` is an
        error or guidance message.
    """
    if data and work_index is not None:
        position = int(work_index)
        script = data[position]
        try:
            result = generate_summary(script, position, api_key, model_id)
            return result, extract_initial_summary(result)
        except Exception as e:
            return None, f"Error generating summary: {str(e)}"

    if not custom_scenes:
        return None, "Please select a work or input custom scenes."

    try:
        scenes = parse_scenes(custom_scenes)
        if not isinstance(scenes, list):
            return None, "Invalid format. Please provide scenes as a list."
        result = generate_summary({"name": "custom", "scenes": scenes}, 0, api_key, model_id)
        return result, extract_initial_summary(result)
    except (json.JSONDecodeError, SyntaxError, ValueError) as e:
        return None, f"Invalid format. Error: {str(e)}"
    except Exception as e:
        return None, f"Error generating summary: {str(e)}"
|
|
|
def calculate_narrative_fact_score(summary, kg_raw, api_key, model_id):
    """Score each chunk summary against its source chunk and the knowledge graph.

    Args:
        summary: Result of ``generate_summary``; its ``'processed_dataset'``
            entry supplies ``'script_chunks'`` and ``'script_summaries'``.
        kg_raw: Iterable of triples, each a mapping with ``'subject'``,
            ``'predicate'`` and ``'object'``.
        api_key: API key forwarded to ``NarrativeFactScore``.
        model_id: Model identifier forwarded to ``NarrativeFactScore``.

    Returns:
        ``(total_output, partial_output)``. Both carry ``'fact_score'`` (the
        mean chunk score, or 0.0 when nothing was scored) and
        ``'summary_feedback_pairs'``. The total variant holds the full
        per-chunk record; the partial variant keeps only the per-sentence
        scores, summary chunks and feedbacks.
    """
    _set_seed(42)
    factscorer = NarrativeFactScore(split_type='gpt', model='gptscore', api_key=api_key, model_id=model_id)

    processed = summary['processed_dataset']
    chunks, summaries = processed['script_chunks'], processed['script_summaries']

    # Flatten triples to plain sentences; a self-referencing triple drops the
    # repeated object.
    kg = [
        f"{elem['subject']} {elem['predicate']}"
        if elem['subject'] == elem['object']
        else f"{elem['subject']} {elem['predicate']} {elem['object']}"
        for elem in kg_raw
    ]

    scores, scores_per_sent, relevant_scenes, summary_chunks, feedbacks = (
        factscorer.score_src_hyp_long(chunks, summaries, kg)
    )

    total_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    partial_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    total_score = 0
    for i, score in enumerate(scores):
        total_output['summary_feedback_pairs'].append({
            'src': chunks[i],
            'summary': summaries[i],
            'score': score,
            'scores_per_sent': scores_per_sent[i],
            'relevant_scenes': relevant_scenes[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        })
        partial_output['summary_feedback_pairs'].append({
            'scores_per_sent': scores_per_sent[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        })
        total_score += score

    # An empty script yields no scores; report 0.0 instead of dividing by zero.
    scored_count = len(scores)
    mean_score = float(total_score / scored_count) if scored_count else 0.0
    total_output['fact_score'] = mean_score
    partial_output['fact_score'] = mean_score
    return total_output, partial_output
|
|
|
def refine_summary(summary, fact_score, api_key, model_id):
    """Rewrite the chunk summaries whose factuality score is below 0.9.

    Args:
        summary: The ``processed_dataset`` dict of a summary result; its
            ``"script"`` and ``"scenes"`` entries are carried over.
        fact_score: The full scoring output; every item of
            ``'summary_feedback_pairs'`` supplies ``'src'``, ``'summary'``,
            ``'score'``, ``'scores_per_sent'``, ``'summary_chunks'`` and
            ``'feedbacks'``.
        api_key: API key forwarded to ``Summarizer``.
        model_id: Model identifier forwarded to ``Summarizer``.

    Returns:
        A dict shaped like the output of ``generate_summary``.
        ``summary_sep`` lists only the chunks that were sent for revision.
    """
    _set_seed(42)
    threshold = 0.9
    summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42,
    )

    refined = {
        "script": summary["script"],
        "scenes": summary["scenes"],
        "script_chunks": [],
        "script_summaries": [],
    }
    revision_records = []

    for pair in tqdm(fact_score['summary_feedback_pairs']):
        src_chunk = pair['src']
        original_summary = pair['summary']

        # Summaries that already score well are kept as they are.
        if pair['score'] >= threshold:
            refined["script_chunks"].append(src_chunk)
            refined["script_summaries"].append(original_summary.strip())
            continue

        # Sentences scored exactly 0 are the ones sent back for revision.
        flagged = np.where(np.array(pair['scores_per_sent']) == 0)[0]
        hallu_summary_parts = np.array(pair['summary_chunks'])[flagged]
        feedbacks = np.array(pair['feedbacks'])[flagged]

        prompt = build_summarizer_prompt(
            prompt_template="./templates/self_correction.txt",
            input_text_list=[src_chunk, original_summary],
        )
        for ordinal, (statement, reason) in enumerate(zip(hallu_summary_parts, feedbacks), start=1):
            prompt += f"\n- Statement to Revise {ordinal}: {statement} (Reason for Revision: {reason})"
        prompt += "\n- Revised Summary: "

        revised_summary = summarizer.inference_with_gpt(prompt=prompt)
        if not revised_summary.strip():
            # An empty answer from the model falls back to the original text.
            revised_summary = original_summary

        refined["script_chunks"].append(src_chunk)
        refined["script_summaries"].append(revised_summary)

        revision_records.append({
            "chunk_index": len(refined["script_chunks"]) - 1,
            "chunk": src_chunk.strip(),
            "summary": revised_summary.strip(),
            "org_summary": original_summary.strip(),
            "hallu_in_summary": list(hallu_summary_parts),
            "feedbacks": list(feedbacks),
        })

    aggregated = {
        'script': summary['script'],
        'summaries': ' '.join(refined["script_summaries"]),
    }

    return {
        "summary_sep": revision_records,
        "summary_agg": aggregated,
        "processed_dataset": refined,
    }
|
|
|
def refine_summary_and_return_diff(summary, fact_score, api_key, model_id):
    """Refine the summary and return an HTML page diffing old against new.

    ``summary`` is a full summary result; its ``'processed_dataset'`` is
    refined with ``refine_summary`` and the aggregated texts before and after
    are compared line by line with ``difflib.HtmlDiff.make_file`` in context
    mode.
    """
    refined = refine_summary(summary['processed_dataset'], fact_score, api_key, model_id)
    before_lines = summary['summary_agg']['summaries'].splitlines()
    after_lines = refined['summary_agg']['summaries'].splitlines()
    return HtmlDiff().make_file(before_lines, after_lines, context=True)
|
|
|
def open_kg(kg_data):
    """Return HTML that embeds the rendered knowledge graph in an iframe.

    Args:
        kg_data: The knowledge-graph state; only checked for ``None`` to tell
            whether a graph has been built yet.

    Returns:
        A guidance message when no graph exists, an ``<iframe>`` whose
        ``srcdoc`` holds the contents of ``refined_kg.html``, or a red error
        ``<div>`` when that file cannot be read.
    """
    import html

    if kg_data is None:
        return "Please build the knowledge graph first."
    try:
        with open('refined_kg.html', 'r', encoding='utf-8') as f:
            html_content = f.read()
    except Exception as e:
        return f'<div style="color: red;">Error reading KG file: {html.escape(str(e))}</div>'

    # The document is placed inside a double-quoted attribute, so quotes and
    # ampersands must be entity-escaped or the first '"' in the file would end
    # the attribute. The browser decodes the entities before parsing srcdoc.
    escaped_document = html.escape(html_content, quote=True)
    return f'''
<iframe
    srcdoc="{escaped_document}"
    style="width: 100%; height: 500px; border: none;"
></iframe>
'''
|
|
|
def format_fact_score_output(fact_score_result):
    """Render a factuality result as plain text for display.

    Returns a fixed placeholder when the result is empty. Otherwise the text
    starts with the overall score as a percentage, followed for each chunk by
    its summary text, any non-blank feedback as bullet points, and an
    80-character dashed separator.
    """
    if not fact_score_result:
        return "No factuality analysis available"

    report = [f"Overall Factuality Score: {fact_score_result['fact_score']*100:.1f}%\n"]

    for chunk_number, pair in enumerate(fact_score_result['summary_feedback_pairs'], 1):
        chunk_text = ' '.join(pair['summary_chunks'])
        report += [
            f"\nChunk {chunk_number} Analysis:",
            "Original Text:",
            f"{chunk_text}\n",
        ]

        if pair['feedbacks']:
            report.append("Feedback:")
            report.extend(f"• {note}" for note in pair['feedbacks'] if note.strip())

        report.append("-" * 80)

    return "\n".join(report)
|
|
|
|
|
def _score_and_format(summary, kg, api_key, model):
    """Score the summary once and return ``(score_state, display_text)``.

    The full scoring output is kept as state because the refinement step reads
    its ``'src'``, ``'summary'`` and ``'score'`` fields. When the summary or
    the knowledge graph is still missing, a guidance message is returned
    instead of running the scorer.
    """
    if summary is None or kg is None:
        return None, "Please build the knowledge graph and generate the initial summary first."
    total_output, _partial_output = calculate_narrative_fact_score(summary, kg, api_key, model)
    return total_output, format_fact_score_output(total_output)


# Gradio UI: four tabs (dataset selection, knowledge graph, summary generation,
# summary refinement) sharing state through gr.State holders.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # NarrativeFactScore: Script Factuality Evaluation
        Evaluate and refine script summaries using narrative factuality scoring.
        """
    )

    with gr.Accordion("Model Settings", open=True):
        with gr.Row():
            api_key_input = gr.Textbox(
                label="GPT API Key",
                placeholder="Enter your GPT API key",
                type="password",
                scale=2
            )
            model_selector = gr.Dropdown(
                choices=[
                    "gpt-4o-mini",
                    "gpt-4o",
                    "gpt-4-turbo",
                    "gpt-3.5-turbo-0125"
                ],
                value="gpt-4o",
                label="Model Selection",
                scale=1
            )

    with gr.Tabs():
        with gr.TabItem("Dataset Selection"):
            with gr.Row():
                dataset_selector = gr.Radio(
                    choices=["MovieSum", "MENSA", "Custom"],
                    label="Dataset",
                    info="Choose the dataset or input custom script"
                )
                data_type_selector = gr.Radio(
                    choices=["train", "validation", "test"],
                    label="Split Type",
                    info="Select data split",
                    visible=True
                )
            name_list = gr.Dropdown(
                choices=[],
                label="Select Script",
                info="Choose a script to analyze",
                visible=True
            )
            custom_input = gr.Textbox(
                label="Custom Script Input",
                info="Enter scenes as a JSON list: ['scene1', 'scene2', ...]",
                lines=10,
                visible=False
            )
            narrative_output = gr.Textbox(
                label="Script Content",
                interactive=False,
                lines=10
            )

        with gr.TabItem("Knowledge Graph"):
            with gr.Row():
                generate_kg_button = gr.Button(
                    "Generate Knowledge Graph",
                    variant="primary"
                )
                open_kg_button = gr.Button("View Graph")
            kg_status = gr.Textbox(
                label="Status",
                interactive=False
            )
            kg_viewer = gr.HTML(label="Knowledge Graph Visualization")

        with gr.TabItem("Summary Generation"):
            generate_summary_button = gr.Button(
                "Generate Initial Summary",
                variant="primary"
            )
            summary_output = gr.Textbox(
                label="Generated Summary",
                interactive=False,
                lines=5
            )
            calculate_score_button = gr.Button("Calculate Factuality Score")
            fact_score_display = gr.Textbox(
                label="Factuality Analysis",
                interactive=False,
                lines=10
            )

        with gr.TabItem("Summary Refinement"):
            refine_button = gr.Button(
                "Refine Summary",
                variant="primary"
            )
            refined_output = gr.HTML(label="Refined Summary with Changes")

    # Cross-tab state shared between the event handlers below.
    work_index = gr.State()
    data_state = gr.State()
    kg_output = gr.State()
    summary_state = gr.State()
    fact_score_state = gr.State()

    # The split selector only applies to the bundled datasets.
    dataset_selector.change(
        fn=lambda x: gr.update(visible=x in ["MovieSum", "MENSA"]),
        inputs=[dataset_selector],
        outputs=data_type_selector
    )

    dataset_selector.change(
        fn=update_name_list_interface,
        inputs=[dataset_selector, data_type_selector],
        outputs=[name_list, custom_input, narrative_output]
    )

    # Drop any previously selected work when the dataset changes; otherwise a
    # stale selection would take precedence over the custom script input.
    dataset_selector.change(
        fn=lambda _dataset: (None, None),
        inputs=[dataset_selector],
        outputs=[work_index, data_state]
    )

    # The script list depends on the split as well, so refresh it when the
    # split changes (for example when the dataset was picked first).
    data_type_selector.change(
        fn=update_name_list_interface,
        inputs=[dataset_selector, data_type_selector],
        outputs=[name_list, custom_input, narrative_output]
    )

    name_list.change(
        fn=get_narrative_content_with_index,
        inputs=[dataset_selector, data_type_selector, name_list],
        outputs=[narrative_output, work_index, data_state]
    )

    generate_kg_button.click(
        fn=build_kg_with_data,
        inputs=[
            data_state,
            work_index,
            custom_input,
            api_key_input,
            model_selector
        ],
        outputs=[kg_output, kg_status]
    )

    open_kg_button.click(
        fn=open_kg,
        inputs=[kg_output],
        outputs=kg_viewer
    )

    generate_summary_button.click(
        fn=generate_summary_with_data,
        inputs=[data_state, work_index, custom_input, api_key_input, model_selector],
        outputs=[summary_state, summary_output]
    )

    # One scoring run, two return values matching the two outputs.
    calculate_score_button.click(
        fn=_score_and_format,
        inputs=[summary_state, kg_output, api_key_input, model_selector],
        outputs=[fact_score_state, fact_score_display]
    )

    refine_button.click(
        fn=refine_summary_and_return_diff,
        inputs=[summary_state, fact_score_state, api_key_input, model_selector],
        outputs=refined_output
    )

if __name__ == "__main__":
    demo.launch()