few-shot-demo / prediction.py
spdin
add csv file
f0ad92c
import os
import pickle
import streamlit as st
import model
def main() -> None:
    """Render the "Model Prediction" page.

    Lists the models stored under ``models/<session id>``, lets the user
    pick one and enter some text, and, when "Predict" is pressed, writes the
    predicted label together with its confidence as a percentage.

    The Streamlit script run is stopped early, after writing
    "Model is not available", when the session's model directory is missing
    or contains no model sub-directories.
    """
    st.title("Model Prediction")

    # NOTE(review): assumes `key` has already been stored in the session
    # state elsewhere in the app; attribute access fails otherwise -- confirm.
    session_id = st.session_state.key
    session_dir = f"models/{session_id}"

    if not os.path.isdir(session_dir):
        st.write("Model is not available")
        st.stop()

    # label.pkl is read from inside each entry, so only directories can be
    # usable models; stray files are skipped. Sorting keeps the select box
    # order stable because os.listdir returns entries in arbitrary order.
    models = {
        entry: os.path.abspath(os.path.join(session_dir, entry))
        for entry in sorted(os.listdir(session_dir))
        if os.path.isdir(os.path.join(session_dir, entry))
    }
    if not models:
        # With no options the select box yields None, and the lookup in the
        # "Predict" branch would raise KeyError.
        st.write("Model is not available")
        st.stop()

    model_name = st.selectbox("Select a model", options=list(models))

    # Text input
    text = st.text_area("Enter some text here", height=200)

    # Prediction button
    if st.button("Predict"):
        model_dir = models[model_name]

        # NOTE(review): pickle.load can execute arbitrary code while
        # deserializing; acceptable only if label.pkl is always written by
        # this application itself -- confirm.
        with open(os.path.join(model_dir, "label.pkl"), "rb") as f:
            label_map = pickle.load(f)

        classifier = model.create_classifier(model_dir)

        # The classifier is called on a one-element batch, so index 0 is the
        # result for `text`.
        # NOTE(review): assumes the outputs are tensor-like objects exposing
        # .item() -- confirm against model.create_classifier.
        prediction = classifier([text])
        prediction_class = prediction[0].item()
        confidence_score = classifier.predict_proba([text])[0][prediction_class].item()

        st.write(
            "The predicted label is:",
            label_map[prediction_class],
            f"{round(confidence_score * 100, 2)}%",
        )