ACMC committed on
Commit
bd73a7b
·
1 Parent(s): bf9e30f
Files changed (3) hide show
  1. app.py +99 -46
  2. utils.py +52 -49
  3. validation.py +40 -27
app.py CHANGED
@@ -1,33 +1,41 @@
1
  # %%
 
 
 
2
  from uuid import uuid4
3
- import gradio as gr
4
  import datasets
5
- import json
6
- import io
7
- from utils import (
8
- process_chat_file,
9
- transform_conversations_dataset_into_training_examples,
10
- )
11
- from validation import (
12
- check_format_errors,
13
- estimate_cost,
14
- get_distributions,
15
- )
16
  import matplotlib.pyplot as plt
17
 
 
 
 
18
 
19
- def convert_to_dataset(files, do_spelling_correction, progress):
 
 
 
 
20
  modified_dataset = None
21
  for file in progress.tqdm(files, desc="Processing files"):
22
  if modified_dataset is None:
23
  # First file
24
  modified_dataset = process_chat_file(
25
- file, do_spelling_correction=do_spelling_correction
 
 
 
 
26
  )
27
  else:
28
  # Concatenate the datasets
29
  this_file_dataset = process_chat_file(
30
- file, do_spelling_correction=do_spelling_correction
 
 
 
 
31
  )
32
  modified_dataset = datasets.concatenate_datasets(
33
  [modified_dataset, this_file_dataset]
@@ -43,25 +51,41 @@ def file_upload_callback(
43
  user_role,
44
  model_role,
45
  whatsapp_name,
 
 
46
  progress=gr.Progress(),
47
  ):
48
- print(f"Processing {files}")
49
- full_system_prompt = f"""You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
50
- # Task
51
  The {model_role} and the {user_role} can send multiple messages in a row, as a JSON list of strings. Your answer always needs to be JSON compliant. The strings are delimited by double quotes ("). The strings are separated by a comma (,). The list is delimited by square brackets ([, ]). Always start your answer with [", and close it with "]. Do not write anything else in your answer after "].
52
  # Information about me
53
- You should use the following information about me to answer:
54
  {system_prompt}"""
55
  # Example
56
  # [{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
57
  # Response:
58
  # [{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
59
 
 
 
 
 
 
 
 
 
60
  # # Avoid using the full system prompt for now, as it is too long and increases the cost of the training
61
  # full_system_prompt = system_prompt
62
  dataset = convert_to_dataset(
63
- files=files, progress=progress, do_spelling_correction=do_spelling_correction
 
 
 
 
 
64
  )
 
 
65
  training_examples_ds = transform_conversations_dataset_into_training_examples(
66
  conversations_ds=dataset,
67
  system_prompt=full_system_prompt,
@@ -69,6 +93,7 @@ You should use the following information about me to answer:
69
  model_role=model_role,
70
  whatsapp_name=whatsapp_name,
71
  )
 
72
 
73
  # Split into training and validation datasets (80% and 20%)
74
  training_examples_ds = training_examples_ds.train_test_split(
@@ -78,9 +103,9 @@ You should use the following information about me to answer:
78
  training_examples_ds["train"],
79
  training_examples_ds["test"],
80
  )
81
- training_examples_ds = training_examples_ds#.select(
82
  # range(min(250, len(training_examples_ds)))
83
- #)
84
  validation_examples_ds = validation_examples_ds.select(
85
  range(min(200, len(validation_examples_ds)))
86
  )
@@ -124,6 +149,12 @@ You should use the following information about me to answer:
124
  file_path_validation = f"validation_examples_{uuid}.jsonl"
125
  validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)
126
 
 
 
 
 
 
 
127
  return (
128
  file_path,
129
  gr.update(visible=True),
@@ -142,7 +173,7 @@ def remove_file_and_hide_button(file_path):
142
  try:
143
  os.remove(file_path)
144
  except Exception as e:
145
- print(f"Error removing file {file_path}: {e}")
146
 
147
  return gr.update(visible=False)
148
 
@@ -190,32 +221,52 @@ with gr.Blocks(theme=theme) as demo:
190
  info="Enter your WhatsApp name as it appears in your profile. It needs to match exactly your name. If you're unsure, you can check the chat messages to see it.",
191
  )
192
 
193
- user_role = gr.Textbox(
194
- label="Role for User",
195
- info="This is a technical parameter. If you don't know what to write, just type 'user'.",
196
- value="user",
197
- )
 
 
198
 
199
- model_role = gr.Textbox(
200
- label="Role for Model",
201
- info="This is a technical parameter. If you don't know what to write, just type 'model'.",
202
- value="model",
203
- )
204
 
205
- do_spelling_correction = gr.Checkbox(
206
- label="Do Spelling Correction (English)",
207
- info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
208
- )
 
209
 
210
- # Allow the user to choose the validation split size
211
- validation_split = gr.Slider(
212
- minimum=0.0,
213
- maximum=0.5,
214
- value=0.2,
215
- interactive=True,
216
- label="Validation Split",
217
- info="Choose the percentage of the dataset to be used for validation. For example, if you choose 0.2, 20% of the dataset will be used for validation and 80% for training.",
218
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  submit = gr.Button(value="Submit", variant="primary")
221
 
@@ -253,6 +304,8 @@ with gr.Blocks(theme=theme) as demo:
253
  user_role,
254
  model_role,
255
  whatsapp_name,
 
 
256
  ],
257
  outputs=[
258
  output_file,
 
1
  # %%
2
+ import io
3
+ import json
4
+ import logging
5
  from uuid import uuid4
6
+
7
  import datasets
8
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
9
  import matplotlib.pyplot as plt
10
 
11
+ from utils import (process_chat_file,
12
+ transform_conversations_dataset_into_training_examples)
13
+ from validation import check_format_errors, estimate_cost, get_distributions
14
 
15
+ logger = logging.getLogger(__name__)
16
+ logger.setLevel(logging.INFO)
17
+
18
+
19
+ def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, datetime_dayfirst, message_line_format):
20
  modified_dataset = None
21
  for file in progress.tqdm(files, desc="Processing files"):
22
  if modified_dataset is None:
23
  # First file
24
  modified_dataset = process_chat_file(
25
+ file,
26
+ do_spelling_correction=do_spelling_correction,
27
+ whatsapp_name=whatsapp_name,
28
+ datetime_dayfirst=datetime_dayfirst,
29
+ message_line_format=message_line_format,
30
  )
31
  else:
32
  # Concatenate the datasets
33
  this_file_dataset = process_chat_file(
34
+ file,
35
+ do_spelling_correction=do_spelling_correction,
36
+ whatsapp_name=whatsapp_name,
37
+ datetime_dayfirst=datetime_dayfirst,
38
+ message_line_format=message_line_format,
39
  )
40
  modified_dataset = datasets.concatenate_datasets(
41
  [modified_dataset, this_file_dataset]
 
51
  user_role,
52
  model_role,
53
  whatsapp_name,
54
+ datetime_dayfirst,
55
+ message_line_format,
56
  progress=gr.Progress(),
57
  ):
58
+ logger.info(f"Processing {files}")
59
+ full_system_prompt = f"""# Task
60
+ You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
61
  The {model_role} and the {user_role} can send multiple messages in a row, as a JSON list of strings. Your answer always needs to be JSON compliant. The strings are delimited by double quotes ("). The strings are separated by a comma (,). The list is delimited by square brackets ([, ]). Always start your answer with [", and close it with "]. Do not write anything else in your answer after "].
62
  # Information about me
 
63
  {system_prompt}"""
64
  # Example
65
  # [{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
66
  # Response:
67
  # [{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
68
 
69
+ # Check if the user has not chosen any files
70
+ if not files or len(files) == 0:
71
+ raise gr.Error("Please upload at least one file.")
72
+
73
+ # Check if the user has not entered their whatsapp name
74
+ if not whatsapp_name or len(whatsapp_name) == 0:
75
+ raise gr.Error("Please enter your WhatsApp name.")
76
+
77
  # # Avoid using the full system prompt for now, as it is too long and increases the cost of the training
78
  # full_system_prompt = system_prompt
79
  dataset = convert_to_dataset(
80
+ files=files,
81
+ progress=progress,
82
+ do_spelling_correction=do_spelling_correction,
83
+ whatsapp_name=whatsapp_name,
84
+ datetime_dayfirst=datetime_dayfirst,
85
+ message_line_format=message_line_format,
86
  )
87
+ logger.info(f"Number of conversations of dataset before being transformed: {len(dataset)}")
88
+
89
  training_examples_ds = transform_conversations_dataset_into_training_examples(
90
  conversations_ds=dataset,
91
  system_prompt=full_system_prompt,
 
93
  model_role=model_role,
94
  whatsapp_name=whatsapp_name,
95
  )
96
+ logger.info(f"Number of training examples: {len(training_examples_ds)}")
97
 
98
  # Split into training and validation datasets (80% and 20%)
99
  training_examples_ds = training_examples_ds.train_test_split(
 
103
  training_examples_ds["train"],
104
  training_examples_ds["test"],
105
  )
106
+ training_examples_ds = training_examples_ds # .select(
107
  # range(min(250, len(training_examples_ds)))
108
+ # )
109
  validation_examples_ds = validation_examples_ds.select(
110
  range(min(200, len(validation_examples_ds)))
111
  )
 
149
  file_path_validation = f"validation_examples_{uuid}.jsonl"
150
  validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)
151
 
152
+ # If there's less than 50 training examples, show a warning message
153
+ if len(training_examples_ds) < 50:
154
+ gr.Warning(
155
+ "Warning: There are less than 50 training examples. The model may not perform well with such a small dataset. Consider adding more chat files to increase the number of training examples."
156
+ )
157
+
158
  return (
159
  file_path,
160
  gr.update(visible=True),
 
173
  try:
174
  os.remove(file_path)
175
  except Exception as e:
176
+ logger.info(f"Error removing file {file_path}: {e}")
177
 
178
  return gr.update(visible=False)
179
 
 
221
  info="Enter your WhatsApp name as it appears in your profile. It needs to match exactly your name. If you're unsure, you can check the chat messages to see it.",
222
  )
223
 
224
+ # Advanced parameters section, collapsed by default
225
+ with gr.Accordion(label="Advanced Parameters", open=False):
226
+ gr.Markdown(
227
+ """
228
+ These are advanced parameters that you can change if you know what you're doing. If you're unsure, you can leave them as they are.
229
+ """
230
+ )
231
 
232
+ user_role = gr.Textbox(
233
+ label="Role for User",
234
+ info="This is a technical parameter. If you don't know what to write, just type 'user'.",
235
+ value="user",
236
+ )
237
 
238
+ model_role = gr.Textbox(
239
+ label="Role for Model",
240
+ info="This is a technical parameter. Usual values are 'model' or 'assistant'.",
241
+ value="model",
242
+ )
243
 
244
+ message_line_format = gr.Textbox(
245
+ label="Message Line Format",
246
+ info="Format of each message line in the chat file, as a regular expression. The default value should work for most cases.",
247
+ value=r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+): (?P<message>.+)",
248
+ )
249
+
250
+ datetime_dayfirst = gr.Checkbox(
251
+ label="Date format: Day first",
252
+ info="Check this box if the date time format in the chat messages is in the format 'DD/MM/YYYY'. You can check your phone settings to see the date format. Otherwise, it will be assumed that the date time format is 'MM/DD/YYYY'.",
253
+ value=True,
254
+ )
255
+
256
+ do_spelling_correction = gr.Checkbox(
257
+ label="Do Spelling Correction (English)",
258
+ info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
259
+ )
260
+
261
+ # Allow the user to choose the validation split size
262
+ validation_split = gr.Slider(
263
+ minimum=0.0,
264
+ maximum=0.5,
265
+ value=0.2,
266
+ interactive=True,
267
+ label="Validation Split",
268
+ info="Choose the percentage of the dataset to be used for validation. For example, if you choose 0.2, 20% of the dataset will be used for validation and 80% for training.",
269
+ )
270
 
271
  submit = gr.Button(value="Submit", variant="primary")
272
 
 
304
  user_role,
305
  model_role,
306
  whatsapp_name,
307
+ datetime_dayfirst,
308
+ message_line_format,
309
  ],
310
  outputs=[
311
  output_file,
utils.py CHANGED
@@ -1,36 +1,13 @@
1
- import datasets
2
  import datetime
3
- import os
4
  import json
5
-
 
6
  import re
 
 
7
 
8
- exp = re.compile(
9
- r"(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+), (?P<hour>\d+):(?P<minute>\d+) - (?P<contact_name>.+): (?P<message>.+)"
10
- )
11
-
12
-
13
- def process_line(example):
14
- # The lines have this format: dd/mm/yy, hh:mm - <person>: <msg>
15
- try:
16
- groups = exp.match(example["text"]).groupdict()
17
- timestamp = datetime.datetime(
18
- int(groups["year"]),
19
- int(groups["month"]),
20
- int(groups["day"]),
21
- int(groups["hour"]),
22
- int(groups["minute"]),
23
- ).timestamp()
24
- return {
25
- "message": groups["message"],
26
- "contact_name": groups["contact_name"],
27
- "timestamp": timestamp,
28
- }
29
- except Exception as e:
30
- print(e)
31
- print(example["text"])
32
- raise e
33
-
34
 
35
  # %%
36
  # Now, create message groups ('conversations')
@@ -63,10 +40,11 @@ def printable_conversation(conversation):
63
  )
64
 
65
 
 
 
66
  # %%
67
  # Use spacy to spell check the messages
68
  import spacy
69
- import contextualSpellCheck
70
  from spellchecker import SpellChecker
71
 
72
  spell = SpellChecker()
@@ -78,17 +56,17 @@ def spell_check_conversation(conversation):
78
  for i, message in enumerate(conversation["conversations"]):
79
  # Use SpaCy to get the words
80
  words = spell.split_words(message["message"])
81
- print(f"Words: {words}")
82
  corrected_message = []
83
  for word in words:
84
  correction = spell.correction(word)
85
  if (correction != None) and (correction != word):
86
- print(f"Spell check: {word} -> {correction}")
87
  corrected_message.append(correction)
88
  else:
89
  corrected_message.append(word)
90
 
91
- print(f"Corrected message: {corrected_message}")
92
  joined_message = " ".join(corrected_message)
93
  conversation["conversations"][i]["message"] = joined_message
94
 
@@ -107,7 +85,7 @@ def spell_check_conversation_spacy(conversation):
107
  docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]]))
108
  for i, doc in enumerate(docs):
109
  if doc._.performed_spellCheck:
110
- print(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
111
  conversation["conversations"][i]["message"] = doc._.outcome_spellCheck
112
 
113
  return conversation
@@ -144,8 +122,8 @@ A: I'm fine too
144
  To do it, we'll use MobileBERT with the next sentence prediction head. We'll use the first message as the first sentence, and the second message as the second sentence. If the model predicts that the second sentence is more likely to be the next sentence, we'll swap the messages.
145
  """
146
 
147
- from transformers import AutoTokenizer, AutoModelForNextSentencePrediction
148
  import torch
 
149
 
150
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
151
  model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
@@ -186,10 +164,12 @@ def swap_messages_if_needed(message1, message2):
186
  swap = logits[0, 0] - logits[1, 0] < -0.2
187
  if swap:
188
  # Swap the messages
189
- print(f"YES Swapping messages: {message1['message']} <-> {message2['message']}")
 
 
190
  return message2, message1
191
  else:
192
- # print(f"NOT swapping messages: {message1['message']} <-> {message2['message']}")
193
  return message1, message2
194
 
195
 
@@ -208,8 +188,8 @@ def swap_messages_if_needed_in_conversation(conversation):
208
  new_conversation[-1] = message1
209
  new_conversation.append(message2)
210
 
211
- # print(f"\nOriginal conversation:\n{printable_conversation(conversation)}")
212
- # print(f"\nNew conversation:\n{printable_conversation(new_conversation)}")
213
  return new_conversation
214
 
215
 
@@ -226,26 +206,38 @@ test_conversation = [
226
  "timestamp": 3,
227
  },
228
  ]
229
- # print(swap_messages_if_needed_in_conversation(test_conversation))
230
 
231
  # %%
232
  # Now, we'll train an mT5 model to generate the next message in a conversation
233
  import os
234
 
235
 
236
- # For the contact_name, rewrite everything that is not 'Aldi' to 'Other'
237
- def rewrite_contact_name(conversation):
238
- for message in conversation["conversations"]:
239
- if message["contact_name"] != "Aldi":
240
- message["contact_name"] = "Other"
241
- return conversation
242
-
243
-
244
  # %%
245
- def process_chat_file(file, do_spelling_correction, do_reordering=False):
246
  """
247
  Process a chat file and return a dataset with the conversations.
248
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  ds = (
250
  datasets.load_dataset("text", data_files=[file])["train"]
251
  .filter(
@@ -288,6 +280,13 @@ def process_chat_file(file, do_spelling_correction, do_reordering=False):
288
  else:
289
  reordered_conversations_ds = spell_checked_conversations_ds
290
 
 
 
 
 
 
 
 
291
  changed_contact_name_ds = reordered_conversations_ds.map(
292
  rewrite_contact_name
293
  ) # , num_proc=os.cpu_count() - 1)
@@ -372,6 +371,10 @@ def transform_conversations_dataset_into_training_examples(
372
  ]
373
  }
374
  )
 
 
 
 
375
  # Before returning, flatten the list of dictionaries into a dictionary of lists
376
  flattened_examples = {}
377
  for key in processed_examples[0].keys():
 
 
1
  import datetime
 
2
  import json
3
+ import logging
4
+ import os
5
  import re
6
+ import datasets
7
+ import dateutil.parser
8
 
9
+ logger = logging.getLogger(__name__)
10
+ logger.setLevel(logging.INFO)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # %%
13
  # Now, create message groups ('conversations')
 
40
  )
41
 
42
 
43
+ import contextualSpellCheck
44
+
45
  # %%
46
  # Use spacy to spell check the messages
47
  import spacy
 
48
  from spellchecker import SpellChecker
49
 
50
  spell = SpellChecker()
 
56
  for i, message in enumerate(conversation["conversations"]):
57
  # Use SpaCy to get the words
58
  words = spell.split_words(message["message"])
59
+ logger.info(f"Words: {words}")
60
  corrected_message = []
61
  for word in words:
62
  correction = spell.correction(word)
63
  if (correction != None) and (correction != word):
64
+ logger.info(f"Spell check: {word} -> {correction}")
65
  corrected_message.append(correction)
66
  else:
67
  corrected_message.append(word)
68
 
69
+ logger.info(f"Corrected message: {corrected_message}")
70
  joined_message = " ".join(corrected_message)
71
  conversation["conversations"][i]["message"] = joined_message
72
 
 
85
  docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]]))
86
  for i, doc in enumerate(docs):
87
  if doc._.performed_spellCheck:
88
+ logger.info(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
89
  conversation["conversations"][i]["message"] = doc._.outcome_spellCheck
90
 
91
  return conversation
 
122
  To do it, we'll use MobileBERT with the next sentence prediction head. We'll use the first message as the first sentence, and the second message as the second sentence. If the model predicts that the second sentence is more likely to be the next sentence, we'll swap the messages.
123
  """
124
 
 
125
  import torch
126
+ from transformers import AutoModelForNextSentencePrediction, AutoTokenizer
127
 
128
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
129
  model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
 
164
  swap = logits[0, 0] - logits[1, 0] < -0.2
165
  if swap:
166
  # Swap the messages
167
+ logger.info(
168
+ f"Swapping messages: {message1['message']} <-> {message2['message']}"
169
+ )
170
  return message2, message1
171
  else:
172
+ # logger.info(f"NOT swapping messages: {message1['message']} <-> {message2['message']}")
173
  return message1, message2
174
 
175
 
 
188
  new_conversation[-1] = message1
189
  new_conversation.append(message2)
190
 
191
+ # logger.info(f"\nOriginal conversation:\n{printable_conversation(conversation)}")
192
+ # logger.info(f"\nNew conversation:\n{printable_conversation(new_conversation)}")
193
  return new_conversation
194
 
195
 
 
206
  "timestamp": 3,
207
  },
208
  ]
209
+ # logger.info(swap_messages_if_needed_in_conversation(test_conversation))
210
 
211
  # %%
212
  # Now, we'll train an mT5 model to generate the next message in a conversation
213
  import os
214
 
215
 
 
 
 
 
 
 
 
 
216
  # %%
217
+ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayfirst, message_line_format, do_reordering=False):
218
  """
219
  Process a chat file and return a dataset with the conversations.
220
  """
221
+ exp = re.compile(
222
+ # r"(?P<msg_datetime>.+?) - (?P<contact_name>.+): (?P<message>.+)"
223
+ # r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+): (?P<message>.+)"
224
+ message_line_format
225
+ )
226
+
227
def process_line(example):
    """Parse one raw chat line into its message, sender and timestamp.

    The line is matched against the compiled ``exp`` pattern from the
    enclosing scope, which must expose the named groups ``msg_datetime``,
    ``contact_name`` and ``message``. The date/time text is handed to
    ``dateutil.parser.parse`` with ``dayfirst=datetime_dayfirst``.

    Args:
        example: Mapping with a ``"text"`` key holding the raw line.

    Returns:
        A dict with the keys ``"message"``, ``"contact_name"`` and
        ``"timestamp"`` (the value of ``datetime.timestamp()``).

    Raises:
        ValueError: If the line does not match the pattern.
        Exception: Any other parsing error is logged together with the
            offending line and re-raised unchanged.
    """
    text = example["text"]
    try:
        match = exp.match(text)
        if match is None:
            # Without this check a non-matching line fails with an opaque
            # "'NoneType' object has no attribute 'groupdict'".
            raise ValueError(
                f"Line does not match the message line format: {text!r}"
            )
        groups = match.groupdict()
        timestamp = dateutil.parser.parse(
            groups["msg_datetime"], dayfirst=datetime_dayfirst
        ).timestamp()
        return {
            "message": groups["message"],
            "contact_name": groups["contact_name"],
            "timestamp": timestamp,
        }
    except Exception:
        # Record the offending line with the traceback, then propagate the
        # original exception (bare ``raise`` keeps the traceback intact).
        logger.exception(text)
        raise
240
+
241
  ds = (
242
  datasets.load_dataset("text", data_files=[file])["train"]
243
  .filter(
 
280
  else:
281
  reordered_conversations_ds = spell_checked_conversations_ds
282
 
283
+ # For the contact_name, rewrite everything that is not 'my_whatsapp_name' to 'Other'
284
def rewrite_contact_name(conversation):
    """Relabel every sender other than ``whatsapp_name`` as ``"Other"``.

    The messages under ``conversation["conversations"]`` are modified in
    place, and the same ``conversation`` object is returned.
    """
    not_mine = (
        entry
        for entry in conversation["conversations"]
        if entry["contact_name"] != whatsapp_name
    )
    for entry in not_mine:
        entry["contact_name"] = "Other"
    return conversation
289
+
290
  changed_contact_name_ds = reordered_conversations_ds.map(
291
  rewrite_contact_name
292
  ) # , num_proc=os.cpu_count() - 1)
 
371
  ]
372
  }
373
  )
374
+ else:
375
+ logger.warning(
376
+ f"Discarding conversation because the length is not at least {MIN_MESSAGES_THRESHOLD}: {messages}"
377
+ )
378
  # Before returning, flatten the list of dictionaries into a dictionary of lists
379
  flattened_examples = {}
380
  for key in processed_examples[0].keys():
validation.py CHANGED
@@ -1,7 +1,12 @@
1
- import numpy as np
2
  from collections import defaultdict
 
 
3
  import tiktoken
4
 
 
 
 
5
 
6
  def check_format_errors(train_dataset, user_role, model_role):
7
  """
@@ -24,7 +29,10 @@ def check_format_errors(train_dataset, user_role, model_role):
24
  if "role" not in message or "content" not in message:
25
  format_errors["message_missing_key"] += 1
26
 
27
- if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
 
 
 
28
  format_errors["message_unrecognized_key"] += 1
29
 
30
  if message.get("role", None) not in ["system", user_role, model_role]:
@@ -40,14 +48,15 @@ def check_format_errors(train_dataset, user_role, model_role):
40
  format_errors["example_missing_assistant_message"] += 1
41
 
42
  if format_errors:
43
- print("Found errors:")
44
  for k, v in format_errors.items():
45
- print(f"{k}: {v}")
46
  else:
47
- print("No errors found")
48
 
49
  return format_errors if format_errors else {}
50
 
 
51
  def get_distributions(train_dataset, user_role, model_role):
52
  """
53
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
@@ -76,7 +85,6 @@ def get_distributions(train_dataset, user_role, model_role):
76
  num_tokens += len(encoding.encode(message["content"]))
77
  return num_tokens
78
 
79
-
80
  n_missing_system = 0
81
  n_missing_user = 0
82
  n_messages = []
@@ -92,13 +100,13 @@ def get_distributions(train_dataset, user_role, model_role):
92
  n_messages.append(len(messages))
93
  convo_lens.append(num_tokens_from_messages(messages))
94
  assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
95
-
96
  return {
97
  "n_missing_system": n_missing_system,
98
  "n_missing_user": n_missing_user,
99
  "n_messages": n_messages,
100
  "convo_lens": convo_lens,
101
- "assistant_message_lens": assistant_message_lens
102
  }
103
 
104
 
@@ -106,48 +114,49 @@ def check_token_counts(train_dataset, user_role, model_role):
106
  """
107
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
108
  """
109
- def print_distribution(values, name):
110
- print(f"\n#### Distribution of {name}:")
111
- print(f"min / max: {min(values)}, {max(values)}")
112
- print(f"mean / median: {np.mean(values)}, {np.median(values)}")
113
- print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")
114
-
115
 
 
 
 
 
 
116
 
117
  # Warnings and tokens counts
118
- distributions = get_distributions(train_dataset, user_role=user_role, model_role=model_role)
 
 
119
  n_missing_system = distributions["n_missing_system"]
120
  n_missing_user = distributions["n_missing_user"]
121
  n_messages = distributions["n_messages"]
122
  convo_lens = distributions["convo_lens"]
123
  assistant_message_lens = distributions["assistant_message_lens"]
124
 
125
- print("Num examples missing system message:", n_missing_system)
126
- print("Num examples missing user message:", n_missing_user)
127
  print_distribution(n_messages, "num_messages_per_example")
128
  print_distribution(convo_lens, "num_total_tokens_per_example")
129
  print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
130
  n_too_long = sum(l > 4096 for l in convo_lens)
131
- print(
132
  f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning"
133
  )
134
 
135
- return
136
 
137
 
138
  def estimate_cost(train_dataset, user_role, model_role):
139
  """
140
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
141
  """
142
- distributions = get_distributions(train_dataset, user_role=user_role, model_role=model_role)
 
 
143
  n_missing_system = distributions["n_missing_system"]
144
  n_missing_user = distributions["n_missing_user"]
145
  n_messages = distributions["n_messages"]
146
  convo_lens = distributions["convo_lens"]
147
  assistant_message_lens = distributions["assistant_message_lens"]
148
 
149
-
150
-
151
  # Pricing and default n_epochs estimate
152
  MAX_TOKENS_PER_EXAMPLE = 4096
153
 
@@ -159,10 +168,13 @@ def estimate_cost(train_dataset, user_role, model_role):
159
 
160
  n_epochs = TARGET_EPOCHS
161
  n_train_examples = len(train_dataset)
162
- if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
163
- n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
164
- elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
165
- n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
 
 
 
166
 
167
  n_billing_tokens_in_dataset = sum(
168
  min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
@@ -170,5 +182,6 @@ def estimate_cost(train_dataset, user_role, model_role):
170
 
171
  return {
172
  "Estimated number of tokens in dataset": n_billing_tokens_in_dataset,
173
- f"Estimated number of tokens that will be billed (assuming {n_epochs} training epochs)": n_epochs * n_billing_tokens_in_dataset
 
174
  }
 
1
+ import logging
2
  from collections import defaultdict
3
+
4
+ import numpy as np
5
  import tiktoken
6
 
7
+ logger = logging.getLogger(__name__)
8
+ logger.setLevel(logging.INFO)
9
+
10
 
11
  def check_format_errors(train_dataset, user_role, model_role):
12
  """
 
29
  if "role" not in message or "content" not in message:
30
  format_errors["message_missing_key"] += 1
31
 
32
+ if any(
33
+ k not in ("role", "content", "name", "function_call", "weight")
34
+ for k in message
35
+ ):
36
  format_errors["message_unrecognized_key"] += 1
37
 
38
  if message.get("role", None) not in ["system", user_role, model_role]:
 
48
  format_errors["example_missing_assistant_message"] += 1
49
 
50
  if format_errors:
51
+ logger.warning("Found errors:")
52
  for k, v in format_errors.items():
53
+ logger.warning(f"{k}: {v}")
54
  else:
55
+ logger.info("No errors found")
56
 
57
  return format_errors if format_errors else {}
58
 
59
+
60
  def get_distributions(train_dataset, user_role, model_role):
61
  """
62
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
 
85
  num_tokens += len(encoding.encode(message["content"]))
86
  return num_tokens
87
 
 
88
  n_missing_system = 0
89
  n_missing_user = 0
90
  n_messages = []
 
100
  n_messages.append(len(messages))
101
  convo_lens.append(num_tokens_from_messages(messages))
102
  assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
103
+
104
  return {
105
  "n_missing_system": n_missing_system,
106
  "n_missing_user": n_missing_user,
107
  "n_messages": n_messages,
108
  "convo_lens": convo_lens,
109
+ "assistant_message_lens": assistant_message_lens,
110
  }
111
 
112
 
 
114
  """
115
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
116
  """
 
 
 
 
 
 
117
 
118
def print_distribution(values, name):
    """Log summary statistics of ``values`` under the heading ``name``.

    Logs min/max, mean/median and the 5th/95th percentiles at INFO level
    through the module-level ``logger``.

    Args:
        values: Sequence of numbers to summarise.
        name: Label used in the heading line.
    """
    if len(values) == 0:
        # min()/max() raise ValueError on an empty sequence; report instead.
        logger.info(f"\n#### Distribution of {name}: no values")
        return

    # The line is labelled "p5 / p95", so use the 0.05 and 0.95 quantiles
    # (the previous code computed 0.1 and 0.9 under the same label).
    p5 = np.quantile(values, 0.05)
    p95 = np.quantile(values, 0.95)

    logger.info(f"\n#### Distribution of {name}:")
    logger.info(f"min / max: {min(values)}, {max(values)}")
    logger.info(f"mean / median: {np.mean(values)}, {np.median(values)}")
    logger.info(f"p5 / p95: {p5}, {p95}")
123
 
124
  # Warnings and tokens counts
125
+ distributions = get_distributions(
126
+ train_dataset, user_role=user_role, model_role=model_role
127
+ )
128
  n_missing_system = distributions["n_missing_system"]
129
  n_missing_user = distributions["n_missing_user"]
130
  n_messages = distributions["n_messages"]
131
  convo_lens = distributions["convo_lens"]
132
  assistant_message_lens = distributions["assistant_message_lens"]
133
 
134
+ logger.info("Num examples missing system message:", n_missing_system)
135
+ logger.info("Num examples missing user message:", n_missing_user)
136
  print_distribution(n_messages, "num_messages_per_example")
137
  print_distribution(convo_lens, "num_total_tokens_per_example")
138
  print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
139
  n_too_long = sum(l > 4096 for l in convo_lens)
140
+ logger.info(
141
  f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning"
142
  )
143
 
144
+ return
145
 
146
 
147
  def estimate_cost(train_dataset, user_role, model_role):
148
  """
149
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
150
  """
151
+ distributions = get_distributions(
152
+ train_dataset, user_role=user_role, model_role=model_role
153
+ )
154
  n_missing_system = distributions["n_missing_system"]
155
  n_missing_user = distributions["n_missing_user"]
156
  n_messages = distributions["n_messages"]
157
  convo_lens = distributions["convo_lens"]
158
  assistant_message_lens = distributions["assistant_message_lens"]
159
 
 
 
160
  # Pricing and default n_epochs estimate
161
  MAX_TOKENS_PER_EXAMPLE = 4096
162
 
 
168
 
169
  n_epochs = TARGET_EPOCHS
170
  n_train_examples = len(train_dataset)
171
+ try:
172
+ if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
173
+ n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
174
+ elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
175
+ n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
176
+ except:
177
+ n_epochs = TARGET_EPOCHS
178
 
179
  n_billing_tokens_in_dataset = sum(
180
  min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
 
182
 
183
  return {
184
  "Estimated number of tokens in dataset": n_billing_tokens_in_dataset,
185
+ f"Estimated number of tokens that will be billed (assuming {n_epochs} training epochs)": n_epochs
186
+ * n_billing_tokens_in_dataset,
187
  }