# dejanseo's picture
# Update app.py
# 2581e99 verified
import streamlit as st
import torch
from transformers import AlbertTokenizer, AlbertForSequenceClassification
import plotly.graph_objects as go
# URL of the logo
logo_url = "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png"
# Display the logo at the top using st.logo
st.logo(logo_url, link="https://dejan.ai")

# Streamlit app title and description
st.title("Search Query Form Classifier")
st.write(
    "Ambiguous search queries are candidates for query expansion. Our model identifies such queries with an 80 percent accuracy and is deployed in a batch processing pipeline directly connected with Google Search Console API. In this demo you can test the model capability by testing individual queries."
)
st.write("Enter a query to check if it's well-formed:")

# Identifier of the classifier on the Hugging Face Model Hub.
model_name = 'dejanseo/Query-Quality-Classifier'


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and classifier once and reuse them across reruns.

    Streamlit re-executes this entire script on every widget interaction.
    Without caching, the tokenizer and model would be downloaded/deserialized
    and moved to the device again on every query edit.

    Args:
        name: Hugging Face Model Hub identifier of the checkpoint.

    Returns:
        A ``(tokenizer, model, device)`` tuple; the model is in evaluation
        mode and already placed on ``device``.
    """
    tok = AlbertTokenizer.from_pretrained(name)
    mdl = AlbertForSequenceClassification.from_pretrained(name)
    # Evaluation mode: inference only (e.g. dropout disabled).
    mdl.eval()
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mdl.to(dev)
    return tok, mdl, dev


tokenizer, model, device = _load_model(model_name)
# Two tabs: one for scoring a single typed query, one for a pasted batch.
tab1, tab2 = st.tabs(["Single Query", "Bulk Query"])

with tab1:
    # Text box pre-filled with an example query.
    user_input = st.text_input(
        "Query:",
        "where can I book cheap flights to london",
    )
    # Attribution link, currently disabled:
    # st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")
def classify_query(query, max_length=32):
    """Return the model's confidence, as a percentage, that ``query`` is well-formed.

    Args:
        query: Search query text to score.
        max_length: Token length the query is padded/truncated to before
            inference. Defaults to 32, the length this app always used.

    Returns:
        The softmax probability of class index 1 (the well-formed class)
        multiplied by 100, as a plain Python float in [0, 100].
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # the arguments are unchanged, so the produced tensors are the same.
    inputs = tokenizer(
        query,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    # Batch of one -> row 0; column 1 is the well-formed class.
    probabilities = torch.softmax(logits, dim=1)[0]
    return probabilities[1].item() * 100
def get_color(confidence):
    """Return the RGBA fill colour for a well-formedness confidence (0-100).

    Scores below 50 are drawn red; everything else is drawn green.
    """
    below_threshold = confidence < 50
    return 'rgba(255, 51, 0, 0.8)' if below_threshold else 'rgba(57, 172, 57, 0.8)'
# Check and display classification for single query (rendered in the first tab).
with tab1:
    if user_input:
        confidence = classify_query(user_input)

        fig = go.Figure()
        # Full-length grey bar acting as the empty 0-100 "track".
        fig.add_trace(go.Bar(
            x=[100],
            y=['Well-formedness Factor'],
            orientation='h',
            marker=dict(color='lightgrey'),
            width=0.8,
        ))
        # Coloured fill up to the confidence value, drawn on top of the track.
        fig.add_trace(go.Bar(
            x=[confidence],
            y=['Well-formedness Factor'],
            orientation='h',
            marker=dict(color=get_color(confidence)),
            width=0.8,
        ))
        fig.update_layout(
            # Plotly's default barmode is 'group', which would place the two
            # traces side by side; 'overlay' draws the fill over the grey track.
            barmode='overlay',
            xaxis=dict(range=[0, 100], title='Well-formedness Factor'),
            yaxis=dict(showticklabels=False),
            width=600,
            height=250,  # Increase height for better visibility
            title_text='Well-formedness Factor',
            plot_bgcolor='rgba(0,0,0,0)',
            showlegend=False,
        )
        st.plotly_chart(fig)

        if confidence >= 50:
            st.success(f"Query Score: {confidence:.2f}% Most likely doesn't require query expansion.")
        else:
            st.error(f"The query is likely not well-formed with a score of {100 - confidence:.2f}% and most likely requires query expansion.")

        # The same call-to-action follows either verdict.
        st.subheader(":sparkles: What's next?", divider="gray")
        st.write("Connect with Google Search Console, Semrush, Ahrefs or any other search query source API and detect all queries which could benefit from expansion.")
        st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")
with tab2:
    st.write("Paste multiple queries line-separated (no headers or extra data):")
    bulk_input = st.text_area("Bulk Queries:", height=200)

    # Drop blank / whitespace-only lines (trailing newline, spacing between
    # pasted chunks); scoring an empty string would only produce noise rows.
    bulk_queries = [line.strip() for line in bulk_input.splitlines() if line.strip()]

    if bulk_queries:
        st.write("Processing queries...")
        # Classify each query in bulk input
        results = [(query, classify_query(query)) for query in bulk_queries]

        # Display one score line plus verdict per query
        for query, confidence in results:
            st.write(f"Query: {query} - Score: {confidence:.2f}%")
            if confidence >= 50:
                st.success("Well-formed")
            else:
                st.error("Not well-formed")

        st.subheader(":sparkles: What's next?", divider="gray")
        st.write("Connect with Google Search Console, Semrush, Ahrefs or any other search query source API and detect all queries which could benefit from expansion.")
        st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")