Use sparse matrices
- app.py +5 -1
- data_utils.py +24 -21
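In short: data_utils.py stops building the formula index with a pandas extractall/pivot_table over chemical_formula_descriptive and instead counts elements directly from species_at_sites, stores the normalized count matrix as a scipy CSR sparse matrix, and caches it on disk with save_npz/load_npz alongside the pickled DataFrame. app.py drops the now-unused build_embeddings_index import and creates the cache directory at startup.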
app.py CHANGED
@@ -18,7 +18,6 @@ from components import (
     get_upload_div,
 )
 from data_utils import (
-    build_embeddings_index,
     build_formula_index,
     get_crystal_plot,
     get_dataset,
@@ -29,6 +28,11 @@ from data_utils import (
 EMPTY_DATA = False
 CACHE_PATH = None
 
+if CACHE_PATH is not None:
+    import os
+
+    os.makedirs(CACHE_PATH, exist_ok=True)
+
 dataset = get_dataset()
 
 display_columns_query = [
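CACHE_PATH defaults to None, so the new os.makedirs guard is inert until a path is configured; once it is set, the directory is created up front (exist_ok=True makes this idempotent across restarts), presumably so that build_formula_index in data_utils.py can read and write its train_df.pkl and dataset_index.npz cache files there.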
data_utils.py CHANGED
@@ -72,6 +72,7 @@ mapping_table_idx_dataset_idx = {}
 
 
 def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
+    print("Building formula index")
     if empty_data:
         return np.zeros((1, 1)), {}
 
@@ -80,40 +81,42 @@ def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
     use_dataset = dataset.select(index_range)
 
     # Preprocessing step to create an index for the dataset
-    …
-        train_df = pickle.load(open(f"{cache_path}/train_df.pkl", "rb"))
+    from scipy.sparse import load_npz
 
-        …
+    if cache_path is not None and os.path.exists(f"{cache_path}/train_df.pkl"):
+        train_df = pickle.load(open(f"{cache_path}/train_df.pkl", "rb"))
+        dataset_index = load_npz(f"{cache_path}/dataset_index.npz")
     else:
         train_df = use_dataset.select_columns(
-            ["…
+            ["species_at_sites", "immutable_id", "functional"]
         ).to_pandas()
 
-
-        extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
-        extracted["count"] = extracted["count"].replace("", "1").astype(int)
-
-        wide_df = (
-            extracted.reset_index().pivot_table(  # Move index to columns for pivoting
-                index="level_0",  # original row index
-                columns="element",
-                values="count",
-                aggfunc="sum",
-                fill_value=0,
-            )
-        )
+        import tqdm
 
-        all_elements = …
-        …
+        all_elements = {
+            str(el.symbol): i for i, el in enumerate(periodictable.elements)
+        }  # full element list
+        dataset_index = np.zeros((len(train_df), len(all_elements)))
 
-        …
+        for idx, species in tqdm.tqdm(enumerate(train_df["species_at_sites"].values)):
+            for el in species:
+                dataset_index[idx, all_elements[el]] += 1
 
         dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
         dataset_index = (
            dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
         )  # Normalize vectors
 
+        from scipy.sparse import csr_matrix, save_npz
+
+        dataset_index = csr_matrix(dataset_index)
+
+        if cache_path is not None:
+            pickle.dump(train_df, open(f"{cache_path}/train_df.pkl", "wb"))
+            save_npz(f"{cache_path}/dataset_index.npz", dataset_index)
+
     immutable_id_to_idx = train_df["immutable_id"].to_dict()
+    del train_df
     immutable_id_to_idx = {v: k for k, v in immutable_id_to_idx.items()}
 
     return dataset_index, immutable_id_to_idx
@@ -162,7 +165,7 @@ def search_materials(
     numb = int(numb) if numb else 1
     query_vector[map_periodic_table[el]] = numb
 
-    similarity = …
+    similarity = dataset_index.dot(query_vector) / (np.linalg.norm(query_vector))
     indices = np.argsort(similarity)[::-1][:top_k]
 
     options = [dataset[int(i)] for i in indices]
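To see why the sparse representation is a drop-in change for the search path, here is a minimal, self-contained sketch of the round trip the commit implements. It uses an illustrative 3-element universe and made-up compositions instead of the real periodictable-wide index; only the dataset_index.npz file name and the normalization/similarity steps come from the diff.

import numpy as np
from scipy.sparse import csr_matrix, load_npz, save_npz

# Toy element-count matrix: one row per structure, one column per element
# (hypothetical 3-element universe; the real index has one column per
# entry in periodictable.elements).
dataset_index = np.array(
    [
        [2.0, 1.0, 0.0],  # 2 atoms of element 0, 1 atom of element 1
        [0.0, 2.0, 1.0],
        [1.0, 0.0, 1.0],
    ]
)

# Same two-step normalization as build_formula_index: fractional
# composition first, then unit-length rows.
dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
dataset_index = dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]

# Element-count rows are mostly zeros, so CSR storage keeps only the
# non-zero entries; save_npz/load_npz round-trips it through the cache.
dataset_index = csr_matrix(dataset_index)
save_npz("dataset_index.npz", dataset_index)
dataset_index = load_npz("dataset_index.npz")

# Query side, as in search_materials: a sparse matrix dotted with a
# dense 1-D vector returns a dense 1-D ndarray, so the downstream
# np.argsort keeps working unchanged.
query_vector = np.array([2.0, 1.0, 0.0])
similarity = dataset_index.dot(query_vector) / np.linalg.norm(query_vector)
top_k = 2
indices = np.argsort(similarity)[::-1][:top_k]
print(indices, similarity[indices])

The similarity denominator only needs the query norm because the stored rows are already unit length, so the dot product is a true cosine similarity between the query composition and each structure.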