Tom Aarsen committed
Commit dd9a98e · 1 Parent(s): bc2993c

Update training script to separate dataset loading & training

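This commit moves all dataset preparation out of main() into a load_train_eval_datasets() helper that caches the assembled DatasetDicts on disk and reuses them on later runs. Below is a minimal, self-contained sketch of that load-or-build caching pattern, assuming only the Hugging Face datasets library; the datasets/demo path and the toy data are illustrative, not taken from the commit.

from datasets import Dataset, DatasetDict

def load_or_build(path: str) -> DatasetDict:
    """Return the cached DatasetDict at `path`, building and saving it on a cache miss."""
    try:
        # A cache hit memory-maps the saved Arrow files rather than reading them into RAM.
        return DatasetDict.load_from_disk(path)
    except FileNotFoundError:
        # Cache miss: build the splits once, persist them, and return them.
        dataset = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
        splits = dataset.train_test_split(test_size=0.25, seed=12)
        dataset_dict = DatasetDict({"train": splits["train"], "test": splits["test"]})
        dataset_dict.save_to_disk(path)
        return dataset_dict

dataset_dict = load_or_build("datasets/demo")

Unlike this sketch, the commit's helper calls quit() right after saving: the train_test_split calls keep their results in process memory, so the script is run once to build the cache and then restarted, ensuring training always reads the datasets memory-mapped from disk.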
Files changed (1)
  1. train.py +143 -123
train.py CHANGED
@@ -20,6 +20,148 @@ logging.basicConfig(
 random.seed(12)
 
 
+def load_train_eval_datasets():
+    """
+    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.
+
+    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
+    """
+    try:
+        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
+        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
+        return train_dataset, eval_dataset
+    except FileNotFoundError:
+        print("Loading gooaq dataset...")
+        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
+        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
+        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
+        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
+        print("Loaded gooaq dataset.")
+
+        print("Loading msmarco dataset...")
+        msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train")
+        msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
+        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
+        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
+        print("Loaded msmarco dataset.")
+
+        print("Loading squad dataset...")
+        squad_dataset = load_dataset("sentence-transformers/squad", split="train")
+        squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
+        squad_train_dataset: Dataset = squad_dataset_dict["train"]
+        squad_eval_dataset: Dataset = squad_dataset_dict["test"]
+        print("Loaded squad dataset.")
+
+        print("Loading s2orc dataset...")
+        s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]")
+        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
+        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
+        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
+        print("Loaded s2orc dataset.")
+
+        print("Loading allnli dataset...")
+        allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
+        allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
+        print("Loaded allnli dataset.")
+
+        print("Loading paq dataset...")
+        paq_dataset = load_dataset("sentence-transformers/paq", split="train")
+        paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
+        paq_train_dataset: Dataset = paq_dataset_dict["train"]
+        paq_eval_dataset: Dataset = paq_dataset_dict["test"]
+        print("Loaded paq dataset.")
+
+        print("Loading trivia_qa dataset...")
+        trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
+        trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
+        trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
+        trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
+        print("Loaded trivia_qa dataset.")
+
+        print("Loading msmarco_10m dataset...")
+        msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
+        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12)
+        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
+        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
+        print("Loaded msmarco_10m dataset.")
+
+        print("Loading swim_ir dataset...")
+        swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"])
+        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12)
+        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
+        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
+        print("Loaded swim_ir dataset.")
+
+        # NOTE: 20 negatives
+        print("Loading pubmedqa dataset...")
+        pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train")
+        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
+        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
+        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
+        print("Loaded pubmedqa dataset.")
+
+        # NOTE: A lot of overlap with anchor/positives
+        print("Loading miracl dataset...")
+        miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train")
+        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
+        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
+        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
+        print("Loaded miracl dataset.")
+
+        # NOTE: A lot of overlap with anchor/positives
+        print("Loading mldr dataset...")
+        mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train")
+        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
+        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
+        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
+        print("Loaded mldr dataset.")
+
+        # NOTE: A lot of overlap with anchor/positives
+        print("Loading mr_tydi dataset...")
+        mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train")
+        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
+        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
+        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
+        print("Loaded mr_tydi dataset.")
+
+        train_dataset = DatasetDict({
+            "gooaq": gooaq_train_dataset,
+            "msmarco": msmarco_train_dataset,
+            "squad": squad_train_dataset,
+            "s2orc": s2orc_train_dataset,
+            "allnli": allnli_train_dataset,
+            "paq": paq_train_dataset,
+            "trivia_qa": trivia_qa_train_dataset,
+            "msmarco_10m": msmarco_10m_train_dataset,
+            "swim_ir": swim_ir_train_dataset,
+            "pubmedqa": pubmedqa_train_dataset,
+            "miracl": miracl_train_dataset,
+            "mldr": mldr_train_dataset,
+            "mr_tydi": mr_tydi_train_dataset,
+        })
+        eval_dataset = DatasetDict({
+            "gooaq": gooaq_eval_dataset,
+            "msmarco": msmarco_eval_dataset,
+            "squad": squad_eval_dataset,
+            "s2orc": s2orc_eval_dataset,
+            "allnli": allnli_eval_dataset,
+            "paq": paq_eval_dataset,
+            "trivia_qa": trivia_qa_eval_dataset,
+            "msmarco_10m": msmarco_10m_eval_dataset,
+            "swim_ir": swim_ir_eval_dataset,
+            "pubmedqa": pubmedqa_eval_dataset,
+            "miracl": miracl_eval_dataset,
+            "mldr": mldr_eval_dataset,
+            "mr_tydi": mr_tydi_eval_dataset,
+        })
+
+        train_dataset.save_to_disk("datasets/train_dataset")
+        eval_dataset.save_to_disk("datasets/eval_dataset")
+
+        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
+        quit()
+
+
 def main():
     # 1. Load a model to finetune with 2. (Optional) model card data
     static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-uncased"), embedding_dim=1024)
@@ -33,129 +175,7 @@ def main():
     )
 
     # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
-    print("Loading gooaq dataset...")
-    gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
-    gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
-    gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
-    gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
-    print("Loaded gooaq dataset.")
-
-    print("Loading msmarco dataset...")
-    msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train")
-    msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
-    msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
-    msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
-    print("Loaded msmarco dataset.")
-
-    print("Loading squad dataset...")
-    squad_dataset = load_dataset("sentence-transformers/squad", split="train")
-    squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
-    squad_train_dataset: Dataset = squad_dataset_dict["train"]
-    squad_eval_dataset: Dataset = squad_dataset_dict["test"]
-    print("Loaded squad dataset.")
-
-    print("Loading s2orc dataset...")
-    s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]")
-    s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
-    s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
-    s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
-    print("Loaded s2orc dataset.")
-
-    print("Loading allnli dataset...")
-    allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
-    allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
-    print("Loaded allnli dataset.")
-
-    print("Loading paq dataset...")
-    paq_dataset = load_dataset("sentence-transformers/paq", split="train")
-    paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
-    paq_train_dataset: Dataset = paq_dataset_dict["train"]
-    paq_eval_dataset: Dataset = paq_dataset_dict["test"]
-    print("Loaded paq dataset.")
-
-    print("Loading trivia_qa dataset...")
-    trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
-    trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
-    trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
-    trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
-    print("Loaded trivia_qa dataset.")
-
-    print("Loading msmarco_10m dataset...")
-    msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
-    msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12)
-    msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
-    msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
-    print("Loaded msmarco_10m dataset.")
-
-    print("Loading swim_ir dataset...")
-    swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"])
-    swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12)
-    swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
-    swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
-    print("Loaded swim_ir dataset.")
-
-    # NOTE: 20 negatives
-    print("Loading pubmedqa dataset...")
-    pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train")
-    pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
-    pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
-    pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
-    print("Loaded pubmedqa dataset.")
-
-    # NOTE: A lot of overlap with anchor/positives
-    print("Loading miracl dataset...")
-    miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train")
-    miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
-    miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
-    miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
-    print("Loaded miracl dataset.")
-
-    # NOTE: A lot of overlap with anchor/positives
-    print("Loading mldr dataset...")
-    mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train")
-    mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
-    mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
-    mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
-    print("Loaded mldr dataset.")
-
-    # NOTE: A lot of overlap with anchor/positives
-    print("Loading mr_tydi dataset...")
-    mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train")
-    mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
-    mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
-    mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
-    print("Loaded mr_tydi dataset.")
-
-    train_dataset = DatasetDict({
-        "gooaq": gooaq_train_dataset,
-        "msmarco": msmarco_train_dataset,
-        "squad": squad_train_dataset,
-        "s2orc": s2orc_train_dataset,
-        "allnli": allnli_train_dataset,
-        "paq": paq_train_dataset,
-        "trivia_qa": trivia_qa_train_dataset,
-        "msmarco_10m": msmarco_10m_train_dataset,
-        "swim_ir": swim_ir_train_dataset,
-        "pubmedqa": pubmedqa_train_dataset,
-        "miracl": miracl_train_dataset,
-        "mldr": mldr_train_dataset,
-        "mr_tydi": mr_tydi_train_dataset,
-    })
-    eval_dataset = {
-        "gooaq": gooaq_eval_dataset,
-        "msmarco": msmarco_eval_dataset,
-        "squad": squad_eval_dataset,
-        "s2orc": s2orc_eval_dataset,
-        "allnli": allnli_eval_dataset,
-        "paq": paq_eval_dataset,
-        "trivia_qa": trivia_qa_eval_dataset,
-        "msmarco_10m": msmarco_10m_eval_dataset,
-        "swim_ir": swim_ir_eval_dataset,
-        "pubmedqa": pubmedqa_eval_dataset,
-        "miracl": miracl_eval_dataset,
-        "mldr": mldr_eval_dataset,
-        "mr_tydi": mr_tydi_eval_dataset,
-    }
+    train_dataset, eval_dataset = load_train_eval_datasets()
     print(train_dataset)
 
     # 4. Define a loss function
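The loss definition referenced by the "# 4." comment sits outside the changed hunks. Per the "# 3." comment, each dataset is trained with MNRL (MultipleNegativesRankingLoss) wrapped in MRL (MatryoshkaLoss). A sketch of how that pairing typically looks in Sentence Transformers, with the model built as in the diff; the matryoshka_dims values are an assumption for illustration, not taken from this commit.

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.models import StaticEmbedding
from transformers import AutoTokenizer

# Model construction as in the diff: a single static (non-contextual) embedding module.
static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-uncased"), embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# MNRL uses the other in-batch positives as negatives; MatryoshkaLoss re-applies it
# at several truncated embedding dimensions. The dims below are illustrative.
loss = MatryoshkaLoss(model, MultipleNegativesRankingLoss(model), matryoshka_dims=[1024, 512, 256, 128, 64, 32])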