# app.py by Firefly777a (Hugging Face), commit cd851c8 ("create app.py"), 7.23 kB
import functools
import os
from typing import Any, Callable, List, Optional, Tuple

import gradio as gr
import nltk
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

nltk.download('punkt')
# Folder (relative to the working directory) that holds the example text files
EXAMPLES_FOLDER_NAME = "examples"

# Hugging Face Hub repo ids of the summarization models offered for inference
_HF_REPO_IDS = (
    "facebook/bart-large-cnn",
    "sshleifer/distilbart-xsum-12-6",
    "google/pegasus-xsum",
    "philschmid/bart-large-cnn-samsum",
    "linydub/bart-large-samsum",
    "philschmid/distilbart-cnn-12-6-samsum",
    "knkarthick/MEETING-SUMMARY-BART-LARGE-XSUM-SAMSUM-DIALOGSUM-AMI",
)
# Model names in the "huggingface/<repo id>" form passed to gr.Interface.load
HF_MODELS = [f"huggingface/{repo_id}" for repo_id in _HF_REPO_IDS]
################################################################################
# Functions: Document statistics
################################################################################
@functools.lru_cache(maxsize=8)
def _load_tokenizer(model_path):
    """Load the Hugging Face tokenizer for *model_path*, reusing earlier loads.

    Cached so repeated calls do not re-run ``AutoTokenizer.from_pretrained``.
    """
    return AutoTokenizer.from_pretrained(model_path)


def count_tokens(input_text, model_path='sshleifer/distilbart-cnn-12-6'):
    """Count how many tokens a Hugging Face tokenizer produces for *input_text*.

    Args:
        input_text: Text to tokenize.
        model_path: Hub repo id (or local path) whose tokenizer is used.

    Returns:
        The length of the tokenizer's ``input_ids``. NOTE(review): this likely
        includes any special tokens the tokenizer adds by default; confirm if
        an exact content-token count matters.
    """
    # The tokenizer is cached per model_path; the original loaded it on every call.
    tokenizer = _load_tokenizer(model_path)
    return len(tokenizer(input_text)['input_ids'])
def count_sentences(input_text):
    """Return how many sentences ``nltk.sent_tokenize`` finds in *input_text*."""
    return len(nltk.sent_tokenize(input_text))
def count_words(input_text):
    """Return how many tokens ``nltk.word_tokenize`` yields for *input_text*.

    NOTE(review): word_tokenize may emit punctuation as separate tokens, and
    those are counted here too. Verify this is acceptable as a "word" count.
    """
    words = nltk.word_tokenize(input_text)
    return len(words)
def compute_stats(input_text, models: Optional[List[str]] = None, max_tokens: int = 1024):
    """Build a display string of token, sentence and word counts for *input_text*.

    Args:
        input_text: Text to analyze.
        models: Currently unused; kept so existing callers keep working.
        max_tokens: Token limit reported in the last line. Defaults to 1024,
            which the author set manually because they do not intend to use
            models with a smaller limit.

    Returns:
        A newline-terminated, multi-line string: one line each for tokens,
        sentences and words, then the max-token line.
    """
    num_tokens = count_tokens(input_text)
    num_sentences = count_sentences(input_text)
    num_words = count_words(input_text)
    # The first two lines keep their trailing space before "\n", as in the original output.
    return (
        f"| Tokens: {num_tokens} \n"
        f"| Sentences: {num_sentences} \n"
        f"| Words: {num_words}\n"
        f"The max number of tokens for the model is: {max_tokens}\n"
    )
# # A function to loop through a list of strings
# # returning the last element in the filepath for each string
# def get_file_names(file_paths):
# # Create a list of file names
# file_names = []
# # Loop through the file paths
# for file_path in file_paths:
# # Get the last element in the file path
# file_name = file_path.split('/')[-2:]
# # Add the file name to the list
# file_names.append(file_name)
# # Loop through the file names and append to a string
# file_names_str = ""
# for file_name in file_names:
# breakpoint()
# file_names_str += file_name[0] + "\n"
# # Return the list of file names
# return file_names_str
################################################################################
# Functions: Huggingface Inference
################################################################################
@functools.lru_cache(maxsize=4)
def _load_summarizer(model_name):
    """Build the 'summarization' pipeline for *model_name*, reusing earlier builds."""
    return pipeline('summarization', model=model_name)


def predict(dialog_text, hf_model_name="philschmid/bart-large-cnn-samsum"):
    """Summarize *dialog_text* with a Hugging Face summarization pipeline.

    Args:
        dialog_text: Text of a dialog conversation to summarize.
        hf_model_name: Hub repo id of the summarization model. Defaults to
            the model that was previously hard-coded.

    Returns:
        A 2-tuple ``(output, "output2")``. ``output`` is the summary prefixed
        with the model name. ``"output2"`` is a fixed placeholder string,
        kept so callers expecting two values still work.
    """
    # The original referenced an undefined ``hf_model_name`` (NameError) and
    # rebuilt the pipeline on every call.
    summarizer = _load_summarizer(hf_model_name)
    # truncation=True truncates over-long inputs. A ``max_length`` kwarg here
    # would be forwarded to generation and cap the *summary* length rather
    # than truncate the input, so it is deliberately not passed.
    summary = summarizer(dialog_text, truncation=True)
    output = f"{hf_model_name} output: {summary[0]['summary_text']}"
    return output, "output2"
def recursive_predict(dialog_text: str, hf_model_name: Optional[Tuple[str, ...]] = None):
    """Placeholder for a not-yet-written summarizer.

    Args:
        dialog_text: Text to summarize.
        hf_model_name: Model name(s) to use. Optional, because a single-input
            Gradio interface calls this function with only ``dialog_text``.

    Raises:
        NotImplementedError: always, until the feature is implemented.
    """
    # The original body stopped in ``breakpoint()``, which blocks a running
    # server waiting on the debugger, and then returned an undefined name.
    raise NotImplementedError("recursive_predict is not implemented yet")
################################################################################
# Functions: Gradio Utilities
################################################################################
def get_examples(folder_path):
    """Load each text file in *folder_path* as one Gradio example row.

    Files are read in sorted filename order, so the examples have a stable
    order (``os.listdir`` order is arbitrary). Entries that are not regular
    files, such as sub-directories, are skipped instead of crashing ``open``.

    Args:
        folder_path: Directory containing the example text files.

    Returns:
        A list of ``[file_text, ["None"]]`` rows, one per file.
        NOTE(review): the ``["None"]`` second element is carried over as-is;
        its meaning to Gradio is not established here. Confirm it is needed.
    """
    examples = []
    for file_name in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, file_name)
        if not os.path.isfile(file_path):
            continue
        # Decode explicitly as UTF-8 rather than the platform-dependent default.
        with open(file_path, 'r', encoding='utf-8') as f:
            examples.append([f.read(), ["None"]])
    return examples
def get_hf_interfaces(models_to_load):
    """Load one Gradio interface per model name in *models_to_load*.

    Args:
        models_to_load: Model names as accepted by ``gr.Interface.load``,
            e.g. ``"huggingface/facebook/bart-large-cnn"``.

    Returns:
        The loaded interfaces, in the same order as *models_to_load*.
    """
    prefix = "huggingface/"
    interfaces = []
    for model_name in models_to_load:
        # Title/alias each interface with its own repo id (the name without the
        # "huggingface/" prefix). The original gave every interface the same
        # placeholder strings, so they could not be told apart.
        repo_id = model_name[len(prefix):] if model_name.startswith(prefix) else model_name
        interfaces.append(gr.Interface.load(model_name, title=repo_id, alias=repo_id))
    return interfaces
################################################################################
# Build Gradio app
################################################################################
# print_details = gr.Interface(
# fn=lambda x: get_file_names(HF_MODELS),
# inputs="text",
# outputs="text",
# title="Statistics of the document"
# )
# Interface that outputs a string of various document statistics
document_statistics = gr.Interface(
    fn=compute_stats,
    inputs="text",
    outputs="text",
    title="Statistics of the document",
)

# Interface wrapping recursive_predict, which is still unfinished
maddie_mixer_summarization = gr.Interface(
    fn=recursive_predict,
    inputs="text",
    outputs="text",
    # The title used to be copy-pasted from document_statistics.
    title="Recursive summarization (Maddie Mixer)",
)
# Build examples to pass along to the gradio app
examples = get_examples(EXAMPLES_FOLDER_NAME)

# One interface per Hugging Face model, with the document statistics
# interface placed first in the list
all_interfaces = [document_statistics, *get_hf_interfaces(HF_MODELS)]
# all_interfaces.append(maddie_mixer_summarization)  # enable once recursive_predict is implemented

# Run every interface on the same input text
app = gr.Parallel(
    *all_interfaces,
    title='Text Summarizer (Maddie Custom)',
    description="Write a summary of a text",
    examples=examples,
    inputs=gr.inputs.Textbox(lines=10, label="Text"),
)

# Launch only when run as a script, so importing this module does not
# start a web server as a side effect.
if __name__ == "__main__":
    app.launch(inbrowser=True, show_error=True)