import gradio as gr
import json
from tqdm import tqdm
import numpy as np
import random
import torch
import ast
from difflib import HtmlDiff

from src.kg.main import script2kg
from src.summary.summarizer import Summarizer
from src.summary.utils import preprocess_script, chunk_script_gpt
from src.summary.prompt import build_summarizer_prompt
from src.fact.narrativefactscore import NarrativeFactScore


def _set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def parse_scenes(scene_text):
    try:
        return json.loads(scene_text)
    except json.JSONDecodeError:
        return ast.literal_eval(scene_text)


def set_name_list(dataset, data_type):
    if dataset == "MovieSum":
        if data_type == "train":
            return ['8MM_1999', 'The Iron Lady_2011', 'Adventureland_2009', 'Napoleon_2023',
                    'Kubo and the Two Strings_2016', 'The Woman King_2022', 'What They Had_2018',
                    'Synecdoche, New York_2008', 'Black Christmas_2006', 'Superbad_2007']
        elif data_type == "validation":
            return ['The Boondock Saints_1999', 'The House with a Clock in Its Walls_2018',
                    'The Unbelievable Truth_1989', 'Insidious_2010', 'If Beale Street Could Talk_2018',
                    'The Battle of Shaker Heights_2003', '20th Century Women_2016',
                    'Captain Phillips_2013', 'Conspiracy Theory_1997', 'Domino_2005']
        elif data_type == "test":
            # Return test dataset names (shortened for brevity)
            return ['A Nightmare on Elm Street 3: Dream Warriors_1987', 'Van Helsing_2004',
                    'Oppenheimer_2023', 'Armored_2009', 'The Martian_2015']
    elif dataset == "MENSA":
        if data_type == "train":
            return ['The_Ides_of_March_(film)', 'An_American_Werewolf_in_Paris',
                    'Batman_&_Robin_(film)', 'Airplane_II:_The_Sequel', 'Krull_(film)']
        elif data_type == "validation":
            return ['Pleasantville_(film)', 'V_for_Vendetta_(film)',
                    'Mary_Shelleys_Frankenstein_(film)', 'Rapture_(1965_film)', 'Get_Out']
        elif data_type == "test":
            return ['Knives_Out', 'Black_Panther', 'Pet_Sematary_(film)', 'Panic_Room',
                    'The_Village_(2004_film)']
    return []


def update_name_list_interface(dataset, data_type):
    if dataset in ["MovieSum", "MENSA"]:
        return (
            gr.update(choices=set_name_list(dataset, data_type), value=None, visible=True),
            gr.update(visible=False),
            gr.update(value="")
        )
    else:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(value="Click next 'Knowledge Graph' to continue")
        )


def read_data(dataset, data_type):
    file_path = f"dataset/{dataset}/{data_type}.jsonl"
    try:
        with open(file_path, 'r', encoding='utf8') as f:
            return [json.loads(line) for line in f]
    except FileNotFoundError:
        return []


def find_work_index(data, work_name):
    for idx, entry in enumerate(data):
        if entry.get("name") == work_name:
            return idx, entry
    return None, "Work not found in the selected dataset."


def get_narrative_content(dataset, data_type, work):
    data = read_data(dataset, data_type)
    for entry in data:
        if entry.get("name") == work:
            return entry['scenes']
    return "Work not found in the selected dataset."
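
# --- Hedged usage sketch ------------------------------------------------------
# parse_scenes accepts either strict JSON or a Python literal, so both of these
# (hypothetical) inputs yield the same two-element list:
#
#   parse_scenes('["INT. LAB - NIGHT", "EXT. STREET - DAY"]')    # JSON
#   parse_scenes("['INT. LAB - NIGHT', 'EXT. STREET - DAY']")    # Python literal
#
# read_data expects dataset/<dataset>/<data_type>.jsonl with one JSON object per
# line, each carrying at least the "name" and "scenes" keys used below, e.g.:
#
#   {"name": "Knives_Out", "scenes": ["INT. THROMBEY MANSION - DAY ...", "..."]}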
def get_narrative_content_with_index(dataset, data_type, work):
    data = read_data(dataset, data_type)
    for idx, entry in enumerate(data):
        if entry.get("name") == work:
            # For MovieSum and MENSA datasets, only return scenes
            if dataset in ["MovieSum", "MENSA"]:
                return "\n".join(entry['scenes']), idx, data
            # For other datasets or custom input, return full content
            return entry, idx, data
    return "Work not found in the selected dataset.", None, None


def show_diff(original, revised):
    d = HtmlDiff()
    original_lines = original.splitlines(keepends=True)
    revised_lines = revised.splitlines(keepends=True)
    diff_table = d.make_table(original_lines, revised_lines,
                              fromdesc='Original Summary', todesc='Refined Summary',
                              context=True, numlines=2)
    return diff_table


def extract_initial_summary(summary_result):
    return summary_result['summary_agg']['summaries']


def extract_factuality_score_and_details(fact_score_result):
    factuality_score = fact_score_result['fact_score']
    feedback_list = []
    for i, feedback_data in enumerate(fact_score_result['summary_feedback_pairs']):
        feedbacks = [fb for fb in feedback_data['feedbacks'] if fb.strip()]
        if feedbacks:
            feedback_list.append(f"In chunk {i + 1}: {'; '.join(feedbacks)}")
    incorrect_details = "\n".join(feedback_list)
    return factuality_score, incorrect_details


def build_kg(script, idx, api_key, model_id):
    return script2kg(script['scenes'], idx, script['name'], api_key, model_id)


def build_kg_custom(scenes, idx, api_key, model_id):
    return script2kg(scenes, idx, "custom", api_key, model_id)


def build_kg_with_data(data, work_index, custom_scenes, api_key, model_id):
    if data and work_index is not None:
        # Dataset mode
        script = data[int(work_index)]
        try:
            kg = script2kg(script['scenes'], int(work_index), script['name'], api_key, model_id)
            return kg, "Knowledge Graph built successfully!"
        except Exception as e:
            return None, f"Error building knowledge graph: {str(e)}"
    elif custom_scenes:
        # Custom script mode
        try:
            scenes = parse_scenes(custom_scenes)
            if not isinstance(scenes, list):
                return None, "Invalid format. Please provide scenes as a list."
            kg = build_kg_custom(scenes, 0, api_key, model_id)
            return kg, "Knowledge Graph built successfully!"
        except (json.JSONDecodeError, SyntaxError, ValueError) as e:
            return None, f"Invalid format. Error: {str(e)}"
        except Exception as e:
            return None, f"Error building knowledge graph: {str(e)}"
    return None, "Please select a work or input custom scenes."
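
# --- Hedged wiring sketch -----------------------------------------------------
# A minimal sketch of how build_kg_with_data could be wired into a Gradio
# Blocks layout. The component names (build_btn, kg_state, work_index_state,
# custom_scenes_box, api_key_box, model_id_box, status_box) are illustrative
# assumptions, not taken from the original interface definition:
#
#   with gr.Blocks() as demo:
#       data_state = gr.State()
#       kg_state = gr.State()
#       ...
#       build_btn.click(
#           build_kg_with_data,
#           inputs=[data_state, work_index_state, custom_scenes_box,
#                   api_key_box, model_id_box],
#           outputs=[kg_state, status_box],
#       )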
def generate_summary(script, idx, api_key, model_id):
    _set_seed(42)
    scripty_summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42
    )
    scenes = [f"s#{i}\n{s}" for i, s in enumerate(script['scenes'])]
    script = "\n\n".join(scenes)
    script_chunks = chunk_script_gpt(script=script, model_id=model_id, chunk_size=2048)

    script_summaries = []
    for chunk in tqdm(script_chunks):
        chunk = preprocess_script(chunk)
        prompt = build_summarizer_prompt(
            prompt_template="./templates/external_summary.txt",
            input_text_list=[chunk]
        )
        script_summ = scripty_summarizer.inference_with_gpt(prompt=prompt)
        script_summaries.append(script_summ.strip())

    elem_dict_list = []
    agg_dict = {
        'script': ' '.join(script_chunks),
        'summaries': ' '.join(script_summaries)
    }
    for i, (chunk, summary) in enumerate(zip(script_chunks, script_summaries)):
        elem_dict = {
            "chunk_index": i,
            "chunk": chunk.strip(),
            "summary": summary.strip()
        }
        elem_dict_list.append(elem_dict)

    processed_dataset = {
        "script": script,
        "scenes": scenes,
        "script_chunks": script_chunks,
        "script_summaries": script_summaries,
    }
    return {"summary_sep": elem_dict_list, "summary_agg": agg_dict,
            "processed_dataset": processed_dataset}


def generate_summary_with_data(data, work_index, custom_scenes, api_key, model_id):
    if data and work_index is not None:
        # Dataset mode
        script = data[int(work_index)]
        try:
            summary = generate_summary(script, int(work_index), api_key, model_id)
            return summary, extract_initial_summary(summary)
        except Exception as e:
            return None, f"Error generating summary: {str(e)}"
    elif custom_scenes:
        # Custom script mode
        try:
            scenes = parse_scenes(custom_scenes)
            if not isinstance(scenes, list):
                return None, "Invalid format. Please provide scenes as a list."
            script = {"name": "custom", "scenes": scenes}
            summary = generate_summary(script, 0, api_key, model_id)
            return summary, extract_initial_summary(summary)
        except (json.JSONDecodeError, SyntaxError, ValueError) as e:
            return None, f"Invalid format. Error: {str(e)}"
        except Exception as e:
            return None, f"Error generating summary: {str(e)}"
    return None, "Please select a work or input custom scenes."
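
def _demo_generate_summary(api_key, model_id="gpt-4o-mini"):
    """Hedged end-to-end sketch (not part of the original UI flow): summarize
    one MENSA test script and return the aggregated summary text. Assumes
    dataset/MENSA/test.jsonl exists on disk and api_key is valid for the
    OpenAI-backed Summarizer; the model_id default here is an assumption, not
    a project setting."""
    data = read_data("MENSA", "test")
    idx, script = find_work_index(data, "Knives_Out")
    if idx is None:
        return script  # the "Work not found ..." message
    result = generate_summary(script, idx, api_key, model_id)
    return extract_initial_summary(result)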
def calculate_narrative_fact_score(summary, kg_raw, api_key, model_id):
    _set_seed(42)
    factscorer = NarrativeFactScore(split_type='gpt', model='gptscore',
                                    api_key=api_key, model_id=model_id)
    summary = summary['processed_dataset']
    chunks, summaries = summary['script_chunks'], summary['script_summaries']

    total_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    partial_output = {'fact_score': 0, 'summary_feedback_pairs': []}
    total_score = 0

    # Flatten the KG into "subject predicate object" triple strings; self-loops
    # (subject == object) are rendered as "subject predicate".
    kg = []
    for elem in kg_raw:
        if elem['subject'] == elem['object']:
            kg.append(f"{elem['subject']} {elem['predicate']}")
        else:
            kg.append(f"{elem['subject']} {elem['predicate']} {elem['object']}")

    scores, scores_per_sent, relevant_scenes, summary_chunks, feedbacks = \
        factscorer.score_src_hyp_long(chunks, summaries, kg)

    for i, score in enumerate(scores):
        output_elem = {
            'src': chunks[i],
            'summary': summaries[i],
            'score': score,
            'scores_per_sent': scores_per_sent[i],
            'relevant_scenes': relevant_scenes[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        }
        output_elem_part = {
            'scores_per_sent': scores_per_sent[i],
            'summary_chunks': summary_chunks[i],
            'feedbacks': feedbacks[i],
        }
        total_output['summary_feedback_pairs'].append(output_elem)
        partial_output['summary_feedback_pairs'].append(output_elem_part)
        total_score += score

    total_output['fact_score'] = float(total_score / len(scores))
    partial_output['fact_score'] = float(total_score / len(scores))
    return total_output, partial_output


def refine_summary(summary, fact_score, api_key, model_id):
    _set_seed(42)
    threshold = 0.9
    summarizer = Summarizer(
        inference_mode="org",
        model_id=model_id,
        api_key=api_key,
        dtype="float16",
        seed=42
    )
    processed_dataset = {
        "script": summary["script"],
        "scenes": summary["scenes"],
        "script_chunks": [],
        "script_summaries": []
    }
    elem_dict_list = []
    agg_dict = {}

    for factscore_chunk in tqdm(fact_score['summary_feedback_pairs']):
        src_chunk = factscore_chunk['src']
        original_summary = factscore_chunk['summary']

        # Chunks already above the factuality threshold are kept verbatim.
        if factscore_chunk['score'] >= threshold:
            processed_dataset["script_chunks"].append(src_chunk)
            processed_dataset["script_summaries"].append(original_summary.strip())
            continue

        # Sentences scored 0 are treated as hallucinated and sent for revision.
        hallu_idxs = np.where(np.array(factscore_chunk['scores_per_sent']) == 0)[0]
        hallu_summary_parts = np.array(factscore_chunk['summary_chunks'])[hallu_idxs]
        feedbacks = np.array(factscore_chunk['feedbacks'])[hallu_idxs]

        prompt = build_summarizer_prompt(
            prompt_template="./templates/self_correction.txt",
            input_text_list=[src_chunk, original_summary]
        )
        for j, (hallu_summ, feedback) in enumerate(zip(hallu_summary_parts, feedbacks)):
            prompt += f"\n- Statement to Revise {j + 1}: {hallu_summ} (Reason for Revision: {feedback})"
        prompt += "\n- Revised Summary: "

        revised_summary = summarizer.inference_with_gpt(prompt=prompt)
        if len(revised_summary.strip()) == 0:
            revised_summary = original_summary

        processed_dataset["script_chunks"].append(src_chunk)
        processed_dataset["script_summaries"].append(revised_summary)

        elem_dict = {
            "chunk_index": len(processed_dataset["script_chunks"]) - 1,
            "chunk": src_chunk.strip(),
            "summary": revised_summary.strip(),
            "org_summary": original_summary.strip(),
            "hallu_in_summary": list(hallu_summary_parts),
            "feedbacks": list(feedbacks),
        }
        elem_dict_list.append(elem_dict)

    agg_dict['script'] = summary['script']
    agg_dict['summaries'] = ' '.join(processed_dataset["script_summaries"])
    return {
        "summary_sep": elem_dict_list,
        "summary_agg": agg_dict,
        "processed_dataset": processed_dataset
    }
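
# --- Worked example -----------------------------------------------------------
# How refine_summary selects statements to revise, on made-up values:
#
#   scores_per_sent = [1, 0, 1, 0]             -> hallu_idxs = [1, 3]
#   summary_chunks  = ["A.", "B.", "C.", "D."] -> hallu parts = ["B.", "D."]
#
# Only "B." and "D." (with their feedbacks) are appended to the self-correction
# prompt; chunks whose overall score meets the 0.9 threshold skip the LLM call.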
def refine_summary_and_return_diff(summary, fact_score, api_key, model_id):
    refined_summary = refine_summary(summary['processed_dataset'], fact_score,
                                     api_key, model_id)
    diff = HtmlDiff().make_file(
        summary['summary_agg']['summaries'].splitlines(),
        refined_summary['summary_agg']['summaries'].splitlines(),
        context=True
    )
    return diff


def open_kg(kg_data):
    if kg_data is None:
        return "Please build the knowledge graph first."
    try:
        with open('refined_kg.html', 'r', encoding='utf-8') as f:
            html_content = f.read()
        # Embed the saved visualization inline so Gradio can render it
        # (assumes refined_kg.html was written by the KG-building step).
        escaped = html_content.replace('"', '&quot;')
        return f'<iframe srcdoc="{escaped}" width="100%" height="600px" frameborder="0"></iframe>'
    except Exception as e:
        return f"Error opening knowledge graph: {str(e)}"
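
# --- Hedged end-to-end sketch ---------------------------------------------------
# A minimal command-line walk-through of the full pipeline (KG -> summary ->
# NarrativeFactScore -> refinement), using only functions defined above. It
# assumes: OPENAI_API_KEY is set, dataset/MENSA/test.jsonl exists, and
# build_kg_with_data returns a list of {'subject', 'predicate', 'object'} dicts,
# as calculate_narrative_fact_score expects. The model id below is an
# assumption, not a project default.
if __name__ == "__main__":
    import os

    api_key = os.environ.get("OPENAI_API_KEY", "")
    model_id = "gpt-4o-mini"

    data = read_data("MENSA", "test")
    idx, script = find_work_index(data, "Knives_Out")
    if idx is None:
        raise SystemExit(script)  # the "Work not found ..." message

    kg, status = build_kg_with_data(data, idx, None, api_key, model_id)
    print(status)

    summary, initial = generate_summary_with_data(data, idx, None, api_key, model_id)
    print("Initial summary:\n", initial)

    if kg is not None and summary is not None:
        total, _ = calculate_narrative_fact_score(summary, kg, api_key, model_id)
        score, details = extract_factuality_score_and_details(total)
        print(f"NarrativeFactScore: {score:.3f}")
        if details:
            print("Flagged statements:\n", details)

        refined = refine_summary(summary["processed_dataset"], total, api_key, model_id)
        print("Refined summary:\n", refined["summary_agg"]["summaries"])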