Spaces:
Sleeping
Sleeping
import os | |
import pickle | |
import streamlit as st | |
import model | |
def _list_model_dirs(session_dir):
    """Map each model sub-directory name in *session_dir* to its absolute path.

    Only directories are returned: every model is later opened as a directory
    (``<model>/label.pkl``), so stray files would only fail at predict time.
    Names are sorted so the select box order is stable across reruns.
    """
    return {
        entry: os.path.abspath(os.path.join(session_dir, entry))
        for entry in sorted(os.listdir(session_dir))
        if os.path.isdir(os.path.join(session_dir, entry))
    }


def main():
    """Render the "Model Prediction" page.

    Looks up the models stored under ``models/<session id>``, lets the user
    pick one and enter text, and on "Predict" shows the predicted label with
    its confidence as a percentage.

    The page shows "Model is not available" and stops when the session has no
    ``key`` yet, when the session's model directory does not exist, or when
    that directory contains no model sub-directories.
    """
    st.title("Model Prediction")

    # The session id is read from st.session_state["key"]; opening this page
    # before it is set would otherwise raise AttributeError.
    if "key" not in st.session_state:
        st.write("Model is not available")
        st.stop()
    session_id = st.session_state.key

    session_dir = f"models/{session_id}"
    if not os.path.isdir(session_dir):
        st.write("Model is not available")
        st.stop()

    models = _list_model_dirs(session_dir)
    if not models:
        # With no options the select box returns None and the lookup on
        # "Predict" would raise KeyError.
        st.write("Model is not available")
        st.stop()

    model_name = st.selectbox("Select a model", options=list(models))

    # Text input
    text = st.text_area("Enter some text here", height=200)

    # Prediction button
    if st.button("Predict"):
        model_dir = models[model_name]

        # NOTE(security): pickle.load can execute arbitrary code. This is only
        # safe while everything under models/ is written by this application.
        with open(os.path.join(model_dir, "label.pkl"), "rb") as f:
            label_map = pickle.load(f)

        classifier = model.create_classifier(model_dir)

        # Calling the classifier yields one entry per input text; the first
        # entry's .item() is used as the index into both the probability row
        # and label_map (presumably a class index -- confirm against `model`).
        prediction = classifier([text])
        prediction_class = prediction[0].item()
        confidence_score = classifier.predict_proba([text])[0][prediction_class].item()

        st.write(
            "The predicted label is:",
            label_map[prediction_class],
            f"{round(confidence_score*100,2)}%",
        )