# Commit eaa3d8a by JihyukKim: "Initial commit"
import gradio as gr
import json
from tqdm import tqdm
import numpy as np
import random
import torch
import ast
from difflib import HtmlDiff
from src.kg.main import script2kg
from src.summary.summarizer import Summarizer
from src.summary.utils import preprocess_script, chunk_script_gpt
from src.summary.prompt import build_summarizer_prompt
from src.fact.narrativefactscore import NarrativeFactScore
def _set_seed(seed):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def parse_scenes(scene_text):
    """Decode a user-supplied scene list.

    Strict JSON is attempted first; if that fails, the text is parsed as a
    Python literal so single-quoted lists like ``['a', 'b']`` also work.
    ``ast.literal_eval`` evaluates literals only, never arbitrary code.

    Args:
        scene_text: String holding a JSON or Python-literal value.

    Returns:
        The decoded object. Callers check separately that it is a list.

    Raises:
        Errors from ``ast.literal_eval`` (e.g. ``ValueError``, ``SyntaxError``)
        when the text is neither valid JSON nor a valid Python literal.
    """
    try:
        decoded = json.loads(scene_text)
    except json.JSONDecodeError:
        decoded = ast.literal_eval(scene_text)
    return decoded
def set_name_list(dataset, data_type):
    """Return the selectable work names for a dataset/split pair.

    Args:
        dataset: ``"MovieSum"`` or ``"MENSA"``; any other value yields ``[]``.
        data_type: ``"train"``, ``"validation"`` or ``"test"``; any other value
            (including ``None``) yields ``[]``.

    Returns:
        A freshly built list of work names, safe for the caller to mutate.
    """
    catalog = (
        (("MovieSum", "train"), (
            '8MM_1999', 'The Iron Lady_2011', 'Adventureland_2009', 'Napoleon_2023',
            'Kubo and the Two Strings_2016', 'The Woman King_2022', 'What They Had_2018',
            'Synecdoche, New York_2008', 'Black Christmas_2006', 'Superbad_2007',
        )),
        (("MovieSum", "validation"), (
            'The Boondock Saints_1999', 'The House with a Clock in Its Walls_2018',
            'The Unbelievable Truth_1989', 'Insidious_2010', 'If Beale Street Could Talk_2018',
            'The Battle of Shaker Heights_2003', '20th Century Women_2016',
            'Captain Phillips_2013', 'Conspiracy Theory_1997', 'Domino_2005',
        )),
        # Test split names are a shortened subset.
        (("MovieSum", "test"), (
            'A Nightmare on Elm Street 3: Dream Warriors_1987', 'Van Helsing_2004',
            'Oppenheimer_2023', 'Armored_2009', 'The Martian_2015',
        )),
        (("MENSA", "train"), (
            'The_Ides_of_March_(film)', 'An_American_Werewolf_in_Paris',
            'Batman_&_Robin_(film)', 'Airplane_II:_The_Sequel', 'Krull_(film)',
        )),
        (("MENSA", "validation"), (
            'Pleasantville_(film)', 'V_for_Vendetta_(film)',
            'Mary_Shelleys_Frankenstein_(film)', 'Rapture_(1965_film)', 'Get_Out',
        )),
        (("MENSA", "test"), (
            'Knives_Out', 'Black_Panther', 'Pet_Sematary_(film)',
            'Panic_Room', 'The_Village_(2004_film)',
        )),
    )
    for (known_dataset, known_split), names in catalog:
        if dataset == known_dataset and data_type == known_split:
            return list(names)
    return []
def update_name_list_interface(dataset, data_type):
    """Switch the Dataset tab between the script picker and custom input.

    Args:
        dataset: Selected dataset radio value.
        data_type: Selected split radio value.

    Returns:
        Three ``gr.update`` objects for ``(name_list, custom_input,
        narrative_output)``. For "MovieSum"/"MENSA" the picker is refilled,
        cleared and shown, the custom box hidden and the content box emptied.
        For anything else the picker is hidden, the custom box shown and a
        hint written to the content box.
    """
    if dataset not in ["MovieSum", "MENSA"]:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(value="Click next 'Knowledge Graph' to continue"),
        )
    refreshed_picker = gr.update(choices=set_name_list(dataset, data_type), value=None, visible=True)
    hidden_custom_box = gr.update(visible=False)
    cleared_content = gr.update(value="")
    return refreshed_picker, hidden_custom_box, cleared_content
def read_data(dataset, data_type):
    """Load every record of ``dataset/<dataset>/<data_type>.jsonl``.

    The path is relative to the current working directory. Blank or
    whitespace-only lines are skipped rather than handed to ``json.loads``,
    which would raise on them.

    Args:
        dataset: Dataset directory name under ``dataset/``.
        data_type: Split name; the file ``<data_type>.jsonl`` is read.

    Returns:
        A list of decoded JSON objects, one per non-blank line, or ``[]`` when
        the file does not exist.
    """
    file_path = f"dataset/{dataset}/{data_type}.jsonl"
    try:
        with open(file_path, 'r', encoding='utf8') as f:
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        return []
def find_work_index(data, work_name):
    """Locate the first record whose ``"name"`` equals *work_name*.

    Args:
        data: Iterable of dict records.
        work_name: Name to look for.

    Returns:
        ``(index, record)`` for the first match, otherwise
        ``(None, "Work not found in the selected dataset.")``.
    """
    hit = next(
        ((position, record) for position, record in enumerate(data) if record.get("name") == work_name),
        None,
    )
    if hit is None:
        return None, "Work not found in the selected dataset."
    return hit
def get_narrative_content(dataset, data_type, work):
    """Return the scene list of *work* from a dataset split.

    Args:
        dataset: Dataset directory name.
        data_type: Split name.
        work: Value compared with each record's ``"name"``.

    Returns:
        The first matching record's ``'scenes'`` value, or the string
        ``"Work not found in the selected dataset."`` when nothing matches.
    """
    matching_scenes = (
        record['scenes']
        for record in read_data(dataset, data_type)
        if record.get("name") == work
    )
    return next(matching_scenes, "Work not found in the selected dataset.")
def get_narrative_content_with_index(dataset, data_type, work):
    """Look up *work* and return its display text, index and the loaded split.

    Args:
        dataset: Dataset directory name.
        data_type: Split name.
        work: Value compared with each record's ``"name"``.

    Returns:
        ``(content, index, records)``. For "MovieSum"/"MENSA", ``content`` is
        the record's scenes joined with newlines; otherwise it is the whole
        record. When nothing matches, returns
        ``("Work not found in the selected dataset.", None, None)``.
    """
    records = read_data(dataset, data_type)
    scenes_only = dataset in ["MovieSum", "MENSA"]
    for position, record in enumerate(records):
        if record.get("name") != work:
            continue
        if scenes_only:
            return "\n".join(record['scenes']), position, records
        return record, position, records
    return "Work not found in the selected dataset.", None, None
def show_diff(original, revised):
    """Render a side-by-side HTML table diff of two summaries.

    Args:
        original: Summary before refinement.
        revised: Summary after refinement.

    Returns:
        The ``<table>`` markup from ``difflib.HtmlDiff.make_table``. Only
        changed lines plus two lines of context are shown.
    """
    differ = HtmlDiff()
    return differ.make_table(
        original.splitlines(keepends=True),
        revised.splitlines(keepends=True),
        fromdesc='Original Summary',
        todesc='Refined Summary',
        context=True,
        numlines=2,
    )
def extract_initial_summary(summary_result):
    """Return the aggregated summary text of a summarisation result.

    Args:
        summary_result: Dict whose ``'summary_agg'`` entry holds ``'summaries'``.
    """
    aggregated = summary_result['summary_agg']
    return aggregated['summaries']
def extract_factuality_score_and_details(fact_score_result):
    """Pull the overall score and a per-chunk feedback digest from a result.

    Args:
        fact_score_result: Dict with ``'fact_score'`` and
            ``'summary_feedback_pairs'``; each pair has a ``'feedbacks'`` list
            of strings.

    Returns:
        ``(fact_score, details)``, where ``details`` has one line per chunk that
        has non-blank feedback, formatted ``"In chunk <n>: fb1; fb2"``
        (1-based ``n``).
    """
    overall = fact_score_result['fact_score']
    digest_lines = []
    for chunk_no, pair in enumerate(fact_score_result['summary_feedback_pairs'], start=1):
        non_blank = [fb for fb in pair['feedbacks'] if fb.strip()]
        if not non_blank:
            continue
        digest_lines.append(f"In chunk {chunk_no}: {'; '.join(non_blank)}")
    return overall, "\n".join(digest_lines)
def build_kg(script, idx, api_key, model_id):
    """Build a knowledge graph for one dataset script record.

    Args:
        script: Record providing ``'scenes'`` and ``'name'``.
        idx: Index of the record, forwarded to ``script2kg``.
        api_key: GPT API key.
        model_id: GPT model name.

    Returns:
        Whatever ``script2kg`` returns.
    """
    scenes, work_name = script['scenes'], script['name']
    return script2kg(scenes, idx, work_name, api_key, model_id)
def build_kg_custom(scenes, idx, api_key, model_id):
    """Build a knowledge graph for user-supplied scenes, named ``"custom"``.

    Args:
        scenes: List of scene texts.
        idx: Index forwarded to ``script2kg``.
        api_key: GPT API key.
        model_id: GPT model name.

    Returns:
        Whatever ``script2kg`` returns.
    """
    work_name = "custom"
    return script2kg(scenes, idx, work_name, api_key, model_id)
def build_kg_with_data(data, work_index, custom_scenes, api_key, model_id):
    """Build the knowledge graph for the selected work or for custom scenes.

    Dataset mode (``data`` non-empty and ``work_index`` set) takes precedence
    over custom mode.

    Args:
        data: Loaded dataset split (list of records) or ``None``.
        work_index: Index of the selected record, or ``None``.
        custom_scenes: Raw text of a JSON/Python list of scenes.
        api_key: GPT API key.
        model_id: GPT model name.

    Returns:
        ``(kg, status_message)``. ``kg`` is ``None`` on any failure, and the
        message then explains why.
    """
    success_message = "Knowledge Graph built successfully!"

    if data and work_index is not None:
        position = int(work_index)
        record = data[position]
        try:
            graph = build_kg(record, position, api_key, model_id)
        except Exception as err:
            return None, f"Error building knowledge graph: {str(err)}"
        return graph, success_message

    if not custom_scenes:
        return None, "Please select a work or input custom scenes."

    try:
        scenes = parse_scenes(custom_scenes)
        if not isinstance(scenes, list):
            return None, "Invalid format. Please provide scenes as a list."
        graph = build_kg_custom(scenes, 0, api_key, model_id)
    except (json.JSONDecodeError, SyntaxError, ValueError) as err:
        return None, f"Invalid format. Error: {str(err)}"
    except Exception as err:
        return None, f"Error building knowledge graph: {str(err)}"
    return graph, success_message
def generate_summary(script, idx, api_key, model_id):
    """Summarise a script chunk by chunk with the selected GPT model.

    Each scene is prefixed with ``s#<i>``. The numbered scenes are joined,
    split into chunks of ``chunk_size=2048`` and each preprocessed chunk is
    summarised through the ``external_summary.txt`` prompt template.

    Args:
        script: Record providing a ``'scenes'`` list.
        idx: Not used by this function; kept for call compatibility.
        api_key: GPT API key.
        model_id: GPT model name.

    Returns:
        Dict with three keys:
        ``summary_sep``: list of ``{chunk_index, chunk, summary}``, one per chunk.
        ``summary_agg``: ``{'script', 'summaries'}``, the chunks and the chunk
        summaries each joined with single spaces.
        ``processed_dataset``: the numbered script text, numbered scenes,
        chunks and chunk summaries.
    """
    _set_seed(42)
    summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42,
    )

    numbered_scenes = [f"s#{i}\n{s}" for i, s in enumerate(script['scenes'])]
    full_script = "\n\n".join(numbered_scenes)
    chunks = chunk_script_gpt(script=full_script, model_id=model_id, chunk_size=2048)

    chunk_summaries = []
    for raw_chunk in tqdm(chunks):
        prompt = build_summarizer_prompt(
            prompt_template="./templates/external_summary.txt",
            input_text_list=[preprocess_script(raw_chunk)],
        )
        chunk_summaries.append(summarizer.inference_with_gpt(prompt=prompt).strip())

    # Aggregates and per-chunk records use the raw (unpreprocessed) chunks.
    aggregate = {
        'script': ' '.join(chunks),
        'summaries': ' '.join(chunk_summaries),
    }
    per_chunk = [
        {"chunk_index": i, "chunk": text.strip(), "summary": summ.strip()}
        for i, (text, summ) in enumerate(zip(chunks, chunk_summaries))
    ]
    processed = {
        "script": full_script,
        "scenes": numbered_scenes,
        "script_chunks": chunks,
        "script_summaries": chunk_summaries,
    }
    return {"summary_sep": per_chunk, "summary_agg": aggregate, "processed_dataset": processed}
def generate_summary_with_data(data, work_index, custom_scenes, api_key, model_id):
    """Summarise the selected dataset work or the custom scenes.

    Dataset mode (``data`` non-empty and ``work_index`` set) takes precedence
    over custom mode.

    Args:
        data: Loaded dataset split (list of records) or ``None``.
        work_index: Index of the selected record, or ``None``.
        custom_scenes: Raw text of a JSON/Python list of scenes.
        api_key: GPT API key.
        model_id: GPT model name.

    Returns:
        ``(summary_result, display_text)``. On failure ``summary_result`` is
        ``None`` and ``display_text`` is the error message.
    """
    if data and work_index is not None:
        record = data[int(work_index)]
        try:
            result = generate_summary(record, int(work_index), api_key, model_id)
            return result, extract_initial_summary(result)
        except Exception as err:
            return None, f"Error generating summary: {str(err)}"

    if not custom_scenes:
        return None, "Please select a work or input custom scenes."

    try:
        scenes = parse_scenes(custom_scenes)
        if not isinstance(scenes, list):
            return None, "Invalid format. Please provide scenes as a list."
        custom_record = {"name": "custom", "scenes": scenes}
        result = generate_summary(custom_record, 0, api_key, model_id)
        return result, extract_initial_summary(result)
    except (json.JSONDecodeError, SyntaxError, ValueError) as err:
        return None, f"Invalid format. Error: {str(err)}"
    except Exception as err:
        return None, f"Error generating summary: {str(err)}"
def calculate_narrative_fact_score(summary, kg_raw, api_key, model_id):
    """Score each summary chunk against its source chunk and the knowledge graph.

    Args:
        summary: Result of ``generate_summary``; its ``processed_dataset``
            supplies ``script_chunks`` and ``script_summaries``.
        kg_raw: Iterable of dicts with ``subject``, ``predicate`` and
            ``object`` keys, flattened here into text statements.
        api_key: GPT API key.
        model_id: GPT model name.

    Returns:
        ``(total_output, partial_output)``. Both carry ``fact_score``, the mean
        of the per-chunk scores (``0.0`` when there are no chunks), and
        ``summary_feedback_pairs``. ``total_output`` entries also include the
        source chunk, summary, chunk score and relevant scenes. ``partial_output``
        keeps only the per-sentence scores, summary chunks and feedbacks.
    """
    _set_seed(42)
    factscorer = NarrativeFactScore(split_type='gpt', model='gptscore', api_key=api_key, model_id=model_id)
    processed = summary['processed_dataset']
    chunks, summaries = processed['script_chunks'], processed['script_summaries']

    # Flatten triples into statements; when subject == object the object is omitted.
    kg = []
    for elem in kg_raw:
        if elem['subject'] == elem['object']:
            kg.append(f"{elem['subject']} {elem['predicate']}")
        else:
            kg.append(f"{elem['subject']} {elem['predicate']} {elem['object']}")

    scores, scores_per_sent, relevant_scenes, summary_chunks, feedbacks = factscorer.score_src_hyp_long(chunks, summaries, kg)

    total_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    partial_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    total_score = 0
    for i, score in enumerate(scores):
        total_output['summary_feedback_pairs'].append({
            'src': chunks[i],
            'summary': summaries[i],
            'score': score,
            'scores_per_sent': scores_per_sent[i],
            'relevant_scenes': relevant_scenes[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        })
        partial_output['summary_feedback_pairs'].append({
            'scores_per_sent': scores_per_sent[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        })
        total_score += score

    # An empty script produces no chunks; avoid ZeroDivisionError in that case.
    mean_score = float(total_score / len(scores)) if len(scores) > 0 else 0.0
    total_output['fact_score'] = mean_score
    partial_output['fact_score'] = mean_score
    return total_output, partial_output
def refine_summary(summary, fact_score, api_key, model_id, threshold=0.9):
    """Rewrite low-scoring summary chunks using the factuality feedback.

    Chunks scoring at least ``threshold`` are kept (whitespace-stripped). For
    the others, the statements whose per-sentence score is exactly 0 are listed
    with their feedback in a self-correction prompt, and the model output
    replaces the chunk summary. An empty model output falls back to the
    original summary.

    Args:
        summary: ``processed_dataset`` dict from ``generate_summary``; its
            ``script`` and ``scenes`` entries are copied through.
        fact_score: ``total_output`` from ``calculate_narrative_fact_score``.
        api_key: GPT API key.
        model_id: GPT model name.
        threshold: Minimum chunk score that skips refinement (default 0.9).

    Returns:
        Dict with ``summary_sep`` (entries for revised chunks only),
        ``summary_agg`` (script plus all chunk summaries joined with spaces)
        and ``processed_dataset`` (every chunk, revised or kept).
    """
    _set_seed(42)
    summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42,
    )
    processed_dataset = {
        "script": summary["script"],
        "scenes": summary["scenes"],
        "script_chunks": [],
        "script_summaries": [],
    }
    elem_dict_list = []

    for factscore_chunk in tqdm(fact_score['summary_feedback_pairs']):
        src_chunk = factscore_chunk['src']
        original_summary = factscore_chunk['summary']
        if factscore_chunk['score'] >= threshold:
            processed_dataset["script_chunks"].append(src_chunk)
            processed_dataset["script_summaries"].append(original_summary.strip())
            continue

        # Statements scored exactly 0 are the ones the prompt asks to revise.
        hallu_idxs = np.where(np.array(factscore_chunk['scores_per_sent']) == 0)[0]
        hallu_summary_parts = np.array(factscore_chunk['summary_chunks'])[hallu_idxs]
        feedbacks = np.array(factscore_chunk['feedbacks'])[hallu_idxs]

        prompt = build_summarizer_prompt(
            prompt_template="./templates/self_correction.txt",
            input_text_list=[src_chunk, original_summary]
        )
        for j, (hallu_summ, feedback) in enumerate(zip(hallu_summary_parts, feedbacks)):
            prompt += f"\n- Statement to Revise {j + 1}: {hallu_summ} (Reason for Revision: {feedback})"
        prompt += "\n- Revised Summary: "

        # Strip like the kept-chunk path so stray model whitespace/newlines do
        # not leak into the aggregated text or the line-based diff.
        revised_summary = summarizer.inference_with_gpt(prompt=prompt).strip()
        if not revised_summary:
            revised_summary = original_summary.strip()

        processed_dataset["script_chunks"].append(src_chunk)
        processed_dataset["script_summaries"].append(revised_summary)
        elem_dict_list.append({
            "chunk_index": len(processed_dataset["script_chunks"]) - 1,
            "chunk": src_chunk.strip(),
            "summary": revised_summary,
            "org_summary": original_summary.strip(),
            "hallu_in_summary": list(hallu_summary_parts),
            "feedbacks": list(feedbacks),
        })

    agg_dict = {
        'script': summary['script'],
        'summaries': ' '.join(processed_dataset["script_summaries"]),
    }
    return {
        "summary_sep": elem_dict_list,
        "summary_agg": agg_dict,
        "processed_dataset": processed_dataset
    }
def refine_summary_and_return_diff(summary, fact_score, api_key, model_id):
    """Refine the current summary and return an HTML diff against the original.

    Args:
        summary: Result of ``generate_summary``, or ``None`` if not generated yet.
        fact_score: ``total_output`` of ``calculate_narrative_fact_score``, or
            ``None`` if scoring has not run yet.
        api_key: GPT API key.
        model_id: GPT model name.

    Returns:
        A complete HTML page from ``difflib.HtmlDiff.make_file`` with labelled
        columns and only changed lines plus context, or a plain instruction
        message when a prerequisite step is missing.
    """
    # Both values come from UI state and are None until the earlier steps run;
    # without this guard the handler fails with a TypeError.
    if summary is None or fact_score is None:
        return "Please generate a summary and calculate the factuality score first."
    refined_summary = refine_summary(summary['processed_dataset'], fact_score, api_key, model_id)
    return HtmlDiff().make_file(
        summary['summary_agg']['summaries'].splitlines(),
        refined_summary['summary_agg']['summaries'].splitlines(),
        fromdesc='Original Summary',
        todesc='Refined Summary',
        context=True,
    )
def open_kg(kg_data):
    """Embed the rendered knowledge-graph page in an ``<iframe srcdoc=...>``.

    Args:
        kg_data: Knowledge graph from the build step. It is only checked for
            ``None``; the HTML itself is read from ``refined_kg.html`` in the
            current working directory (presumably written by the KG build
            step; not visible here).

    Returns:
        An HTML snippet containing the iframe, a plain message when no graph
        has been built, or a red error ``<div>`` if the file cannot be read.
    """
    import html

    if kg_data is None:
        return "Please build the knowledge graph first."
    try:
        with open('refined_kg.html', 'r', encoding='utf-8') as f:
            html_content = f.read()
    except Exception as e:
        return f'<div style="color: red;">Error reading KG file: {html.escape(str(e))}</div>'
    # srcdoc is an attribute value: the page must be fully entity-escaped
    # (``&`` included), otherwise existing entities such as ``&amp;``/``&lt;``
    # are decoded by the attribute parser and the embedded page is corrupted.
    escaped_page = html.escape(html_content, quote=True)
    return f'''
    <iframe
        srcdoc="{escaped_page}"
        style="width: 100%; height: 500px; border: none;"
    ></iframe>
    '''
def format_fact_score_output(fact_score_result):
    """Render a factuality result as a plain-text report.

    Args:
        fact_score_result: Dict with ``fact_score`` (a 0-1 ratio) and
            ``summary_feedback_pairs``, whose items carry ``summary_chunks``
            and ``feedbacks`` lists. May be falsy.

    Returns:
        The overall score as a percentage, then for each chunk its summary
        statements and any non-blank feedback bullets, separated by rules.
        Returns ``"No factuality analysis available"`` for a falsy input.
    """
    if not fact_score_result:
        return "No factuality analysis available"

    formatted_output = [f"Overall Factuality Score: {fact_score_result['fact_score']*100:.1f}%\n"]
    for i, chunk in enumerate(fact_score_result['summary_feedback_pairs'], 1):
        formatted_output.append(f"\nChunk {i} Analysis:")
        formatted_output.append("Original Text:")
        formatted_output.append(f"{' '.join(chunk['summary_chunks'])}\n")

        raw_feedbacks = chunk['feedbacks']
        if raw_feedbacks is None:
            raw_feedbacks = []
        # Emit the header only when at least one feedback is non-blank, so a
        # list of empty strings does not leave a dangling "Feedback:" line.
        bullets = [f"• {feedback}" for feedback in raw_feedbacks if feedback.strip()]
        if bullets:
            formatted_output.append("Feedback:")
            formatted_output.extend(bullets)
        formatted_output.append("-" * 80)
    return "\n".join(formatted_output)
def _score_and_format(summary, kg, api_key, model_id):
    """Run factuality scoring once and return ``(state, report)`` for the UI.

    The first value (full scoring output, later consumed by refinement) goes to
    ``fact_score_state``; the second is the formatted text report.
    """
    if summary is None:
        return None, "Please generate a summary first."
    if kg is None:
        return None, "Please build the knowledge graph first."
    total_output, _partial_output = calculate_narrative_fact_score(summary, kg, api_key, model_id)
    return total_output, format_fact_score_output(total_output)


def _refresh_names_for_split(dataset, data_type):
    """Refill the script picker when the split changes for a known dataset."""
    if dataset in ["MovieSum", "MENSA"]:
        return update_name_list_interface(dataset, data_type)
    # No dataset chosen yet (or Custom): leave the three components untouched.
    return gr.update(), gr.update(), gr.update()


def _clear_selection(_dataset):
    """Forget the previously selected work so another mode cannot reuse it."""
    return None, None


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # NarrativeFactScore: Script Factuality Evaluation
        Evaluate and refine script summaries using narrative factuality scoring.
        """
    )

    with gr.Accordion("Model Settings", open=True):
        with gr.Row():
            api_key_input = gr.Textbox(
                label="GPT API Key",
                placeholder="Enter your GPT API key",
                type="password",
                scale=2,
            )
            model_selector = gr.Dropdown(
                choices=[
                    "gpt-4o-mini",
                    "gpt-4o",
                    "gpt-4-turbo",
                    "gpt-3.5-turbo-0125",
                ],
                value="gpt-4o",
                label="Model Selection",
                scale=1,
            )

    with gr.Tabs():
        with gr.TabItem("Dataset Selection"):
            with gr.Row():
                dataset_selector = gr.Radio(
                    choices=["MovieSum", "MENSA", "Custom"],
                    label="Dataset",
                    info="Choose the dataset or input custom script",
                )
                data_type_selector = gr.Radio(
                    choices=["train", "validation", "test"],
                    label="Split Type",
                    info="Select data split",
                    visible=True,
                )
            name_list = gr.Dropdown(
                choices=[],
                label="Select Script",
                info="Choose a script to analyze",
                visible=True,
            )
            custom_input = gr.Textbox(
                label="Custom Script Input",
                info="Enter scenes as a JSON list: ['scene1', 'scene2', ...]",
                lines=10,
                visible=False,
            )
            narrative_output = gr.Textbox(
                label="Script Content",
                interactive=False,
                lines=10,
            )

        with gr.TabItem("Knowledge Graph"):
            with gr.Row():
                generate_kg_button = gr.Button(
                    "Generate Knowledge Graph",
                    variant="primary",
                )
                open_kg_button = gr.Button("View Graph")
            kg_status = gr.Textbox(
                label="Status",
                interactive=False,
            )
            kg_viewer = gr.HTML(label="Knowledge Graph Visualization")

        with gr.TabItem("Summary Generation"):
            generate_summary_button = gr.Button(
                "Generate Initial Summary",
                variant="primary",
            )
            summary_output = gr.Textbox(
                label="Generated Summary",
                interactive=False,
                lines=5,
            )
            calculate_score_button = gr.Button("Calculate Factuality Score")
            fact_score_display = gr.Textbox(
                label="Factuality Analysis",
                interactive=False,
                lines=10,
            )

        with gr.TabItem("Summary Refinement"):
            refine_button = gr.Button(
                "Refine Summary",
                variant="primary",
            )
            refined_output = gr.HTML(label="Refined Summary with Changes")

    # Hidden states
    work_index = gr.State()
    data_state = gr.State()
    kg_output = gr.State()
    summary_state = gr.State()
    fact_score_state = gr.State()

    # Event handlers
    dataset_selector.change(
        fn=lambda x: gr.update(visible=x in ["MovieSum", "MENSA"]),
        inputs=[dataset_selector],
        outputs=data_type_selector,
    )
    dataset_selector.change(
        fn=update_name_list_interface,
        inputs=[dataset_selector, data_type_selector],
        outputs=[name_list, custom_input, narrative_output],
    )
    # Switching dataset (notably to Custom) invalidates the old selection;
    # otherwise the KG/summary steps keep using the previously picked work.
    dataset_selector.change(
        fn=_clear_selection,
        inputs=[dataset_selector],
        outputs=[work_index, data_state],
    )
    # The split may be chosen after the dataset, so refresh the list then too.
    data_type_selector.change(
        fn=_refresh_names_for_split,
        inputs=[dataset_selector, data_type_selector],
        outputs=[name_list, custom_input, narrative_output],
    )
    name_list.change(
        fn=get_narrative_content_with_index,
        inputs=[dataset_selector, data_type_selector, name_list],
        outputs=[narrative_output, work_index, data_state],
    )
    generate_kg_button.click(
        fn=build_kg_with_data,
        inputs=[
            data_state,  # data
            work_index,  # work_index
            custom_input,  # custom_scenes
            api_key_input,  # api_key
            model_selector,  # model_id
        ],
        outputs=[kg_output, kg_status],
    )
    open_kg_button.click(
        fn=open_kg,
        inputs=[kg_output],
        outputs=kg_viewer,
    )
    generate_summary_button.click(
        fn=generate_summary_with_data,
        inputs=[data_state, work_index, custom_input, api_key_input, model_selector],
        outputs=[summary_state, summary_output],
    )
    # Score once: exactly two outputs (full result for refinement + report).
    calculate_score_button.click(
        fn=_score_and_format,
        inputs=[summary_state, kg_output, api_key_input, model_selector],
        outputs=[fact_score_state, fact_score_display],
    )
    refine_button.click(
        fn=refine_summary_and_return_diff,
        inputs=[summary_state, fact_score_state, api_key_input, model_selector],
        outputs=refined_output,
    )

if __name__ == "__main__":
    demo.launch()