# my_space / app.py
# Author: daniild71r -- commit bf8e5ea ("minor adjustments")
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
from tokenizers import Tokenizer
def fake_hash(x):
    """Hash any object to the constant 0.

    Registered in ``st.cache(hash_funcs=...)`` for ``tokenizers.Tokenizer``
    so the cache treats every tokenizer as identical instead of trying to
    hash its contents.
    """
    constant_hash = 0
    return constant_hash
@st.cache(hash_funcs={Tokenizer: fake_hash}, suppress_st_warning=True, allow_output_mutation=True)
def initialize():
    """Build the classification pipeline and load the category mappings.

    Wrapped in ``st.cache`` so the model and JSON files are loaded once and
    reused across script reruns. ``Tokenizer`` objects are hashed with
    ``fake_hash`` (a constant), so Streamlit does not try to hash them.

    Returns:
        tuple: ``(the_pipeline, cat_mapping, cat_name_mapping)``, where
        ``the_pipeline`` is a ``TextClassificationPipeline`` configured with
        ``return_all_scores=True``, ``cat_mapping`` is the parsed content of
        ``cat_mapping.json`` and ``cat_name_mapping`` the parsed content of
        ``cat_name_mapping.json``.
    """
    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Model weights come from the local './final_model' directory, while the
    # tokenizer is loaded from the base checkpoint name above.
    model = AutoModelForSequenceClassification.from_pretrained('./final_model')
    the_pipeline = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
        device=-1,  # -1 selects the CPU in transformers pipelines
    )
    # Context managers close each file as soon as it has been parsed;
    # JSON is UTF-8, so do not rely on the platform's default encoding.
    with open('cat_mapping.json', 'r', encoding='utf-8') as cat_mapping_file:
        cat_mapping = json.load(cat_mapping_file)
    with open('cat_name_mapping.json', 'r', encoding='utf-8') as cat_name_mapping_file:
        cat_name_mapping = json.load(cat_name_mapping_file)
    return the_pipeline, cat_mapping, cat_name_mapping
def get_top(the_pipeline, cat_mapping, title, summary, thresh=0.95):
    """Return the top categories whose cumulative score reaches ``thresh``.

    The title and summary are joined with ``' || '`` and classified with
    ``the_pipeline``. Categories are taken in descending score order until
    their scores sum to at least ``thresh``; if that never happens, all
    categories are returned.

    Args:
        the_pipeline: Callable classifier. ``the_pipeline(text)[0]`` must be a
            list of ``{'label': ..., 'score': float}`` dicts, one per class.
        cat_mapping: Maps a label with its first six characters removed
            (presumably the default ``'LABEL_'`` prefix -- verify against the
            model config) to the short category name.
        title: Article title.
        summary: Article summary.
        thresh: Cumulative-score cut-off (default 0.95).

    Returns:
        A list of dicts shaped like the pipeline entries, sorted by
        descending score and with ``'label'`` replaced by the short category
        name. If the input is missing or too long, or classification fails,
        a human-readable message string is returned instead.
    """
    # Whitespace-only fields carry no information; treat them as missing.
    if not (title and title.strip()) or not (summary and summary.strip()):
        return 'Not enough data to compute.'
    question = title + ' || ' + summary
    if len(question) > 4000:
        return 'Your input is suspiciously long, try something shorter.'
    try:
        result = the_pipeline(question)[0]
        # sorted() leaves the pipeline's own list untouched; stable for ties.
        ranked = sorted(result, key=lambda entry: entry['score'], reverse=True)
        current_sum = 0
        scores = []
        for entry in ranked:
            # Only the selected entries are remapped, so labels that are never
            # shown do not need a mapping entry.
            short_label = cat_mapping[entry['label'][6:]]
            scores.append(dict(entry, label=short_label))
            current_sum += entry['score']
            if current_sum >= thresh:
                break
        return scores
    except Exception:
        # Best-effort boundary for the UI: any ordinary failure becomes a
        # friendly message. Exception (not BaseException) lets KeyboardInterrupt,
        # SystemExit and Streamlit's script-control exceptions (which derive
        # from BaseException) propagate as intended.
        return 'Something unexpected happened, I\'m sorry. Try again.'
st.markdown('## Welcome to the CS article classification page!')
st.markdown('### What\'s below is pretty much self-explanatory.')

# Header image. The query parameters (size, quality, sign, type) look like the
# image host's resize/signature options -- presumably needed for the signed
# URL to resolve; confirm before editing them.
img_source = 'https://sun9-55.userapi.com/impg/azBQ_VTvbgEVonbL9hhFEpwyKAhjAtpVl4H2GQ/I4Vq0H6c3UM.jpg'
img_params = 'size=1200x900&quality=96&sign=f42419d9cdbf6fe55016fb002e4e85ae&type=album'
st.markdown(
    f'<img src="{img_source}?{img_params}" width="70%"><br>',
    unsafe_allow_html=True
)

title = st.text_input(
    'Please, insert the title of the CS article you are interested in.',
    placeholder='The title (e. g. Incorporating alien technologies in CV)'
)
summary = st.text_area(
    'Now, please, insert the summary of the CS article you are interested in.',
    height=240, placeholder='The summary itself.'
)

# initialize() is wrapped in st.cache, so the model is loaded once and reused
# on every rerun triggered by user input.
the_pipeline, cat_mapping, cat_name_mapping = initialize()
scores = get_top(the_pipeline, cat_mapping, title, summary)

# get_top returns a message string for missing/invalid input or failures,
# and otherwise a list of {'label', 'score'} dicts.
if isinstance(scores, str):
    st.markdown(scores)
else:
    for score in scores:
        percent = round(score['score'] * 100, 2)
        category_short = score['label']
        # Fall back to the short code so a missing name entry does not crash the page.
        category_full = cat_name_mapping.get(category_short, category_short)
        # '%' needs no escaping in Markdown; writing '\%' here would be an
        # invalid Python string escape (SyntaxWarning on 3.12+).
        st.markdown(
            f'I\'m {percent}% certain that the article is from the '
            f'{category_short} category, which is "{category_full}"'
        )