File size: 33,168 Bytes
6ca641e
1
2
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DNA k-mer masked-language-model pretraining\n",
    "\n",
    "Trains a BERT with the `prajjwal1/bert-tiny` architecture (weights initialised from scratch unless resuming from a checkpoint) as a masked language model over DNA k-mer tokens from the `zhihan1996/DNA_bert_6` tokenizer.\n",
    "\n",
    "- Masking: each sampled position is widened to its neighbouring k-mers (`MASK_LIST`), then 80% `[MASK]` / 10% random / 10% unchanged.\n",
    "- Logging: TensorBoard, Weights & Biases, and `app.log`.\n",
    "- Outputs: `checkpoint-<step>` directories under `config['checkpoint_dir']`; the final model under `config['output_dir']`.\n",
    "\n",
    "**Run order:** Restart & Run All. The last cell calls `main(config)`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import glob\n",
    "import logging\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import re\n",
    "import shutil\n",
    "import time\n",
    "from copy import deepcopy\n",
    "from multiprocessing import Pool\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import wandb\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from tqdm import tqdm, trange\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from transformers import (\n",
    "    WEIGHTS_NAME,\n",
    "    AdamW,\n",
    "    BertConfig,\n",
    "    BertForMaskedLM,\n",
    "    BertTokenizer,\n",
    "    CamembertConfig,\n",
    "    CamembertForMaskedLM,\n",
    "    CamembertTokenizer,\n",
    "    DistilBertConfig,\n",
    "    DistilBertForMaskedLM,\n",
    "    DistilBertTokenizer,\n",
    "    GPT2Config,\n",
    "    GPT2LMHeadModel,\n",
    "    GPT2Tokenizer,\n",
    "    OpenAIGPTConfig,\n",
    "    OpenAIGPTLMHeadModel,\n",
    "    OpenAIGPTTokenizer,\n",
    "    PreTrainedModel,\n",
    "    PreTrainedTokenizer,\n",
    "    RobertaConfig,\n",
    "    RobertaForMaskedLM,\n",
    "    RobertaTokenizer,\n",
    "    get_linear_schedule_with_warmup,\n",
    "    get_cosine_with_hard_restarts_schedule_with_warmup,\n",
    ")\n",
    "\n",
    "try:\n",
    "    from torch.utils.tensorboard import SummaryWriter\n",
    "except ImportError:\n",
    "    from tensorboardX import SummaryWriter\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenizer, model registry and constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downloaded from the Hugging Face Hub (unless cached) when this cell runs. MODEL_CLASSES[\"dna\"]\n",
    "# stores this tokenizer *instance* where the other entries store tokenizer classes.\n",
    "DNATokenizer = AutoTokenizer.from_pretrained(\"zhihan1996/DNA_bert_6\", trust_remote_code=True)\n",
    "\n",
    "# model type -> (config class, model class, tokenizer class or instance)\n",
    "MODEL_CLASSES = {\n",
    "    \"gpt2\": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),\n",
    "    \"openai-gpt\": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),\n",
    "    \"dna\": (BertConfig, BertForMaskedLM, DNATokenizer),\n",
    "    \"bert\": (BertConfig, BertForMaskedLM, BertTokenizer),\n",
    "    \"roberta\": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),\n",
    "    \"distilbert\": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),\n",
    "    \"camembert\": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),\n",
    "}\n",
    "\n",
    "# k (as a string) -> offsets, relative to a sampled mask centre, of neighbouring tokens that\n",
    "# mask_tokens() masks too. Presumably this keeps overlapping k-mers from revealing the masked\n",
    "# bases (DNABERT-style span masking) -- confirm against the DNABERT paper.\n",
    "MASK_LIST = {\n",
    "    \"3\": [-1, 1],\n",
    "    \"4\": [-1, 1, 2],\n",
    "    \"5\": [-2, -1, 1, 2],\n",
    "    \"6\": [-2, -1, 1, 2, 3]\n",
    "}\n",
    "\n",
    "# Default training/evaluation corpus (a Kaggle input dataset).\n",
    "DEFAULT_DATA_FILE = r\"/kaggle/input/random-dna-sequences-for-transfomer-pretraining/6_12k.txt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Every tunable for the run. Keys read with `config.get()` are optional and fall back to the defaults shown here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    'line_by_line': True,  # one example per line (LineByLineTextDataset) vs fixed-size blocks (TextDataset)\n",
    "    'should_continue': False,  # resume from the newest checkpoint in checkpoint_dir; raises if there is none\n",
    "    'mlm': True,\n",
    "    'mlm_probability': 0.15,\n",
    "    'kmer': 6,  # selects the MASK_LIST entry used by mask_tokens()\n",
    "    'config_name': None,  # unused\n",
    "    'tokenizer_name': None,  # unused\n",
    "    'cache_dir': None,\n",
    "    'train_data_file': DEFAULT_DATA_FILE,\n",
    "    'eval_data_file': DEFAULT_DATA_FILE,  # NOTE: same file as training, so eval is not held out\n",
    "    'features_cache_dir': '/kaggle/working/',  # where tokenized datasets are pickled\n",
    "    'checkpoint_dir': r\"/kaggle/working/output\",  # holds checkpoint-<step>/ and training_logs/\n",
    "    'wandb_project': 'dna-bert-pretraining',  # only used if no wandb run is active yet\n",
    "    'block_size': 512,  # main() caps this at 512 (non-positive values become 512)\n",
    "    'do_train': True,\n",
    "    'do_eval': True,\n",
    "    'evaluate_during_training': True,\n",
    "    'per_gpu_train_batch_size': 175,\n",
    "    'per_gpu_eval_batch_size': 25,\n",
    "    'gradient_accumulation_steps': 1,\n",
    "    'learning_rate': 4e-4,\n",
    "    'weight_decay': 0.01,\n",
    "    'adam_epsilon': 1e-6,\n",
    "    'beta1': 0.9,\n",
    "    'beta2': 0.98,\n",
    "    'max_grad_norm': 1.0,\n",
    "    'num_train_epochs': 2000,\n",
    "    'max_steps': -1,  # > 0 overrides num_train_epochs\n",
    "    'warmup_steps': 2000,  # linear LR warmup length, in optimizer steps\n",
    "    'logging_steps': 200,\n",
    "    'save_steps': 1000,\n",
    "    'save_total_limit': 10,\n",
    "    'eval_all_checkpoints': False,\n",
    "    'no_cuda': False,\n",
    "    'overwrite_output_dir': True,\n",
    "    'overwrite_cache': False,\n",
    "    'seed': 42,\n",
    "    'n_process': 1,  # > 1 tokenizes LineByLineTextDataset in a process pool\n",
    "    'fp16': False,  # only logged; mixed precision is not implemented\n",
    "    'fp16_opt_level': 'O1',  # unused\n",
    "    'local_rank': -1,\n",
    "    'server_ip': '',  # unused\n",
    "    'server_port': '',  # unused\n",
    "    'output_dir': './output',  # final model + eval_results.txt\n",
    "    'device': 'cuda',  # overwritten by main()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_collate(tokenizer: PreTrainedTokenizer):\n",
    "    \"\"\"Return a DataLoader collate_fn that right-pads 1-D id tensors into a (batch, max_len) tensor.\n",
    "\n",
    "    Pads with tokenizer.pad_token_id, or with pad_sequence's default 0 when there is no pad token.\n",
    "    \"\"\"\n",
    "\n",
    "    def collate(examples: List[torch.Tensor]) -> torch.Tensor:\n",
    "        if tokenizer.pad_token_id is None:\n",
    "            return pad_sequence(examples, batch_first=True)\n",
    "        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
    "\n",
    "    return collate\n",
    "\n",
    "\n",
    "def _cached_features_path(config, file_path: str, block_size: int, prefix: str) -> str:\n",
    "    \"\"\"Path of the pickle cache for `file_path` tokenized with `block_size`.\n",
    "\n",
    "    Caches go to config['features_cache_dir'] (default /kaggle/working/) rather than next to the\n",
    "    input file, whose directory may be read-only (e.g. Kaggle input datasets).\n",
    "    \"\"\"\n",
    "    cache_dir = config.get('features_cache_dir', '/kaggle/working/')\n",
    "    filename = os.path.basename(file_path)\n",
    "    return os.path.join(cache_dir, prefix + str(block_size) + \"_\" + filename)\n",
    "\n",
    "\n",
    "def _load_cached_features(path: str):\n",
    "    \"\"\"Unpickle cached token ids. pickle can execute code, so only load caches this notebook wrote.\"\"\"\n",
    "    logger.info(\"Loading features from cached file %s\", path)\n",
    "    with open(path, \"rb\") as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "def _save_cached_features(path: str, examples) -> None:\n",
    "    \"\"\"Pickle `examples` to `path`, creating its directory if needed.\"\"\"\n",
    "    logger.info(\"Saving features into cached file %s\", path)\n",
    "    os.makedirs(os.path.dirname(path) or \".\", exist_ok=True)\n",
    "    with open(path, \"wb\") as handle:\n",
    "        pickle.dump(examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "\n",
    "class TextDataset(Dataset):\n",
    "    \"\"\"LM dataset over one long text.\n",
    "\n",
    "    The whole file is tokenized and cut into non-overlapping blocks (the trailing partial block\n",
    "    is dropped); each block is wrapped in the tokenizer's special tokens so that its total length,\n",
    "    special tokens included, is `block_size`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer: PreTrainedTokenizer, config, file_path: str, block_size=512):\n",
    "        assert os.path.isfile(file_path)\n",
    "        cached_features_file = _cached_features_path(config, file_path, block_size, \"dna_cached_blocks_\")\n",
    "\n",
    "        if os.path.exists(cached_features_file) and not config['overwrite_cache']:\n",
    "            self.examples = _load_cached_features(cached_features_file)\n",
    "            return\n",
    "\n",
    "        logger.info(\"Creating features from dataset file at %s\", os.path.dirname(file_path))\n",
    "        with open(file_path, encoding=\"utf-8\") as f:\n",
    "            text = f.read()\n",
    "        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))\n",
    "\n",
    "        # Reserve room for the special tokens so a finished block never exceeds block_size\n",
    "        # (main() caps block_size at 512).\n",
    "        content_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)\n",
    "        if content_size <= 0:\n",
    "            raise ValueError(\"block_size={} leaves no room for tokens after special tokens\".format(block_size))\n",
    "        self.examples = [\n",
    "            tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + content_size])\n",
    "            for i in range(0, len(tokenized_text) - content_size + 1, content_size)\n",
    "        ]\n",
    "        _save_cached_features(cached_features_file, self.examples)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.examples)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        return torch.tensor(self.examples[item], dtype=torch.long)\n",
    "\n",
    "\n",
    "def convert_line_to_example(tokenizer, lines, max_length, add_special_tokens=True):\n",
    "    \"\"\"Tokenize `lines` into lists of token ids of at most `max_length` (special tokens included).\n",
    "\n",
    "    truncation=True makes explicit the truncation transformers otherwise falls back to with a\n",
    "    warning. Kept at module level so multiprocessing.Pool can pickle it.\n",
    "    \"\"\"\n",
    "    if not lines:\n",
    "        return []\n",
    "    return tokenizer.batch_encode_plus(\n",
    "        lines, add_special_tokens=add_special_tokens, max_length=max_length, truncation=True\n",
    "    )[\"input_ids\"]\n",
    "\n",
    "\n",
    "class LineByLineTextDataset(Dataset):\n",
    "    \"\"\"LM dataset with one example per non-blank line, each truncated to `block_size` ids.\n",
    "\n",
    "    Token ids are pickled to a cache file and reused unless config['overwrite_cache'] is set.\n",
    "    With config['n_process'] > 1 the lines are tokenized in a multiprocessing.Pool.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokenizer: PreTrainedTokenizer, config, file_path: str, block_size=512):\n",
    "        assert os.path.isfile(file_path)\n",
    "        cached_features_file = _cached_features_path(config, file_path, block_size, \"dna_cached_lm_\")\n",
    "\n",
    "        if os.path.exists(cached_features_file) and not config['overwrite_cache']:\n",
    "            self.examples = _load_cached_features(cached_features_file)\n",
    "            return\n",
    "\n",
    "        logger.info(\"Creating features from dataset file at %s\", file_path)\n",
    "        with open(file_path, encoding=\"utf-8\") as f:\n",
    "            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]\n",
    "\n",
    "        n_proc = config['n_process']\n",
    "        if n_proc <= 1:\n",
    "            self.examples = convert_line_to_example(tokenizer, lines, block_size)\n",
    "        else:\n",
    "            self.examples = self._encode_in_parallel(tokenizer, lines, block_size, n_proc)\n",
    "\n",
    "        _save_cached_features(cached_features_file, self.examples)\n",
    "\n",
    "    @staticmethod\n",
    "    def _encode_in_parallel(tokenizer, lines, block_size, n_proc):\n",
    "        \"\"\"Tokenize `lines` in `n_proc` worker processes; output order matches input order.\"\"\"\n",
    "        len_slice = len(lines) // n_proc\n",
    "        # n_proc contiguous slices; the last one also absorbs the remainder.\n",
    "        bounds = [len_slice * i for i in range(n_proc)] + [len(lines)]\n",
    "        jobs = [(tokenizer, lines[bounds[i] : bounds[i + 1]], block_size) for i in range(n_proc)]\n",
    "        # The context manager shuts the pool down even if a worker raises.\n",
    "        with Pool(n_proc) as pool:\n",
    "            chunks = pool.starmap(convert_line_to_example, jobs)\n",
    "        return [ids for chunk in chunks for ids in chunk]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.examples)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return torch.tensor(self.examples[i], dtype=torch.long)\n",
    "\n",
    "\n",
    "def load_and_cache_examples(config, tokenizer, evaluate=False):\n",
    "    \"\"\"Build the train (or, with evaluate=True, eval) dataset chosen by config['line_by_line'].\n",
    "\n",
    "    Paths come from config['train_data_file'] / config['eval_data_file'], falling back to\n",
    "    DEFAULT_DATA_FILE. NOTE(review): by default both are the same file, so eval perplexity is\n",
    "    measured on training data.\n",
    "    \"\"\"\n",
    "    file_key = 'eval_data_file' if evaluate else 'train_data_file'\n",
    "    file_path = config.get(file_key) or DEFAULT_DATA_FILE\n",
    "    dataset_class = LineByLineTextDataset if config['line_by_line'] else TextDataset\n",
    "    return dataset_class(tokenizer, config, file_path=file_path, block_size=config['block_size'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seeding and checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(config):\n",
    "    \"\"\"Seed Python, NumPy and torch, plus every CUDA device when config['n_gpu'] > 0.\"\"\"\n",
    "    random.seed(config['seed'])\n",
    "    np.random.seed(config['seed'])\n",
    "    torch.manual_seed(config['seed'])\n",
    "    if config['n_gpu'] > 0:\n",
    "        torch.cuda.manual_seed_all(config['seed'])\n",
    "\n",
    "\n",
    "def _checkpoint_dir(config) -> str:\n",
    "    \"\"\"Directory holding `checkpoint-<step>` subdirectories and `training_logs/`.\"\"\"\n",
    "    return config.get('checkpoint_dir', r\"/kaggle/working/output\")\n",
    "\n",
    "\n",
    "def _sorted_checkpoints(config, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n",
    "    \"\"\"Checkpoint directories, oldest first: by modification time if use_mtime, else by step number.\n",
    "\n",
    "    In step mode, directories whose name has no trailing `-<digits>` are skipped.\n",
    "    \"\"\"\n",
    "    ordering_and_checkpoint_path = []\n",
    "    glob_checkpoints = glob.glob(os.path.join(_checkpoint_dir(config), \"{}-*\".format(checkpoint_prefix)))\n",
    "\n",
    "    for path in glob_checkpoints:\n",
    "        if use_mtime:\n",
    "            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n",
    "        else:\n",
    "            regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n",
    "            if regex_match and regex_match.groups():\n",
    "                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n",
    "\n",
    "    return [path for _, path in sorted(ordering_and_checkpoint_path)]\n",
    "\n",
    "\n",
    "def _rotate_checkpoints(config, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n",
    "    \"\"\"Delete the oldest checkpoints so at most config['save_total_limit'] remain (no-op if unset or <= 0).\"\"\"\n",
    "    limit = config['save_total_limit']\n",
    "    if not limit or limit <= 0:\n",
    "        return\n",
    "\n",
    "    checkpoints_sorted = _sorted_checkpoints(config, checkpoint_prefix, use_mtime)\n",
    "    if len(checkpoints_sorted) <= limit:\n",
    "        return\n",
    "\n",
    "    for checkpoint in checkpoints_sorted[: len(checkpoints_sorted) - limit]:\n",
    "        logger.info(\"Deleting older checkpoint [{}] due to config['save_total_limit']\".format(checkpoint))\n",
    "        shutil.rmtree(checkpoint)\n",
    "\n",
    "\n",
    "def _save_checkpoint(config, model, tokenizer, optimizer, scheduler, global_step: int, checkpoint_prefix=\"checkpoint\") -> None:\n",
    "    \"\"\"Write model, tokenizer, config, optimizer and scheduler state to <checkpoint_dir>/<prefix>-<step>.\n",
    "\n",
    "    Older checkpoints beyond config['save_total_limit'] are pruned before the optimizer/scheduler\n",
    "    state is written (the newest checkpoint is never pruned).\n",
    "    \"\"\"\n",
    "    output_dir = os.path.join(_checkpoint_dir(config), \"{}-{}\".format(checkpoint_prefix, global_step))\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "    model_to_save = model.module if hasattr(model, \"module\") else model  # Take care of distributed/parallel training\n",
    "    model_to_save.save_pretrained(output_dir)\n",
    "    tokenizer.save_pretrained(output_dir)\n",
    "    torch.save(config, os.path.join(output_dir, \"training_args.bin\"))\n",
    "    logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
    "\n",
    "    _rotate_checkpoints(config, checkpoint_prefix)\n",
    "\n",
    "    torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n",
    "    torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n",
    "    logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## k-mer aware masking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, config) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.\n",
    "\n",
    "    Mask centres are drawn per position with probability config['mlm_probability'] (special and\n",
    "    padding positions excluded). Each centre is then widened by the offsets in\n",
    "    MASK_LIST[str(config.get('kmer', 6))], kept within [1, last maskable position] of its row.\n",
    "\n",
    "    Args:\n",
    "        inputs: batch of token ids; modified IN PLACE and also returned.\n",
    "        tokenizer: must define a mask token.\n",
    "        config: reads 'mlm_probability' and optional 'kmer' (default 6).\n",
    "\n",
    "    Returns:\n",
    "        (inputs, labels), where labels are -100 at every unmasked position.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the tokenizer has no mask token.\n",
    "    \"\"\"\n",
    "    mask_list = MASK_LIST[str(config.get('kmer', 6))]\n",
    "\n",
    "    if tokenizer.mask_token is None:\n",
    "        raise ValueError(\n",
    "            \"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer.\"\n",
    "        )\n",
    "\n",
    "    labels = inputs.clone()\n",
    "    # Sample mask centres with probability config['mlm_probability'], never on special/pad tokens.\n",
    "    probability_matrix = torch.full(labels.shape, config['mlm_probability'])\n",
    "    special_tokens_mask = [\n",
    "        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()\n",
    "    ]\n",
    "    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)\n",
    "    if tokenizer.pad_token is not None:\n",
    "        padding_mask = labels.eq(tokenizer.pad_token_id)\n",
    "        probability_matrix.masked_fill_(padding_mask, value=0.0)\n",
    "\n",
    "    masked_indices = torch.bernoulli(probability_matrix).bool()\n",
    "\n",
    "    # Widen each sampled centre to its neighbours. Centres are read from a frozen copy so that\n",
    "    # positions added here are not themselves expanded again.\n",
    "    masks = deepcopy(masked_indices)\n",
    "    for i, masked_index in enumerate(masks):\n",
    "        non_zero_indices = torch.where(probability_matrix[i] != 0)[0]\n",
    "        if non_zero_indices.numel() == 0:\n",
    "            continue  # no maskable position in this row\n",
    "        end = non_zero_indices.tolist()[-1]  # last maskable (non-special, non-pad) position\n",
    "        mask_centers = set(torch.where(masked_index == 1)[0].tolist())\n",
    "        new_centers = set(mask_centers)\n",
    "        for center in mask_centers:\n",
    "            for offset in mask_list:\n",
    "                current_index = center + offset\n",
    "                if 1 <= current_index <= end:  # position 0 is presumably [CLS]; never widen onto it\n",
    "                    new_centers.add(current_index)\n",
    "        masked_indices[i][list(new_centers)] = True\n",
    "\n",
    "    labels[~masked_indices] = -100  # We only compute loss on masked tokens\n",
    "\n",
    "    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\n",
    "    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices\n",
    "    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)\n",
    "\n",
    "    # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random token\n",
    "    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced\n",
    "    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)\n",
    "    inputs[indices_random] = random_words[indices_random]\n",
    "\n",
    "    # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n",
    "    return inputs, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(config, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:\n",
    "    \"\"\"Train the model.\n",
    "\n",
    "    Sets config['train_batch_size'] (and config['num_train_epochs'] when max_steps > 0). Logs to\n",
    "    TensorBoard (rank -1/0) and wandb, optionally evaluates every `logging_steps` optimizer steps,\n",
    "    and saves a checkpoint every `save_steps` optimizer steps.\n",
    "\n",
    "    Returns:\n",
    "        (global_step, summed training loss divided by the number of optimizer steps).\n",
    "    \"\"\"\n",
    "    tb_writer = SummaryWriter() if config['local_rank'] in [-1, 0] else None\n",
    "\n",
    "    config['train_batch_size'] = config['per_gpu_train_batch_size'] * max(1, config['n_gpu'])\n",
    "    train_sampler = RandomSampler(train_dataset) if config['local_rank'] == -1 else DistributedSampler(train_dataset)\n",
    "    train_dataloader = DataLoader(\n",
    "        train_dataset, sampler=train_sampler, batch_size=config['train_batch_size'], collate_fn=_make_collate(tokenizer)\n",
    "    )\n",
    "\n",
    "    if config['max_steps'] > 0:\n",
    "        t_total = config['max_steps']\n",
    "        # max(1, ...) avoids ZeroDivisionError when an epoch has fewer batches than accumulation steps.\n",
    "        steps_per_epoch = max(1, len(train_dataloader) // config['gradient_accumulation_steps'])\n",
    "        config['num_train_epochs'] = config['max_steps'] // steps_per_epoch + 1\n",
    "    else:\n",
    "        t_total = len(train_dataloader) // config['gradient_accumulation_steps'] * config['num_train_epochs']\n",
    "\n",
    "    # Resize BEFORE building the optimizer: when the vocab size changes, resize_token_embeddings\n",
    "    # installs a new embedding weight, and an optimizer created earlier would keep stepping the\n",
    "    # discarded one, leaving the word embeddings untrained.\n",
    "    model_to_resize = model.module if hasattr(model, \"module\") else model\n",
    "    model_to_resize.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    # Prepare optimizer and schedule (linear warmup and decay)\n",
    "    no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
    "    optimizer_grouped_parameters = [\n",
    "        {\n",
    "            \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
    "            \"weight_decay\": config['weight_decay'],\n",
    "        },\n",
    "        {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n",
    "    ]\n",
    "    # NOTE(review): transformers.AdamW is deprecated in favour of torch.optim.AdamW (slightly\n",
    "    # different eps / weight-decay placement); kept so optimisation matches earlier runs.\n",
    "    optimizer = AdamW(\n",
    "        optimizer_grouped_parameters,\n",
    "        lr=config['learning_rate'],\n",
    "        eps=config['adam_epsilon'],\n",
    "        betas=(config['beta1'], config['beta2']),\n",
    "    )\n",
    "    scheduler = get_linear_schedule_with_warmup(\n",
    "        optimizer, num_warmup_steps=config.get('warmup_steps', 2000), num_training_steps=t_total\n",
    "    )\n",
    "\n",
    "    # train_batch_size is scaled by n_gpu above, so actually split each batch across the GPUs.\n",
    "    if config['n_gpu'] > 1 and not isinstance(model, torch.nn.DataParallel):\n",
    "        model = torch.nn.DataParallel(model)\n",
    "    # Distributed training: DistributedSampler alone does not synchronise gradients.\n",
    "    if config['local_rank'] != -1:\n",
    "        model = torch.nn.parallel.DistributedDataParallel(\n",
    "            model, device_ids=[config['local_rank']], output_device=config['local_rank'], find_unused_parameters=True\n",
    "        )\n",
    "\n",
    "    # Train!\n",
    "    logger.info(\"***** Running training *****\")\n",
    "    logger.info(\"  Num examples = %d\", len(train_dataset))\n",
    "    logger.info(\"  Num Epochs = %d\", config['num_train_epochs'])\n",
    "    logger.info(\"  Instantaneous batch size per GPU = %d\", config['per_gpu_train_batch_size'])\n",
    "    logger.info(\n",
    "        \"  Total train batch size (w. parallel, distributed & accumulation) = %d\",\n",
    "        config['train_batch_size']\n",
    "        * config['gradient_accumulation_steps']\n",
    "        * (torch.distributed.get_world_size() if config['local_rank'] != -1 else 1),\n",
    "    )\n",
    "    logger.info(\"  Gradient Accumulation steps = %d\", config['gradient_accumulation_steps'])\n",
    "    logger.info(\"  Total optimization steps = %d\", t_total)\n",
    "\n",
    "    global_step = 0\n",
    "    epochs_trained = 0\n",
    "    steps_trained_in_current_epoch = 0\n",
    "    tr_loss, logging_loss = 0.0, 0.0\n",
    "    is_main_process = config['local_rank'] in [-1, 0]\n",
    "\n",
    "    model.zero_grad()\n",
    "    train_iterator = trange(\n",
    "        epochs_trained, int(config['num_train_epochs']), desc=\"Epoch\", disable=not is_main_process\n",
    "    )\n",
    "    set_seed(config)  # Added here for reproducibility\n",
    "\n",
    "    for epoch in train_iterator:\n",
    "        epoch_start_time = time.time()\n",
    "        epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=not is_main_process)\n",
    "        for step, batch in enumerate(epoch_iterator):\n",
    "            # Skip past any already trained steps if resuming training\n",
    "            if steps_trained_in_current_epoch > 0:\n",
    "                steps_trained_in_current_epoch -= 1\n",
    "                continue\n",
    "\n",
    "            inputs, labels = mask_tokens(batch, tokenizer, config) if config['mlm'] else (batch, batch)\n",
    "            inputs = inputs.to(config['device'])\n",
    "            labels = labels.to(config['device'])\n",
    "            model.train()\n",
    "            outputs = model(inputs, labels=labels)\n",
    "            loss = outputs[0]  # with labels given, the first output is the loss\n",
    "\n",
    "            if config['n_gpu'] > 1:\n",
    "                loss = loss.mean()  # mean() to average on multi-gpu parallel training\n",
    "            if config['gradient_accumulation_steps'] > 1:\n",
    "                loss = loss / config['gradient_accumulation_steps']\n",
    "\n",
    "            loss.backward()\n",
    "            tr_loss += loss.item()\n",
    "\n",
    "            if (step + 1) % config['gradient_accumulation_steps'] == 0:\n",
    "                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])\n",
    "                optimizer.step()\n",
    "                scheduler.step()  # Update learning rate schedule\n",
    "                model.zero_grad()\n",
    "                global_step += 1\n",
    "\n",
    "                wandb.log({\"learning_rate\": scheduler.get_last_lr()[0], \"loss\": loss.item(), \"global_step\": global_step})\n",
    "\n",
    "                if is_main_process and config['logging_steps'] > 0 and global_step % config['logging_steps'] == 0:\n",
    "                    # Only evaluate when single GPU otherwise metrics may not average well\n",
    "                    if config['local_rank'] == -1 and config['evaluate_during_training']:\n",
    "                        results = evaluate(config, model, tokenizer)\n",
    "                        for key, value in results.items():\n",
    "                            tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n",
    "                            wandb.log({f\"eval_{key}\": value, \"global_step\": global_step})\n",
    "                    tb_writer.add_scalar(\"lr\", scheduler.get_last_lr()[0], global_step)\n",
    "                    tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / config['logging_steps'], global_step)\n",
    "                    logging_loss = tr_loss\n",
    "\n",
    "                if is_main_process and config['save_steps'] > 0 and global_step % config['save_steps'] == 0:\n",
    "                    _save_checkpoint(config, model, tokenizer, optimizer, scheduler, global_step)\n",
    "\n",
    "            if config['max_steps'] > 0 and global_step > config['max_steps']:\n",
    "                epoch_iterator.close()\n",
    "                break\n",
    "        if config['max_steps'] > 0 and global_step > config['max_steps']:\n",
    "            train_iterator.close()\n",
    "            break\n",
    "\n",
    "        epoch_time = time.time() - epoch_start_time\n",
    "        logging.info(f'Epoch {epoch + 1}: Time {epoch_time:.4f}s')\n",
    "        log_dir = os.path.join(_checkpoint_dir(config), 'training_logs')\n",
    "        os.makedirs(log_dir, exist_ok=True)\n",
    "        with open(os.path.join(log_dir, 'log.txt'), 'a') as log_file:\n",
    "            log_file.write(f\"Epoch {epoch + 1}/{config['num_train_epochs']}:\\n\")\n",
    "            log_file.write(f\"  Epoch Time: {epoch_time}\\n\")\n",
    "        wandb.log({\"epoch_time\": epoch_time, \"epoch\": epoch + 1})\n",
    "\n",
    "    if tb_writer is not None:\n",
    "        tb_writer.close()\n",
    "\n",
    "    # max(..., 1): no optimizer step happens if the dataloader is empty.\n",
    "    return global_step, tr_loss / max(global_step, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(config, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix=\"\") -> Dict:\n",
    "    \"\"\"Compute masked-LM perplexity on the eval dataset.\n",
    "\n",
    "    NOTE: mask_tokens() draws fresh random masks, so the value varies between calls.\n",
    "    Sets config['eval_batch_size'], appends to <output_dir>/<prefix>/eval_results.txt and logs to wandb.\n",
    "\n",
    "    Returns:\n",
    "        {\"perplexity\": float}\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the eval dataset yields no batches.\n",
    "    \"\"\"\n",
    "    eval_output_dir = config['output_dir']\n",
    "    eval_dataset = load_and_cache_examples(config, tokenizer, evaluate=True)\n",
    "\n",
    "    if config['local_rank'] in [-1, 0]:\n",
    "        os.makedirs(eval_output_dir, exist_ok=True)\n",
    "\n",
    "    config['eval_batch_size'] = config['per_gpu_eval_batch_size'] * max(1, config['n_gpu'])\n",
    "    eval_dataloader = DataLoader(\n",
    "        eval_dataset,\n",
    "        sampler=SequentialSampler(eval_dataset),\n",
    "        batch_size=config['eval_batch_size'],\n",
    "        collate_fn=_make_collate(tokenizer),\n",
    "    )\n",
    "\n",
    "    # multi-gpu evaluate\n",
    "    if config['n_gpu'] > 1 and not isinstance(model, torch.nn.DataParallel):\n",
    "        model = torch.nn.DataParallel(model)\n",
    "\n",
    "    # Eval!\n",
    "    logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
    "    logger.info(\"  Num examples = %d\", len(eval_dataset))\n",
    "    logger.info(\"  Batch size = %d\", config['eval_batch_size'])\n",
    "    eval_loss = 0.0\n",
    "    nb_eval_steps = 0\n",
    "    model.eval()\n",
    "\n",
    "    for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
    "        inputs, labels = mask_tokens(batch, tokenizer, config) if config['mlm'] else (batch, batch)\n",
    "        inputs = inputs.to(config['device'])\n",
    "        labels = labels.to(config['device'])\n",
    "        with torch.no_grad():\n",
    "            lm_loss = model(inputs, labels=labels)[0]\n",
    "            eval_loss += lm_loss.mean().item()\n",
    "        nb_eval_steps += 1\n",
    "\n",
    "    if nb_eval_steps == 0:\n",
    "        raise ValueError(\"Evaluation dataset is empty; cannot compute perplexity.\")\n",
    "    eval_loss = eval_loss / nb_eval_steps\n",
    "    perplexity = torch.exp(torch.tensor(eval_loss))\n",
    "    result = {\"perplexity\": perplexity.item()}\n",
    "\n",
    "    # Log the scalar rather than the whole dict so it charts as one series.\n",
    "    wandb.log({\"eval perplexity\": result[\"perplexity\"]})\n",
    "\n",
    "    output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n",
    "    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)  # the <prefix> subdir may not exist yet\n",
    "    with open(output_eval_file, \"a\") as writer:\n",
    "        logger.info(\"***** Eval results {} *****\".format(prefix))\n",
    "        for key in sorted(result.keys()):\n",
    "            logger.info(\"  %s = %s\", key, str(result[key]))\n",
    "            writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Entry point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(config):\n",
    "    \"\"\"Train and/or evaluate as configured.\n",
    "\n",
    "    Mutates `config` ('n_gpu', 'device', 'block_size', and 'model_name_or_path' when resuming).\n",
    "    Returns evaluation metrics keyed '<metric>_<step>' (empty dict when do_eval is False).\n",
    "    \"\"\"\n",
    "    # Handle checkpoint continuation\n",
    "    if config['should_continue']:\n",
    "        sorted_checkpoints = _sorted_checkpoints(config)\n",
    "        if len(sorted_checkpoints) == 0:\n",
    "            raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n",
    "        config['model_name_or_path'] = sorted_checkpoints[-1]\n",
    "\n",
    "    output_dir = config.get('output_dir', './output')\n",
    "    if (\n",
    "        os.path.exists(output_dir)\n",
    "        and os.listdir(output_dir)\n",
    "        and config['do_train']\n",
    "        and not config.get('overwrite_output_dir', False)\n",
    "    ):\n",
    "        raise ValueError(\n",
    "            \"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(\n",
    "                output_dir\n",
    "            )\n",
    "        )\n",
    "\n",
    "    # Setup CUDA, GPU & distributed training\n",
    "    local_rank = config.get('local_rank', -1)\n",
    "    no_cuda = config.get('no_cuda', False)\n",
    "    if local_rank == -1 or no_cuda:\n",
    "        device = torch.device(\"cuda:0\" if torch.cuda.is_available() and not no_cuda else \"cpu\")\n",
    "        # With no_cuda report 0 GPUs, so nothing downstream wraps a CPU model in DataParallel.\n",
    "        config['n_gpu'] = 0 if no_cuda else torch.cuda.device_count()\n",
    "    else:\n",
    "        torch.cuda.set_device(local_rank)\n",
    "        device = torch.device(\"cuda\", local_rank)\n",
    "        torch.distributed.init_process_group(backend=\"nccl\")\n",
    "        config['n_gpu'] = 1\n",
    "    config['device'] = device\n",
    "\n",
    "    # Setup logging\n",
    "    logging.basicConfig(\n",
    "        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
    "        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "        level=logging.INFO if local_rank in [-1, 0] else logging.WARN,\n",
    "        filename='app.log',\n",
    "    )\n",
    "    logger.warning(\n",
    "        \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n",
    "        local_rank,\n",
    "        device,\n",
    "        config['n_gpu'],\n",
    "        bool(local_rank != -1),\n",
    "        config.get('fp16', False),\n",
    "    )\n",
    "\n",
    "    # wandb.log() raises unless a run is active; start one if the user has not already.\n",
    "    if wandb.run is None:\n",
    "        wandb.init(\n",
    "            project=config.get('wandb_project', 'dna-bert-pretraining'),\n",
    "            config={k: v for k, v in config.items() if v is None or isinstance(v, (bool, int, float, str))},\n",
    "        )\n",
    "\n",
    "    # Set seed\n",
    "    set_seed(config)\n",
    "\n",
    "    # Load pretrained model and tokenizer\n",
    "    if local_rank not in [-1, 0]:\n",
    "        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab\n",
    "\n",
    "    config_class, model_class, tokenizer_class = MODEL_CLASSES['dna']\n",
    "    # Only the bert-tiny *architecture* is used; weights are not loaded from it.\n",
    "    config_obj = config_class.from_pretrained('prajjwal1/bert-tiny', cache_dir=config.get('cache_dir', None))\n",
    "    # tokenizer_class is the DNATokenizer instance; from_pretrained is a classmethod, so this works.\n",
    "    tokenizer = tokenizer_class.from_pretrained('zhihan1996/DNA_bert_6', cache_dir=config.get('cache_dir', None))\n",
    "\n",
    "    # Cap at 512 (presumably the model's max_position_embeddings).\n",
    "    if config.get('block_size', 512) <= 0:\n",
    "        config['block_size'] = 512\n",
    "    else:\n",
    "        config['block_size'] = min(config['block_size'], 512)\n",
    "\n",
    "    if config.get('model_name_or_path'):\n",
    "        # Resume from a checkpoint using its own saved config (no `config=config_obj`): train()\n",
    "        # resized the embeddings to the DNA tokenizer, so bert-tiny's vocab size may not match.\n",
    "        model = model_class.from_pretrained(\n",
    "            config['model_name_or_path'],\n",
    "            from_tf=bool(\".ckpt\" in config['model_name_or_path']),\n",
    "            cache_dir=config.get('cache_dir', None),\n",
    "        )\n",
    "    else:\n",
    "        logger.info(\"Training new model from scratch\")\n",
    "        model = model_class(config=config_obj)\n",
    "\n",
    "    model.to(config['device'])\n",
    "\n",
    "    if local_rank == 0:\n",
    "        torch.distributed.barrier()\n",
    "\n",
    "    logger.info(\"Training/evaluation parameters %s\", config)\n",
    "\n",
    "    # Training\n",
    "    if config.get('do_train', False):\n",
    "        if local_rank not in [-1, 0]:\n",
    "            torch.distributed.barrier()\n",
    "\n",
    "        train_dataset = load_and_cache_examples(config, tokenizer, evaluate=False)\n",
    "\n",
    "        if local_rank == 0:\n",
    "            torch.distributed.barrier()\n",
    "\n",
    "        global_step, tr_loss = train(config, train_dataset, model, tokenizer)\n",
    "        logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n",
    "\n",
    "    # Save and reload model\n",
    "    if config.get('do_train', False) and (local_rank == -1 or torch.distributed.get_rank() == 0):\n",
    "        if local_rank in [-1, 0]:\n",
    "            os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "        logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
    "        model_to_save = model.module if hasattr(model, \"module\") else model\n",
    "        model_to_save.save_pretrained(output_dir)\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        torch.save(config, os.path.join(output_dir, \"training_args.bin\"))\n",
    "\n",
    "        model = model_class.from_pretrained(output_dir)\n",
    "        tokenizer = tokenizer_class.from_pretrained(output_dir)\n",
    "        model.to(config['device'])\n",
    "\n",
    "    # Evaluation\n",
    "    results = {}\n",
    "    if config.get('do_eval', False) and local_rank in [-1, 0]:\n",
    "        checkpoints = [output_dir]\n",
    "        if config.get('eval_all_checkpoints', False):\n",
    "            # NOTE(review): newer transformers may save model.safetensors instead of WEIGHTS_NAME\n",
    "            # (pytorch_model.bin); such checkpoints would not be found by this glob.\n",
    "            checkpoints = list(\n",
    "                os.path.dirname(c) for c in sorted(glob.glob(output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n",
    "            )\n",
    "            logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN)\n",
    "        logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
    "        for checkpoint in checkpoints:\n",
    "            global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n",
    "            prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n",
    "\n",
    "            model = model_class.from_pretrained(checkpoint)\n",
    "            model.to(config['device'])\n",
    "            result = evaluate(config, model, tokenizer, prefix=prefix)\n",
    "            results.update((k + \"_{}\".format(global_step), v) for k, v in result.items())\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    main(config)"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "gpu",
   "dataSources": [
    {
     "datasetId": 5477436,
     "sourceId": 9095316,
     "sourceType": "datasetVersion"
    }
   ],
   "dockerImageVersionId": 30732,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}