Spaces:
Runtime error
Runtime error
def main(): | |
""" | |
Creates a Streamlit web app that classifies a given body of text as either human-made or AI-generated, | |
using a pre-trained model. | |
""" | |
import streamlit as st | |
import numpy as np | |
import joblib | |
import string | |
import time | |
import scipy | |
import spacy | |
import re | |
from transformers import AutoTokenizer | |
import torch | |
from eli5.lime import TextExplainer | |
from eli5.lime.samplers import MaskingTextSampler | |
import eli5 | |
import shap | |
from custom_models import HF_DistilBertBasedModelAppDocs, HF_BertBasedModelAppDocs | |
# Initialize Spacy | |
nlp = spacy.load("en_core_web_sm") | |
# device to run DL model | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
def format_text(text: str) -> str: | |
""" | |
This function takes a string as input and returns a formatted version of the string. | |
The function replaces specific substrings in the input string with empty strings, | |
converts the string to lowercase, removes any leading or trailing whitespace, | |
and removes any punctuation from the string. | |
""" | |
text = nlp(text) | |
text = " ".join([token.text for token in text if token.ent_type_ not in ["PERSON", "DATE"]]) | |
pattern = r"\b[A-Za-z]+\d+\b" | |
text = re.sub(pattern, "", text) | |
return text.replace("REDACTED", "").lower().replace("[Name]", "").replace("[your name]", "").\ | |
replace("dear admissions committee,", "").replace("sincerely,","").\ | |
replace("[university's name]","fordham").replace("dear sir/madam,","").\ | |
replace("โ statement of intent ","").\ | |
replace('program: master of science in data analytics name of applicant: ',"").\ | |
replace("data analytics", "data science").replace("| \u200b","").\ | |
replace("m.s. in data science at lincoln center ","").\ | |
translate(str.maketrans('', '', string.punctuation)).strip().lstrip() | |
# Define the function to classify text | |
def nb_lr(model, text: str) -> (int, float): | |
""" | |
This function takes a previously trained Sklearn Pipeline | |
model (NaiveBayes or Logistic Regression), then returns prediction probability, | |
and the final prediction as a tuple. | |
""" | |
# Clean and format the input text | |
text = format_text(text) | |
# Predict using either LR or NB and get prediction probability | |
prediction = model.predict([text]).item() | |
predict_proba = round(model.predict_proba([text]).squeeze()[prediction].item(),4) | |
return prediction, predict_proba | |
def torch_pred(tokenizer, model, text): | |
""" | |
This function takes a pre-trained tokenizer, a previously trained transformer-based model | |
model (DistilBert or Bert), then returns prediction probability, | |
and the final prediction as a tuple. | |
""" | |
# DL models (BERT/DistilBERT based models) | |
cleaned_text_tokens = tokenizer([text], padding='max_length', max_length=512, truncation=True) | |
with torch.inference_mode(): | |
input_ids, att = cleaned_text_tokens["input_ids"], cleaned_text_tokens["attention_mask"] | |
input_ids = torch.tensor(input_ids).to(device) | |
attention_mask = torch.tensor(att).to(device) | |
logits = model(input_ids=input_ids, attention_mask=attention_mask) | |
_, prediction = torch.max(logits, 1) | |
prediction = prediction.item() | |
predict_proba = round(torch.softmax(logits, 1).cpu().squeeze().tolist()[prediction],4) | |
return prediction, predict_proba | |
def pred_str(prediction:int) -> str: | |
""" | |
This function takes an integer value as input and returns a string representing the type of the input's source. | |
The input is expected to be a prediction from a classification model that distinguishes between human-made and AI-generated text. | |
""" | |
# Map the predicted class to string output | |
if prediction == 0: | |
return "Human-made ๐คทโโ๏ธ๐คทโโ๏ธ" | |
else: | |
return "Generated with AI ๐ฆพ" | |
def load_tokenizer(option): | |
""" | |
Load pre-trained tokenizer and and save in cache memory. | |
""" | |
if option == "BERT-based model": | |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", padding='max_length', max_length=512, truncation=True) | |
else: | |
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", padding='max_length', max_length=512, truncation=True) | |
return tokenizer | |
def load_model(option): | |
""" | |
Load trained Transformer-based models and save in cache memory. | |
""" | |
if option == "BERT-based model": | |
model = HF_BertBasedModelAppDocs.from_pretrained("ferdmartin/HF_BertBasedModelAppDocs").to(device) | |
else: | |
model = HF_DistilBertBasedModelAppDocs.from_pretrained("ferdmartin/HF_DistilBertBasedModelAppDocs").to(device) | |
return model | |
# Streamlit app: | |
# List of models available | |
models_available = {"Logistic Regression":"models/baseline_model_lr2.joblib", | |
"Naive Bayes": "models/baseline_model_nb2.joblib", | |
"DistilBERT-based model (BERT light)": "ferdmartin/HF_DistilBertBasedModelAppDocs", | |
"BERT-based model": "ferdmartin/HF_BertBasedModelAppDocs" | |
} | |
st.set_page_config(page_title="AI/Human GradAppDocs", page_icon="๐ค", layout="wide") | |
st.title("Academic Application Document Classifier") | |
st.header("Is it human-made ๐ or Generated with AI ๐ค ? ") | |
# Check the model to use | |
def restore_prediction_state(): | |
"""Restore session_state variable to clear prediction after changing model""" | |
if "prediction" in st.session_state: | |
del st.session_state.prediction | |
option = st.selectbox("Select a model to use:", models_available, on_change=restore_prediction_state) | |
# Load the selected trained model | |
if option in ("BERT-based model", "DistilBERT-based model (BERT light)"): | |
tokenizer = load_tokenizer(option) | |
model = load_model(option) | |
else: | |
model = joblib.load(models_available[option]) | |
text = st.text_area("Enter either a statement of intent or a letter of recommendation:") | |
#Hide footer "made with streamlit" | |
hide_st_style = """ | |
<style> | |
footer {visibility: hidden;} | |
header {visibility: hidden;} | |
</style> | |
""" | |
st.markdown(hide_st_style, unsafe_allow_html=True) | |
# Use model | |
if st.button("Let's check this text!"): | |
if text.strip() == "": | |
# In case there is no input for the model | |
st.error("Please enter some text") | |
else: | |
with st.spinner("Wait for the magic ๐ช๐ฎ"): | |
# Use models | |
if option in ("Naive Bayes", "Logistic Regression"): # Use Sklearn pipeline models | |
prediction, predict_proba = nb_lr(model, text) | |
st.session_state["sklearn"] = True | |
else: | |
prediction, predict_proba = torch_pred(tokenizer, model, text) # Use transformers | |
st.session_state["torch"] = True | |
# Store the result in session state | |
st.session_state["color_pred"] = "blue" if prediction == 0 else "red" # Set color for prediction output string | |
prediction = pred_str(prediction) # Map predictions (int => str) | |
st.session_state["prediction"] = prediction | |
st.session_state["predict_proba"] = predict_proba | |
st.session_state["text"] = text | |
# Print result | |
st.markdown(f"I think this text is: **:{st.session_state['color_pred']}[{st.session_state['prediction']}]** (Prediction probability: {st.session_state['predict_proba'] * 100}%)") | |
elif "prediction" in st.session_state: | |
# Display the stored result if available | |
st.markdown(f"I think this text is: **:{st.session_state['color_pred']}[{st.session_state['prediction']}]** (Prediction probability: {st.session_state['predict_proba'] * 100}%)") | |
if st.button("Model Explanation"): | |
# Check if there's text in the session state | |
if "text" in st.session_state and "prediction" in st.session_state: | |
if option in ("Naive Bayes", "Logistic Regression"): | |
with st.spinner('Wait for it ๐ญ...'): | |
explainer = TextExplainer(sampler=MaskingTextSampler()) | |
explainer.fit(st.session_state["text"], model.predict_proba) | |
html = eli5.format_as_html(explainer.explain_prediction(target_names=["Human", "AI"])) | |
else: | |
with st.spinner('Wait for it ๐ญ... BERT-based model explanations take around 4-10 minutes. In case you want to abort, refresh the page.'): | |
def f(x): | |
"""TORCH EXPLAINER PRED FUNC (USES logits)""" | |
tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=512, truncation=True) for v in x])#.cuda() | |
outputs = model(tv).detach().cpu().numpy() | |
scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T | |
val = scipy.special.logit(scores[:,1]) # use one vs rest logit units | |
return val | |
explainer = shap.Explainer(f, tokenizer) # build explainer using masking tokens and selected transformer-based model | |
shap_values = explainer([st.session_state["text"]], fixed_context=1) | |
html = shap.plots.text(shap_values, display=False) | |
# Render HTML | |
st.components.v1.html(html, height=500, scrolling = True) | |
else: | |
st.error("Please enter some text and click 'Let's check!' before requesting an explanation.") | |
if __name__ == "__main__": | |
main() |