Spaces:
Runtime error
Runtime error
RMakushkin
commited on
Commit
·
d713893
1
Parent(s):
c8c7f92
Upload 4 files
Browse files- .gitattributes +1 -0
- app.py +87 -0
- embeddings_main.npy +3 -0
- faiss_index_main.index +3 -0
- func.py +53 -0
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
embs.txt filter=lfs diff=lfs merge=lfs -text
|
37 |
dataset.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
embs.txt filter=lfs diff=lfs merge=lfs -text
|
37 |
dataset.csv filter=lfs diff=lfs merge=lfs -text
|
38 |
+
faiss_index_main.index filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import ast
|
5 |
+
import faiss
|
6 |
+
|
7 |
+
from func import filter_by_ganre, embed_user
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
"""
|
12 |
+
# Умный поиск сериалов
|
13 |
+
"""
|
14 |
+
|
15 |
+
df = pd.read_csv('dataset.csv')
|
16 |
+
embeddings = np.load('embeddings_main.npy')
|
17 |
+
index = faiss.read_index('faiss_index_main.index')
|
18 |
+
|
19 |
+
df['ganres'] = df['ganres'].apply(lambda x: ast.literal_eval(x))
|
20 |
+
|
21 |
+
st.write(f'<p style="font-family: Arial, sans-serif; font-size: 24px; ">Количество сериалов, \
|
22 |
+
предоставляемых сервисом {len(df)}</p>', unsafe_allow_html=True)
|
23 |
+
|
24 |
+
ganres_lst = sorted(['драма', 'документальный', 'биография', 'комедия', 'фэнтези', 'приключения', 'для детей', 'мультсериалы',
|
25 |
+
'мелодрама', 'боевик', 'детектив', 'фантастика', 'триллер', 'семейный', 'криминал', 'исторический', 'музыкальные',
|
26 |
+
'мистика', 'аниме', 'ужасы', 'спорт', 'скетч-шоу', 'военный', 'для взрослых', 'вестерн'])
|
27 |
+
|
28 |
+
st.sidebar.header('Панель инструментов :gear:')
|
29 |
+
choice_g = st.sidebar.multiselect("Выберите жанры", options=ganres_lst)
|
30 |
+
n = st.sidebar.selectbox("Количество отображаемых элементов на странице", options=[5, 10, 15])
|
31 |
+
|
32 |
+
|
33 |
+
# col3, col4 = st.columns([5,2])
|
34 |
+
|
35 |
+
# with col3:
|
36 |
+
text = st.text_input('Введите описание для рекомендации')
|
37 |
+
|
38 |
+
# with col4:
|
39 |
+
|
40 |
+
button = st.button('Отправить запрос', type="primary")
|
41 |
+
|
42 |
+
if text and button:
|
43 |
+
if len(choice_g) == 0:
|
44 |
+
choice_g = ganres_lst
|
45 |
+
filt_ind = filter_by_ganre(df, choice_g)
|
46 |
+
user_emb = embed_user(filt_ind, embeddings, text, n)
|
47 |
+
_, sorted_indices = index.search(user_emb.reshape(1, -1), n)
|
48 |
+
st.write(f'<p style="font-family: Arial, sans-serif; font-size: 18px; text-align: center;"><strong>Всего подобранных \
|
49 |
+
рекомендаций {len(sorted_indices[0])}</strong></p>', unsafe_allow_html=True)
|
50 |
+
st.write('\n')
|
51 |
+
|
52 |
+
# Отображение изображений и названий
|
53 |
+
# for ind, sim in top_dict.items():
|
54 |
+
# col1, col2 = st.columns([3, 4])
|
55 |
+
# with col1:
|
56 |
+
# st.image(df['poster'][ind], width=300)
|
57 |
+
# with col2:
|
58 |
+
# st.write(f"***Название:*** {df['title'][ind]}")
|
59 |
+
# st.write(f"***Жанр:*** {', '.join(df['ganres'][ind])}")
|
60 |
+
# st.write(f"***Описание:*** {df['description'][ind]}")
|
61 |
+
# similarity = round(sim, 4)
|
62 |
+
# st.write(f"***Cosine Similarity : {similarity}***")
|
63 |
+
# st.write(f"***Ссылка на фильм : {df['url'][ind]}***")
|
64 |
+
|
65 |
+
# st.markdown(
|
66 |
+
# "<hr style='border: 2px solid #000; margin-top: 10px; margin-bottom: 10px;'>",
|
67 |
+
# unsafe_allow_html=True
|
68 |
+
# )
|
69 |
+
|
70 |
+
for ind in sorted_indices[0]:
|
71 |
+
col1, col2 = st.columns([3, 4])
|
72 |
+
with col1:
|
73 |
+
st.image(df['poster'][ind], width=300)
|
74 |
+
with col2:
|
75 |
+
st.write(f"***Название:*** {df['title'][ind]}")
|
76 |
+
st.write(f"***Жанр:*** {', '.join(df['ganres'][ind])}")
|
77 |
+
st.write(f"***Описание:*** {df['description'][ind]}")
|
78 |
+
# similarity = round(sim, 4)
|
79 |
+
# st.write(f"***Cosine Similarity : {similarity}***")
|
80 |
+
st.write(f"***Ссылка на фильм : {df['url'][ind]}***")
|
81 |
+
|
82 |
+
st.markdown(
|
83 |
+
"<hr style='border: 2px solid #000; margin-top: 10px; margin-bottom: 10px;'>",
|
84 |
+
unsafe_allow_html=True
|
85 |
+
)
|
86 |
+
|
87 |
+
# streamlit run app.py
|
embeddings_main.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b33d9e4726eff511c3f0f74dd9d1f22f863828aa0c03ff060c2983be3dce0115
|
3 |
+
size 45892736
|
faiss_index_main.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a5fbaa50af8354c8a54372b1c763337f98792c351fa2e3aa266f448ec8266da2
|
3 |
+
size 45892653
|
func.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from transformers import BertModel, BertTokenizer
|
5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
6 |
+
|
7 |
+
|
8 |
+
tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased-sentence")
|
9 |
+
model = BertModel.from_pretrained("DeepPavlov/rubert-base-cased-sentence")
|
10 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
11 |
+
|
12 |
+
|
13 |
+
def filter_by_ganre(df: pd.DataFrame, ganre_list: list):
|
14 |
+
filtered_df = df[df['ganres'].apply(lambda x: any(g in ganre_list for g in(x)))]
|
15 |
+
filt_ind = filtered_df.index.to_list()
|
16 |
+
return filt_ind
|
17 |
+
|
18 |
+
# def mean_pooling(model_output, attention_mask):
|
19 |
+
# token_embeddings = model_output['last_hidden_state']
|
20 |
+
# input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
21 |
+
# sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
|
22 |
+
# sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
23 |
+
# return sum_embeddings / sum_mask
|
24 |
+
|
25 |
+
# def recommendation(filt_ind: list, embeddings: np.array, user_text: str, n=10):
|
26 |
+
# token_user_text = tokenizer(user_text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
|
27 |
+
# user_embeddings = torch.Tensor().to(device)
|
28 |
+
# model.to(device)
|
29 |
+
# model.eval()
|
30 |
+
# with torch.no_grad():
|
31 |
+
# batch = {k: v.to(device) for k, v in token_user_text.items()}
|
32 |
+
# outputs = model(**batch)
|
33 |
+
# user_embeddings = torch.cat([user_embeddings, mean_pooling(outputs, batch['attention_mask'])])
|
34 |
+
# user_embeddings = user_embeddings.cpu().numpy()
|
35 |
+
# cosine_similarities = cosine_similarity(embeddings[filt_ind], user_embeddings.reshape(1, -1))
|
36 |
+
# df_res = pd.DataFrame(cosine_similarities.ravel(), columns=['cos_sim']).sort_values('cos_sim', ascending=False)
|
37 |
+
# dict_topn = df_res.iloc[:n, :].cos_sim.to_dict()
|
38 |
+
# return dict_topn
|
39 |
+
|
40 |
+
|
41 |
+
def embed_user(filt_ind: list, embeddings:np.array, user_text: str, n=10):
|
42 |
+
tokens = tokenizer(user_text, return_tensors="pt", padding=True, truncation=True).to(device)
|
43 |
+
model.to(device)
|
44 |
+
model.eval()
|
45 |
+
with torch.no_grad():
|
46 |
+
outputs = model(**tokens)
|
47 |
+
user_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().reshape(1, -1)
|
48 |
+
return user_embedding
|
49 |
+
|
50 |
+
# cosine_similarities = cosine_similarity(embeddings[filt_ind], user_embedding.reshape(1, -1))
|
51 |
+
# df_res = pd.DataFrame(cosine_similarities.ravel(), columns=['cos_sim']).sort_values('cos_sim', ascending=False)
|
52 |
+
# dict_topn = df_res.iloc[:n, :].cos_sim.to_dict()
|
53 |
+
# return dict_topn
|