mlsa-iai-msu-lab
commited on
Commit
·
b8669e9
1
Parent(s):
cb81375
Update README.md
Browse files
README.md
CHANGED
@@ -43,14 +43,11 @@ def get_sentence_embedding(title, abstract, model, tokenizer, max_length=None):
|
|
43 |
sentence = '</s>'.join([title, abstract])
|
44 |
encoded_input = tokenizer(
|
45 |
[sentence], padding=True, truncation=True, return_tensors='pt', max_length=max_length).to(model.device)
|
46 |
-
|
47 |
# Compute token embeddings
|
48 |
with torch.no_grad():
|
49 |
model_output = model(**encoded_input)
|
50 |
-
|
51 |
# Perform pooling
|
52 |
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
53 |
-
|
54 |
# Normalize embeddings
|
55 |
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
|
56 |
return sentence_embeddings.cpu().detach().numpy()[0]
|
|
|
43 |
sentence = '</s>'.join([title, abstract])
|
44 |
encoded_input = tokenizer(
|
45 |
[sentence], padding=True, truncation=True, return_tensors='pt', max_length=max_length).to(model.device)
|
|
|
46 |
# Compute token embeddings
|
47 |
with torch.no_grad():
|
48 |
model_output = model(**encoded_input)
|
|
|
49 |
# Perform pooling
|
50 |
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
|
|
51 |
# Normalize embeddings
|
52 |
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
|
53 |
return sentence_embeddings.cpu().detach().numpy()[0]
|