winglian commited on
Commit
2e56203
·
1 Parent(s): be3d396

another fix for shard and train split

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/data.py +22 -13
src/axolotl/utils/data.py CHANGED
@@ -48,7 +48,7 @@ def load_tokenized_prepared_datasets(
48
  (
49
  str(cfg.sequence_len)
50
  + "@"
51
- + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
52
  + "|"
53
  + tokenizer_name
54
  ).encode("utf-8")
@@ -112,13 +112,22 @@ def load_tokenized_prepared_datasets(
112
  raise Exception("unhandled dataset load")
113
  # support for using a subset of the data
114
  if d.shards:
 
115
  ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
 
 
 
 
 
 
116
  d_type = d.type
117
  d_type_split = d_type.split(":")
118
  d_base_type = d_type_split[0]
119
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
 
 
120
  if ds_strategy := load(d.type, tokenizer, cfg):
121
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
122
  datasets.append(ds_wrapper)
123
  elif d_base_type == "alpaca":
124
  ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -127,7 +136,7 @@ def load_tokenized_prepared_datasets(
127
  cfg.train_on_inputs,
128
  cfg.sequence_len,
129
  )
130
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
131
  datasets.append(ds_wrapper)
132
  elif d_base_type == "explainchoice":
133
  ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -136,7 +145,7 @@ def load_tokenized_prepared_datasets(
136
  cfg.train_on_inputs,
137
  cfg.sequence_len,
138
  )
139
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
140
  datasets.append(ds_wrapper)
141
  elif d_base_type == "concisechoice":
142
  ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -145,7 +154,7 @@ def load_tokenized_prepared_datasets(
145
  cfg.train_on_inputs,
146
  cfg.sequence_len,
147
  )
148
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
149
  datasets.append(ds_wrapper)
150
  elif d_base_type == "summarizetldr":
151
  ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
@@ -154,7 +163,7 @@ def load_tokenized_prepared_datasets(
154
  cfg.train_on_inputs,
155
  cfg.sequence_len,
156
  )
157
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
158
  datasets.append(ds_wrapper)
159
  elif d_base_type == "jeopardy":
160
  ds_strategy = JeopardyPromptTokenizingStrategy(
@@ -163,7 +172,7 @@ def load_tokenized_prepared_datasets(
163
  cfg.train_on_inputs,
164
  cfg.sequence_len,
165
  )
166
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
167
  datasets.append(ds_wrapper)
168
  elif d_base_type == "oasst":
169
  ds_strategy = OpenAssistantPromptTokenizingStrategy(
@@ -172,7 +181,7 @@ def load_tokenized_prepared_datasets(
172
  cfg.train_on_inputs,
173
  cfg.sequence_len,
174
  )
175
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
176
  datasets.append(ds_wrapper)
177
  elif d_base_type == "gpteacher":
178
  ds_strategy = GPTeacherPromptTokenizingStrategy(
@@ -181,7 +190,7 @@ def load_tokenized_prepared_datasets(
181
  cfg.train_on_inputs,
182
  cfg.sequence_len,
183
  )
184
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
185
  datasets.append(ds_wrapper)
186
  elif d_base_type == "reflection":
187
  ds_strategy = AlpacaReflectionPTStrategy(
@@ -190,7 +199,7 @@ def load_tokenized_prepared_datasets(
190
  cfg.train_on_inputs,
191
  cfg.sequence_len,
192
  )
193
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
194
  datasets.append(ds_wrapper)
195
  elif d_base_type == "sharegpt":
196
  ds_strategy = ShareGPTPromptTokenizingStrategy(
@@ -199,7 +208,7 @@ def load_tokenized_prepared_datasets(
199
  cfg.train_on_inputs,
200
  cfg.sequence_len,
201
  )
202
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
203
  datasets.append(ds_wrapper)
204
  elif d_base_type == "completion":
205
  ds_strategy = CompletionPromptTokenizingStrategy(
@@ -208,7 +217,7 @@ def load_tokenized_prepared_datasets(
208
  cfg.train_on_inputs,
209
  cfg.sequence_len,
210
  )
211
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
212
  datasets.append(ds_wrapper)
213
  else:
214
  logging.error(f"unhandled prompt tokenization strategy: {d.type}")
@@ -255,7 +264,7 @@ def load_prepare_datasets(
255
  + "@"
256
  + str(max_packed_sequence_len)
257
  + seed
258
- + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
259
  + "|"
260
  + tokenizer_name
261
  ).encode("utf-8")
 
48
  (
49
  str(cfg.sequence_len)
50
  + "@"
51
+ + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
52
  + "|"
53
  + tokenizer_name
54
  ).encode("utf-8")
 
112
  raise Exception("unhandled dataset load")
113
  # support for using a subset of the data
114
  if d.shards:
115
+ <<<<<<< Updated upstream
116
  ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
117
+ =======
118
+ if "train" in ds:
119
+ ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
120
+ else:
121
+ ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
122
+ >>>>>>> Stashed changes
123
  d_type = d.type
124
  d_type_split = d_type.split(":")
125
  d_base_type = d_type_split[0]
126
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
127
+ if "train" in ds:
128
+ ds = ds["train"]
129
  if ds_strategy := load(d.type, tokenizer, cfg):
130
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
131
  datasets.append(ds_wrapper)
132
  elif d_base_type == "alpaca":
133
  ds_strategy = AlpacaPromptTokenizingStrategy(
 
136
  cfg.train_on_inputs,
137
  cfg.sequence_len,
138
  )
139
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
140
  datasets.append(ds_wrapper)
141
  elif d_base_type == "explainchoice":
142
  ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
 
145
  cfg.train_on_inputs,
146
  cfg.sequence_len,
147
  )
148
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
149
  datasets.append(ds_wrapper)
150
  elif d_base_type == "concisechoice":
151
  ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
 
154
  cfg.train_on_inputs,
155
  cfg.sequence_len,
156
  )
157
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
158
  datasets.append(ds_wrapper)
159
  elif d_base_type == "summarizetldr":
160
  ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
 
163
  cfg.train_on_inputs,
164
  cfg.sequence_len,
165
  )
166
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
167
  datasets.append(ds_wrapper)
168
  elif d_base_type == "jeopardy":
169
  ds_strategy = JeopardyPromptTokenizingStrategy(
 
172
  cfg.train_on_inputs,
173
  cfg.sequence_len,
174
  )
175
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
176
  datasets.append(ds_wrapper)
177
  elif d_base_type == "oasst":
178
  ds_strategy = OpenAssistantPromptTokenizingStrategy(
 
181
  cfg.train_on_inputs,
182
  cfg.sequence_len,
183
  )
184
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
185
  datasets.append(ds_wrapper)
186
  elif d_base_type == "gpteacher":
187
  ds_strategy = GPTeacherPromptTokenizingStrategy(
 
190
  cfg.train_on_inputs,
191
  cfg.sequence_len,
192
  )
193
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
194
  datasets.append(ds_wrapper)
195
  elif d_base_type == "reflection":
196
  ds_strategy = AlpacaReflectionPTStrategy(
 
199
  cfg.train_on_inputs,
200
  cfg.sequence_len,
201
  )
202
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
203
  datasets.append(ds_wrapper)
204
  elif d_base_type == "sharegpt":
205
  ds_strategy = ShareGPTPromptTokenizingStrategy(
 
208
  cfg.train_on_inputs,
209
  cfg.sequence_len,
210
  )
211
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
212
  datasets.append(ds_wrapper)
213
  elif d_base_type == "completion":
214
  ds_strategy = CompletionPromptTokenizingStrategy(
 
217
  cfg.train_on_inputs,
218
  cfg.sequence_len,
219
  )
220
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
221
  datasets.append(ds_wrapper)
222
  else:
223
  logging.error(f"unhandled prompt tokenization strategy: {d.type}")
 
264
  + "@"
265
  + str(max_packed_sequence_len)
266
  + seed
267
+ + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
268
  + "|"
269
  + tokenizer_name
270
  ).encode("utf-8")