"""Gradio demo that decomposes a complex question into sub-questions with a fine-tuned T5 model."""

import torch
import gradio as gr

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Prefer Apple Silicon (MPS), then CUDA, and fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned question-decomposition checkpoint from the Hugging Face Hub.
model_path = "thenHung/question_decomposer_t5"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.to(device)
model.eval()


def decompose_question(question):
    """
    Decompose a complex question into sub-questions.

    Args:
        question (str): Input complex question.

    Returns:
        list: Decomposed sub-questions, or a single-element list with an error message.
    """
    try:
        # The checkpoint was trained with this task prefix in front of the question.
        input_text = f"decompose question: {question}"
        input_ids = tokenizer(
            input_text,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids.to(device)

        # Beam search (num_beams=4) for more stable decoding than greedy search.
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_length=128,
                num_beams=4,
                early_stopping=True
            )

        # The model joins sub-questions with a literal " [SEP] " marker.
        decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        sub_questions = decoded_output.split(" [SEP] ")

        return sub_questions
    except Exception as e:
        return [f"Error: {str(e)}"]


# Gradio UI: one textbox in, a JSON list of sub-questions out.
demo = gr.Interface(
    fn=decompose_question,
    inputs=gr.Textbox(label="Enter your complex question"),
    outputs=gr.JSON(label="Decomposed Sub-Questions"),
    title="Question Decomposer",
    description="Breaks down complex questions into simpler sub-questions using a T5 model.",
    examples=[
        "Who is taller between John and Mary?",
        "What is the capital of Vietnam and the largest city in Vietnam?",
    ],
)


if __name__ == "__main__":
    demo.launch()
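
# To also get a temporary public URL (a standard Gradio option), launch with:
#   demo.launch(share=True)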