another fix for shard and train split
src/axolotl/utils/data.py (+22 −13)

One file changes: the commit folds each dataset's `shards` setting into the prepared-dataset cache fingerprints and normalizes a `DatasetDict` down to its `train` split before tokenization, so the strategy branches can wrap `ds` directly.
@@ -48,7 +48,7 @@ def load_tokenized_prepared_datasets(
         (
             str(cfg.sequence_len)
             + "@"
-            + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+            + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
             + "|"
             + tokenizer_name
         ).encode("utf-8")
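This first hunk widens the cache fingerprint so that changing a dataset's `shards` setting invalidates the prepared-dataset cache instead of silently reusing a stale one. A minimal sketch of how such a fingerprint might be computed; `hashlib.md5` and the `dataset_fingerprint` name are assumptions, since the actual hash call sits outside this hunk:

import hashlib

# Minimal sketch, not the file's actual helper: md5 is an assumption.
# Folding d.shards into the key means two configs that differ only in
# how much of a dataset they subsample no longer collide on one entry.
def dataset_fingerprint(cfg, tokenizer_name: str) -> str:
    key = (
        str(cfg.sequence_len)
        + "@"
        + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
        + "|"
        + tokenizer_name
    ).encode("utf-8")
    return hashlib.md5(key).hexdigest()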
@@ -112,13 +112,22 @@ def load_tokenized_prepared_datasets(
             raise Exception("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
+<<<<<<< Updated upstream
             ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+=======
+            if "train" in ds:
+                ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
+            else:
+                ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
+>>>>>>> Stashed changes
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
         d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
+        if "train" in ds:
+            ds = ds["train"]
         if ds_strategy := load(d.type, tokenizer, cfg):
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "alpaca":
             ds_strategy = AlpacaPromptTokenizingStrategy(
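This hunk does the actual shard-and-split work: when subsampling is requested, take the first of N shards, and afterwards collapse any remaining `DatasetDict` to its `train` split so every branch below can wrap `ds` directly. Note that the hunk lands with unresolved git conflict markers (`<<<<<<< Updated upstream` through `>>>>>>> Stashed changes`), which are not valid Python; the intended logic is the stashed-changes side. A minimal sketch of that logic, with `isinstance(ds, DatasetDict)` standing in for the diff's `"train" in ds` (membership tests on a flat `Dataset` iterate rows, not split names) and a single `shards` parameter standing in for the diff's mix of `d.shards` and `cfg.shards`:

from datasets import DatasetDict

# Sketch of the intended "Stashed changes" logic, not the committed code.
def subset_and_normalize(ds, shards):
    if shards:
        if isinstance(ds, DatasetDict):
            # shuffle, keep the train split, then take the first of N shards
            ds = ds.shuffle(seed=42)["train"].shard(num_shards=shards, index=0)
        else:
            # already a flat Dataset: shuffle and shard it directly
            ds = ds.shuffle(seed=42).shard(num_shards=shards, index=0)
    if isinstance(ds, DatasetDict):
        # collapse any remaining DatasetDict to its train split
        ds = ds["train"]
    return ds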
@@ -127,7 +136,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "explainchoice":
             ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -136,7 +145,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "concisechoice":
             ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -145,7 +154,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "summarizetldr":
             ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
@@ -154,7 +163,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "jeopardy":
             ds_strategy = JeopardyPromptTokenizingStrategy(
@@ -163,7 +172,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "oasst":
             ds_strategy = OpenAssistantPromptTokenizingStrategy(
@@ -172,7 +181,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "gpteacher":
             ds_strategy = GPTeacherPromptTokenizingStrategy(
@@ -181,7 +190,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "reflection":
             ds_strategy = AlpacaReflectionPTStrategy(
@@ -190,7 +199,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "sharegpt":
             ds_strategy = ShareGPTPromptTokenizingStrategy(
@@ -199,7 +208,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         elif d_base_type == "completion":
             ds_strategy = CompletionPromptTokenizingStrategy(
@@ -208,7 +217,7 @@ def load_tokenized_prepared_datasets(
                 cfg.train_on_inputs,
                 cfg.sequence_len,
             )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
         else:
             logging.error(f"unhandled prompt tokenization strategy: {d.type}")
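The hunks above are mechanical: every prompt-strategy branch previously indexed `ds["train"]` when wrapping, and now receives the already-normalized `ds`. The shared shape of each branch, sketched with the alpaca case using names from the module's own imports; the diff elides the leading constructor arguments, so the prompter and tokenizer arguments here are assumptions:

# Hedged sketch of one branch after this commit; AlpacaPrompter(d_prompt_style)
# and the tokenizer argument are assumed, since the diff elides them.
ds_strategy = AlpacaPromptTokenizingStrategy(
    AlpacaPrompter(d_prompt_style),  # assumed prompter construction
    tokenizer,                       # assumed tokenizer argument
    cfg.train_on_inputs,
    cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)  # ds, not ds["train"]
datasets.append(ds_wrapper)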
@@ -255,7 +264,7 @@ def load_prepare_datasets(
         + "@"
         + str(max_packed_sequence_len)
         + seed
-        + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+        + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
         + "|"
         + tokenizer_name
     ).encode("utf-8")