Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,759 Bytes
d59f323 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from mmengine.hooks import Hook
from xtuner.registry import BUILDER
class SpecialDatasetInfoHook(Hook):
def __init__(self, tokenizer, is_intern_repo_dataset=False, special_tokens=None):
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.is_intern_repo_dataset = is_intern_repo_dataset
def log(self, runner, dataset, mode='train'):
def _log(input_ids, log_prefix=''):
if self.is_intern_repo_dataset:
input_ids = [abs(x) for x in input_ids]
text = self.tokenizer.decode(input_ids)
runner.logger.info(text)
runner.logger.info(f'Num {mode} samples {len(dataset)}')
runner.logger.info(f'{mode} example:')
if 'chosen_ids' in dataset[0]:
_log(dataset[0]['chosen_ids'], log_prefix='chosen: ')
_log(dataset[0]['rejected_ids'], log_prefix='rejected: ')
else:
_log(dataset[0]['input_ids'])
def before_train(self, runner) -> None:
do_train = runner.train_loop is not None
do_eval = runner.val_loop is not None
if do_train:
train_dataset = runner.train_dataloader.dataset
self.log(runner, train_dataset, mode='train')
if do_eval:
eval_dataset = runner.val_dataloader.dataset
self.log(runner, eval_dataset, mode='eval')
def before_val(self, runner) -> None:
eval_dataset = runner.val_dataloader.dataset
self.log(runner, eval_dataset, mode='eval')
def before_test(self, runner) -> None:
test_dataset = runner.test_dataloader.dataset
self.log(runner, test_dataset, mode='test')
|