Lint and format
Browse files- .gitignore +1 -1
- docker/Dockerfile-base +0 -1
- examples/falcon/config-7b-lora.yml +0 -1
- examples/falcon/config-7b.yml +0 -1
- scripts/alpaca_json_to_jsonl.py +21 -5
- scripts/finetune.py +19 -15
- src/axolotl/datasets.py +4 -7
- src/axolotl/utils/data.py +31 -23
- tests/test_prompters.py +6 -4
.gitignore
CHANGED
@@ -160,4 +160,4 @@ cython_debug/
|
|
160 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
161 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
162 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
163 |
-
.idea/
|
|
|
160 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
161 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
162 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
163 |
+
.idea/
|
docker/Dockerfile-base
CHANGED
@@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
|
|
99 |
pip3 install awscli && \
|
100 |
# The base image ships with `pydantic==1.8.2` which is not working
|
101 |
pip3 install -U --no-cache-dir pydantic
|
102 |
-
|
|
|
99 |
pip3 install awscli && \
|
100 |
# The base image ships with `pydantic==1.8.2` which is not working
|
101 |
pip3 install -U --no-cache-dir pydantic
|
|
examples/falcon/config-7b-lora.yml
CHANGED
@@ -61,4 +61,3 @@ special_tokens:
|
|
61 |
pad_token: "<|endoftext|>"
|
62 |
bos_token: ">>ABSTRACT<<"
|
63 |
eos_token: "<|endoftext|>"
|
64 |
-
|
|
|
61 |
pad_token: "<|endoftext|>"
|
62 |
bos_token: ">>ABSTRACT<<"
|
63 |
eos_token: "<|endoftext|>"
|
|
examples/falcon/config-7b.yml
CHANGED
@@ -61,4 +61,3 @@ special_tokens:
|
|
61 |
pad_token: "<|endoftext|>"
|
62 |
bos_token: ">>ABSTRACT<<"
|
63 |
eos_token: "<|endoftext|>"
|
64 |
-
|
|
|
61 |
pad_token: "<|endoftext|>"
|
62 |
bos_token: ">>ABSTRACT<<"
|
63 |
eos_token: "<|endoftext|>"
|
|
scripts/alpaca_json_to_jsonl.py
CHANGED
@@ -1,23 +1,39 @@
|
|
|
|
|
|
1 |
import os
|
2 |
import sys
|
|
|
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
import fire
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# add src to the pythonpath so we don't need to pip install this
|
9 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
10 |
src_dir = os.path.join(project_root, "src")
|
11 |
sys.path.insert(0, src_dir)
|
12 |
|
13 |
-
from axolotl.convert import *
|
14 |
-
|
15 |
|
16 |
def main(
|
17 |
-
|
18 |
output: Optional[Path] = None,
|
19 |
to_stdout: Optional[bool] = False,
|
20 |
):
|
|
|
|
|
|
|
|
|
21 |
file_reader = FileReader()
|
22 |
if to_stdout or output is None:
|
23 |
writer = StdoutWriter()
|
@@ -28,7 +44,7 @@ def main(
|
|
28 |
|
29 |
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
30 |
|
31 |
-
converter.convert(
|
32 |
|
33 |
|
34 |
if __name__ == "__main__":
|
|
|
1 |
+
"""Module to convert json file to jsonl"""
|
2 |
+
|
3 |
import os
|
4 |
import sys
|
5 |
+
|
6 |
+
from typing import Optional
|
7 |
from pathlib import Path
|
8 |
|
9 |
import fire
|
10 |
+
|
11 |
+
|
12 |
+
from axolotl.convert import (
|
13 |
+
FileReader,
|
14 |
+
StdoutWriter,
|
15 |
+
FileWriter,
|
16 |
+
JsonlSerializer,
|
17 |
+
JsonParser,
|
18 |
+
JsonToJsonlConverter,
|
19 |
+
)
|
20 |
+
|
21 |
|
22 |
# add src to the pythonpath so we don't need to pip install this
|
23 |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
24 |
src_dir = os.path.join(project_root, "src")
|
25 |
sys.path.insert(0, src_dir)
|
26 |
|
|
|
|
|
27 |
|
28 |
def main(
|
29 |
+
file: Path,
|
30 |
output: Optional[Path] = None,
|
31 |
to_stdout: Optional[bool] = False,
|
32 |
):
|
33 |
+
"""
|
34 |
+
Convert a json file to jsonl
|
35 |
+
"""
|
36 |
+
|
37 |
file_reader = FileReader()
|
38 |
if to_stdout or output is None:
|
39 |
writer = StdoutWriter()
|
|
|
44 |
|
45 |
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
46 |
|
47 |
+
converter.convert(file, output)
|
48 |
|
49 |
|
50 |
if __name__ == "__main__":
|
scripts/finetune.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import importlib
|
2 |
import logging
|
3 |
import os
|
@@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels
|
|
16 |
from axolotl.utils.validation import validate_config
|
17 |
from axolotl.utils.dict import DictDefault
|
18 |
|
19 |
-
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
20 |
-
src_dir = os.path.join(project_root, "src")
|
21 |
-
sys.path.insert(0, src_dir)
|
22 |
-
|
23 |
from axolotl.utils.data import load_prepare_datasets
|
24 |
from axolotl.utils.models import load_model, load_tokenizer
|
25 |
from axolotl.utils.trainer import setup_trainer
|
26 |
from axolotl.utils.wandb import setup_wandb_env_vars
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
29 |
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
30 |
|
@@ -37,7 +40,7 @@ def choose_device(cfg):
|
|
37 |
try:
|
38 |
if torch.backends.mps.is_available():
|
39 |
return "mps"
|
40 |
-
except:
|
41 |
return "cpu"
|
42 |
|
43 |
cfg.device = get_device()
|
@@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|
73 |
|
74 |
model.eval()
|
75 |
with torch.no_grad():
|
76 |
-
# gc = GenerationConfig() # TODO swap out and use this
|
77 |
generated = model.generate(
|
78 |
inputs=batch["input_ids"].to(cfg.device),
|
79 |
do_sample=True,
|
@@ -130,12 +133,12 @@ def train(
|
|
130 |
config = choose_config(config)
|
131 |
|
132 |
# load the config from the yaml file
|
133 |
-
with open(config, "
|
134 |
-
cfg: DictDefault = DictDefault(yaml.load(
|
135 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
136 |
# then overwrite the value
|
137 |
cfg_keys = cfg.keys()
|
138 |
-
for k in kwargs:
|
139 |
# if not strict, allow writing to cfg even if it's not in the yml already
|
140 |
if k in cfg_keys or cfg.strict is False:
|
141 |
# handle booleans
|
@@ -167,13 +170,11 @@ def train(
|
|
167 |
|
168 |
# load the tokenizer first
|
169 |
logging.info("loading tokenizer...")
|
170 |
-
tokenizer = load_tokenizer(
|
171 |
-
cfg.base_model_config,
|
172 |
-
cfg.tokenizer_type,
|
173 |
-
cfg
|
174 |
-
)
|
175 |
|
176 |
-
if check_not_in(
|
|
|
|
|
177 |
train_dataset, eval_dataset = load_prepare_datasets(
|
178 |
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
179 |
)
|
@@ -262,10 +263,13 @@ def train(
|
|
262 |
|
263 |
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
264 |
|
|
|
265 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
266 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
267 |
if cfg.local_rank == 0:
|
268 |
model.save_pretrained(cfg.output_dir)
|
|
|
|
|
269 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
270 |
|
271 |
|
|
|
1 |
+
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
2 |
+
|
3 |
import importlib
|
4 |
import logging
|
5 |
import os
|
|
|
18 |
from axolotl.utils.validation import validate_config
|
19 |
from axolotl.utils.dict import DictDefault
|
20 |
|
|
|
|
|
|
|
|
|
21 |
from axolotl.utils.data import load_prepare_datasets
|
22 |
from axolotl.utils.models import load_model, load_tokenizer
|
23 |
from axolotl.utils.trainer import setup_trainer
|
24 |
from axolotl.utils.wandb import setup_wandb_env_vars
|
25 |
|
26 |
+
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
27 |
+
src_dir = os.path.join(project_root, "src")
|
28 |
+
sys.path.insert(0, src_dir)
|
29 |
+
|
30 |
+
|
31 |
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
32 |
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
33 |
|
|
|
40 |
try:
|
41 |
if torch.backends.mps.is_available():
|
42 |
return "mps"
|
43 |
+
except Exception: # pylint: disable=broad-exception-caught
|
44 |
return "cpu"
|
45 |
|
46 |
cfg.device = get_device()
|
|
|
76 |
|
77 |
model.eval()
|
78 |
with torch.no_grad():
|
79 |
+
# gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme
|
80 |
generated = model.generate(
|
81 |
inputs=batch["input_ids"].to(cfg.device),
|
82 |
do_sample=True,
|
|
|
133 |
config = choose_config(config)
|
134 |
|
135 |
# load the config from the yaml file
|
136 |
+
with open(config, encoding="utf-8") as file:
|
137 |
+
cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))
|
138 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
139 |
# then overwrite the value
|
140 |
cfg_keys = cfg.keys()
|
141 |
+
for k, _ in kwargs.items():
|
142 |
# if not strict, allow writing to cfg even if it's not in the yml already
|
143 |
if k in cfg_keys or cfg.strict is False:
|
144 |
# handle booleans
|
|
|
170 |
|
171 |
# load the tokenizer first
|
172 |
logging.info("loading tokenizer...")
|
173 |
+
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
if check_not_in(
|
176 |
+
["inference", "shard", "merge_lora"], kwargs
|
177 |
+
): # don't need to load dataset for these
|
178 |
train_dataset, eval_dataset = load_prepare_datasets(
|
179 |
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
180 |
)
|
|
|
263 |
|
264 |
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
265 |
|
266 |
+
# pylint: disable=fixme
|
267 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
268 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
269 |
if cfg.local_rank == 0:
|
270 |
model.save_pretrained(cfg.output_dir)
|
271 |
+
|
272 |
+
# pylint: disable=fixme
|
273 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
274 |
|
275 |
|
src/axolotl/datasets.py
CHANGED
@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
|
|
82 |
else:
|
83 |
example_len = 0
|
84 |
|
85 |
-
if (
|
86 |
-
|
87 |
-
or buffer_len + int(add_concat_token) + example_len
|
88 |
-
> self.seq_length
|
89 |
):
|
90 |
if buffer["input_ids"]:
|
91 |
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
|
|
95 |
: self.seq_length
|
96 |
]
|
97 |
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
98 |
-
if (
|
99 |
-
|
100 |
-
and attention_mask.size() == input_ids.size()
|
101 |
):
|
102 |
yield {
|
103 |
"input_ids": input_ids,
|
|
|
82 |
else:
|
83 |
example_len = 0
|
84 |
|
85 |
+
if not example_len or (
|
86 |
+
buffer_len + int(add_concat_token) + example_len > self.seq_length
|
|
|
|
|
87 |
):
|
88 |
if buffer["input_ids"]:
|
89 |
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
|
|
93 |
: self.seq_length
|
94 |
]
|
95 |
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
96 |
+
if labels.size() == input_ids.size() and (
|
97 |
+
attention_mask.size() == input_ids.size()
|
|
|
98 |
):
|
99 |
yield {
|
100 |
"input_ids": input_ids,
|
src/axolotl/utils/data.py
CHANGED
@@ -1,14 +1,12 @@
|
|
1 |
import logging
|
2 |
from hashlib import md5
|
3 |
from pathlib import Path
|
4 |
-
from typing import Union
|
5 |
|
6 |
from datasets import (
|
7 |
load_from_disk,
|
8 |
load_dataset,
|
9 |
-
IterableDataset,
|
10 |
Dataset,
|
11 |
-
concatenate_datasets,
|
12 |
DatasetDict,
|
13 |
)
|
14 |
from huggingface_hub import hf_hub_download
|
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
|
|
48 |
md5(
|
49 |
(
|
50 |
str(cfg.sequence_len)
|
51 |
-
+ "@"
|
52 |
-
+ "|".join(
|
53 |
-
|
54 |
-
|
|
|
|
|
55 |
).encode("utf-8")
|
56 |
).hexdigest()
|
57 |
)
|
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
|
|
68 |
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
69 |
)
|
70 |
dataset = dataset["train"]
|
71 |
-
except:
|
72 |
pass
|
73 |
|
74 |
if dataset:
|
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
|
|
109 |
fp = hf_hub_download(
|
110 |
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
111 |
)
|
112 |
-
ds: Dataset = load_dataset(
|
|
|
|
|
113 |
if not ds:
|
114 |
-
raise
|
115 |
# support for using a subset of the data
|
116 |
if d.shards:
|
117 |
if "train" in ds:
|
118 |
-
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
|
|
|
|
|
119 |
else:
|
120 |
-
ds: Dataset = ds.shuffle(seed=42).shard(
|
|
|
|
|
121 |
d_type = d.type
|
122 |
d_type_split = d_type.split(":")
|
123 |
d_base_type = d_type_split[0]
|
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(
|
|
243 |
|
244 |
def load_prepare_datasets(
|
245 |
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
246 |
-
) ->
|
247 |
max_packed_sequence_len = (
|
248 |
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
249 |
)
|
@@ -259,12 +265,14 @@ def load_prepare_datasets(
|
|
259 |
md5(
|
260 |
(
|
261 |
str(cfg.sequence_len)
|
262 |
-
+ "@"
|
263 |
-
+ str(max_packed_sequence_len)
|
264 |
-
+ seed
|
265 |
-
+ "|".join(
|
266 |
-
|
267 |
-
|
|
|
|
|
268 |
).encode("utf-8")
|
269 |
).hexdigest()
|
270 |
)
|
@@ -285,7 +293,7 @@ def load_prepare_datasets(
|
|
285 |
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
286 |
)
|
287 |
dataset = dataset["train"]
|
288 |
-
except:
|
289 |
pass
|
290 |
|
291 |
if dataset:
|
@@ -327,9 +335,9 @@ def load_prepare_datasets(
|
|
327 |
d
|
328 |
for d in dataset
|
329 |
if len(d["input_ids"]) < cfg.sequence_len
|
330 |
-
and len(d["input_ids"]) > 0
|
331 |
-
and len(d["input_ids"]) == len(d["attention_mask"])
|
332 |
-
and len(d["input_ids"]) == len(d["labels"])
|
333 |
]
|
334 |
)
|
335 |
|
|
|
1 |
import logging
|
2 |
from hashlib import md5
|
3 |
from pathlib import Path
|
4 |
+
from typing import Tuple, Union
|
5 |
|
6 |
from datasets import (
|
7 |
load_from_disk,
|
8 |
load_dataset,
|
|
|
9 |
Dataset,
|
|
|
10 |
DatasetDict,
|
11 |
)
|
12 |
from huggingface_hub import hf_hub_download
|
|
|
46 |
md5(
|
47 |
(
|
48 |
str(cfg.sequence_len)
|
49 |
+
+ "@" # noqa: W503
|
50 |
+
+ "|".join( # noqa: W503
|
51 |
+
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
52 |
+
)
|
53 |
+
+ "|" # noqa: W503
|
54 |
+
+ tokenizer_name # noqa: W503
|
55 |
).encode("utf-8")
|
56 |
).hexdigest()
|
57 |
)
|
|
|
68 |
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
69 |
)
|
70 |
dataset = dataset["train"]
|
71 |
+
except Exception: # pylint: disable=broad-except
|
72 |
pass
|
73 |
|
74 |
if dataset:
|
|
|
109 |
fp = hf_hub_download(
|
110 |
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
111 |
)
|
112 |
+
ds: Dataset = load_dataset(
|
113 |
+
"json", data_files=fp, streaming=False, split=None
|
114 |
+
)
|
115 |
if not ds:
|
116 |
+
raise ValueError("unhandled dataset load")
|
117 |
# support for using a subset of the data
|
118 |
if d.shards:
|
119 |
if "train" in ds:
|
120 |
+
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
|
121 |
+
num_shards=d.shards, index=0
|
122 |
+
)
|
123 |
else:
|
124 |
+
ds: Dataset = ds.shuffle(seed=42).shard(
|
125 |
+
num_shards=d.shards, index=0
|
126 |
+
)
|
127 |
d_type = d.type
|
128 |
d_type_split = d_type.split(":")
|
129 |
d_base_type = d_type_split[0]
|
|
|
249 |
|
250 |
def load_prepare_datasets(
|
251 |
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
252 |
+
) -> Tuple[Dataset, Dataset]:
|
253 |
max_packed_sequence_len = (
|
254 |
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
255 |
)
|
|
|
265 |
md5(
|
266 |
(
|
267 |
str(cfg.sequence_len)
|
268 |
+
+ "@" # noqa: W503
|
269 |
+
+ str(max_packed_sequence_len) # noqa: W503
|
270 |
+
+ seed # noqa: W503
|
271 |
+
+ "|".join( # noqa: W503
|
272 |
+
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
273 |
+
)
|
274 |
+
+ "|" # noqa: W503
|
275 |
+
+ tokenizer_name # noqa: W503
|
276 |
).encode("utf-8")
|
277 |
).hexdigest()
|
278 |
)
|
|
|
293 |
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
294 |
)
|
295 |
dataset = dataset["train"]
|
296 |
+
except Exception: # pylint: disable=broad-except
|
297 |
pass
|
298 |
|
299 |
if dataset:
|
|
|
335 |
d
|
336 |
for d in dataset
|
337 |
if len(d["input_ids"]) < cfg.sequence_len
|
338 |
+
and len(d["input_ids"]) > 0 # noqa: W503
|
339 |
+
and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
|
340 |
+
and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
|
341 |
]
|
342 |
)
|
343 |
|
tests/test_prompters.py
CHANGED
@@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|
12 |
|
13 |
def test_prompt_style_w_instruct(self):
|
14 |
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
|
15 |
-
res = next(
|
|
|
|
|
16 |
assert "Below is an instruction" in res
|
17 |
assert "### Instruction:" in res
|
18 |
assert "### Input:" in res
|
@@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|
30 |
|
31 |
def test_prompt_style_w_chat(self):
|
32 |
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
|
33 |
-
res = next(
|
|
|
|
|
34 |
assert "Below is an instruction" in res
|
35 |
assert "### Instruction:" not in res
|
36 |
assert "### Input:" not in res
|
@@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|
45 |
assert "### Response:" not in res
|
46 |
assert "USER:" in res
|
47 |
assert "ASSISTANT:" in res
|
48 |
-
|
49 |
-
|
|
|
12 |
|
13 |
def test_prompt_style_w_instruct(self):
|
14 |
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
|
15 |
+
res = next(
|
16 |
+
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
17 |
+
)
|
18 |
assert "Below is an instruction" in res
|
19 |
assert "### Instruction:" in res
|
20 |
assert "### Input:" in res
|
|
|
32 |
|
33 |
def test_prompt_style_w_chat(self):
|
34 |
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
|
35 |
+
res = next(
|
36 |
+
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
37 |
+
)
|
38 |
assert "Below is an instruction" in res
|
39 |
assert "### Instruction:" not in res
|
40 |
assert "### Input:" not in res
|
|
|
49 |
assert "### Response:" not in res
|
50 |
assert "USER:" in res
|
51 |
assert "ASSISTANT:" in res
|
|
|
|