Nanobit committed on
Commit
974dc00
·
1 Parent(s): 572d114

Fix set mem_id for inference and refactor

Browse files
scripts/finetune.py CHANGED
@@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
78
  )
79
 
80
  if cfg.landmark_attention:
 
 
 
81
  model.set_mem_cache_args(
82
  max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
83
  )
 
78
  )
79
 
80
  if cfg.landmark_attention:
81
+ from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
82
+
83
+ set_model_mem_id(model, tokenizer)
84
  model.set_mem_cache_args(
85
  max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
86
  )
src/axolotl/monkeypatch/llama_landmark_attn.py CHANGED
@@ -29,6 +29,7 @@ import torch
29
  import torch.utils.checkpoint
30
  from torch import nn
31
  from torch.nn import CrossEntropyLoss
 
32
  from transformers.modeling_outputs import (
33
  BaseModelOutputWithPast,
34
  CausalLMOutputWithPast,
@@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn():
1237
  transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
1238
  transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
1239
  transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
 
 
 
 
 
 
 
 
 
 
29
  import torch.utils.checkpoint
30
  from torch import nn
31
  from torch.nn import CrossEntropyLoss
32
+ from transformers import LlamaTokenizer
33
  from transformers.modeling_outputs import (
34
  BaseModelOutputWithPast,
35
  CausalLMOutputWithPast,
 
1238
  transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
1239
  transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
1240
  transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
1241
+
1242
+
1243
+ def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
1244
+ mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
1245
+ model.set_mem_id(mem_id)
1246
+
1247
+
1248
+ def get_mem_id(tokenizer: LlamaTokenizer):
1249
+ return tokenizer.convert_tokens_to_ids(MEM_TOKEN)
src/axolotl/utils/trainer.py CHANGED
@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
239
  if cfg.is_llama_derived_model and cfg.landmark_attention:
240
  from functools import partial
241
 
242
- from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
 
 
 
 
243
 
244
- mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
245
- model.set_mem_id(mem_id)
246
 
247
  logging.info("Adding landmark attention tokens to dataset")
248
 
249
  for dataset in [train_dataset, eval_dataset]:
250
  dataset = dataset.map(
251
- partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
252
  batched=False,
253
  num_proc=32,
254
  )
 
239
  if cfg.is_llama_derived_model and cfg.landmark_attention:
240
  from functools import partial
241
 
242
+ from axolotl.monkeypatch.llama_landmark_attn import (
243
+ add_mem_tokens,
244
+ get_mem_id,
245
+ set_model_mem_id,
246
+ )
247
 
248
+ set_model_mem_id(model, tokenizer)
 
249
 
250
  logging.info("Adding landmark attention tokens to dataset")
251
 
252
  for dataset in [train_dataset, eval_dataset]:
253
  dataset = dataset.map(
254
+ partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
255
  batched=False,
256
  num_proc=32,
257
  )