Spaces:
Runtime error
Runtime error
Trent
commited on
Commit
·
73ee9f2
1
Parent(s):
de69128
Use distilbert
Browse files- backend/config.py +1 -1
- backend/inference.py +2 -2
- backend/utils.py +1 -1
- data/stackoverflow-titles-distilbert-emb.csv +3 -0
backend/config.py
CHANGED
@@ -10,5 +10,5 @@ QA_MODELS_ID = dict(
|
|
10 |
)
|
11 |
|
12 |
SEARCH_MODELS_ID = dict(
|
13 |
-
|
14 |
)
|
|
|
10 |
)
|
11 |
|
12 |
SEARCH_MODELS_ID = dict(
|
13 |
+
distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
|
14 |
)
|
backend/inference.py
CHANGED
@@ -47,7 +47,7 @@ def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict:
|
|
47 |
def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
48 |
# Proceeding with model
|
49 |
print(model_name)
|
50 |
-
assert model_name == "
|
51 |
model = load_model(model_name, model_dict)
|
52 |
|
53 |
# Creating embeddings
|
@@ -77,7 +77,7 @@ def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
|
77 |
def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
78 |
# Proceeding with model
|
79 |
print(model_name)
|
80 |
-
assert model_name == "
|
81 |
model = load_model(model_name, model_dict)
|
82 |
|
83 |
# Creating embeddings
|
|
|
47 |
def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
48 |
# Proceeding with model
|
49 |
print(model_name)
|
50 |
+
assert model_name == "distilbert_qa"
|
51 |
model = load_model(model_name, model_dict)
|
52 |
|
53 |
# Creating embeddings
|
|
|
77 |
def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
78 |
# Proceeding with model
|
79 |
print(model_name)
|
80 |
+
assert model_name == "distilbert_qa"
|
81 |
model = load_model(model_name, model_dict)
|
82 |
|
83 |
# Creating embeddings
|
backend/utils.py
CHANGED
@@ -23,7 +23,7 @@ def load_model(model_name, model_dict):
|
|
23 |
@st.cache(allow_output_mutation=True)
|
24 |
def load_embeddings():
|
25 |
# embedding pre-generated
|
26 |
-
corpus_emb = torch.from_numpy(np.loadtxt('./data/stackoverflow-titles-
|
27 |
return corpus_emb.float()
|
28 |
|
29 |
@st.cache(allow_output_mutation=True)
|
|
|
23 |
@st.cache(allow_output_mutation=True)
|
24 |
def load_embeddings():
|
25 |
# embedding pre-generated
|
26 |
+
corpus_emb = torch.from_numpy(np.loadtxt('./data/stackoverflow-titles-distilbert-emb.csv', max_rows=10000))
|
27 |
return corpus_emb.float()
|
28 |
|
29 |
@st.cache(allow_output_mutation=True)
|
data/stackoverflow-titles-distilbert-emb.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f54b58e7835fac510ef46b8ba38c58c9942d769cace977e42a3bb274344ee9f
|
3 |
+
size 3916646328
|