Trent committed on
Commit
73ee9f2
·
1 Parent(s): de69128

Use distilbert

Browse files
backend/config.py CHANGED
@@ -10,5 +10,5 @@ QA_MODELS_ID = dict(
10
  )
11
 
12
  SEARCH_MODELS_ID = dict(
13
- mpnet_qa='flax-sentence-embeddings/mpnet_stackexchange_v1'
14
  )
 
10
  )
11
 
12
  SEARCH_MODELS_ID = dict(
13
+ distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
14
  )
backend/inference.py CHANGED
@@ -47,7 +47,7 @@ def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict:
47
  def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
48
  # Proceeding with model
49
  print(model_name)
50
- assert model_name == "mpnet_qa"
51
  model = load_model(model_name, model_dict)
52
 
53
  # Creating embeddings
@@ -77,7 +77,7 @@ def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
77
  def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
78
  # Proceeding with model
79
  print(model_name)
80
- assert model_name == "mpnet_qa"
81
  model = load_model(model_name, model_dict)
82
 
83
  # Creating embeddings
 
47
  def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
48
  # Proceeding with model
49
  print(model_name)
50
+ assert model_name == "distilbert_qa"
51
  model = load_model(model_name, model_dict)
52
 
53
  # Creating embeddings
 
77
  def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
78
  # Proceeding with model
79
  print(model_name)
80
+ assert model_name == "distilbert_qa"
81
  model = load_model(model_name, model_dict)
82
 
83
  # Creating embeddings
backend/utils.py CHANGED
@@ -23,7 +23,7 @@ def load_model(model_name, model_dict):
23
  @st.cache(allow_output_mutation=True)
24
  def load_embeddings():
25
  # embedding pre-generated
26
- corpus_emb = torch.from_numpy(np.loadtxt('./data/stackoverflow-titles-mpnet-emb.csv', max_rows=10000))
27
  return corpus_emb.float()
28
 
29
  @st.cache(allow_output_mutation=True)
 
23
  @st.cache(allow_output_mutation=True)
24
  def load_embeddings():
25
  # embedding pre-generated
26
+ corpus_emb = torch.from_numpy(np.loadtxt('./data/stackoverflow-titles-distilbert-emb.csv', max_rows=10000))
27
  return corpus_emb.float()
28
 
29
  @st.cache(allow_output_mutation=True)
data/stackoverflow-titles-distilbert-emb.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f54b58e7835fac510ef46b8ba38c58c9942d769cace977e42a3bb274344ee9f
3
+ size 3916646328