import random import logging from datasets import load_dataset, Dataset, DatasetDict from sentence_transformers import ( SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, SentenceTransformerModelCardData, ) from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers from sentence_transformers.evaluation import NanoBEIREvaluator from sentence_transformers.models.StaticEmbedding import StaticEmbedding from transformers import AutoTokenizer logging.basicConfig( format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO ) random.seed(12) def load_train_eval_datasets(): """ Either load the train and eval datasets from disk or load them from the datasets library & save them to disk. Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training. """ try: train_dataset = DatasetDict.load_from_disk("datasets/train_dataset") eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset") return train_dataset, eval_dataset except FileNotFoundError: print("Loading gooaq dataset...") gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train") gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12) gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"] gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"] print("Loaded gooaq dataset.") print("Loading msmarco dataset...") msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train") msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12) msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"] msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"] print("Loaded msmarco dataset.") print("Loading squad dataset...") squad_dataset = load_dataset("sentence-transformers/squad", split="train") squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12) squad_train_dataset: Dataset = squad_dataset_dict["train"] squad_eval_dataset: Dataset = squad_dataset_dict["test"] print("Loaded squad dataset.") print("Loading s2orc dataset...") s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]") s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12) s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"] s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"] print("Loaded s2orc dataset.") print("Loading allnli dataset...") allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train") allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev") print("Loaded allnli dataset.") print("Loading paq dataset...") paq_dataset = load_dataset("sentence-transformers/paq", split="train") paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12) paq_train_dataset: Dataset = paq_dataset_dict["train"] paq_eval_dataset: Dataset = paq_dataset_dict["test"] print("Loaded paq dataset.") print("Loading trivia_qa dataset...") trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train") trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12) trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"] trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"] print("Loaded trivia_qa dataset.") print("Loading msmarco_10m dataset...") msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train") msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12) msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"] msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"] print("Loaded msmarco_10m dataset.") print("Loading swim_ir dataset...") swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"]) swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12) swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"] swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"] print("Loaded swim_ir dataset.") # NOTE: 20 negatives print("Loading pubmedqa dataset...") pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train") pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12) pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"] pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"] print("Loaded pubmedqa dataset.") # NOTE: A lot of overlap with anchor/positives print("Loading miracl dataset...") miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train") miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12) miracl_train_dataset: Dataset = miracl_dataset_dict["train"] miracl_eval_dataset: Dataset = miracl_dataset_dict["test"] print("Loaded miracl dataset.") # NOTE: A lot of overlap with anchor/positives print("Loading mldr dataset...") mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train") mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12) mldr_train_dataset: Dataset = mldr_dataset_dict["train"] mldr_eval_dataset: Dataset = mldr_dataset_dict["test"] print("Loaded mldr dataset.") # NOTE: A lot of overlap with anchor/positives print("Loading mr_tydi dataset...") mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train") mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12) mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"] mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"] print("Loaded mr_tydi dataset.") train_dataset = DatasetDict({ "gooaq": gooaq_train_dataset, "msmarco": msmarco_train_dataset, "squad": squad_train_dataset, "s2orc": s2orc_train_dataset, "allnli": allnli_train_dataset, "paq": paq_train_dataset, "trivia_qa": trivia_qa_train_dataset, "msmarco_10m": msmarco_10m_train_dataset, "swim_ir": swim_ir_train_dataset, "pubmedqa": pubmedqa_train_dataset, "miracl": miracl_train_dataset, "mldr": mldr_train_dataset, "mr_tydi": mr_tydi_train_dataset, }) eval_dataset = DatasetDict({ "gooaq": gooaq_eval_dataset, "msmarco": msmarco_eval_dataset, "squad": squad_eval_dataset, "s2orc": s2orc_eval_dataset, "allnli": allnli_eval_dataset, "paq": paq_eval_dataset, "trivia_qa": trivia_qa_eval_dataset, "msmarco_10m": msmarco_10m_eval_dataset, "swim_ir": swim_ir_eval_dataset, "pubmedqa": pubmedqa_eval_dataset, "miracl": miracl_eval_dataset, "mldr": mldr_eval_dataset, "mr_tydi": mr_tydi_eval_dataset, }) train_dataset.save_to_disk("datasets/train_dataset") eval_dataset.save_to_disk("datasets/eval_dataset") # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk # So we're calling quit() here. Running the script again will load the datasets from disk. quit() def main(): # 1. Load a model to finetune with 2. (Optional) model card data static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-uncased"), embedding_dim=1024) model = SentenceTransformer( modules=[static_embedding], model_card_data=SentenceTransformerModelCardData( language="en", license="apache-2.0", model_name="Static Embeddings with BERT uncased tokenizer finetuned on various datasets", ), ) # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL) train_dataset, eval_dataset = load_train_eval_datasets() print(train_dataset) # 4. Define a loss function loss = MultipleNegativesRankingLoss(model) loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024]) # 5. (Optional) Specify training arguments run_name = "static-retrieval-mrl-en-v1" args = SentenceTransformerTrainingArguments( # Required parameter: output_dir=f"models/{run_name}", # Optional training parameters: num_train_epochs=1, per_device_train_batch_size=2048, per_device_eval_batch_size=2048, learning_rate=2e-1, warmup_ratio=0.1, fp16=False, # Set to False if you get an error that your GPU can't run on FP16 bf16=True, # Set to True if you have a GPU that supports BF16 batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL, # Optional tracking/debugging parameters: eval_strategy="steps", eval_steps=250, save_strategy="steps", save_steps=250, save_total_limit=2, logging_steps=250, logging_first_step=True, run_name=run_name, # Will be used in W&B if `wandb` is installed ) # 6. (Optional) Create an evaluator & evaluate the base model evaluator = NanoBEIREvaluator() evaluator(model) # 7. Create a trainer & train trainer = SentenceTransformerTrainer( model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, loss=loss, evaluator=evaluator, ) trainer.train() # (Optional) Evaluate the trained model on the evaluator after training evaluator(model) # 8. Save the trained model model.save_pretrained(f"models/{run_name}/final") # 9. (Optional) Push it to the Hugging Face Hub model.push_to_hub(run_name, private=True) if __name__ == "__main__": main()