|
from mteb import MTEB |
|
import torch |
|
import clip |
|
|
|
import numpy as np |
|
|
|
# Select GPU when available; all tokenized text is moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP ResNet-50 checkpoint once at import time (downloads on first
# run). PREPROCESS is the image transform; only MODEL's text tower is used below.
MODEL, PREPROCESS = clip.load("RN50", device=DEVICE)
|
|
|
|
|
# English MTEB task names, grouped by task type.
TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonPolarityClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "ImdbClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]

TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]

TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]

TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]

TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "ClimateFEVER",
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "FEVER",
    "FiQA2018",
    "HotpotQA",
    "MSMARCO",
    "NFCorpus",
    "NQ",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
    "TRECCOVID",
]

TASK_LIST_STS = [
    "BIOSSES",
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STS17",
    "STS22",
    "STSBenchmark",
    "SummEval",
]

# BUG FIX: the original wrote each "+ TASK_LIST_X" on its own line with no
# enclosing parentheses, so every continuation was parsed as an independent
# statement applying unary "+" to a list (a TypeError at import time), and
# TASK_LIST would otherwise have held only the classification tasks.
# Parentheses make the whole concatenation a single expression.
TASK_LIST = (
    TASK_LIST_CLASSIFICATION
    + TASK_LIST_CLUSTERING
    + TASK_LIST_PAIR_CLASSIFICATION
    + TASK_LIST_RERANKING
    + TASK_LIST_RETRIEVAL
    + TASK_LIST_STS
)
|
|
|
|
|
|
|
|
|
class ClipModel:
    """Wrapper exposing CLIP's text encoder through the MTEB ``encode`` API."""

    def encode(self, sentences, batch_size=1, **kwargs):
        """Return one embedding per input sentence.

        Args:
            sentences (`List[str]`): Sentences to encode.
            batch_size (`int`): Accepted for MTEB interface compatibility but
                not used — sentences are encoded one at a time, as in the
                original implementation.

        Returns:
            `List[np.ndarray]`: One squeezed embedding array per sentence.
        """
        embeddings = []
        for sentence in sentences:
            # BUG FIX: the original wrapped tokenization in a bare `except:`
            # and retried with the sentence cut to 154 *characters* — which can
            # still tokenize to more than 77 tokens and re-raise uncaught.
            # `truncate=True` tells clip.tokenize to clip to the model's
            # 77-token context directly, making both the bare except and the
            # fragile fallback unnecessary.
            tokens = clip.tokenize(sentence, truncate=True).to(DEVICE)
            # Inference only: disable autograd to save memory/time.
            with torch.no_grad():
                features = MODEL.encode_text(tokens)
            embeddings.append(features.cpu().numpy().squeeze())
        return embeddings
|
|
|
|
|
# Run every selected English MTEB task against the CLIP text encoder and
# write per-task result files under results/clip/.
model = ClipModel()
# F541 fix: the output_folder literal had an `f` prefix but no placeholders;
# the plain string is byte-identical at runtime.
evaluation = MTEB(tasks=TASK_LIST, output_folder="results/clip/", task_langs=["en"])
evaluation.run(model)
|
|