Nanobit commited on
Commit
01c8a33
·
1 Parent(s): 7eb33a7

Lint pygmalion

Browse files
src/axolotl/prompt_strategies/pygmalion.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import copy
2
  import logging
3
  from collections import defaultdict
@@ -9,10 +11,14 @@ IGNORE_TOKEN_ID = -100
9
 
10
 
11
  class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
 
 
 
 
12
  bot_prefix_token_ids = []
13
 
14
  def __init__(self, prompter, tokenizer, *args, **kwargs):
15
- super().__init__(prompter, tokenizer)
16
  res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
17
  self.bot_prefix_token_ids = res["input_ids"]
18
 
@@ -23,7 +29,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
23
  "labels": [],
24
  }
25
  current_len = 0
26
- for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
27
  role, message = part
28
  if role == "system":
29
  prefix = "<|system|>"
@@ -96,10 +102,16 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
96
 
97
 
98
  class PygmalionPrompter:
 
 
 
 
99
  def __init__(self, *args, **kwargs):
100
  pass
101
 
102
- def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
 
 
103
  for msg in source:
104
  yield msg["role"], msg["value"]
105
 
 
1
+ """Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
2
+
3
  import copy
4
  import logging
5
  from collections import defaultdict
 
11
 
12
 
13
  class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
14
+ """
15
+ Tokenizing strategy for Pygmalion.
16
+ """
17
+
18
  bot_prefix_token_ids = []
19
 
20
  def __init__(self, prompter, tokenizer, *args, **kwargs):
21
+ super().__init__(prompter, tokenizer, *args, **kwargs)
22
  res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
23
  self.bot_prefix_token_ids = res["input_ids"]
24
 
 
29
  "labels": [],
30
  }
31
  current_len = 0
32
+ for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
33
  role, message = part
34
  if role == "system":
35
  prefix = "<|system|>"
 
102
 
103
 
104
  class PygmalionPrompter:
105
+ """
106
+ Prompter for Pygmalion.
107
+ """
108
+
109
  def __init__(self, *args, **kwargs):
110
  pass
111
 
112
+ def build_prompt(
113
+ self, source, *args, **kwargs # pylint: disable=unused-argument
114
+ ) -> Generator[str, None, None]:
115
  for msg in source:
116
  yield msg["role"], msg["value"]
117