import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import time
import sys
import traceback

# Global variable to store error information from model loading
error_message = ""

# Load the model and tokenizer from Hugging Face
model_name = "ambrosfitz/history-qa-flan-t5-large"
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
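    # Use a GPU if one is available, otherwise fall back to the CPU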
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    error_message = f"Error loading model or tokenizer: {str(e)}\n{traceback.format_exc()}"
    print(error_message)

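# Prompt the model to produce a question/answer pair for the given text and
# parse the "Question: ... Answer: ..." markers it is expected to emit.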
def generate_qa(text, max_length=512):
    try:
        input_text = f"Generate a history question and answer based on this text: {text}"
        input_ids = tokenizer(input_text, return_tensors="pt", max_length=max_length, truncation=True).input_ids.to(device)
        
        with torch.no_grad():
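            # Sample a single sequence with moderate temperature so outputs vary between runs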
            outputs = model.generate(input_ids, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=0.7)
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Parse the generated text
        parts = generated_text.split("Question: ")
        if len(parts) > 1:
            qa_parts = parts[1].split("Answer: ")
            question = qa_parts[0].strip()
            answer = qa_parts[1].strip() if len(qa_parts) > 1 else "No answer provided."
            return f"Question: {question}\n\nAnswer: {answer}"
        else:
            return "Unable to generate a proper question and answer. Please try again with a different input."
    except Exception as e:
        return f"An error occurred: {str(e)}\n{traceback.format_exc()}"

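# Stream the generated Q&A back to the chat one character at a time so the
# response "types out" in the Gradio Chatbot.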
def slow_qa(message, history):
    try:
        full_response = generate_qa(message)
        for i in range(len(full_response)):
            time.sleep(0.01)
            yield full_response[:i+1]
    except Exception as e:
        yield f"An error occurred: {str(e)}\n{traceback.format_exc()}"

# Create and launch the Gradio interface
try:
    iface = gr.ChatInterface(
        slow_qa,
        chatbot=gr.Chatbot(height=500),
        textbox=gr.Textbox(placeholder="Enter historical text here...", container=False, scale=7),
        title="History Q&A Generator (FLAN-T5)",
        description="Enter a piece of historical text, and the model will generate a related question and answer.",
        theme="soft",
        examples=[
            "The American Revolution was a colonial revolt that took place between 1765 and 1783.",
            "World War II was a global conflict that lasted from 1939 to 1945, involving many of the world's nations.",
            "The Renaissance was a period of cultural, artistic, political, and economic revival following the Middle Ages."
        ],
        cache_examples=False,
        retry_btn="Regenerate",
        undo_btn="Remove last",
        clear_btn="Clear",
    )
    
    if error_message:
        print("Launching interface with error message.")
    else:
        print("Launching interface normally.")
    iface.launch(debug=True)
except Exception as e:
    print(f"An error occurred while creating or launching the interface: {str(e)}\n{traceback.format_exc()}")