# app.py (uploaded by rasmodev)
# Last commit: "Update app.py" (38aadda, verified); file size 3.31 kB
# Import the key libraries
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from scipy.special import softmax
import nltk
import re
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# Fetch the NLTK corpora that preprocessing relies on (stop-word list and WordNet)
for _resource in ('stopwords', 'wordnet'):
    nltk.download(_resource)

# Pull the fine-tuned tokenizer/model pair from the Hugging Face Hub
model_path = "rasmodev/Covid-19_Sentiment_Analysis_RoBERTa_Model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# Preprocess text (username and link placeholders, and text preprocessing)
def preprocess(text):
    """Normalize a tweet before it is handed to the tokenizer.

    Each whitespace-separated token of the lowercased text is handled as
    follows:

    * a user mention (``@name``) becomes the placeholder ``@user``;
    * otherwise every character that is not an ASCII letter is removed,
      English stop words are dropped, and the remainder is lemmatized;
    * a cleaned token that starts with ``http`` (i.e. a link) becomes the
      placeholder ``http``.

    Args:
        text: Raw tweet text.

    Returns:
        The surviving tokens joined by single spaces; an empty string when
        nothing survives.
    """
    stop_words = set(stopwords.words('english'))
    lemmatizer = WordNetLemmatizer()

    new_text = []
    for token in text.lower().split():
        # Mentions must be detected on the raw token: the character filter
        # below strips '@', which previously made this placeholder unreachable.
        if token.startswith('@') and len(token) > 1:
            new_text.append('@user')
            continue

        # Remove special characters and digits; a token made only of such
        # characters disappears entirely.
        word = re.sub(r'[^a-zA-Z\s]', '', token)
        if not word or word in stop_words:
            continue

        word = lemmatizer.lemmatize(word)

        # A stripped URL such as "httpstcoabc" still starts with 'http'.
        new_text.append('http' if word.startswith('http') else word)

    return " ".join(new_text)
# Perform sentiment analysis
def sentiment_analysis(text):
    """Classify the sentiment of *text* and render the result as HTML.

    The text is cleaned with ``preprocess``, tokenized, and scored by the
    module-level ``model``. The class with the highest softmax probability
    (index 0 = Negative, 1 = Neutral, 2 = Positive) is shown as a
    colour-coded button together with its probability in percent.

    Args:
        text: Raw tweet text entered by the user.

    Returns:
        An HTML string with a centred button labelled
        ``"<Label>(<probability>%)"``.
    """
    text = preprocess(text)

    # Truncate explicitly: without it, input longer than the encoder's
    # position limit raises inside the model. 512 is the usual RoBERTa
    # limit (NOTE(review): confirm against this checkpoint's config).
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)

    # Inference only, so no gradients are needed
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert the logits of the single input sequence to probabilities
    scores_ = softmax(outputs.logits[0].detach().numpy())

    # Labels and their badge colours, indexed by class id
    labels = ['Negative', 'Neutral', 'Positive']
    colors = ['red', 'yellow', 'green']
    font_colors = ['white', 'black', 'white']

    # Look the winning class up once and reuse the index everywhere
    best = int(scores_.argmax())
    max_label = labels[best]
    max_percentage = scores_[best] * 100

    label_html = f'<div style="display: flex; justify-content: center;"><button style="text-align: center; font-size: 16px; padding: 10px; border-radius: 15px; background-color: {colors[best]}; color: {font_colors[best]};">{max_label}({max_percentage:.2f}%)</button></div>'
    return label_html
# Assemble the Gradio UI: a tweet text box in, an HTML sentiment badge out
tweet_box = gr.Textbox(placeholder="Write your tweet here...")
result_view = gr.HTML()
app_description = "This App Analyzes the sentiment of COVID-19 related tweets. Negative: Indicates a negative sentiment, Neutral: Indicates a neutral sentiment, Positive: Indicates a positive sentiment."
sample_tweets = [
    ["Covid vaccines are irrelevant"],
    ["The Vaccine is Good I have had no issues!"],
]

interface = gr.Interface(
    fn=sentiment_analysis,
    inputs=tweet_box,
    outputs=result_view,
    title="COVID-19 Sentiment Analysis App",
    description=app_description,
    theme="default",
    layout="horizontal",
    examples=sample_tweets,
)

# Start serving the app
interface.launch()