---
license: apache-2.0
---


This model is continually pre-trained from [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) with the structure proposed in [MemoryLLM](https://arxiv.org/abs/2402.04624).  
We equip Llama-3 with 12800 memory tokens in each layer, leading to a memory pool of 1.67B parameters. 
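The 1.67B figure can be checked with simple arithmetic. This sketch assumes that each memory token is stored as one hidden-size vector in every layer, and it uses the published Llama-3-8B shape (32 layers, hidden size 4096). The per-token storage layout is an assumption, not something this card states.

```python
# Rough check of the memory-pool size quoted above.
# ASSUMPTION: one hidden-size vector per memory token per layer.
num_layers = 32                  # transformer layers in Llama-3-8B
hidden_size = 4096               # hidden dimension in Llama-3-8B
memory_tokens_per_layer = 12800  # stated in this model card

memory_params = num_layers * memory_tokens_per_layer * hidden_size
print(f"{memory_params:,} parameters = {memory_params / 1e9:.2f}B")
# prints: 1,677,721,600 parameters = 1.68B
```

The product is about 1.68 billion, which matches the ~1.67B quoted above up to rounding.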


To use the model, first clone the MemoryLLM repository, which provides the `modeling_memoryllm` module:
```
git clone git@github.com:wangyu-ustc/MemoryLLM.git
cd MemoryLLM
```
Then, from inside the repository directory, load the chat model and its tokenizer:
```python
import torch
from modeling_memoryllm import MemoryLLM
from transformers import AutoTokenizer

# Load the chat model in float16 with FlashAttention-2 (requires the flash-attn package)
model = MemoryLLM.from_pretrained("YuWangX/memoryllm-8b-chat", attn_implementation="flash_attention_2", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("YuWangX/memoryllm-8b-chat")
model = model.cuda()
```

### How to use the model

Inject a piece of context into the model using the following script:

```python
model = model.cuda()

# Self-Update with the new context
ctx = "Last week, John had a wonderful picnic with David. During their conversation, David mentioned multiple times that he likes eating apples. Though he didn't mention any other fruits, John says he can infer that David also likes bananas."

# Make sure the injected context is at least 16 tokens long: this was the hard minimum
# during training, and injecting fewer than 16 tokens will disturb the memory.
model.inject_memory(tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda(), update_memory=True)

# Generation
messages = [{
    'role': 'user', "content": "What fruits does David like?",
}]

inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
inputs = inputs[:, 1:] # remove bos token

outputs = model.generate(input_ids=inputs.cuda(),
                         max_new_tokens=20)
response = tokenizer.decode(outputs[0])
print(response)
```
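Because of the 16-token minimum, a long document may need to be split before it is injected. The sketch below shows one hypothetical way to do that. `chunk_for_injection`, `MIN_INJECT_TOKENS`, and the 512-token `max_chunk` default are illustrative names and values, not part of the MemoryLLM repository, and injecting a document as several sequential chunks is an assumption about intended usage. The function cuts the token ids into fixed-size pieces. If the last piece falls below 16 tokens, it moves tokens over from the previous piece so that every injection call meets the minimum.

```python
from typing import List

# Hard minimum stated in this model card; the rest of this sketch is hypothetical.
MIN_INJECT_TOKENS = 16


def chunk_for_injection(token_ids: List[int], max_chunk: int = 512) -> List[List[int]]:
    """Split token ids into chunks of at most `max_chunk` tokens, none shorter
    than MIN_INJECT_TOKENS. `max_chunk` is an assumed per-call budget."""
    if len(token_ids) < MIN_INJECT_TOKENS:
        raise ValueError(
            f"context has {len(token_ids)} tokens; need at least {MIN_INJECT_TOKENS}"
        )
    if max_chunk < 2 * MIN_INJECT_TOKENS:
        raise ValueError("max_chunk must be at least 2 * MIN_INJECT_TOKENS")

    # Fixed-size slices; only the final slice can be shorter than max_chunk.
    chunks = [token_ids[i:i + max_chunk] for i in range(0, len(token_ids), max_chunk)]

    # If the final slice is too short, borrow tokens from the end of the previous
    # slice. The previous slice keeps at least max_chunk - 16 >= 16 tokens.
    if len(chunks) > 1 and len(chunks[-1]) < MIN_INJECT_TOKENS:
        deficit = MIN_INJECT_TOKENS - len(chunks[-1])
        chunks[-1] = chunks[-2][-deficit:] + chunks[-1]
        chunks[-2] = chunks[-2][:-deficit]
    return chunks


# Example usage with the model loaded above (not executed here; `long_text` is hypothetical):
# ids = tokenizer(long_text, add_special_tokens=False).input_ids
# for chunk in chunk_for_injection(ids):
#     model.inject_memory(torch.tensor([chunk]).cuda(), update_memory=True)
```

Rebalancing the last two chunks, rather than simply appending a short tail to the previous chunk, keeps every piece within the `max_chunk` budget while preserving the original token order.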