Spaces:
Runtime error
Runtime error
SyrWin
commited on
Commit
·
95f97c5
1
Parent(s):
a281fc1
init
Browse files- .gitattributes +1 -0
- .gitignore +18 -0
- README.md +49 -13
- all_checkpoints/.gitignore +2 -0
- app.py +309 -0
- average_ckpt.py +33 -0
- convert.py +14 -0
- data_provider/__init__.py +0 -0
- data_provider/caption_dataset.py +93 -0
- data_provider/chebi_dataset.py +42 -0
- data_provider/context_gen.py +207 -0
- data_provider/data_utils.py +144 -0
- data_provider/molecule_abstract_dataset.py +222 -0
- data_provider/pretrain_dm.py +309 -0
- data_provider/r_smiles.py +449 -0
- data_provider/reaction_action_dataset.py +100 -0
- data_provider/synthesis_dataset.py +160 -0
- data_provider/tune_dm.py +312 -0
- demo.json +7 -0
- demo.py +224 -0
- environment.yml +489 -0
- figures/frameworks.jpg +3 -0
- gin_pretrained/graphcl_80.pth +3 -0
- graph_gen.ipynb +190 -0
- lora_config.json +14 -0
- main.py +157 -0
- model/allowed_words.json +118 -0
- model/blip2.py +126 -0
- model/blip2_llama.py +266 -0
- model/blip2_model.py +381 -0
- model/blip2_opt.py +417 -0
- model/blip2_t5.py +305 -0
- model/blip2qformer.py +603 -0
- model/dist_funs.py +83 -0
- model/gin_model.py +397 -0
- model/help_funcs.py +86 -0
- model/modeling_llama.py +888 -0
- model/modeling_opt.py +1223 -0
- model/opt_flash_attention.py +331 -0
- read_results/baselines.py +141 -0
- read_results/read_results.py +28 -0
- read_results/score.py +358 -0
- read_results/t_test.py +48 -0
- read_results/utils.py +256 -0
- visualize_context_gen.py +164 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
figures/frameworks.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
kvplm_pretrained
|
2 |
+
__pycache__
|
3 |
+
test*
|
4 |
+
log*
|
5 |
+
*.sh.e*
|
6 |
+
*.sh.o*
|
7 |
+
.d*
|
8 |
+
llms
|
9 |
+
Text2graph
|
10 |
+
fig/
|
11 |
+
results/
|
12 |
+
*.out
|
13 |
+
*.err
|
14 |
+
debug*
|
15 |
+
data
|
16 |
+
scripts
|
17 |
+
conda_env
|
18 |
+
tmp*
|
README.md
CHANGED
@@ -1,13 +1,49 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ReactXT: Understanding Molecular “Reaction-ship” via Reaction-Contextualized Molecule-Text Pretraining
|
2 |
+
|
3 |
+
## Comparison to previous molecule-text generative modeling methods
|
4 |
+
|
5 |
+
![fig1](./figures/comparison.pdf)
|
6 |
+
|
7 |
+
|
8 |
+
## Framework of ReactXT
|
9 |
+
|
10 |
+
![fig1](./figures/frameworks.pdf)
|
11 |
+
|
12 |
+
|
13 |
+
## Requirements
|
14 |
+
|
15 |
+
Our environment is detailed in `environment.yml`. To create a new environment `reactxt`, run the following command:
|
16 |
+
|
17 |
+
```bash
|
18 |
+
conda env create -f environment.yml
|
19 |
+
```
|
20 |
+
|
21 |
+
|
22 |
+
## Reproduce the results
|
23 |
+
|
24 |
+
### Reaction-Contextualized Molecule-Text Pretraining
|
25 |
+
|
26 |
+
```bash
|
27 |
+
bash scripts/run_pretrain.sh
|
28 |
+
```
|
29 |
+
|
30 |
+
### Finetuning on downstream tasks
|
31 |
+
|
32 |
+
1. Experimental Procedure Prediction on OpenExp
|
33 |
+
|
34 |
+
```bash
|
35 |
+
bash scripts/run_action.sh
|
36 |
+
```
|
37 |
+
|
38 |
+
2. Molecule Captioning on PubChem324k and CheBI-20
|
39 |
+
|
40 |
+
```bash
|
41 |
+
bash scripts/run_caption.sh
|
42 |
+
bash scripts/run_chebi.sh
|
43 |
+
```
|
44 |
+
|
45 |
+
3. Retro-synthesis Prediction on USPTO-50k
|
46 |
+
|
47 |
+
```bash
|
48 |
+
bash scripts/run_retro.sh
|
49 |
+
```
|
all_checkpoints/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
app.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import warnings
|
5 |
+
from rdkit import Chem
|
6 |
+
from rdkit.Chem import CanonSmiles
|
7 |
+
from rdkit.Chem import MolFromSmiles, MolToSmiles
|
8 |
+
from data_provider.pretrain_dm import PretrainDM
|
9 |
+
from data_provider.tune_dm import *
|
10 |
+
from model.opt_flash_attention import replace_opt_attn_with_flash_attn
|
11 |
+
from model.blip2_model import Blip2Model
|
12 |
+
from data_provider.data_utils import json_read, json_write
|
13 |
+
from data_provider.data_utils import smiles2data, reformat_smiles
|
14 |
+
import gradio as gr
|
15 |
+
from datetime import datetime
|
16 |
+
|
17 |
+
## for pyg bug
|
18 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
19 |
+
## for A5000 gpus
|
20 |
+
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)
|
21 |
+
|
22 |
+
def smiles_split(string, separator='.'):
|
23 |
+
string = str(string)
|
24 |
+
mols = []
|
25 |
+
for smi in string.split(separator):
|
26 |
+
mol = MolFromSmiles(smi)
|
27 |
+
if mol is None:
|
28 |
+
continue # Skip invalid SMILES strings
|
29 |
+
mols.append(mol)
|
30 |
+
|
31 |
+
parts = []
|
32 |
+
current_part = []
|
33 |
+
charge_count = 0
|
34 |
+
|
35 |
+
for mol in mols:
|
36 |
+
charge = Chem.GetFormalCharge(mol)
|
37 |
+
if charge==0:
|
38 |
+
if current_part:
|
39 |
+
smiles = '.'.join([MolToSmiles(m) for m in current_part])
|
40 |
+
smiles = CanonSmiles(smiles)
|
41 |
+
parts.append(smiles)
|
42 |
+
current_part = []
|
43 |
+
charge_count = 0
|
44 |
+
parts.append(MolToSmiles(mol))
|
45 |
+
else:
|
46 |
+
charge_count += charge
|
47 |
+
current_part.append(mol)
|
48 |
+
if charge_count == 0:
|
49 |
+
smiles = '.'.join([MolToSmiles(m) for m in current_part])
|
50 |
+
smiles = CanonSmiles(smiles)
|
51 |
+
parts.append(smiles)
|
52 |
+
current_part = []
|
53 |
+
charge_count = 0
|
54 |
+
if current_part:
|
55 |
+
smiles = '.'.join([MolToSmiles(m) for m in current_part])
|
56 |
+
smiles = CanonSmiles(smiles)
|
57 |
+
parts.append(smiles)
|
58 |
+
|
59 |
+
return parts
|
60 |
+
|
61 |
+
def get_args():
|
62 |
+
parser = argparse.ArgumentParser()
|
63 |
+
parser.add_argument('--filename', type=str, default="main")
|
64 |
+
parser.add_argument('--seed', type=int, default=42, help='random seed')
|
65 |
+
# MM settings
|
66 |
+
parser.add_argument('--mode', type=str, default='pretrain', choices=['pretrain', 'ft', 'eval', 'pretrain_eval'])
|
67 |
+
parser.add_argument('--strategy_name', type=str, default='mydeepspeed')
|
68 |
+
parser.add_argument('--iupac_prediction', action='store_true', default=False)
|
69 |
+
parser.add_argument('--ckpt_path', type=str, default=None)
|
70 |
+
# parser = Trainer.add_argparse_args(parser)
|
71 |
+
parser = Blip2Model.add_model_specific_args(parser) # add model args
|
72 |
+
parser = PretrainDM.add_model_specific_args(parser)
|
73 |
+
parser.add_argument('--accelerator', type=str, default='gpu')
|
74 |
+
parser.add_argument('--devices', type=str, default='0,1,2,3')
|
75 |
+
parser.add_argument('--precision', type=str, default='bf16-mixed')
|
76 |
+
parser.add_argument('--downstream_task', type=str, default='action', choices=['action', 'synthesis', 'caption', 'chebi'])
|
77 |
+
parser.add_argument('--max_epochs', type=int, default=10)
|
78 |
+
parser.add_argument('--enable_flash', action='store_true', default=False)
|
79 |
+
parser.add_argument('--disable_graph_cache', action='store_true', default=False)
|
80 |
+
parser.add_argument('--generate_restrict_tokens', action='store_true', default=False)
|
81 |
+
parser.add_argument('--train_restrict_tokens', action='store_true', default=False)
|
82 |
+
parser.add_argument('--smiles_type', type=str, default='default', choices=['default', 'canonical', 'restricted', 'unrestricted', 'r_smiles'])
|
83 |
+
parser.add_argument('--accumulate_grad_batches', type=int, default=1)
|
84 |
+
parser.add_argument('--tqdm_interval', type=int, default=50)
|
85 |
+
parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
|
86 |
+
args = parser.parse_args()
|
87 |
+
|
88 |
+
if args.enable_flash:
|
89 |
+
replace_opt_attn_with_flash_attn()
|
90 |
+
return args
|
91 |
+
|
92 |
+
app_config = {
|
93 |
+
"init_checkpoint": "all_checkpoints/ckpt_tune_hybridFeb11_May31/last_converted.ckpt",
|
94 |
+
"filename": "app",
|
95 |
+
"opt_model": "facebook/galactica-1.3b",
|
96 |
+
"num_workers": 4,
|
97 |
+
"rxn_max_len": 512,
|
98 |
+
"text_max_len": 512,
|
99 |
+
"precision": "bf16-mixed",
|
100 |
+
"max_inference_len": 512,
|
101 |
+
}
|
102 |
+
|
103 |
+
class InferenceRunner:
|
104 |
+
def __init__(self, model, tokenizer, rxn_max_len, smi_max_len,
|
105 |
+
smiles_type='default', device='cuda', args=None):
|
106 |
+
self.model = model
|
107 |
+
self.rxn_max_len = rxn_max_len
|
108 |
+
self.smi_max_len = smi_max_len
|
109 |
+
self.tokenizer = tokenizer
|
110 |
+
self.collater = Collater([], [])
|
111 |
+
self.mol_ph = '<mol>' * args.num_query_token
|
112 |
+
self.mol_token_id = tokenizer.mol_token_id
|
113 |
+
self.is_gal = args.opt_model.find('galactica') >= 0
|
114 |
+
self.collater = Collater([], [])
|
115 |
+
self.device = device
|
116 |
+
self.smiles_type = smiles_type
|
117 |
+
self.args = args
|
118 |
+
time_stamp = datetime.now().strftime("%Y.%m.%d-%H:%M")
|
119 |
+
self.cache_dir = f'results/{self.args.filename}/{time_stamp}'
|
120 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
121 |
+
|
122 |
+
def make_query_dict(self, rxn_string):
|
123 |
+
try:
|
124 |
+
reactant, solvent, product = rxn_string.split('>')
|
125 |
+
reactant = smiles_split(reactant)
|
126 |
+
product = smiles_split(product)
|
127 |
+
solvent = smiles_split(solvent) if solvent else []
|
128 |
+
assert reactant and product
|
129 |
+
except:
|
130 |
+
raise KeyError('Please input a valid reaction string')
|
131 |
+
|
132 |
+
extracted_molecules = {product[0]: "$-1$"}
|
133 |
+
for mol in reactant+solvent:
|
134 |
+
extracted_molecules[mol] = f"${len(extracted_molecules)}$"
|
135 |
+
|
136 |
+
result_dict = {}
|
137 |
+
result_dict['time_stamp'] = datetime.now().strftime("%Y.%m.%d %H:%M:%S.%f")[:-3]
|
138 |
+
result_dict['reaction_string'] = rxn_string
|
139 |
+
result_dict['REACTANT'] = reactant
|
140 |
+
result_dict['SOLVENT'] = solvent
|
141 |
+
result_dict['CATALYST'] = []
|
142 |
+
result_dict['PRODUCT'] = product
|
143 |
+
result_dict['extracted_molecules'] = extracted_molecules
|
144 |
+
return result_dict
|
145 |
+
|
146 |
+
def save_prediction(self, result_dict):
|
147 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
148 |
+
result_id = result_dict['time_stamp']
|
149 |
+
result_path = os.path.join(self.cache_dir, f'{result_id}.json')
|
150 |
+
json_write(result_path, result_dict)
|
151 |
+
|
152 |
+
def make_prompt(self, param_dict, smi_max_len=128):
|
153 |
+
smiles_list = []
|
154 |
+
prompt = ''
|
155 |
+
prompt += 'Reactants: '
|
156 |
+
smiles_wrapper = lambda x: reformat_smiles(x, smiles_type=self.smiles_type)[:smi_max_len]
|
157 |
+
for smi in param_dict['REACTANT']:
|
158 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
159 |
+
smiles_list.append(smi)
|
160 |
+
|
161 |
+
prompt += 'Product: '
|
162 |
+
for smi in param_dict['PRODUCT']:
|
163 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
164 |
+
smiles_list.append(smi)
|
165 |
+
|
166 |
+
if param_dict['CATALYST']:
|
167 |
+
prompt += 'Catalysts: '
|
168 |
+
for smi in param_dict['CATALYST']:
|
169 |
+
if smi in param_dict["extracted_molecules"]:
|
170 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
171 |
+
else:
|
172 |
+
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
173 |
+
smiles_list.append(smi)
|
174 |
+
|
175 |
+
if param_dict['SOLVENT']:
|
176 |
+
prompt += 'Solvents: '
|
177 |
+
for smi in param_dict['SOLVENT']:
|
178 |
+
if smi in param_dict["extracted_molecules"]:
|
179 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
180 |
+
else:
|
181 |
+
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
182 |
+
smiles_list.append(smi)
|
183 |
+
|
184 |
+
prompt += 'Action Squence: '
|
185 |
+
return prompt, smiles_list
|
186 |
+
|
187 |
+
def get_action_elements(self, rxn_dict):
|
188 |
+
input_text, smiles_list = self.make_prompt(rxn_dict, self.smi_max_len)
|
189 |
+
|
190 |
+
graph_list = []
|
191 |
+
for smiles in smiles_list:
|
192 |
+
graph_item = smiles2data(smiles)
|
193 |
+
graph_list.append(graph_item)
|
194 |
+
return graph_list, input_text
|
195 |
+
|
196 |
+
@torch.no_grad()
|
197 |
+
def predict(self, rxn_dict, temperature=1):
|
198 |
+
graphs, prompt_tokens = self.tokenize(rxn_dict)
|
199 |
+
result_dict = rxn_dict
|
200 |
+
samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
|
201 |
+
prediction = self.model.blip2opt.generate(
|
202 |
+
samples,
|
203 |
+
do_sample=self.args.do_sample,
|
204 |
+
num_beams=self.args.num_beams,
|
205 |
+
max_length=self.args.max_inference_len,
|
206 |
+
min_length=self.args.min_inference_len,
|
207 |
+
num_captions=self.args.num_generate_captions,
|
208 |
+
temperature=temperature,
|
209 |
+
use_graph=True
|
210 |
+
)[0]
|
211 |
+
for k, v in result_dict['extracted_molecules'].items():
|
212 |
+
prediction = prediction.replace(v, k)
|
213 |
+
result_dict['prediction'] = prediction
|
214 |
+
return result_dict
|
215 |
+
|
216 |
+
def tokenize(self, rxn_dict):
|
217 |
+
graph_list, input_text = self.get_action_elements(rxn_dict)
|
218 |
+
if graph_list:
|
219 |
+
graphs = self.collater(graph_list).to(self.device)
|
220 |
+
input_prompt = smiles_handler(input_text, self.mol_ph, self.is_gal)[0]
|
221 |
+
|
222 |
+
## deal with prompt
|
223 |
+
self.tokenizer.padding_side = 'left'
|
224 |
+
input_prompt_tokens = self.tokenizer(input_prompt,
|
225 |
+
truncation=True,
|
226 |
+
padding='max_length',
|
227 |
+
add_special_tokens=True,
|
228 |
+
max_length=self.rxn_max_len,
|
229 |
+
return_tensors='pt',
|
230 |
+
return_attention_mask=True).to(self.device)
|
231 |
+
is_mol_token = input_prompt_tokens.input_ids == self.mol_token_id
|
232 |
+
input_prompt_tokens['is_mol_token'] = is_mol_token
|
233 |
+
return graphs, input_prompt_tokens
|
234 |
+
|
235 |
+
def main(args):
|
236 |
+
device = torch.device('cuda')
|
237 |
+
# model
|
238 |
+
if args.init_checkpoint:
|
239 |
+
model = Blip2Model(args).to(device)
|
240 |
+
ckpt = torch.load(args.init_checkpoint, map_location='cpu')
|
241 |
+
model.load_state_dict(ckpt['state_dict'], strict=False)
|
242 |
+
print(f"loaded model from {args.init_checkpoint}")
|
243 |
+
else:
|
244 |
+
model = Blip2Model(args).to(device)
|
245 |
+
model.eval()
|
246 |
+
|
247 |
+
print('total params:', sum(p.numel() for p in model.parameters()))
|
248 |
+
|
249 |
+
if args.opt_model.find('galactica') >= 0 or args.opt_model.find('t5') >= 0:
|
250 |
+
tokenizer = model.blip2opt.opt_tokenizer
|
251 |
+
elif args.opt_model.find('llama') >= 0 or args.opt_model.find('vicuna') >= 0:
|
252 |
+
tokenizer = model.blip2opt.llm_tokenizer
|
253 |
+
else:
|
254 |
+
raise NotImplementedError
|
255 |
+
|
256 |
+
infer_runner = InferenceRunner(
|
257 |
+
model=model,
|
258 |
+
tokenizer=tokenizer,
|
259 |
+
rxn_max_len=args.rxn_max_len,
|
260 |
+
smi_max_len=args.smi_max_len,
|
261 |
+
device=device,
|
262 |
+
args=args
|
263 |
+
)
|
264 |
+
example_inputs = json_read('demo.json')
|
265 |
+
example_inputs = [[e] for e in example_inputs]
|
266 |
+
|
267 |
+
def online_chat(reaction_string, temperature=1):
|
268 |
+
data_item = infer_runner.make_query_dict(reaction_string)
|
269 |
+
result = infer_runner.predict(data_item, temperature=temperature)
|
270 |
+
infer_runner.save_prediction(result)
|
271 |
+
prediction = result['prediction'].replace(' ; ', ' ;\n')
|
272 |
+
return prediction
|
273 |
+
|
274 |
+
with gr.Blocks(css="""
|
275 |
+
.center { display: flex; justify-content: center; }
|
276 |
+
""") as demo:
|
277 |
+
gr.HTML(
|
278 |
+
"""
|
279 |
+
<center><h1><b>ReactXT</b></h1></center>
|
280 |
+
<p style="font-size:20px; font-weight:bold;">This is the demo page of our ACL 2024 paper
|
281 |
+
<i>ReactXT: Understanding Molecular “Reaction-ship” via Reaction-Contextualized Molecule-Text Pretraining.</i></p>
|
282 |
+
""")
|
283 |
+
with gr.Row(elem_classes="center"):
|
284 |
+
gr.Image(value="./figures/frameworks.jpg", elem_classes="center", width=800, label="Framework of ReactXT")
|
285 |
+
gr.HTML(
|
286 |
+
"""
|
287 |
+
<p style="font-size:16px;"> Please input one chemical reaction below, and we will generate the predicted experimental procedure.</p>
|
288 |
+
<p style="font-size:16px;"> The reaction should be in form of <b>Reactants>Reagents>Product</b>.</p>
|
289 |
+
""")
|
290 |
+
|
291 |
+
reaction_string = gr.Textbox(placeholder="Input one reaction", label='Input Reaction')
|
292 |
+
gr.Examples(example_inputs, [reaction_string,], fn=online_chat, label='Example Reactions')
|
293 |
+
with gr.Row():
|
294 |
+
btn = gr.Button("Submit")
|
295 |
+
clear_btn = gr.Button("Clear")
|
296 |
+
temperature = gr.Slider(0.1, 1, value=1, label='Temperature')
|
297 |
+
with gr.Row():
|
298 |
+
out = gr.Textbox(label="ReactXT's Output", placeholder="Predicted experimental procedure")
|
299 |
+
btn.click(fn=online_chat, inputs=[reaction_string, temperature], outputs=[out])
|
300 |
+
clear_btn.click(fn=lambda:("", ""), inputs=[], outputs=[reaction_string, out])
|
301 |
+
|
302 |
+
demo.launch(share=True)
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
if __name__ == '__main__':
|
307 |
+
args = get_args()
|
308 |
+
vars(args).update(app_config)
|
309 |
+
main(args)
|
average_ckpt.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def average_checkpoints(checkpoint_paths):
|
5 |
+
averaged_ckpt = torch.load(checkpoint_paths[-1], map_location=torch.device('cpu'))
|
6 |
+
param_sum_dict = {}
|
7 |
+
for key, value in averaged_ckpt['state_dict'].items():
|
8 |
+
param_sum_dict[key] = value.clone()
|
9 |
+
|
10 |
+
num_checkpoints = len(checkpoint_paths)
|
11 |
+
for ckpt_path in checkpoint_paths[:-1]:
|
12 |
+
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
|
13 |
+
for key, value in checkpoint['state_dict'].items():
|
14 |
+
param_sum_dict[key] += value
|
15 |
+
|
16 |
+
for key in param_sum_dict.keys():
|
17 |
+
param_sum_dict[key] = param_sum_dict[key] / num_checkpoints
|
18 |
+
averaged_ckpt['state_dict'] = param_sum_dict
|
19 |
+
|
20 |
+
return averaged_ckpt
|
21 |
+
|
22 |
+
def parse_arguments():
|
23 |
+
parser = argparse.ArgumentParser(description="Averages the weights of multiple transformer model checkpoints.")
|
24 |
+
parser.add_argument('--checkpoint_paths', nargs='+', required=True,
|
25 |
+
help='List of paths to the checkpoints to be averaged. Example: --checkpoint_paths path1 path2 path3')
|
26 |
+
parser.add_argument('--output_path', type=str, required=True,)
|
27 |
+
return parser.parse_args()
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
args = parse_arguments()
|
31 |
+
averaged_state_dict = average_checkpoints(args.checkpoint_paths)
|
32 |
+
torch.save(averaged_state_dict, args.output_path)
|
33 |
+
print(f"Averaged checkpoint saved to {args.output_path}")
|
convert.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
## read a path using argparse and pass it to convert_zero_checkpoint_to_fp32_state_dict
|
7 |
+
parser = argparse.ArgumentParser()
|
8 |
+
parser.add_argument('--input', type=str, default=None, help='path to the desired checkpoint folder')
|
9 |
+
parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
|
10 |
+
# parser.add_argument('--tag', type=str, help='checkpoint tag used as a unique identifier for checkpoint')
|
11 |
+
args = parser.parse_args()
|
12 |
+
if args.output is None:
|
13 |
+
args.output = Path(args.input) / 'converted.ckpt'
|
14 |
+
convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)
|
data_provider/__init__.py
ADDED
File without changes
|
data_provider/caption_dataset.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Dataset
|
3 |
+
import os
|
4 |
+
from torch_geometric.data import InMemoryDataset
|
5 |
+
from .data_utils import reformat_smiles
|
6 |
+
import random
|
7 |
+
import json
|
8 |
+
|
9 |
+
class PubChemDataset(InMemoryDataset):
|
10 |
+
def __init__(self, path):
|
11 |
+
super(PubChemDataset, self).__init__()
|
12 |
+
self.data, self.slices = torch.load(path)
|
13 |
+
|
14 |
+
def __getitem__(self, idx):
|
15 |
+
return self.get(idx)
|
16 |
+
|
17 |
+
class CaptionDataset(Dataset):
|
18 |
+
def __init__(self, root, mode, smi_max_len=128, use_graph=True, disable_graph_cache=False, smiles_type='default'):
|
19 |
+
super(CaptionDataset, self).__init__(root)
|
20 |
+
self.root = root
|
21 |
+
self.file_path = os.path.join(root, f'{mode}.pt')
|
22 |
+
self.smi_max_len = smi_max_len
|
23 |
+
self.tokenizer = None
|
24 |
+
self.use_graph = use_graph
|
25 |
+
self.smiles_type = smiles_type
|
26 |
+
|
27 |
+
self.data = PubChemDataset(self.file_path)
|
28 |
+
|
29 |
+
def get(self, index):
|
30 |
+
return self.__getitem__(index)
|
31 |
+
|
32 |
+
def len(self):
|
33 |
+
return len(self)
|
34 |
+
|
35 |
+
def __len__(self):
|
36 |
+
return len(self.data)
|
37 |
+
|
38 |
+
def __getitem__(self, index):
|
39 |
+
data = self.data[index]
|
40 |
+
smiles = reformat_smiles(data.smiles, smiles_type=self.smiles_type)
|
41 |
+
smiles_prompt = f'[START_I_SMILES]{smiles[:self.smi_max_len]}[END_I_SMILES]. '
|
42 |
+
|
43 |
+
text_list = []
|
44 |
+
count = 0
|
45 |
+
for line in data.text.split('\n'):
|
46 |
+
count += 1
|
47 |
+
text_list.append(line.strip())
|
48 |
+
if count > 100:
|
49 |
+
break
|
50 |
+
text = ' '.join(text_list) + '\n'
|
51 |
+
graph_list = [data] if self.use_graph else []
|
52 |
+
|
53 |
+
return index, graph_list, text, smiles_prompt
|
54 |
+
|
55 |
+
class PretrainCaptionDataset(Dataset):
|
56 |
+
def __init__(self, root, smi_max_len=128, use_graph=True, disable_graph_cache=False):
|
57 |
+
super(PretrainCaptionDataset, self).__init__(root)
|
58 |
+
self.pre_train_data = CaptionDataset(
|
59 |
+
root,
|
60 |
+
'pretrain',
|
61 |
+
smi_max_len=smi_max_len,
|
62 |
+
use_graph=use_graph,
|
63 |
+
)
|
64 |
+
self.train_data = CaptionDataset(
|
65 |
+
root,
|
66 |
+
'train',
|
67 |
+
smi_max_len=smi_max_len,
|
68 |
+
use_graph=use_graph,
|
69 |
+
)
|
70 |
+
|
71 |
+
def get(self, index):
|
72 |
+
return self.__getitem__(index)
|
73 |
+
|
74 |
+
def len(self):
|
75 |
+
return len(self)
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return len(self.pre_train_data) + len(self.train_data)
|
79 |
+
|
80 |
+
def __getitem__(self, index):
|
81 |
+
if index < len(self.pre_train_data):
|
82 |
+
index, graph_list, text, smiles_prompt = self.pre_train_data[index]
|
83 |
+
else:
|
84 |
+
index, graph_list, text, smiles_prompt = self.train_data[index - len(self.pre_train_data)]
|
85 |
+
graph_item = graph_list[0]
|
86 |
+
if hasattr(graph_item, 'iupac'):
|
87 |
+
del graph_item.iupac
|
88 |
+
if hasattr(graph_item, 'cid'):
|
89 |
+
del graph_item.cid
|
90 |
+
del graph_item.text
|
91 |
+
del graph_item.smiles
|
92 |
+
|
93 |
+
return graph_item, text, smiles_prompt
|
data_provider/chebi_dataset.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Dataset
|
3 |
+
import os
|
4 |
+
from torch_geometric.data import InMemoryDataset
|
5 |
+
import random
|
6 |
+
import json
|
7 |
+
from .data_utils import reformat_smiles
|
8 |
+
|
9 |
+
class ChEBI_dataset(Dataset):
|
10 |
+
def __init__(self, root, mode, smi_max_len=128, use_graph=True, disable_graph_cache=False, smiles_type='default'):
|
11 |
+
super(ChEBI_dataset, self).__init__(root)
|
12 |
+
self.root = root
|
13 |
+
self.file_path = os.path.join(root, f'{mode}.txt')
|
14 |
+
self.smi_max_len = smi_max_len
|
15 |
+
self.tokenizer = None
|
16 |
+
self.use_graph = use_graph
|
17 |
+
self.smiles_type = smiles_type
|
18 |
+
if self.use_graph:
|
19 |
+
self.idx_graph_map = torch.load(os.path.join(root, 'cid_graph_map.pt'))
|
20 |
+
with open(self.file_path) as f:
|
21 |
+
lines = f.readlines()
|
22 |
+
self.data = [line.split('\t', maxsplit=2) for line in lines[1:]]
|
23 |
+
|
24 |
+
|
25 |
+
def get(self, index):
|
26 |
+
return self.__getitem__(index)
|
27 |
+
|
28 |
+
def len(self):
|
29 |
+
return len(self)
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.data)
|
33 |
+
|
34 |
+
def __getitem__(self, index):
|
35 |
+
cid, smiles, text = self.data[index]
|
36 |
+
smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
|
37 |
+
smiles_prompt = f'[START_I_SMILES]{smiles[:self.smi_max_len]}[END_I_SMILES]. '
|
38 |
+
text = text.strip() + '\n'
|
39 |
+
if self.use_graph:
|
40 |
+
graph_list = [self.idx_graph_map[cid]]
|
41 |
+
|
42 |
+
return index, graph_list, text, smiles_prompt
|
data_provider/context_gen.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
from collections import defaultdict
|
7 |
+
from matplotlib import pyplot as plt
|
8 |
+
from collections import Counter
|
9 |
+
from .data_utils import json_read
|
10 |
+
|
11 |
+
def set_random_seed(seed):
|
12 |
+
random.seed(seed)
|
13 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
14 |
+
np.random.seed(seed)
|
15 |
+
|
16 |
+
class Reaction_Cluster:
|
17 |
+
def __init__(self, root, reaction_filename, reverse_ratio=0.5):
|
18 |
+
self.root = root
|
19 |
+
self.reaction_data = json_read(os.path.join(self.root, reaction_filename))
|
20 |
+
self.property_data = json_read(os.path.join(self.root, 'Abstract_property.json'))
|
21 |
+
self.mol_property_map = {d['canon_smiles']: d for d in self.property_data}
|
22 |
+
self.reverse_ratio = reverse_ratio
|
23 |
+
self.rxn_mols_attr = defaultdict(lambda:{
|
24 |
+
'freq': 0,
|
25 |
+
'occurrence': 0,
|
26 |
+
'in_caption': False,
|
27 |
+
})
|
28 |
+
|
29 |
+
self._read_reaction_mols() # add `valid_mols` in each rxn_dict
|
30 |
+
self.mol_counter = Counter(mol for rxn_dict in self.reaction_data for mol in rxn_dict['valid_mols'])
|
31 |
+
self._calculate_Pr() # calculate P(r), add `weight` in each rxn_dict
|
32 |
+
self._calculate_Pir() # calculate P(i|r), add `mol_weight` in each rxn_dict
|
33 |
+
|
34 |
+
def _read_reaction_mols(self):
|
35 |
+
self.valid_rxn_indices = []
|
36 |
+
for rxn_id, rxn_dict in enumerate(self.reaction_data):
|
37 |
+
mol_role_map = {}
|
38 |
+
for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
|
39 |
+
for m in rxn_dict[key]:
|
40 |
+
if m in mol_role_map:
|
41 |
+
continue
|
42 |
+
if m in self.mol_property_map:
|
43 |
+
mol_role_map[m] = key
|
44 |
+
valid_mols = []
|
45 |
+
for mol in mol_role_map:
|
46 |
+
assert mol in self.mol_property_map # this is garanteed by the above if statement
|
47 |
+
if 'abstract' not in self.mol_property_map[mol]:
|
48 |
+
continue
|
49 |
+
valid_mols.append(mol) # here the molecules should be in the R, C, S, P order.
|
50 |
+
if len(valid_mols) > 0:
|
51 |
+
self.valid_rxn_indices.append(rxn_id)
|
52 |
+
rxn_dict['valid_mols'] = valid_mols
|
53 |
+
rxn_dict['mol_role_map'] = mol_role_map
|
54 |
+
|
55 |
+
def _calculate_Pr(self):
|
56 |
+
total_weights = 0
|
57 |
+
for rxn_dict in self.reaction_data:
|
58 |
+
rxn_weight = sum([1/self.mol_counter[mol] for mol in rxn_dict['valid_mols']])
|
59 |
+
rxn_dict['weight'] = rxn_weight
|
60 |
+
total_weights += rxn_weight
|
61 |
+
for rxn_dict in self.reaction_data:
|
62 |
+
rxn_dict['weight'] = rxn_dict['weight'] / total_weights
|
63 |
+
|
64 |
+
def _calculate_Pir(self):
|
65 |
+
for rxn_dict in self.reaction_data:
|
66 |
+
mol_weight = {}
|
67 |
+
for mol in rxn_dict['valid_mols']:
|
68 |
+
mol_weight[mol] = 1/self.mol_counter[mol]
|
69 |
+
total_weight = sum(mol_weight.values())
|
70 |
+
rxn_dict['mol_weight'] = {m:w/total_weight for m, w in mol_weight.items()}
|
71 |
+
|
72 |
+
def choose_mol(self, valid_mols, k=4, weights=None):
|
73 |
+
if k>=len(valid_mols):
|
74 |
+
sampled_indices = list(range(len(valid_mols)))
|
75 |
+
else:
|
76 |
+
sampled_indices = np.random.choice(len(valid_mols), k, replace=False, p=weights)
|
77 |
+
sampled_indices = list(sampled_indices)
|
78 |
+
sampled_indices = sorted(sampled_indices)
|
79 |
+
if random.random() < self.reverse_ratio: # reverse the indices with reverse_ratio chance.
|
80 |
+
sampled_indices.reverse()
|
81 |
+
sampled_mols = [valid_mols[i] for i in sampled_indices]
|
82 |
+
return sampled_mols
|
83 |
+
|
84 |
+
def sample_mol_batch(self, index=None, k=4):
|
85 |
+
if index is None:
|
86 |
+
index = self.sample_rxn_index(1)[0]
|
87 |
+
assert index < len(self.reaction_data)
|
88 |
+
rxn = self.reaction_data[index]
|
89 |
+
valid_mols, weights = zip(*rxn['mol_weight'].items())
|
90 |
+
|
91 |
+
sampled_mols = self.choose_mol(valid_mols, k=k, weights=weights)
|
92 |
+
mol_property_batch = []
|
93 |
+
for mol in sampled_mols:
|
94 |
+
mol_property = self.mol_property_map[mol]
|
95 |
+
mol_role = rxn['mol_role_map'][mol]
|
96 |
+
mol_property['role'] = mol_role
|
97 |
+
mol_property_batch.append(mol_property)
|
98 |
+
if 'rsmiles_map' in rxn:
|
99 |
+
rsmiles_map = random.choice(rxn['rsmiles_map'])
|
100 |
+
for mol_property in mol_property_batch:
|
101 |
+
canon_smiles = mol_property['canon_smiles']
|
102 |
+
if canon_smiles in rsmiles_map:
|
103 |
+
mol_property['r_smiles'] = rsmiles_map[canon_smiles]
|
104 |
+
return mol_property_batch
|
105 |
+
|
106 |
+
def sample_rxn_index(self, num_samples):
|
107 |
+
indices = range(len(self.reaction_data))
|
108 |
+
weights = [d['weight'] for d in self.reaction_data]
|
109 |
+
return np.random.choice(indices, num_samples, replace=False, p=weights)
|
110 |
+
|
111 |
+
def __call__(self, rxn_num=1000, k=4):
|
112 |
+
sampled_indices = self.sample_rxn_index(rxn_num)
|
113 |
+
sampled_batch = [self.sample_mol_batch(idx, k=k) for idx in sampled_indices]
|
114 |
+
return sampled_batch
|
115 |
+
|
116 |
+
def generate_batch_uniform_rxn(self, rxn_num=1000, k=4):
|
117 |
+
assert rxn_num <= len(self.valid_rxn_indices)
|
118 |
+
sampled_rxn_indices = random.sample(self.valid_rxn_indices, rxn_num)
|
119 |
+
sampled_batch = []
|
120 |
+
for rxn_id in sampled_rxn_indices:
|
121 |
+
rxn = self.reaction_data[rxn_id]
|
122 |
+
sampled_mols = self.choose_mol(rxn['valid_mols'], k=k, weights=None)
|
123 |
+
mol_property_batch = []
|
124 |
+
for mol in sampled_mols:
|
125 |
+
mol_property = self.mol_property_map[mol]
|
126 |
+
mol_role = rxn['mol_role_map'][mol]
|
127 |
+
mol_property['role'] = mol_role
|
128 |
+
mol_property_batch.append(mol_property)
|
129 |
+
sampled_batch.append(mol_property_batch)
|
130 |
+
return sampled_batch
|
131 |
+
|
132 |
+
def generate_batch_uniform_mol(self, rxn_num=1000, k=4):
|
133 |
+
valid_mols = list(self.mol_counter.elements())
|
134 |
+
assert rxn_num*k <= len(valid_mols)
|
135 |
+
sampled_batch = []
|
136 |
+
sampled_mol_ids = random.sample(range(len(valid_mols)), rxn_num*k)
|
137 |
+
for i in range(rxn_num):
|
138 |
+
sampled_batch.append([self.mol_property_map[valid_mols[mol_id]] for mol_id in sampled_mol_ids[i*k:(i+1)*k]])
|
139 |
+
return sampled_batch
|
140 |
+
|
141 |
+
def generate_batch_single(self, rxn_num=1000):
|
142 |
+
valid_mols = list(self.mol_counter.elements())
|
143 |
+
sampled_mols = random.sample(valid_mols, rxn_num)
|
144 |
+
total_valid_mols = [[self.mol_property_map[mol]] for mol in sampled_mols]
|
145 |
+
return total_valid_mols
|
146 |
+
|
147 |
+
# visaulize probability for molecules in caption dataset.
|
148 |
+
def visualize_mol_distribution(self):
|
149 |
+
prob_dict = {mol:0.0 for mol in self.mol_property_map.keys()}
|
150 |
+
N = len(prob_dict)
|
151 |
+
M = len(self.reaction_data)
|
152 |
+
assert N == len(self.mol_property_map)
|
153 |
+
print(f'Number of molecules in Caption Dataset: {N}')
|
154 |
+
print(f'Number of Reactions in Reaction Dataset: {M}')
|
155 |
+
|
156 |
+
# prob distribution for molecules
|
157 |
+
for rxn_dict in self.reaction_data:
|
158 |
+
for mol, weight in rxn_dict['mol_weight'].items():
|
159 |
+
prob_dict[mol] += weight * rxn_dict['weight']
|
160 |
+
# sum of prob_dict.values() should already be 1.
|
161 |
+
prob_values = np.array(list(prob_dict.values()))
|
162 |
+
prob_values *= N
|
163 |
+
|
164 |
+
# prob distribution for reactions
|
165 |
+
rxn_weights = np.array([d['weight'] for d in self.reaction_data])
|
166 |
+
# sum of rxn_weights should already be 1.
|
167 |
+
rxn_weights *= M
|
168 |
+
|
169 |
+
return prob_values, rxn_weights
|
170 |
+
|
171 |
+
# visaulize the frequency for molecules in caption dataset.
|
172 |
+
def visualize_mol_frequency(self, rxn_num=1000, k=4, epochs=100):
|
173 |
+
sampled_mols_counter = Counter()
|
174 |
+
sampled_rxns_counter = Counter()
|
175 |
+
for _ in range(epochs):
|
176 |
+
rxn_indices = self.sample_rxn_index(rxn_num)
|
177 |
+
sampled_rxns_counter.update(rxn_indices)
|
178 |
+
for index in rxn_indices:
|
179 |
+
rxn = self.reaction_data[index]
|
180 |
+
if len(rxn['valid_mols']) ==0:
|
181 |
+
continue
|
182 |
+
valid_mols, weights = zip(*rxn['mol_weight'].items())
|
183 |
+
mol_batch = self.choose_mol(valid_mols, k=k, weights=weights)
|
184 |
+
sampled_mols_counter.update(mol_batch)
|
185 |
+
sampled_mols_count = np.array([c for _, c in sorted(sampled_mols_counter.items())])
|
186 |
+
sampled_rxns_count = np.array([c for _, c in sorted(sampled_rxns_counter.items())])
|
187 |
+
return sampled_mols_count, sampled_rxns_count
|
188 |
+
|
189 |
+
def _randomly(self, func, *args, **kwargs):
|
190 |
+
# make fake weights and backup the weights
|
191 |
+
for rxn_dict in self.reaction_data:
|
192 |
+
rxn_dict['weight_bak'] = rxn_dict['weight']
|
193 |
+
rxn_dict['weight'] = 1/len(self.reaction_data)
|
194 |
+
rxn_dict['mol_weight_bak'] = rxn_dict['mol_weight']
|
195 |
+
rxn_dict['mol_weight'] = {m:1/len(rxn_dict['mol_weight']) for m in rxn_dict['mol_weight']}
|
196 |
+
|
197 |
+
# run the function
|
198 |
+
result = func(*args, **kwargs)
|
199 |
+
|
200 |
+
# weights recovery
|
201 |
+
for rxn_dict in self.reaction_data:
|
202 |
+
rxn_dict['weight'] = rxn_dict['weight_bak']
|
203 |
+
del rxn_dict['weight_bak']
|
204 |
+
rxn_dict['mol_weight'] = rxn_dict['mol_weight_bak']
|
205 |
+
del rxn_dict['mol_weight_bak']
|
206 |
+
|
207 |
+
return result
|
data_provider/data_utils.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Data
|
3 |
+
from ogb.utils import smiles2graph
|
4 |
+
from rdkit import Chem
|
5 |
+
import random
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
from rdkit import RDLogger
|
9 |
+
RDLogger.DisableLog('rdApp.*')
|
10 |
+
from .r_smiles import multi_process
|
11 |
+
import multiprocessing
|
12 |
+
|
13 |
+
def reformat_smiles(smiles, smiles_type='default'):
|
14 |
+
if not smiles:
|
15 |
+
return None
|
16 |
+
if smiles_type == 'default':
|
17 |
+
return smiles
|
18 |
+
elif smiles_type=='canonical':
|
19 |
+
mol = Chem.MolFromSmiles(smiles)
|
20 |
+
return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
|
21 |
+
elif smiles_type=='restricted':
|
22 |
+
mol = Chem.MolFromSmiles(smiles)
|
23 |
+
new_atom_order = list(range(mol.GetNumAtoms()))
|
24 |
+
random.shuffle(new_atom_order)
|
25 |
+
random_mol = Chem.RenumberAtoms(mol, newOrder=new_atom_order)
|
26 |
+
return Chem.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
|
27 |
+
elif smiles_type=='unrestricted':
|
28 |
+
mol = Chem.MolFromSmiles(smiles)
|
29 |
+
return Chem.MolToSmiles(mol, canonical=False, doRandom=True, isomericSmiles=False)
|
30 |
+
elif smiles_type=='r_smiles':
|
31 |
+
# the implementation of root-aligned smiles is in r_smiles.py
|
32 |
+
return smiles
|
33 |
+
else:
|
34 |
+
raise NotImplementedError(f"smiles_type {smiles_type} not implemented")
|
35 |
+
|
36 |
+
def json_read(path):
|
37 |
+
with open(path, 'r') as f:
|
38 |
+
data = json.load(f)
|
39 |
+
return data
|
40 |
+
|
41 |
+
def json_write(path, data):
|
42 |
+
with open(path, 'w') as f:
|
43 |
+
json.dump(data, f, indent=4, ensure_ascii=False)
|
44 |
+
|
45 |
+
def format_float_from_string(s):
|
46 |
+
try:
|
47 |
+
float_value = float(s)
|
48 |
+
return f'{float_value:.2f}'
|
49 |
+
except ValueError:
|
50 |
+
return s
|
51 |
+
|
52 |
+
def make_abstract(mol_dict, abstract_max_len=256, property_max_len=256):
|
53 |
+
prompt = ''
|
54 |
+
if 'abstract' in mol_dict:
|
55 |
+
abstract_string = mol_dict['abstract'][:abstract_max_len]
|
56 |
+
prompt += f'[Abstract] {abstract_string} '
|
57 |
+
|
58 |
+
property_string = ''
|
59 |
+
property_dict = mol_dict['property'] if 'property' in mol_dict else {}
|
60 |
+
for property_key in ['Experimental Properties', 'Computed Properties']:
|
61 |
+
if not property_key in property_dict:
|
62 |
+
continue
|
63 |
+
for key, value in property_dict[property_key].items():
|
64 |
+
if isinstance(value, float):
|
65 |
+
key_value_string = f'{key}: {value:.2f}; '
|
66 |
+
elif isinstance(value, str):
|
67 |
+
float_value = format_float_from_string(value)
|
68 |
+
key_value_string = f'{key}: {float_value}; '
|
69 |
+
else:
|
70 |
+
key_value_string = f'{key}: {value}; '
|
71 |
+
if len(property_string+key_value_string) > property_max_len:
|
72 |
+
break
|
73 |
+
property_string += key_value_string
|
74 |
+
if property_string:
|
75 |
+
property_string = property_string[:property_max_len]
|
76 |
+
prompt += f'[Properties] {property_string}. '
|
77 |
+
return prompt
|
78 |
+
|
79 |
+
def smiles2data(smiles):
|
80 |
+
graph = smiles2graph(smiles)
|
81 |
+
x = torch.from_numpy(graph['node_feat'])
|
82 |
+
edge_index = torch.from_numpy(graph['edge_index'], )
|
83 |
+
edge_attr = torch.from_numpy(graph['edge_feat'])
|
84 |
+
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
85 |
+
return data
|
86 |
+
|
87 |
+
import re
|
88 |
+
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
|
89 |
+
|
90 |
+
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
|
91 |
+
|
92 |
+
|
93 |
+
def _insert_split_marker(m: re.Match):
|
94 |
+
"""
|
95 |
+
Applies split marker based on a regex match of special tokens such as
|
96 |
+
[START_DNA].
|
97 |
+
|
98 |
+
Parameters
|
99 |
+
----------
|
100 |
+
n : str
|
101 |
+
Input text to split
|
102 |
+
|
103 |
+
Returns
|
104 |
+
----------
|
105 |
+
str - the text with the split token added
|
106 |
+
"""
|
107 |
+
start_token, _, sequence, end_token = m.groups()
|
108 |
+
sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
|
109 |
+
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
|
110 |
+
|
111 |
+
def escape_custom_split_sequence(text):
|
112 |
+
"""
|
113 |
+
Applies custom splitting to the text for GALILEO's tokenization
|
114 |
+
|
115 |
+
Parameters
|
116 |
+
----------
|
117 |
+
text : str
|
118 |
+
Input text to split
|
119 |
+
|
120 |
+
Returns
|
121 |
+
----------
|
122 |
+
str - the text with the split token added
|
123 |
+
"""
|
124 |
+
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
|
125 |
+
|
126 |
+
def generate_rsmiles(reactants, products, augmentation=20):
|
127 |
+
"""
|
128 |
+
reactants: list of N, reactant smiles
|
129 |
+
products: list of N, product smiles
|
130 |
+
augmentation: int, number of augmentations
|
131 |
+
|
132 |
+
return: list of N x augmentation
|
133 |
+
"""
|
134 |
+
data = [{
|
135 |
+
'reactant': r.strip().replace(' ', ''),
|
136 |
+
'product': p.strip().replace(' ', ''),
|
137 |
+
'augmentation': augmentation,
|
138 |
+
'root_aligned': True,
|
139 |
+
} for r, p in zip(reactants, products)]
|
140 |
+
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
|
141 |
+
results = pool.map(func=multi_process,iterable=data)
|
142 |
+
product_smiles = [smi for r in results for smi in r['src_data']]
|
143 |
+
reactant_smiles = [smi for r in results for smi in r['tgt_data']]
|
144 |
+
return reactant_smiles, product_smiles
|
data_provider/molecule_abstract_dataset.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Dataset
|
3 |
+
import os
|
4 |
+
from .context_gen import Reaction_Cluster
|
5 |
+
import json
|
6 |
+
from .data_utils import smiles2data, reformat_smiles
|
7 |
+
from collections import defaultdict
|
8 |
+
import random
|
9 |
+
from data_provider.caption_dataset import PretrainCaptionDataset
|
10 |
+
from data_provider.synthesis_dataset import SynthesisDataset
|
11 |
+
|
12 |
+
def format_float_from_string(s):
|
13 |
+
try:
|
14 |
+
float_value = float(s)
|
15 |
+
return f'{float_value:.2f}'
|
16 |
+
except ValueError:
|
17 |
+
return s
|
18 |
+
|
19 |
+
class MoleculeAbstract(Dataset):
|
20 |
+
def __init__(self,
|
21 |
+
root,
|
22 |
+
rxn_num=1000,
|
23 |
+
rxn_batch_size=4,
|
24 |
+
smi_max_len=128,
|
25 |
+
prompt=None,
|
26 |
+
disable_graph_cache=False,
|
27 |
+
disable_graphs=False,
|
28 |
+
context_style='weighted_rxn',
|
29 |
+
use_caption_dataset=False,
|
30 |
+
caption_batch_num=10000,
|
31 |
+
synthesis_datasetpath=None,
|
32 |
+
synthesis_batch_num=10000,
|
33 |
+
reverse_ratio=0.5,
|
34 |
+
enable_abstract=True,
|
35 |
+
enable_property=True,
|
36 |
+
smiles_type='default',
|
37 |
+
mode='train'
|
38 |
+
):
|
39 |
+
super(MoleculeAbstract, self).__init__(root)
|
40 |
+
self.root = root
|
41 |
+
self.rxn_num = rxn_num
|
42 |
+
self.rxn_batch_size = rxn_batch_size
|
43 |
+
self.smi_max_len = smi_max_len
|
44 |
+
self.context_style = context_style
|
45 |
+
self.tokenizer = None
|
46 |
+
self.disable_graph_cache = disable_graph_cache
|
47 |
+
self.disable_graphs = disable_graphs
|
48 |
+
self.use_caption_dataset = use_caption_dataset
|
49 |
+
self.smiles_type = smiles_type
|
50 |
+
if use_caption_dataset:
|
51 |
+
self.caption_dataset = PretrainCaptionDataset(
|
52 |
+
os.path.join(root, '../caption_data'),
|
53 |
+
smi_max_len=smi_max_len,
|
54 |
+
use_graph=not self.disable_graphs,
|
55 |
+
disable_graph_cache=disable_graph_cache,
|
56 |
+
smiles_type=smiles_type,
|
57 |
+
)
|
58 |
+
self.caption_batch_num = caption_batch_num
|
59 |
+
self.use_synthesis_dataset = bool(synthesis_datasetpath)
|
60 |
+
if self.use_synthesis_dataset:
|
61 |
+
self.synthesis_dataset = SynthesisDataset(
|
62 |
+
synthesis_datasetpath,
|
63 |
+
'train',
|
64 |
+
smi_max_len,
|
65 |
+
roundrobin_train=True,
|
66 |
+
use_graph=not disable_graphs,
|
67 |
+
disable_graph_cache=disable_graph_cache,
|
68 |
+
smiles_type='default',
|
69 |
+
)
|
70 |
+
self.synthesis_batch_num = synthesis_batch_num
|
71 |
+
if not self.disable_graphs:
|
72 |
+
self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))
|
73 |
+
reaction_filename = 'reactions/reactions_test.json' if (mode=='test') else 'reactions/reactions.json'
|
74 |
+
if smiles_type=='r_smiles':
|
75 |
+
reaction_filename = 'reactions/reactions_wRSMILES.json'
|
76 |
+
self.cluster = Reaction_Cluster(self.root, reaction_filename=reaction_filename, reverse_ratio=reverse_ratio)
|
77 |
+
self.reload_data_list()
|
78 |
+
self.abstract_max_len = 10240
|
79 |
+
self.property_max_len = 10240
|
80 |
+
self.enable_abstract = enable_abstract
|
81 |
+
self.enable_property = enable_property
|
82 |
+
|
83 |
+
def get(self, index):
|
84 |
+
return self.__getitem__(index)
|
85 |
+
|
86 |
+
def len(self):
|
87 |
+
return len(self)
|
88 |
+
|
89 |
+
def __len__(self):
|
90 |
+
data_len = len(self.data_list)
|
91 |
+
if self.use_caption_dataset:
|
92 |
+
data_len += len(self.caption_index_list)
|
93 |
+
if self.use_synthesis_dataset:
|
94 |
+
data_len += len(self.synthesis_index_list)
|
95 |
+
return data_len
|
96 |
+
|
97 |
+
def reload_data_list(self):
|
98 |
+
k = self.rxn_batch_size
|
99 |
+
if self.context_style == 'weighted_rxn':
|
100 |
+
self.data_list = self.cluster(self.rxn_num, k=k)
|
101 |
+
elif self.context_style == 'uniform_rxn':
|
102 |
+
self.data_list = self.cluster.generate_batch_uniform_rxn(self.rxn_num, k=k)
|
103 |
+
elif self.context_style == 'uniform_mol':
|
104 |
+
self.data_list = self.cluster.generate_batch_uniform_mol(self.rxn_num, k=k)
|
105 |
+
elif self.context_style == 'single_mol':
|
106 |
+
self.data_list = self.cluster.generate_batch_single(self.rxn_num)
|
107 |
+
elif self.context_style == 'hybrid':
|
108 |
+
self.data_list = self.cluster(self.rxn_num//2, k=k)
|
109 |
+
self.data_list += self.cluster.generate_batch_uniform_mol(self.rxn_num//2, k=k)
|
110 |
+
else:
|
111 |
+
raise NotImplementedError
|
112 |
+
if self.use_caption_dataset:
|
113 |
+
assert self.caption_batch_num*k <= len(self.caption_dataset)
|
114 |
+
caption_index_list = random.sample(range(len(self.caption_dataset)), self.caption_batch_num*k)
|
115 |
+
self.caption_index_list = [caption_index_list[i*k:(i+1)*k] for i in range(self.caption_batch_num)]
|
116 |
+
else:
|
117 |
+
self.caption_index_list = []
|
118 |
+
if self.use_synthesis_dataset:
|
119 |
+
if self.synthesis_dataset.roundrobin_train:
|
120 |
+
self.synthesis_dataset.reload_data()
|
121 |
+
assert self.synthesis_batch_num <= len(self.synthesis_dataset)
|
122 |
+
self.synthesis_index_list = random.sample(range(len(self.synthesis_dataset)), self.synthesis_batch_num)
|
123 |
+
else:
|
124 |
+
self.synthesis_index_list = []
|
125 |
+
|
126 |
+
def make_prompt(self, mol_batch, smi_max_len=128):
|
127 |
+
mol_prompt_list, text_prompt_list = [], []
|
128 |
+
last_role = None
|
129 |
+
for mol_dict in mol_batch:
|
130 |
+
smiles = mol_dict['canon_smiles']
|
131 |
+
if self.smiles_type=='r_smiles':
|
132 |
+
if 'r_smiles' in mol_dict:
|
133 |
+
smiles = mol_dict['r_smiles']
|
134 |
+
# else:
|
135 |
+
# smiles = reformat_smiles(smiles, smiles_type='restricted')
|
136 |
+
else:
|
137 |
+
smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
|
138 |
+
mol_prompt = f'[START_SMILES]{smiles[:smi_max_len]}[END_SMILES]. '
|
139 |
+
if 'role' in mol_dict:
|
140 |
+
role = {
|
141 |
+
'REACTANT': 'Reactant',
|
142 |
+
'CATALYST': 'Catalyst',
|
143 |
+
'SOLVENT': 'Solvent',
|
144 |
+
'PRODUCT': 'Product',
|
145 |
+
}[mol_dict['role']]
|
146 |
+
if last_role != role:
|
147 |
+
mol_prompt = f'{role}: {mol_prompt}'
|
148 |
+
last_role = role
|
149 |
+
text_prompt = self.make_abstract(mol_dict)
|
150 |
+
mol_prompt_list.append(mol_prompt)
|
151 |
+
text_prompt_list.append(text_prompt)
|
152 |
+
return mol_prompt_list, text_prompt_list
|
153 |
+
|
154 |
+
def make_abstract(self, mol_dict):
|
155 |
+
prompt = ''
|
156 |
+
if self.enable_abstract and 'abstract' in mol_dict:
|
157 |
+
abstract_string = mol_dict['abstract'][:self.abstract_max_len]
|
158 |
+
prompt += f'[Abstract] {abstract_string} '
|
159 |
+
|
160 |
+
if self.enable_property:
|
161 |
+
property_string = ''
|
162 |
+
property_dict = mol_dict['property'] if 'property' in mol_dict else {}
|
163 |
+
for property_key in ['Experimental Properties', 'Computed Properties']:
|
164 |
+
if not property_key in property_dict:
|
165 |
+
continue
|
166 |
+
for key, value in property_dict[property_key].items():
|
167 |
+
if isinstance(value, float):
|
168 |
+
key_value_string = f'{key}: {value:.2f}; '
|
169 |
+
elif isinstance(value, str):
|
170 |
+
float_value = format_float_from_string(value)
|
171 |
+
key_value_string = f'{key}: {float_value}; '
|
172 |
+
else:
|
173 |
+
key_value_string = f'{key}: {value}; '
|
174 |
+
if len(property_string+key_value_string) > self.property_max_len:
|
175 |
+
break
|
176 |
+
property_string += key_value_string
|
177 |
+
if property_string:
|
178 |
+
property_string = property_string[:self.property_max_len]
|
179 |
+
prompt += f'[Properties] {property_string}. '
|
180 |
+
return prompt
|
181 |
+
|
182 |
+
def get_caption_data(self, index):
|
183 |
+
caption_index = self.caption_index_list[index]
|
184 |
+
graph_list, mol_prompt_list, text_prompt_list = [], [], []
|
185 |
+
for idx in caption_index:
|
186 |
+
graph_item, text, smiles_prompt = self.caption_dataset[idx]
|
187 |
+
graph_list.append(graph_item)
|
188 |
+
mol_prompt_list.append(smiles_prompt)
|
189 |
+
text_prompt_list.append(text)
|
190 |
+
|
191 |
+
return graph_list, mol_prompt_list, text_prompt_list
|
192 |
+
|
193 |
+
def get_synthesis_data(self, index):
|
194 |
+
synthesis_index = self.synthesis_index_list[index]
|
195 |
+
_, graph_list, output_text, input_text = self.synthesis_dataset[synthesis_index]
|
196 |
+
return graph_list, [input_text], [output_text]
|
197 |
+
|
198 |
+
def __getitem__(self, index):
|
199 |
+
if index < len(self.data_list):
|
200 |
+
mol_batch = self.data_list[index]
|
201 |
+
elif index < len(self.data_list)+len(self.caption_index_list):
|
202 |
+
assert self.use_caption_dataset
|
203 |
+
return self.get_caption_data(index-len(self.data_list))
|
204 |
+
else:
|
205 |
+
assert self.use_synthesis_dataset
|
206 |
+
return self.get_synthesis_data(index-(len(self.data_list)+len(self.caption_index_list)))
|
207 |
+
|
208 |
+
graph_list = []
|
209 |
+
for mol_dict in mol_batch:
|
210 |
+
smiles = mol_dict['canon_smiles']
|
211 |
+
if self.disable_graphs:
|
212 |
+
graph_item = None
|
213 |
+
else:
|
214 |
+
if self.disable_graph_cache:
|
215 |
+
graph_item = smiles2data(smiles)
|
216 |
+
else:
|
217 |
+
assert smiles in self.mol_graph_map
|
218 |
+
graph_item = self.mol_graph_map[smiles]
|
219 |
+
graph_list.append(graph_item)
|
220 |
+
mol_prompt_list, text_prompt_list = self.make_prompt(mol_batch, smi_max_len=self.smi_max_len)
|
221 |
+
|
222 |
+
return graph_list, mol_prompt_list, text_prompt_list
|
data_provider/pretrain_dm.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
import torch
|
4 |
+
from pytorch_lightning import LightningDataModule
|
5 |
+
import torch_geometric
|
6 |
+
# from torch_geometric.loader import DataLoader
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torch_geometric.loader.dataloader import Collater
|
9 |
+
from data_provider.molecule_abstract_dataset import MoleculeAbstract
|
10 |
+
import re
|
11 |
+
from transformers import BatchEncoding
|
12 |
+
|
13 |
+
# we split individual characters inside special tokens like [START_DNA]
|
14 |
+
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
|
15 |
+
|
16 |
+
# token added to implement a custom sequence tokenization. This token is added at
|
17 |
+
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
|
18 |
+
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
|
19 |
+
# literally in the source code in case we ever include it in the training data.
|
20 |
+
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
|
21 |
+
|
22 |
+
def _insert_split_marker(m: re.Match):
|
23 |
+
"""
|
24 |
+
Applies split marker based on a regex match of special tokens such as
|
25 |
+
[START_DNA].
|
26 |
+
|
27 |
+
Parameters
|
28 |
+
----------
|
29 |
+
n : str
|
30 |
+
Input text to split
|
31 |
+
|
32 |
+
Returns
|
33 |
+
----------
|
34 |
+
str - the text with the split token added
|
35 |
+
"""
|
36 |
+
start_token, _, sequence, end_token = m.groups()
|
37 |
+
sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
|
38 |
+
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
|
39 |
+
|
40 |
+
|
41 |
+
def smiles_handler(text, mol_ph, is_gal=True):
|
42 |
+
smiles_list = []
|
43 |
+
for match in CUSTOM_SEQ_RE.finditer(text):
|
44 |
+
smiles = match.group(3)
|
45 |
+
smiles_list.append(smiles)
|
46 |
+
if is_gal:
|
47 |
+
text = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
|
48 |
+
text = escape_custom_split_sequence(text)
|
49 |
+
return text, smiles_list
|
50 |
+
else:
|
51 |
+
text = CUSTOM_SEQ_RE.sub(r'\3%s' % (mol_ph), text)
|
52 |
+
return text, smiles_list
|
53 |
+
|
54 |
+
|
55 |
+
def escape_custom_split_sequence(text):
|
56 |
+
"""
|
57 |
+
Applies custom splitting to the text for GALILEO's tokenization
|
58 |
+
|
59 |
+
Parameters
|
60 |
+
----------
|
61 |
+
text : str
|
62 |
+
Input text to split
|
63 |
+
|
64 |
+
Returns
|
65 |
+
----------
|
66 |
+
str - the text with the split token added
|
67 |
+
"""
|
68 |
+
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
|
69 |
+
|
70 |
+
|
71 |
+
def tokenize_and_merge_batched_qa_pairs(tokenizer, qa_pairs_list, max_length):
|
72 |
+
tokenized_batches = {
|
73 |
+
'input_ids': [],
|
74 |
+
'attention_mask': []
|
75 |
+
}
|
76 |
+
for qa_pairs in qa_pairs_list:
|
77 |
+
max_length_per_qa = max_length // len(qa_pairs)
|
78 |
+
batch_input_ids = []
|
79 |
+
batch_attention_mask = []
|
80 |
+
for qa in qa_pairs:
|
81 |
+
# here qa should be string
|
82 |
+
tokens = tokenizer(qa,
|
83 |
+
truncation=True,
|
84 |
+
padding=False,
|
85 |
+
add_special_tokens=False,
|
86 |
+
max_length=max_length_per_qa,
|
87 |
+
return_tensors='pt',
|
88 |
+
return_attention_mask=True)
|
89 |
+
batch_input_ids.extend(tokens['input_ids'].squeeze().tolist())
|
90 |
+
batch_attention_mask.extend(tokens['attention_mask'].squeeze().tolist())
|
91 |
+
|
92 |
+
# Pad the batch to max_length
|
93 |
+
padding_length = max_length - len(batch_input_ids)
|
94 |
+
batch_input_ids.extend([tokenizer.pad_token_id] * padding_length)
|
95 |
+
batch_attention_mask.extend([0] * padding_length)
|
96 |
+
|
97 |
+
tokenized_batches['input_ids'].append(torch.tensor(batch_input_ids).unsqueeze(0))
|
98 |
+
tokenized_batches['attention_mask'].append(torch.tensor(batch_attention_mask).unsqueeze(0))
|
99 |
+
|
100 |
+
tokenized_batches['input_ids'] = torch.cat(tokenized_batches['input_ids'], dim=0)
|
101 |
+
tokenized_batches['attention_mask'] = torch.cat(tokenized_batches['attention_mask'], dim=0)
|
102 |
+
|
103 |
+
tokenized_batch = BatchEncoding(data=tokenized_batches, tensor_type='pt')
|
104 |
+
return tokenized_batch
|
105 |
+
|
106 |
+
class TrainCollater:
|
107 |
+
def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False):
|
108 |
+
self.text_max_len = text_max_len
|
109 |
+
self.tokenizer = tokenizer
|
110 |
+
self.collater = Collater([], [])
|
111 |
+
self.mol_ph = mol_ph
|
112 |
+
self.mol_token_id = mol_token_id
|
113 |
+
self.is_gal = is_gal
|
114 |
+
self.disable_graphs = disable_graphs
|
115 |
+
|
116 |
+
def __call__(self, batch):
|
117 |
+
graphs, mol_prompt, text_prompt = zip(*batch)
|
118 |
+
if not self.disable_graphs:
|
119 |
+
graphs = [graph for graph_batch in graphs for graph in graph_batch]
|
120 |
+
graphs = self.collater(graphs)
|
121 |
+
|
122 |
+
qa_pairs = []
|
123 |
+
for mol_batch, text_batch in zip(mol_prompt, text_prompt):
|
124 |
+
qa_list = []
|
125 |
+
for mol_prompt, text_prompt in zip(mol_batch, text_batch):
|
126 |
+
smiles_prompt = smiles_handler(mol_prompt, self.mol_ph, self.is_gal)[0]
|
127 |
+
qa_list.append(f'{smiles_prompt} {text_prompt}')
|
128 |
+
qa_pairs.append(qa_list)
|
129 |
+
|
130 |
+
self.tokenizer.padding_side = 'right'
|
131 |
+
qa_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, qa_pairs, self.text_max_len)
|
132 |
+
|
133 |
+
is_mol_token = qa_batch.input_ids == self.mol_token_id
|
134 |
+
qa_batch['is_mol_token'] = is_mol_token
|
135 |
+
|
136 |
+
return graphs, qa_batch
|
137 |
+
|
138 |
+
class InferenceCollater:
|
139 |
+
def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False, last_only=False):
|
140 |
+
self.text_max_len = text_max_len
|
141 |
+
self.tokenizer = tokenizer
|
142 |
+
self.collater = Collater([], [])
|
143 |
+
self.mol_ph = mol_ph
|
144 |
+
self.mol_token_id = mol_token_id
|
145 |
+
self.is_gal = is_gal
|
146 |
+
self.disable_graphs = disable_graphs
|
147 |
+
self.last_only = last_only
|
148 |
+
|
149 |
+
def __call__(self, batch):
|
150 |
+
graphs, mol_prompt, text_prompt = zip(*batch)
|
151 |
+
rxn_ids = [0 for i in range(len(mol_prompt))]
|
152 |
+
if self.last_only:
|
153 |
+
mol_prompt = [[mol_batch[-1]] for mol_batch in mol_prompt]
|
154 |
+
text_prompt = [[text_batch[-1]] for text_batch in text_prompt]
|
155 |
+
graphs = [[graph_batch[-1]] for graph_batch in graphs]
|
156 |
+
if not self.disable_graphs:
|
157 |
+
graphs = [graph for graph_batch in graphs for graph in graph_batch]
|
158 |
+
graphs = self.collater(graphs)
|
159 |
+
|
160 |
+
input_text, output_text = [], []
|
161 |
+
for mol_batch, text_batch in zip(mol_prompt, text_prompt):
|
162 |
+
qa_list = []
|
163 |
+
for mol_prompt, text_prompt in list(zip(mol_batch, text_batch))[:-1]:
|
164 |
+
smiles_prompt = smiles_handler(mol_prompt, self.mol_ph, self.is_gal)[0]
|
165 |
+
qa_list.append(f'{smiles_prompt} {text_prompt}')
|
166 |
+
qa_list.append(f'{smiles_handler(mol_batch[-1], self.mol_ph, self.is_gal)[0]} ')
|
167 |
+
output_text.append(text_batch[-1])
|
168 |
+
input_text.append(qa_list)
|
169 |
+
|
170 |
+
self.tokenizer.padding_side = 'right'
|
171 |
+
input_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, input_text, self.text_max_len)
|
172 |
+
|
173 |
+
is_mol_token = input_batch.input_ids == self.mol_token_id
|
174 |
+
input_batch['is_mol_token'] = is_mol_token
|
175 |
+
|
176 |
+
return rxn_ids, graphs, input_batch, output_text, input_text
|
177 |
+
|
178 |
+
|
179 |
+
class PretrainDM(LightningDataModule):
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
num_workers: int = 0,
|
183 |
+
batch_size: int = 256,
|
184 |
+
root: str = 'data/',
|
185 |
+
text_max_len: int = 128,
|
186 |
+
rxn_max_len: int = 128,
|
187 |
+
smi_max_len: int = 128,
|
188 |
+
tokenizer=None,
|
189 |
+
args=None,
|
190 |
+
):
|
191 |
+
super().__init__()
|
192 |
+
self.args = args
|
193 |
+
self.batch_size = batch_size
|
194 |
+
self.inference_batch_size = args.inference_batch_size
|
195 |
+
self.num_workers = num_workers
|
196 |
+
self.text_max_len = text_max_len
|
197 |
+
self.rxn_max_len = rxn_max_len
|
198 |
+
self.pretrain_dataset = MoleculeAbstract(
|
199 |
+
root,
|
200 |
+
rxn_num=args.pretrain_rxn_num,
|
201 |
+
rxn_batch_size=args.rxn_batch_size,
|
202 |
+
smi_max_len=smi_max_len,
|
203 |
+
disable_graph_cache=args.disable_graph_cache,
|
204 |
+
context_style=args.context_style,
|
205 |
+
disable_graphs=args.disable_graphs,
|
206 |
+
use_caption_dataset=args.pretrain_use_caption,
|
207 |
+
caption_batch_num=args.caption_batch_num,
|
208 |
+
synthesis_datasetpath=args.pretrain_synthesis_path,
|
209 |
+
synthesis_batch_num=args.synthesis_batch_num,
|
210 |
+
reverse_ratio=args.reverse_ratio,
|
211 |
+
enable_abstract=not args.disable_abstract,
|
212 |
+
enable_property=not args.disable_property,
|
213 |
+
smiles_type=args.smiles_type,
|
214 |
+
)
|
215 |
+
self.test_dataset = MoleculeAbstract(
|
216 |
+
root,
|
217 |
+
rxn_num=args.pretrain_rxn_num,
|
218 |
+
rxn_batch_size=args.rxn_batch_size,
|
219 |
+
smi_max_len=smi_max_len,
|
220 |
+
disable_graph_cache=args.disable_graph_cache,
|
221 |
+
context_style=args.context_style,
|
222 |
+
disable_graphs=args.disable_graphs,
|
223 |
+
use_caption_dataset=args.pretrain_use_caption,
|
224 |
+
caption_batch_num=args.caption_batch_num,
|
225 |
+
reverse_ratio=args.reverse_ratio,
|
226 |
+
enable_abstract=not args.disable_abstract,
|
227 |
+
enable_property=not args.disable_property,
|
228 |
+
smiles_type=args.smiles_type,
|
229 |
+
mode='test',
|
230 |
+
)
|
231 |
+
self.init_tokenizer(tokenizer)
|
232 |
+
self.mol_ph_token = '<mol>' * self.args.num_query_token
|
233 |
+
self.is_gal = args.opt_model.find('galactica') >= 0
|
234 |
+
self.disable_graphs = args.disable_graphs
|
235 |
+
self.last_only = args.pretrain_eval_last_only
|
236 |
+
|
237 |
+
def init_tokenizer(self, tokenizer):
|
238 |
+
self.tokenizer = tokenizer
|
239 |
+
self.pretrain_dataset.tokenizer = tokenizer
|
240 |
+
self.test_dataset.tokenizer = tokenizer
|
241 |
+
self.mol_token_id = self.tokenizer.mol_token_id
|
242 |
+
# self.tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]
|
243 |
+
|
244 |
+
def train_dataloader(self):
|
245 |
+
self.pretrain_dataset.reload_data_list()
|
246 |
+
loader = DataLoader(
|
247 |
+
self.pretrain_dataset,
|
248 |
+
batch_size=self.batch_size,
|
249 |
+
shuffle=True,
|
250 |
+
num_workers=self.num_workers,
|
251 |
+
pin_memory=False,
|
252 |
+
drop_last=True,
|
253 |
+
persistent_workers=True,
|
254 |
+
collate_fn=TrainCollater(
|
255 |
+
tokenizer=self.tokenizer,
|
256 |
+
text_max_len=self.text_max_len,
|
257 |
+
mol_ph=self.mol_ph_token,
|
258 |
+
mol_token_id=self.mol_token_id,
|
259 |
+
is_gal=self.is_gal,
|
260 |
+
disable_graphs=self.disable_graphs,
|
261 |
+
),
|
262 |
+
)
|
263 |
+
return loader
|
264 |
+
def val_dataloader(self):
|
265 |
+
test_loader = DataLoader(
|
266 |
+
self.test_dataset,
|
267 |
+
batch_size=self.inference_batch_size,
|
268 |
+
shuffle=False,
|
269 |
+
num_workers=self.num_workers,
|
270 |
+
pin_memory=False,
|
271 |
+
drop_last=False,
|
272 |
+
persistent_workers=True,
|
273 |
+
collate_fn=InferenceCollater(
|
274 |
+
tokenizer=self.tokenizer,
|
275 |
+
text_max_len=self.text_max_len,
|
276 |
+
mol_ph=self.mol_ph_token,
|
277 |
+
mol_token_id=self.mol_token_id,
|
278 |
+
is_gal=self.is_gal,
|
279 |
+
disable_graphs=self.disable_graphs,
|
280 |
+
last_only=self.last_only,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
return [test_loader]
|
284 |
+
|
285 |
+
def add_model_specific_args(parent_parser):
|
286 |
+
parser = parent_parser.add_argument_group("Data module")
|
287 |
+
parser.add_argument('--num_workers', type=int, default=2)
|
288 |
+
parser.add_argument('--batch_size', type=int, default=4)
|
289 |
+
parser.add_argument('--inference_batch_size', type=int, default=4)
|
290 |
+
parser.add_argument('--use_smiles', action='store_true', default=False)
|
291 |
+
parser.add_argument('--root', type=str, default='data/action_data')
|
292 |
+
parser.add_argument('--context_style', type=str, default='weighted_rxn', choices=['weighted_rxn', 'uniform_rxn', 'uniform_mol', 'single_mol', 'hybrid'])
|
293 |
+
parser.add_argument('--rxn_max_len', type=int, default=512)
|
294 |
+
parser.add_argument('--text_max_len', type=int, default=512)
|
295 |
+
parser.add_argument('--smi_max_len', type=int, default=128)
|
296 |
+
parser.add_argument('--pretrain_rxn_num', type=int, default=50000)
|
297 |
+
parser.add_argument('--reverse_ratio', type=float, default=0.5, help='ratio of reversed reactions (retro reactions)')
|
298 |
+
parser.add_argument('--disable_abstract', action='store_true', default=False)
|
299 |
+
parser.add_argument('--disable_property', action='store_true', default=False)
|
300 |
+
parser.add_argument('--pretrain_use_caption', action='store_true', default=False)
|
301 |
+
parser.add_argument('--caption_batch_num', type=int, default=5000)
|
302 |
+
parser.add_argument('--pretrain_synthesis_path', type=str, default=None)
|
303 |
+
parser.add_argument('--synthesis_batch_num', type=int, default=5000)
|
304 |
+
parser.add_argument('--rxn_batch_size', type=int, default=4)
|
305 |
+
parser.add_argument('--roundrobin_train', action='store_true', default=False)
|
306 |
+
parser.add_argument('--test_subset', type=int, default=-1)
|
307 |
+
parser.add_argument('--pretrain_eval_last_only', default=False, action='store_true')
|
308 |
+
parser.add_argument('--prompt', type=str, default=None)
|
309 |
+
return parent_parser
|
data_provider/r_smiles.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import argparse
|
3 |
+
import re
|
4 |
+
import random
|
5 |
+
import textdistance
|
6 |
+
|
7 |
+
from rdkit import Chem
|
8 |
+
|
9 |
+
|
10 |
+
from rdkit import RDLogger
|
11 |
+
RDLogger.DisableLog('rdApp.*')
|
12 |
+
|
13 |
+
|
14 |
+
def smi_tokenizer(smi):
|
15 |
+
pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
|
16 |
+
regex = re.compile(pattern)
|
17 |
+
tokens = [token for token in regex.findall(smi)]
|
18 |
+
assert smi == ''.join(tokens)
|
19 |
+
return ' '.join(tokens)
|
20 |
+
|
21 |
+
|
22 |
+
def clear_map_canonical_smiles(smi, canonical=True, root=-1):
|
23 |
+
mol = Chem.MolFromSmiles(smi)
|
24 |
+
if mol is not None:
|
25 |
+
for atom in mol.GetAtoms():
|
26 |
+
if atom.HasProp('molAtomMapNumber'):
|
27 |
+
atom.ClearProp('molAtomMapNumber')
|
28 |
+
return Chem.MolToSmiles(mol, isomericSmiles=True, rootedAtAtom=root, canonical=canonical)
|
29 |
+
else:
|
30 |
+
return smi
|
31 |
+
|
32 |
+
|
33 |
+
def get_cano_map_number(smi,root=-1):
|
34 |
+
atommap_mol = Chem.MolFromSmiles(smi)
|
35 |
+
canonical_mol = Chem.MolFromSmiles(clear_map_canonical_smiles(smi,root=root))
|
36 |
+
cano2atommapIdx = atommap_mol.GetSubstructMatch(canonical_mol)
|
37 |
+
correct_mapped = [canonical_mol.GetAtomWithIdx(i).GetSymbol() == atommap_mol.GetAtomWithIdx(index).GetSymbol() for i,index in enumerate(cano2atommapIdx)]
|
38 |
+
atom_number = len(canonical_mol.GetAtoms())
|
39 |
+
if np.sum(correct_mapped) < atom_number or len(cano2atommapIdx) < atom_number:
|
40 |
+
cano2atommapIdx = [0] * atom_number
|
41 |
+
atommap2canoIdx = canonical_mol.GetSubstructMatch(atommap_mol)
|
42 |
+
if len(atommap2canoIdx) != atom_number:
|
43 |
+
return None
|
44 |
+
for i, index in enumerate(atommap2canoIdx):
|
45 |
+
cano2atommapIdx[index] = i
|
46 |
+
id2atommap = [atom.GetAtomMapNum() for atom in atommap_mol.GetAtoms()]
|
47 |
+
|
48 |
+
return [id2atommap[cano2atommapIdx[i]] for i in range(atom_number)]
|
49 |
+
|
50 |
+
|
51 |
+
def get_root_id(mol,root_map_number):
|
52 |
+
root = -1
|
53 |
+
for i, atom in enumerate(mol.GetAtoms()):
|
54 |
+
if atom.GetAtomMapNum() == root_map_number:
|
55 |
+
root = i
|
56 |
+
break
|
57 |
+
return root
|
58 |
+
# root = -1
|
59 |
+
# for i, atom in enumerate(mol.GetAtoms()):
|
60 |
+
# if atom.GetAtomMapNum() == root_map_number:
|
61 |
+
# return i
|
62 |
+
|
63 |
+
|
64 |
+
def get_forward_rsmiles(data):
|
65 |
+
pt = re.compile(r':(\d+)]')
|
66 |
+
product = data['product']
|
67 |
+
reactant = data['reactant']
|
68 |
+
augmentation = data['augmentation']
|
69 |
+
separated = data['separated']
|
70 |
+
pro_mol = Chem.MolFromSmiles(product)
|
71 |
+
rea_mol = Chem.MolFromSmiles(reactant)
|
72 |
+
"""checking data quality"""
|
73 |
+
rids = sorted(re.findall(pt, reactant))
|
74 |
+
pids = sorted(re.findall(pt, product))
|
75 |
+
return_status = {
|
76 |
+
"status":0,
|
77 |
+
"src_data":[],
|
78 |
+
"tgt_data":[],
|
79 |
+
"edit_distance":0,
|
80 |
+
}
|
81 |
+
reactant = reactant.split(".")
|
82 |
+
product = product.split(".")
|
83 |
+
rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
|
84 |
+
max_times = np.prod([len(map_numbers) for map_numbers in rea_atom_map_numbers])
|
85 |
+
times = min(augmentation, max_times)
|
86 |
+
reactant_roots = [[-1 for _ in reactant]]
|
87 |
+
j = 0
|
88 |
+
while j < times:
|
89 |
+
reactant_roots.append([random.sample(rea_atom_map_numbers[k], 1)[0] for k in range(len(reactant))])
|
90 |
+
if reactant_roots[-1] in reactant_roots[:-1]:
|
91 |
+
reactant_roots.pop()
|
92 |
+
else:
|
93 |
+
j += 1
|
94 |
+
if j < augmentation:
|
95 |
+
reactant_roots.extend(random.choices(reactant_roots, k=augmentation - times))
|
96 |
+
times = augmentation
|
97 |
+
reversable = False # no reverse
|
98 |
+
assert times == augmentation
|
99 |
+
if reversable:
|
100 |
+
times = int(times / 2)
|
101 |
+
|
102 |
+
pro_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", pro))) for pro in product]
|
103 |
+
full_pro_atom_map_numbers = set(map(int, re.findall(r"(?<=:)\d+", ".".join(product))))
|
104 |
+
for k in range(times):
|
105 |
+
tmp = list(zip(reactant, reactant_roots[k],rea_atom_map_numbers))
|
106 |
+
random.shuffle(tmp)
|
107 |
+
reactant_k, reactant_roots_k,rea_atom_map_numbers_k = [i[0] for i in tmp], [i[1] for i in tmp], [i[2] for i in tmp]
|
108 |
+
aligned_reactants = []
|
109 |
+
aligned_products = []
|
110 |
+
aligned_products_order = []
|
111 |
+
all_atom_map = []
|
112 |
+
for i, rea in enumerate(reactant_k):
|
113 |
+
rea_root_atom_map = reactant_roots_k[i]
|
114 |
+
rea_root = get_root_id(Chem.MolFromSmiles(rea), root_map_number=rea_root_atom_map)
|
115 |
+
cano_atom_map = get_cano_map_number(rea, rea_root)
|
116 |
+
if cano_atom_map is None:
|
117 |
+
print(f"Reactant Failed to find Canonical Mol with Atom MapNumber")
|
118 |
+
continue
|
119 |
+
rea_smi = clear_map_canonical_smiles(rea, canonical=True, root=rea_root)
|
120 |
+
aligned_reactants.append(rea_smi)
|
121 |
+
all_atom_map.extend(cano_atom_map)
|
122 |
+
|
123 |
+
for i, pro_map_number in enumerate(pro_atom_map_numbers):
|
124 |
+
reactant_candidates = []
|
125 |
+
selected_reactant = []
|
126 |
+
for j, map_number in enumerate(all_atom_map):
|
127 |
+
if map_number in pro_map_number:
|
128 |
+
for rea_index, rea_atom_map_number in enumerate(rea_atom_map_numbers_k):
|
129 |
+
if map_number in rea_atom_map_number and rea_index not in selected_reactant:
|
130 |
+
selected_reactant.append(rea_index)
|
131 |
+
reactant_candidates.append((map_number, j, len(rea_atom_map_number)))
|
132 |
+
|
133 |
+
# select maximal reactant
|
134 |
+
reactant_candidates.sort(key=lambda x: x[2], reverse=True)
|
135 |
+
map_number = reactant_candidates[0][0]
|
136 |
+
j = reactant_candidates[0][1]
|
137 |
+
pro_root = get_root_id(Chem.MolFromSmiles(product[i]), root_map_number=map_number)
|
138 |
+
pro_smi = clear_map_canonical_smiles(product[i], canonical=True, root=pro_root)
|
139 |
+
aligned_products.append(pro_smi)
|
140 |
+
aligned_products_order.append(j)
|
141 |
+
|
142 |
+
sorted_products = sorted(list(zip(aligned_products, aligned_products_order)), key=lambda x: x[1])
|
143 |
+
aligned_products = [item[0] for item in sorted_products]
|
144 |
+
pro_smi = ".".join(aligned_products)
|
145 |
+
if separated:
|
146 |
+
reactants = []
|
147 |
+
reagents = []
|
148 |
+
for i,cano_atom_map in enumerate(rea_atom_map_numbers_k):
|
149 |
+
if len(set(cano_atom_map) & full_pro_atom_map_numbers) > 0:
|
150 |
+
reactants.append(aligned_reactants[i])
|
151 |
+
else:
|
152 |
+
reagents.append(aligned_reactants[i])
|
153 |
+
rea_smi = ".".join(reactants)
|
154 |
+
reactant_tokens = smi_tokenizer(rea_smi)
|
155 |
+
if len(reagents) > 0 :
|
156 |
+
reactant_tokens += " <separated> " + smi_tokenizer(".".join(reagents))
|
157 |
+
else:
|
158 |
+
rea_smi = ".".join(aligned_reactants)
|
159 |
+
reactant_tokens = smi_tokenizer(rea_smi)
|
160 |
+
product_tokens = smi_tokenizer(pro_smi)
|
161 |
+
return_status['src_data'].append(reactant_tokens)
|
162 |
+
return_status['tgt_data'].append(product_tokens)
|
163 |
+
if reversable:
|
164 |
+
aligned_reactants.reverse()
|
165 |
+
aligned_products.reverse()
|
166 |
+
pro_smi = ".".join(aligned_products)
|
167 |
+
rea_smi = ".".join(aligned_reactants)
|
168 |
+
product_tokens = smi_tokenizer(pro_smi)
|
169 |
+
reactant_tokens = smi_tokenizer(rea_smi)
|
170 |
+
return_status['src_data'].append(reactant_tokens)
|
171 |
+
return_status['tgt_data'].append(product_tokens)
|
172 |
+
edit_distances = []
|
173 |
+
for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
|
174 |
+
edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
|
175 |
+
return_status['edit_distance'] = np.mean(edit_distances)
|
176 |
+
return return_status
|
177 |
+
|
178 |
+
|
179 |
+
def get_retro_rsmiles(data):
|
180 |
+
pt = re.compile(r':(\d+)]')
|
181 |
+
product = data['product']
|
182 |
+
reactant = data['reactant']
|
183 |
+
augmentation = data['augmentation']
|
184 |
+
pro_mol = Chem.MolFromSmiles(product)
|
185 |
+
rea_mol = Chem.MolFromSmiles(reactant)
|
186 |
+
"""checking data quality"""
|
187 |
+
rids = sorted(re.findall(pt, reactant))
|
188 |
+
pids = sorted(re.findall(pt, product))
|
189 |
+
return_status = {
|
190 |
+
"status":0,
|
191 |
+
"src_data":[],
|
192 |
+
"tgt_data":[],
|
193 |
+
"edit_distance":0,
|
194 |
+
}
|
195 |
+
pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
|
196 |
+
reactant = reactant.split(".")
|
197 |
+
reversable = False # no shuffle
|
198 |
+
# augmentation = 100
|
199 |
+
if augmentation == 999:
|
200 |
+
product_roots = pro_atom_map_numbers
|
201 |
+
times = len(product_roots)
|
202 |
+
else:
|
203 |
+
product_roots = [-1]
|
204 |
+
# reversable = len(reactant) > 1
|
205 |
+
|
206 |
+
max_times = len(pro_atom_map_numbers)
|
207 |
+
times = min(augmentation, max_times)
|
208 |
+
if times < augmentation: # times = max_times
|
209 |
+
product_roots.extend(pro_atom_map_numbers)
|
210 |
+
product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
|
211 |
+
else: # times = augmentation
|
212 |
+
while len(product_roots) < times:
|
213 |
+
product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
|
214 |
+
# pro_atom_map_numbers.remove(product_roots[-1])
|
215 |
+
if product_roots[-1] in product_roots[:-1]:
|
216 |
+
product_roots.pop()
|
217 |
+
times = len(product_roots)
|
218 |
+
assert times == augmentation
|
219 |
+
if reversable:
|
220 |
+
times = int(times / 2)
|
221 |
+
# candidates = []
|
222 |
+
for k in range(times):
|
223 |
+
pro_root_atom_map = product_roots[k]
|
224 |
+
pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
|
225 |
+
cano_atom_map = get_cano_map_number(product, root=pro_root)
|
226 |
+
if cano_atom_map is None:
|
227 |
+
return_status["status"] = "error_mapping"
|
228 |
+
return return_status
|
229 |
+
pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
|
230 |
+
aligned_reactants = []
|
231 |
+
aligned_reactants_order = []
|
232 |
+
rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
|
233 |
+
used_indices = []
|
234 |
+
for i, rea_map_number in enumerate(rea_atom_map_numbers):
|
235 |
+
for j, map_number in enumerate(cano_atom_map):
|
236 |
+
# select mapping reactans
|
237 |
+
if map_number in rea_map_number:
|
238 |
+
rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
|
239 |
+
rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
|
240 |
+
aligned_reactants.append(rea_smi)
|
241 |
+
aligned_reactants_order.append(j)
|
242 |
+
used_indices.append(i)
|
243 |
+
break
|
244 |
+
sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
|
245 |
+
aligned_reactants = [item[0] for item in sorted_reactants]
|
246 |
+
reactant_smi = ".".join(aligned_reactants)
|
247 |
+
product_tokens = smi_tokenizer(pro_smi)
|
248 |
+
reactant_tokens = smi_tokenizer(reactant_smi)
|
249 |
+
|
250 |
+
return_status['src_data'].append(product_tokens)
|
251 |
+
return_status['tgt_data'].append(reactant_tokens)
|
252 |
+
|
253 |
+
if reversable:
|
254 |
+
aligned_reactants.reverse()
|
255 |
+
reactant_smi = ".".join(aligned_reactants)
|
256 |
+
product_tokens = smi_tokenizer(pro_smi)
|
257 |
+
reactant_tokens = smi_tokenizer(reactant_smi)
|
258 |
+
return_status['src_data'].append(product_tokens)
|
259 |
+
return_status['tgt_data'].append(reactant_tokens)
|
260 |
+
assert len(return_status['src_data']) == data['augmentation']
|
261 |
+
edit_distances = []
|
262 |
+
for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
|
263 |
+
edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
|
264 |
+
return_status['edit_distance'] = np.mean(edit_distances)
|
265 |
+
return return_status
|
266 |
+
|
267 |
+
|
268 |
+
def multi_process(data):
|
269 |
+
pt = re.compile(r':(\d+)]')
|
270 |
+
product = data['product']
|
271 |
+
reactant = data['reactant']
|
272 |
+
augmentation = data['augmentation']
|
273 |
+
pro_mol = Chem.MolFromSmiles(product)
|
274 |
+
rea_mol = Chem.MolFromSmiles(reactant)
|
275 |
+
"""checking data quality"""
|
276 |
+
rids = sorted(re.findall(pt, reactant))
|
277 |
+
pids = sorted(re.findall(pt, product))
|
278 |
+
return_status = {
|
279 |
+
"status":0,
|
280 |
+
"src_data":[],
|
281 |
+
"tgt_data":[],
|
282 |
+
"edit_distance":0,
|
283 |
+
}
|
284 |
+
# if ",".join(rids) != ",".join(pids): # mapping is not 1:1
|
285 |
+
# return_status["status"] = "error_mapping"
|
286 |
+
# if len(set(rids)) != len(rids): # mapping is not 1:1
|
287 |
+
# return_status["status"] = "error_mapping"
|
288 |
+
# if len(set(pids)) != len(pids): # mapping is not 1:1
|
289 |
+
# return_status["status"] = "error_mapping"
|
290 |
+
if "" == product:
|
291 |
+
return_status["status"] = "empty_p"
|
292 |
+
if "" == reactant:
|
293 |
+
return_status["status"] = "empty_r"
|
294 |
+
if rea_mol is None:
|
295 |
+
return_status["status"] = "invalid_r"
|
296 |
+
if len(rea_mol.GetAtoms()) < 5:
|
297 |
+
return_status["status"] = "small_r"
|
298 |
+
if pro_mol is None:
|
299 |
+
return_status["status"] = "invalid_p"
|
300 |
+
if len(pro_mol.GetAtoms()) == 1:
|
301 |
+
return_status["status"] = "small_p"
|
302 |
+
if not all([a.HasProp('molAtomMapNumber') for a in pro_mol.GetAtoms()]):
|
303 |
+
return_status["status"] = "error_mapping_p"
|
304 |
+
"""finishing checking data quality"""
|
305 |
+
|
306 |
+
if return_status['status'] == 0:
|
307 |
+
pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
|
308 |
+
reactant = reactant.split(".")
|
309 |
+
if data['root_aligned']:
|
310 |
+
reversable = False # no shuffle
|
311 |
+
# augmentation = 100
|
312 |
+
if augmentation == 999:
|
313 |
+
product_roots = pro_atom_map_numbers
|
314 |
+
times = len(product_roots)
|
315 |
+
else:
|
316 |
+
product_roots = [-1]
|
317 |
+
# reversable = len(reactant) > 1
|
318 |
+
|
319 |
+
max_times = len(pro_atom_map_numbers)
|
320 |
+
times = min(augmentation, max_times)
|
321 |
+
if times < augmentation: # times = max_times
|
322 |
+
product_roots.extend(pro_atom_map_numbers)
|
323 |
+
product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
|
324 |
+
else: # times = augmentation
|
325 |
+
while len(product_roots) < times:
|
326 |
+
product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
|
327 |
+
# pro_atom_map_numbers.remove(product_roots[-1])
|
328 |
+
if product_roots[-1] in product_roots[:-1]:
|
329 |
+
product_roots.pop()
|
330 |
+
times = len(product_roots)
|
331 |
+
assert times == augmentation
|
332 |
+
if reversable:
|
333 |
+
times = int(times / 2)
|
334 |
+
# candidates = []
|
335 |
+
for k in range(times):
|
336 |
+
pro_root_atom_map = product_roots[k]
|
337 |
+
pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
|
338 |
+
cano_atom_map = get_cano_map_number(product, root=pro_root)
|
339 |
+
if cano_atom_map is None:
|
340 |
+
return_status["status"] = "error_mapping"
|
341 |
+
return return_status
|
342 |
+
pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
|
343 |
+
aligned_reactants = []
|
344 |
+
aligned_reactants_order = []
|
345 |
+
rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
|
346 |
+
used_indices = []
|
347 |
+
for i, rea_map_number in enumerate(rea_atom_map_numbers):
|
348 |
+
for j, map_number in enumerate(cano_atom_map):
|
349 |
+
# select mapping reactans
|
350 |
+
if map_number in rea_map_number:
|
351 |
+
rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
|
352 |
+
rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
|
353 |
+
aligned_reactants.append(rea_smi)
|
354 |
+
aligned_reactants_order.append(j)
|
355 |
+
used_indices.append(i)
|
356 |
+
break
|
357 |
+
sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
|
358 |
+
aligned_reactants = [item[0] for item in sorted_reactants]
|
359 |
+
reactant_smi = ".".join(aligned_reactants)
|
360 |
+
product_tokens = smi_tokenizer(pro_smi)
|
361 |
+
reactant_tokens = smi_tokenizer(reactant_smi)
|
362 |
+
|
363 |
+
return_status['src_data'].append(product_tokens)
|
364 |
+
return_status['tgt_data'].append(reactant_tokens)
|
365 |
+
|
366 |
+
if reversable:
|
367 |
+
aligned_reactants.reverse()
|
368 |
+
reactant_smi = ".".join(aligned_reactants)
|
369 |
+
product_tokens = smi_tokenizer(pro_smi)
|
370 |
+
reactant_tokens = smi_tokenizer(reactant_smi)
|
371 |
+
return_status['src_data'].append(product_tokens)
|
372 |
+
return_status['tgt_data'].append(reactant_tokens)
|
373 |
+
assert len(return_status['src_data']) == data['augmentation']
|
374 |
+
else:
|
375 |
+
cano_product = clear_map_canonical_smiles(product)
|
376 |
+
cano_reactanct = ".".join([clear_map_canonical_smiles(rea) for rea in reactant if len(set(map(int, re.findall(r"(?<=:)\d+", rea))) & set(pro_atom_map_numbers)) > 0 ])
|
377 |
+
return_status['src_data'].append(smi_tokenizer(cano_product))
|
378 |
+
return_status['tgt_data'].append(smi_tokenizer(cano_reactanct))
|
379 |
+
pro_mol = Chem.MolFromSmiles(cano_product)
|
380 |
+
rea_mols = [Chem.MolFromSmiles(rea) for rea in cano_reactanct.split(".")]
|
381 |
+
for i in range(int(augmentation-1)):
|
382 |
+
pro_smi = Chem.MolToSmiles(pro_mol,doRandom=True)
|
383 |
+
rea_smi = [Chem.MolToSmiles(rea_mol,doRandom=True) for rea_mol in rea_mols]
|
384 |
+
rea_smi = ".".join(rea_smi)
|
385 |
+
return_status['src_data'].append(smi_tokenizer(pro_smi))
|
386 |
+
return_status['tgt_data'].append(smi_tokenizer(rea_smi))
|
387 |
+
edit_distances = []
|
388 |
+
for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
|
389 |
+
edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
|
390 |
+
return_status['edit_distance'] = np.mean(edit_distances)
|
391 |
+
return return_status
|
392 |
+
|
393 |
+
if __name__ == '__main__':
|
394 |
+
parser = argparse.ArgumentParser()
|
395 |
+
parser.add_argument('-rxn',type=str,required=True)
|
396 |
+
parser.add_argument('-mode',type=str,default="retro",)
|
397 |
+
parser.add_argument('-forward_mode',type=str,default="separated",)
|
398 |
+
parser.add_argument("-augmentation",type=int,default=1)
|
399 |
+
parser.add_argument("-seed",type=int,default=33)
|
400 |
+
args = parser.parse_args()
|
401 |
+
print(args)
|
402 |
+
reactant,reagent,product = args.rxn.split(">")
|
403 |
+
pt = re.compile(r':(\d+)]')
|
404 |
+
rids = sorted(re.findall(pt, reactant))
|
405 |
+
pids = sorted(re.findall(pt, product))
|
406 |
+
if len(rids) == 0 or len(pids) == 0:
|
407 |
+
print("No atom mapping found!")
|
408 |
+
exit(1)
|
409 |
+
if args.mode == "retro":
|
410 |
+
args.input = product
|
411 |
+
args.output = reactant
|
412 |
+
else:
|
413 |
+
args.input = reactant
|
414 |
+
args.output = product
|
415 |
+
|
416 |
+
print("Original input:", args.input)
|
417 |
+
print("Original output:",args.output)
|
418 |
+
src_smi = clear_map_canonical_smiles(args.input)
|
419 |
+
tgt_smi = clear_map_canonical_smiles(args.output)
|
420 |
+
if src_smi == "" or tgt_smi == "":
|
421 |
+
print("Invalid SMILES!")
|
422 |
+
exit(1)
|
423 |
+
print("Canonical input:", src_smi)
|
424 |
+
print("Canonical output:",tgt_smi)
|
425 |
+
|
426 |
+
mapping_check = True
|
427 |
+
if ",".join(rids) != ",".join(pids): # mapping is not 1:1
|
428 |
+
mapping_check = False
|
429 |
+
if len(set(rids)) != len(rids): # mapping is not 1:1
|
430 |
+
mapping_check = False
|
431 |
+
if len(set(pids)) != len(pids): # mapping is not 1:1
|
432 |
+
mapping_check = False
|
433 |
+
if not mapping_check:
|
434 |
+
print("The quality of the atom mapping may not be good enough, which can affect the effect of root alignment.")
|
435 |
+
data = {
|
436 |
+
'product':product,
|
437 |
+
'reactant':reactant,
|
438 |
+
'augmentation':args.augmentation,
|
439 |
+
'separated':args.forward_mode == "separated"
|
440 |
+
}
|
441 |
+
if args.mode == "retro":
|
442 |
+
res = get_retro_rsmiles(data)
|
443 |
+
else:
|
444 |
+
res = get_forward_rsmiles(data)
|
445 |
+
for index,(src,tgt) in enumerate(zip(res['src_data'], res['tgt_data'])):
|
446 |
+
print(f"ID:{index}")
|
447 |
+
print(f"R-SMILES input:{''.join(src.split())}")
|
448 |
+
print(f"R-SMILES output:{''.join(tgt.split())}")
|
449 |
+
print("Avg. edit distance:", res['edit_distance'])
|
data_provider/reaction_action_dataset.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Dataset
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import json
|
6 |
+
from .data_utils import smiles2data, reformat_smiles
|
7 |
+
|
8 |
+
class ActionDataset(Dataset):
|
9 |
+
def __init__(self, root, mode, smi_max_len, use_graph=True, disable_graph_cache=False, predict_rxn_condition=False, smiles_type='default'):
|
10 |
+
super(ActionDataset, self).__init__(root)
|
11 |
+
self.root = root
|
12 |
+
self.smi_max_len = smi_max_len
|
13 |
+
self.tokenizer = None
|
14 |
+
self.use_graph = use_graph
|
15 |
+
self.disable_graph_cache = disable_graph_cache
|
16 |
+
self.predict_rxn_condition = predict_rxn_condition
|
17 |
+
self.smiles_type = smiles_type
|
18 |
+
|
19 |
+
with open(os.path.join(self.root, f'{mode}.json'), encoding='utf-8') as f:
|
20 |
+
self.data_list = json.load(f)
|
21 |
+
if self.use_graph:
|
22 |
+
self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))
|
23 |
+
# self.data_list = self.data_list[:100]
|
24 |
+
|
25 |
+
def get(self, index):
|
26 |
+
return self.__getitem__(index)
|
27 |
+
|
28 |
+
def len(self):
|
29 |
+
return len(self)
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.data_list)
|
33 |
+
|
34 |
+
def make_prompt(self, param_dict, smi_max_len=128, predict_rxn_condition=False):
|
35 |
+
action_sequence = param_dict['actions']
|
36 |
+
smiles_list = []
|
37 |
+
prompt = ''
|
38 |
+
prompt += 'Reactants: '
|
39 |
+
smiles_wrapper = lambda x: reformat_smiles(x, smiles_type=self.smiles_type)[:smi_max_len]
|
40 |
+
for smi in param_dict['REACTANT']:
|
41 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
42 |
+
smiles_list.append(smi)
|
43 |
+
|
44 |
+
prompt += 'Product: '
|
45 |
+
for smi in param_dict['PRODUCT']:
|
46 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
47 |
+
smiles_list.append(smi)
|
48 |
+
|
49 |
+
if param_dict['CATALYST']:
|
50 |
+
prompt += 'Catalysts: '
|
51 |
+
for smi in param_dict['CATALYST']:
|
52 |
+
if smi in param_dict["extracted_molecules"]:
|
53 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
54 |
+
else:
|
55 |
+
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
56 |
+
smiles_list.append(smi)
|
57 |
+
|
58 |
+
if param_dict['SOLVENT']:
|
59 |
+
prompt += 'Solvents: '
|
60 |
+
for smi in param_dict['SOLVENT']:
|
61 |
+
if smi in param_dict["extracted_molecules"]:
|
62 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
63 |
+
else:
|
64 |
+
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
65 |
+
smiles_list.append(smi)
|
66 |
+
|
67 |
+
if predict_rxn_condition:
|
68 |
+
for value, token in param_dict['extracted_duration'].items():
|
69 |
+
action_sequence = action_sequence.replace(token, value)
|
70 |
+
for value, token in param_dict['extracted_temperature'].items():
|
71 |
+
action_sequence = action_sequence.replace(token, value)
|
72 |
+
else:
|
73 |
+
prompt += 'Temperatures: '
|
74 |
+
for value, token in param_dict['extracted_temperature'].items():
|
75 |
+
prompt += f'{token}: {value} '
|
76 |
+
|
77 |
+
prompt += 'Durations: '
|
78 |
+
for value, token in param_dict['extracted_duration'].items():
|
79 |
+
prompt += f'{token}: {value} '
|
80 |
+
|
81 |
+
prompt += 'Action Squence: '
|
82 |
+
return prompt, smiles_list, action_sequence
|
83 |
+
|
84 |
+
def __getitem__(self, index):
|
85 |
+
rxn_dict = self.data_list[index]
|
86 |
+
rxn_id = rxn_dict['index']
|
87 |
+
input_text, smiles_list, output_text = self.make_prompt(rxn_dict, self.smi_max_len, self.predict_rxn_condition)
|
88 |
+
output_text = output_text.strip() + '\n'
|
89 |
+
|
90 |
+
graph_list = []
|
91 |
+
if self.use_graph:
|
92 |
+
for smiles in smiles_list:
|
93 |
+
if self.disable_graph_cache:
|
94 |
+
graph_item = smiles2data(smiles)
|
95 |
+
else:
|
96 |
+
assert smiles in self.mol_graph_map
|
97 |
+
graph_item = self.mol_graph_map[smiles]
|
98 |
+
graph_list.append(graph_item)
|
99 |
+
return rxn_id, graph_list, output_text, input_text
|
100 |
+
|
data_provider/synthesis_dataset.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Dataset
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import json
|
6 |
+
from .data_utils import smiles2data, escape_custom_split_sequence, reformat_smiles, generate_rsmiles
|
7 |
+
|
8 |
+
class SynthesisDataset(Dataset):
|
9 |
+
def __init__(self,
|
10 |
+
root,
|
11 |
+
mode,
|
12 |
+
smi_max_len=128,
|
13 |
+
use_graph=True,
|
14 |
+
disable_graph_cache=False,
|
15 |
+
smiles_type='default',
|
16 |
+
roundrobin_train=False,
|
17 |
+
test_subset=-1
|
18 |
+
):
|
19 |
+
super(SynthesisDataset, self).__init__(root)
|
20 |
+
self.root = root
|
21 |
+
if 'PtoR' in root:
|
22 |
+
self.task = 'retro'
|
23 |
+
elif 'pretrain' in root:
|
24 |
+
self.task = 'pretrain'
|
25 |
+
elif 'RtoP' in root:
|
26 |
+
self.task = 'forward'
|
27 |
+
else:
|
28 |
+
raise NotImplementedError(f'Invalid task: {root}')
|
29 |
+
if mode=='valid':
|
30 |
+
mode='val'
|
31 |
+
self.mode = mode
|
32 |
+
self.smi_max_len = smi_max_len
|
33 |
+
self.tokenizer = None
|
34 |
+
self.use_graph = use_graph
|
35 |
+
self.disable_graph_cache = disable_graph_cache
|
36 |
+
self.smiles_type = smiles_type
|
37 |
+
self.roundrobin_train = roundrobin_train
|
38 |
+
with open(os.path.join(root, 'mol_graphid_map.json')) as f:
|
39 |
+
self.mol_idx_map = json.load(f)
|
40 |
+
if self.use_graph:
|
41 |
+
self.idx_graph_map = torch.load(os.path.join(root, 'idx_graph_map.pt'))
|
42 |
+
|
43 |
+
if self.roundrobin_train and mode=='train':
|
44 |
+
self.reload_counter=-2
|
45 |
+
self.reload_data()
|
46 |
+
else:
|
47 |
+
with open(os.path.join(root, mode, f'src-{mode}.txt')) as f:
|
48 |
+
self.input_list = f.readlines()
|
49 |
+
with open(os.path.join(root, mode, f'tgt-{mode}.txt')) as f:
|
50 |
+
self.output_list = f.readlines()
|
51 |
+
assert len(self.input_list) == len(self.output_list)
|
52 |
+
self.renew_r_smiles()
|
53 |
+
self.input_list = [smi.strip().replace(' ','') for smi in self.input_list]
|
54 |
+
self.output_list = [smi.strip().replace(' ','') for smi in self.output_list]
|
55 |
+
if test_subset>0 and mode=='test':
|
56 |
+
assert test_subset<=len(self.input_list)
|
57 |
+
self.input_list = self.input_list[:test_subset]
|
58 |
+
self.input_list = self.input_list[:test_subset]
|
59 |
+
|
60 |
+
def reload_data(self):
|
61 |
+
if not self.roundrobin_train:
|
62 |
+
return
|
63 |
+
self.reload_counter = (self.reload_counter+1)%10
|
64 |
+
if hasattr(self, 'input_list'):
|
65 |
+
del self.input_list
|
66 |
+
if hasattr(self, 'output_list'):
|
67 |
+
del self.output_list
|
68 |
+
with open(os.path.join(self.root, f'train/src-train_{self.reload_counter}.txt')) as f:
|
69 |
+
self.input_list = f.readlines()
|
70 |
+
with open(os.path.join(self.root, f'train/tgt-train_{self.reload_counter}.txt')) as f:
|
71 |
+
self.output_list = f.readlines()
|
72 |
+
assert len(self.input_list) == len(self.output_list)
|
73 |
+
self.renew_r_smiles()
|
74 |
+
self.input_list = [smi.strip().replace(' ','') for smi in self.input_list]
|
75 |
+
self.output_list = [smi.strip().replace(' ','') for smi in self.output_list]
|
76 |
+
input_list, output_list = [], []
|
77 |
+
for input_smiles, output_smiles in zip(self.input_list, self.output_list):
|
78 |
+
if input_smiles.count('.') != output_smiles.count('.'):
|
79 |
+
continue
|
80 |
+
input_list.append(input_smiles)
|
81 |
+
output_list.append(output_smiles)
|
82 |
+
print(f'Reloaded data from {self.root}/train/src-train_{self.reload_counter}.txt, filtered len={len(self.input_list)}', flush=True)
|
83 |
+
self.input_list = input_list
|
84 |
+
self.output_list = output_list
|
85 |
+
|
86 |
+
def renew_r_smiles(self):
|
87 |
+
if self.smiles_type == 'r_smiles' and self.mode == 'train':
|
88 |
+
# only renew r_smiles for training set
|
89 |
+
if not hasattr(self, 'input_list_mapped'):
|
90 |
+
# here we back up the original input_list and output_list
|
91 |
+
self.input_list_mapped = self.input_list
|
92 |
+
self.output_list_mapped = self.output_list
|
93 |
+
self.output_list, self.input_list = generate_rsmiles(self.output_list_mapped, self.input_list_mapped)
|
94 |
+
self.input_list = [smi.strip().replace(' ','') for smi in self.input_list]
|
95 |
+
self.output_list = [smi.strip().replace(' ','') for smi in self.output_list]
|
96 |
+
|
97 |
+
def get(self, index):
|
98 |
+
return self.__getitem__(index)
|
99 |
+
|
100 |
+
def len(self):
|
101 |
+
return len(self)
|
102 |
+
|
103 |
+
def __len__(self):
|
104 |
+
return len(self.input_list)
|
105 |
+
|
106 |
+
def make_prompt(self, input_smiles, output_smiles, smi_max_len=512):
|
107 |
+
FORWARD_PROMPT = 'Question: Given the following reactant molecules: {}, what are the expected products? Answer: The product molecules are '
|
108 |
+
FORWARD_CATALYST_PROMPT = '{}, and the following catalyst molecules: {}'
|
109 |
+
RETRO_PROMPT = 'Question: Given the following product molecules: {}, what are the reactants that produce them? Answer: The reactant molecules are '
|
110 |
+
# RETRO_PROMPT = 'Predict the reaction that produces the following product: {} '
|
111 |
+
PRETRAIN_PROMPT = 'Reconstruct the masked molecule: {}. Answer: '
|
112 |
+
smiles_wrapper = lambda x: reformat_smiles(x, smiles_type=self.smiles_type)[:smi_max_len]
|
113 |
+
if self.task=='retro':
|
114 |
+
assert '<separated>' not in input_smiles
|
115 |
+
smiles_list = input_smiles.split('.')
|
116 |
+
in_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in smiles_list])
|
117 |
+
input_prompt = RETRO_PROMPT.format(in_prompt)
|
118 |
+
elif self.task=='forward':
|
119 |
+
if '<separated>' in input_smiles:
|
120 |
+
reactant_smiles, reagent_smiles = input_smiles.split('<separated>')
|
121 |
+
reactant_smiles = reactant_smiles.split('.')
|
122 |
+
reagent_smiles = reagent_smiles.split('.')
|
123 |
+
reactant_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in reactant_smiles])
|
124 |
+
reagent_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in reagent_smiles])
|
125 |
+
smiles_list = reactant_smiles+reagent_smiles
|
126 |
+
input_prompt = FORWARD_CATALYST_PROMPT.format(reactant_prompt, reagent_prompt)
|
127 |
+
else:
|
128 |
+
smiles_list = input_smiles.split('.')
|
129 |
+
reactant_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in smiles_list])
|
130 |
+
input_prompt = reactant_prompt
|
131 |
+
input_prompt = FORWARD_PROMPT.format(input_prompt)
|
132 |
+
elif self.task=='pretrain':
|
133 |
+
in_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in input_smiles.split('.')])
|
134 |
+
input_prompt = PRETRAIN_PROMPT.format(in_prompt)
|
135 |
+
smiles_list = output_smiles.split('.')
|
136 |
+
# output_smiles = ' '.join([f'[START_SMILES]{smi[:smi_max_len]}[END_SMILES]' for smi in output_smiles.split('.')])
|
137 |
+
output_smiles = f'[START_SMILES]{output_smiles}[END_SMILES]'
|
138 |
+
output_smiles = escape_custom_split_sequence(output_smiles)
|
139 |
+
|
140 |
+
return input_prompt, smiles_list, output_smiles
|
141 |
+
|
142 |
+
def __getitem__(self, index):
|
143 |
+
input_smiles = self.input_list[index]
|
144 |
+
output_smiles = self.output_list[index]
|
145 |
+
input_text, smiles_list, output_text = self.make_prompt(input_smiles, output_smiles, smi_max_len=self.smi_max_len)
|
146 |
+
output_text = output_text.strip()+'\n'
|
147 |
+
|
148 |
+
graph_list = []
|
149 |
+
if self.use_graph:
|
150 |
+
for smiles in smiles_list:
|
151 |
+
if self.disable_graph_cache:
|
152 |
+
graph_item = smiles2data(smiles)
|
153 |
+
else:
|
154 |
+
assert smiles in self.mol_idx_map
|
155 |
+
idx = self.mol_idx_map[smiles]
|
156 |
+
assert idx in self.idx_graph_map
|
157 |
+
graph_item = self.idx_graph_map[idx]
|
158 |
+
graph_list.append(graph_item)
|
159 |
+
|
160 |
+
return index, graph_list, output_text, input_text
|
data_provider/tune_dm.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
import torch
|
4 |
+
from pytorch_lightning import LightningDataModule
|
5 |
+
import torch_geometric
|
6 |
+
# from torch_geometric.loader import DataLoader
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torch_geometric.loader.dataloader import Collater
|
9 |
+
from data_provider.reaction_action_dataset import ActionDataset
|
10 |
+
from data_provider.synthesis_dataset import SynthesisDataset
|
11 |
+
from data_provider.caption_dataset import CaptionDataset
|
12 |
+
from data_provider.chebi_dataset import ChEBI_dataset
|
13 |
+
import re
|
14 |
+
|
15 |
+
# we split individual characters inside special tokens like [START_DNA]
|
16 |
+
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
|
17 |
+
|
18 |
+
# token added to implement a custom sequence tokenization. This token is added at
|
19 |
+
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
|
20 |
+
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
|
21 |
+
# literally in the source code in case we ever include it in the training data.
|
22 |
+
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
|
23 |
+
|
24 |
+
def _insert_split_marker(m: re.Match):
|
25 |
+
"""
|
26 |
+
Applies split marker based on a regex match of special tokens such as
|
27 |
+
[START_DNA].
|
28 |
+
|
29 |
+
Parameters
|
30 |
+
----------
|
31 |
+
n : str
|
32 |
+
Input text to split
|
33 |
+
|
34 |
+
Returns
|
35 |
+
----------
|
36 |
+
str - the text with the split token added
|
37 |
+
"""
|
38 |
+
start_token, _, sequence, end_token = m.groups()
|
39 |
+
sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
|
40 |
+
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
|
41 |
+
|
42 |
+
def smiles_handler(text, mol_ph, is_gal=True):
|
43 |
+
smiles_list = []
|
44 |
+
for match in CUSTOM_SEQ_RE.finditer(text):
|
45 |
+
smiles = match.group(3)
|
46 |
+
smiles_list.append(smiles)
|
47 |
+
if is_gal:
|
48 |
+
text = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
|
49 |
+
text = escape_custom_split_sequence(text)
|
50 |
+
return text, smiles_list
|
51 |
+
else:
|
52 |
+
text = CUSTOM_SEQ_RE.sub(r'\3%s' % (mol_ph), text)
|
53 |
+
return text, smiles_list
|
54 |
+
|
55 |
+
def escape_custom_split_sequence(text):
|
56 |
+
"""
|
57 |
+
Applies custom splitting to the text for GALILEO's tokenization
|
58 |
+
|
59 |
+
Parameters
|
60 |
+
----------
|
61 |
+
text : str
|
62 |
+
Input text to split
|
63 |
+
|
64 |
+
Returns
|
65 |
+
----------
|
66 |
+
str - the text with the split token added
|
67 |
+
"""
|
68 |
+
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
|
69 |
+
|
70 |
+
class TrainCollater:
|
71 |
+
def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id, is_gal=True, use_graph=True, use_qa_pair=True):
|
72 |
+
self.rxn_max_len = rxn_max_len
|
73 |
+
self.text_max_len = text_max_len
|
74 |
+
self.tokenizer = tokenizer
|
75 |
+
self.collater = Collater([], [])
|
76 |
+
self.mol_ph = mol_ph
|
77 |
+
self.mol_token_id = mol_token_id
|
78 |
+
self.is_gal = is_gal
|
79 |
+
self.use_graph = use_graph
|
80 |
+
self.use_qa_pair = use_qa_pair
|
81 |
+
|
82 |
+
def __call__(self, batch):
|
83 |
+
return self.collate_qa(batch) if self.use_qa_pair else self.collate(batch)
|
84 |
+
|
85 |
+
def collate(self, batch):
|
86 |
+
rxn_ids, graphs, texts, smiles_prompt = zip(*batch)
|
87 |
+
if graphs:
|
88 |
+
graphs = self.collater(graphs)
|
89 |
+
|
90 |
+
## deal with prompt
|
91 |
+
if self.use_graph:
|
92 |
+
smiles_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in smiles_prompt]
|
93 |
+
else:
|
94 |
+
smiles_prompt = [escape_custom_split_sequence(p) for p in smiles_prompt]
|
95 |
+
|
96 |
+
self.tokenizer.padding_side = 'left'
|
97 |
+
smiles_prompt_tokens = self.tokenizer(text=smiles_prompt,
|
98 |
+
truncation=False,
|
99 |
+
padding='longest',
|
100 |
+
add_special_tokens=True,
|
101 |
+
return_tensors='pt',
|
102 |
+
return_attention_mask=True)
|
103 |
+
|
104 |
+
is_mol_token = smiles_prompt_tokens.input_ids == self.mol_token_id
|
105 |
+
smiles_prompt_tokens['is_mol_token'] = is_mol_token
|
106 |
+
|
107 |
+
self.tokenizer.padding_side = 'right'
|
108 |
+
text_tokens = self.tokenizer(text=texts,
|
109 |
+
truncation=True,
|
110 |
+
padding='longest',
|
111 |
+
add_special_tokens=True,
|
112 |
+
max_length=self.text_max_len,
|
113 |
+
return_tensors='pt',
|
114 |
+
return_attention_mask=True)
|
115 |
+
return rxn_ids, graphs, smiles_prompt_tokens, text_tokens
|
116 |
+
|
117 |
+
def collate_qa(self, batch):
|
118 |
+
rxn_ids, graphs, texts, input_prompt = zip(*batch)
|
119 |
+
graphs = [graph for graph_batch in graphs for graph in graph_batch]
|
120 |
+
if graphs:
|
121 |
+
graphs = self.collater(graphs)
|
122 |
+
|
123 |
+
## deal with prompt
|
124 |
+
if self.use_graph:
|
125 |
+
input_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in input_prompt]
|
126 |
+
else:
|
127 |
+
input_prompt = [escape_custom_split_sequence(p) for p in input_prompt]
|
128 |
+
|
129 |
+
self.tokenizer.padding_side = 'right'
|
130 |
+
qa_pair = [[q, a] for q, a in zip(input_prompt, texts)]
|
131 |
+
qa_batch = self.tokenizer(qa_pair,
|
132 |
+
truncation=True,
|
133 |
+
padding='longest',
|
134 |
+
add_special_tokens=True,
|
135 |
+
max_length=self.rxn_max_len + self.text_max_len,
|
136 |
+
return_tensors='pt',
|
137 |
+
return_attention_mask=True,
|
138 |
+
return_token_type_ids=True)
|
139 |
+
is_mol_token = qa_batch.input_ids == self.mol_token_id
|
140 |
+
qa_batch['is_mol_token'] = is_mol_token
|
141 |
+
return rxn_ids, graphs, qa_batch
|
142 |
+
|
143 |
+
class InferenceCollater:
|
144 |
+
def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id, is_gal=True):
|
145 |
+
self.text_max_len = text_max_len
|
146 |
+
self.rxn_max_len = rxn_max_len
|
147 |
+
self.tokenizer = tokenizer
|
148 |
+
self.collater = Collater([], [])
|
149 |
+
self.mol_ph = mol_ph
|
150 |
+
self.mol_token_id = mol_token_id
|
151 |
+
self.is_gal = is_gal
|
152 |
+
|
153 |
+
def __call__(self, batch):
|
154 |
+
rxn_ids, graphs, texts, input_prompt = zip(*batch)
|
155 |
+
inputs = input_prompt
|
156 |
+
graphs = [graph for graph_batch in graphs for graph in graph_batch]
|
157 |
+
if graphs:
|
158 |
+
graphs = self.collater(graphs)
|
159 |
+
input_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in input_prompt]
|
160 |
+
|
161 |
+
## deal with prompt
|
162 |
+
self.tokenizer.padding_side = 'left'
|
163 |
+
input_prompt_tokens = self.tokenizer(input_prompt,
|
164 |
+
truncation=True,
|
165 |
+
padding='longest',
|
166 |
+
add_special_tokens=True,
|
167 |
+
max_length=self.rxn_max_len,
|
168 |
+
return_tensors='pt',
|
169 |
+
return_attention_mask=True)
|
170 |
+
|
171 |
+
is_mol_token = input_prompt_tokens.input_ids == self.mol_token_id
|
172 |
+
input_prompt_tokens['is_mol_token'] = is_mol_token
|
173 |
+
return rxn_ids, graphs, input_prompt_tokens, texts, inputs
|
174 |
+
|
175 |
+
class TuneDM(LightningDataModule):
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
num_workers: int = 0,
|
179 |
+
batch_size: int = 256,
|
180 |
+
root: str = 'data/',
|
181 |
+
text_max_len: int = 128,
|
182 |
+
smi_max_len: int = 128,
|
183 |
+
rxn_max_len: int = 128,
|
184 |
+
tokenizer=None,
|
185 |
+
downstream_task='action',
|
186 |
+
args=None,
|
187 |
+
):
|
188 |
+
super().__init__()
|
189 |
+
self.args = args
|
190 |
+
self.batch_size = batch_size
|
191 |
+
self.inference_batch_size = args.inference_batch_size
|
192 |
+
self.num_workers = num_workers
|
193 |
+
self.rxn_max_len = rxn_max_len
|
194 |
+
self.text_max_len = text_max_len
|
195 |
+
self.prompt = args.prompt
|
196 |
+
DownstreamDataset = {
|
197 |
+
'action': ActionDataset,
|
198 |
+
'synthesis': SynthesisDataset,
|
199 |
+
'caption': CaptionDataset,
|
200 |
+
'chebi': ChEBI_dataset,
|
201 |
+
}[downstream_task]
|
202 |
+
ds_args = {
|
203 |
+
'use_graph': not args.disable_graphs,
|
204 |
+
'disable_graph_cache': args.disable_graph_cache,
|
205 |
+
'smiles_type': args.smiles_type,
|
206 |
+
}
|
207 |
+
if downstream_task == 'action':
|
208 |
+
ds_args['predict_rxn_condition'] = args.predict_rxn_condition
|
209 |
+
if downstream_task == 'synthesis':
|
210 |
+
ds_args['roundrobin_train'] = args.roundrobin_train
|
211 |
+
ds_args['test_subset'] = args.test_subset
|
212 |
+
self.train_dataset = DownstreamDataset(root, 'train', smi_max_len, **ds_args)
|
213 |
+
self.val_dataset = DownstreamDataset(root, 'valid', smi_max_len, **ds_args)
|
214 |
+
self.test_dataset = DownstreamDataset(root, 'test', smi_max_len, **ds_args)
|
215 |
+
self.init_tokenizer(tokenizer)
|
216 |
+
self.mol_ph_token = '<mol>' * self.args.num_query_token
|
217 |
+
self.is_gal = args.opt_model.find('galactica') >= 0
|
218 |
+
self.use_graph = not args.disable_graphs
|
219 |
+
self.is_t5 = args.opt_model.find('t5') >= 0
|
220 |
+
|
221 |
+
def init_tokenizer(self, tokenizer):
|
222 |
+
self.tokenizer = tokenizer
|
223 |
+
self.train_dataset.tokenizer = tokenizer
|
224 |
+
self.val_dataset.tokenizer = tokenizer
|
225 |
+
self.test_dataset.tokenizer = tokenizer
|
226 |
+
self.mol_token_id = self.tokenizer.mol_token_id
|
227 |
+
# self.tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]
|
228 |
+
|
229 |
+
def train_dataloader(self):
|
230 |
+
if self.args.roundrobin_train:
|
231 |
+
self.train_dataset.reload_data()
|
232 |
+
if hasattr(self.train_dataset, 'renew_r_smiles'):
|
233 |
+
self.train_dataset.renew_r_smiles()
|
234 |
+
loader = DataLoader(
|
235 |
+
self.train_dataset,
|
236 |
+
batch_size=self.batch_size,
|
237 |
+
shuffle=True,
|
238 |
+
num_workers=self.num_workers,
|
239 |
+
pin_memory=False,
|
240 |
+
drop_last=True,
|
241 |
+
persistent_workers=True,
|
242 |
+
collate_fn=TrainCollater(
|
243 |
+
tokenizer=self.tokenizer,
|
244 |
+
text_max_len=self.text_max_len,
|
245 |
+
rxn_max_len=self.rxn_max_len,
|
246 |
+
mol_ph=self.mol_ph_token,
|
247 |
+
mol_token_id=self.mol_token_id,
|
248 |
+
is_gal=self.is_gal,
|
249 |
+
use_graph=self.use_graph,
|
250 |
+
use_qa_pair=not self.is_t5,
|
251 |
+
),
|
252 |
+
)
|
253 |
+
return loader
|
254 |
+
|
255 |
+
def val_dataloader(self):
|
256 |
+
test_loader = DataLoader(
|
257 |
+
self.test_dataset,
|
258 |
+
batch_size=self.inference_batch_size,
|
259 |
+
shuffle=False,
|
260 |
+
num_workers=self.num_workers,
|
261 |
+
pin_memory=False,
|
262 |
+
drop_last=False,
|
263 |
+
persistent_workers=True,
|
264 |
+
collate_fn=InferenceCollater(
|
265 |
+
tokenizer=self.tokenizer,
|
266 |
+
text_max_len=self.text_max_len,
|
267 |
+
rxn_max_len=self.rxn_max_len,
|
268 |
+
mol_ph=self.mol_ph_token,
|
269 |
+
mol_token_id=self.mol_token_id,
|
270 |
+
is_gal=self.is_gal
|
271 |
+
),
|
272 |
+
)
|
273 |
+
return [test_loader]
|
274 |
+
val_loader = DataLoader(
|
275 |
+
self.val_dataset,
|
276 |
+
batch_size=self.batch_size,
|
277 |
+
shuffle=False,
|
278 |
+
num_workers=self.num_workers,
|
279 |
+
pin_memory=False,
|
280 |
+
drop_last=False,
|
281 |
+
persistent_workers=True,
|
282 |
+
collate_fn=InferenceCollater(
|
283 |
+
tokenizer=self.tokenizer,
|
284 |
+
text_max_len=self.text_max_len,
|
285 |
+
rxn_max_len=self.rxn_max_len,
|
286 |
+
mol_ph=self.mol_ph_token,
|
287 |
+
mol_token_id=self.mol_token_id,
|
288 |
+
is_gal=self.is_gal
|
289 |
+
),
|
290 |
+
)
|
291 |
+
return [val_loader, test_loader]
|
292 |
+
|
293 |
+
def test_dataloader(self):
|
294 |
+
loader = DataLoader(
|
295 |
+
self.test_dataset,
|
296 |
+
batch_size=self.inference_batch_size,
|
297 |
+
shuffle=False,
|
298 |
+
num_workers=self.num_workers,
|
299 |
+
pin_memory=False,
|
300 |
+
drop_last=False,
|
301 |
+
persistent_workers=True,
|
302 |
+
collate_fn=InferenceCollater(
|
303 |
+
tokenizer=self.tokenizer,
|
304 |
+
text_max_len=self.text_max_len,
|
305 |
+
rxn_max_len=self.rxn_max_len,
|
306 |
+
mol_ph=self.mol_ph_token,
|
307 |
+
mol_token_id=self.mol_token_id,
|
308 |
+
is_gal=self.is_gal
|
309 |
+
),
|
310 |
+
)
|
311 |
+
return loader
|
312 |
+
|
demo.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
"[BH4-].[Na+].COC(=O)c1nc2ccccn2c1C>C1CCOC1.CO>Cc1c(CO)nc2ccccn12",
|
3 |
+
"NCCN.O=C(O)c1cc2nc(-c3ccc(-c4ccccc4)cc3)c(Cl)cc2n1CO>CN(C)C=O>O=C(O)c1cc2nc(-c3ccc(-c4ccccc4)cc3)c(Cl)cc2[nH]1",
|
4 |
+
"CC[O-].[Na+].CC(C)[N+](=O)[O-].Cc1cc(C)nc(NC(=O)NS(=O)(=O)c2ccccc2CCl)n1>CCO>Cc1cc(C)nc(NC(=O)NS(=O)(=O)c2ccccc2C=O)n1",
|
5 |
+
"COC(=O)c1ccc2c(c1)nc(C(C)(C)C)n2CC1CCC(F)(F)CC1.O=S(=O)([O-])O.[K+]>[Li+].[OH-].C1COCCO1>CC(C)(C)c1nc2cc(C(=O)O)ccc2n1CC1CCC(F)(F)CC1",
|
6 |
+
"FC(F)(F)c1cccc2c(Br)c(Cc3ccccc3)cnc12.CC1(C)OB(c2cncc(C=O)c2)OC1(C)C>O=C([O-])[O-].[Na+].[Na+].Cc1ccccc1.CCO>O=Cc1cncc(-c2c(Cc3ccccc3)cnc3c(C(F)(F)F)cccc23)c1"
|
7 |
+
]
|
demo.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import warnings
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from pytorch_lightning import Trainer, strategies
|
7 |
+
import pytorch_lightning.callbacks as plc
|
8 |
+
from pytorch_lightning.loggers import CSVLogger
|
9 |
+
from pytorch_lightning.callbacks import TQDMProgressBar
|
10 |
+
from data_provider.pretrain_dm import PretrainDM
|
11 |
+
from data_provider.tune_dm import *
|
12 |
+
from model.opt_flash_attention import replace_opt_attn_with_flash_attn
|
13 |
+
from model.blip2_model import Blip2Model
|
14 |
+
from model.dist_funs import MyDeepSpeedStrategy
|
15 |
+
from data_provider.reaction_action_dataset import ActionDataset
|
16 |
+
from data_provider.data_utils import json_read, json_write
|
17 |
+
from data_provider.data_utils import smiles2data, reformat_smiles
|
18 |
+
|
19 |
+
## for pyg bug
|
20 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
21 |
+
## for A5000 gpus
|
22 |
+
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)
|
23 |
+
|
24 |
+
|
25 |
+
class InferenceRunner:
|
26 |
+
def __init__(self, model, tokenizer, rxn_max_len, smi_max_len,
|
27 |
+
smiles_type='default', device='cuda', predict_rxn_condition=True, args=None):
|
28 |
+
self.model = model
|
29 |
+
self.rxn_max_len = rxn_max_len
|
30 |
+
self.smi_max_len = smi_max_len
|
31 |
+
self.tokenizer = tokenizer
|
32 |
+
self.collater = Collater([], [])
|
33 |
+
self.mol_ph = '<mol>' * args.num_query_token
|
34 |
+
self.mol_token_id = tokenizer.mol_token_id
|
35 |
+
self.is_gal = args.opt_model.find('galactica') >= 0
|
36 |
+
self.collater = Collater([], [])
|
37 |
+
self.device = device
|
38 |
+
self.smiles_type = smiles_type
|
39 |
+
self.predict_rxn_condition = predict_rxn_condition
|
40 |
+
self.args = args
|
41 |
+
|
42 |
+
def make_prompt(self, param_dict, smi_max_len=128, predict_rxn_condition=False):
|
43 |
+
action_sequence = param_dict['actions']
|
44 |
+
smiles_list = []
|
45 |
+
prompt = ''
|
46 |
+
prompt += 'Reactants: '
|
47 |
+
smiles_wrapper = lambda x: reformat_smiles(x, smiles_type=self.smiles_type)[:smi_max_len]
|
48 |
+
for smi in param_dict['REACTANT']:
|
49 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
50 |
+
smiles_list.append(smi)
|
51 |
+
|
52 |
+
prompt += 'Product: '
|
53 |
+
for smi in param_dict['PRODUCT']:
|
54 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
55 |
+
smiles_list.append(smi)
|
56 |
+
|
57 |
+
if param_dict['CATALYST']:
|
58 |
+
prompt += 'Catalysts: '
|
59 |
+
for smi in param_dict['CATALYST']:
|
60 |
+
if smi in param_dict["extracted_molecules"]:
|
61 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
62 |
+
else:
|
63 |
+
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
64 |
+
smiles_list.append(smi)
|
65 |
+
|
66 |
+
if param_dict['SOLVENT']:
|
67 |
+
prompt += 'Solvents: '
|
68 |
+
for smi in param_dict['SOLVENT']:
|
69 |
+
if smi in param_dict["extracted_molecules"]:
|
70 |
+
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
71 |
+
else:
|
72 |
+
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
|
73 |
+
smiles_list.append(smi)
|
74 |
+
|
75 |
+
if predict_rxn_condition:
|
76 |
+
for value, token in param_dict['extracted_duration'].items():
|
77 |
+
action_sequence = action_sequence.replace(token, value)
|
78 |
+
for value, token in param_dict['extracted_temperature'].items():
|
79 |
+
action_sequence = action_sequence.replace(token, value)
|
80 |
+
else:
|
81 |
+
prompt += 'Temperatures: '
|
82 |
+
for value, token in param_dict['extracted_temperature'].items():
|
83 |
+
prompt += f'{token}: {value} '
|
84 |
+
|
85 |
+
prompt += 'Durations: '
|
86 |
+
for value, token in param_dict['extracted_duration'].items():
|
87 |
+
prompt += f'{token}: {value} '
|
88 |
+
|
89 |
+
prompt += 'Action Squence: '
|
90 |
+
return prompt, smiles_list, action_sequence
|
91 |
+
|
92 |
+
def get_action_elements(self, rxn_dict):
|
93 |
+
rxn_id = rxn_dict['index']
|
94 |
+
input_text, smiles_list, output_text = self.make_prompt(rxn_dict, self.smi_max_len, self.predict_rxn_condition)
|
95 |
+
output_text = output_text.strip() + '\n'
|
96 |
+
|
97 |
+
graph_list = []
|
98 |
+
for smiles in smiles_list:
|
99 |
+
graph_item = smiles2data(smiles)
|
100 |
+
graph_list.append(graph_item)
|
101 |
+
return rxn_id, graph_list, output_text, input_text
|
102 |
+
|
103 |
+
@torch.no_grad()
|
104 |
+
def predict(self, rxn_dict):
|
105 |
+
rxn_id, graphs, prompt_tokens, output_text, input_text = self.tokenize(rxn_dict)
|
106 |
+
result_dict = {
|
107 |
+
'raw': rxn_dict,
|
108 |
+
'index': rxn_id,
|
109 |
+
'input': input_text,
|
110 |
+
'target': output_text
|
111 |
+
}
|
112 |
+
samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
|
113 |
+
with torch.no_grad():
|
114 |
+
result_dict['prediction'] = self.model.blip2opt.generate(
|
115 |
+
samples,
|
116 |
+
do_sample=self.args.do_sample,
|
117 |
+
num_beams=self.args.num_beams,
|
118 |
+
max_length=self.args.max_inference_len,
|
119 |
+
min_length=self.args.min_inference_len,
|
120 |
+
num_captions=self.args.num_generate_captions,
|
121 |
+
use_graph=True
|
122 |
+
)
|
123 |
+
return result_dict
|
124 |
+
|
125 |
+
def tokenize(self, rxn_dict):
|
126 |
+
rxn_id, graph_list, output_text, input_text = self.get_action_elements(rxn_dict)
|
127 |
+
if graph_list:
|
128 |
+
graphs = self.collater(graph_list).to(self.device)
|
129 |
+
input_prompt = smiles_handler(input_text, self.mol_ph, self.is_gal)[0]
|
130 |
+
|
131 |
+
## deal with prompt
|
132 |
+
self.tokenizer.padding_side = 'left'
|
133 |
+
input_prompt_tokens = self.tokenizer(input_prompt,
|
134 |
+
truncation=True,
|
135 |
+
padding='max_length',
|
136 |
+
add_special_tokens=True,
|
137 |
+
max_length=self.rxn_max_len,
|
138 |
+
return_tensors='pt',
|
139 |
+
return_attention_mask=True).to(self.device)
|
140 |
+
is_mol_token = input_prompt_tokens.input_ids == self.mol_token_id
|
141 |
+
input_prompt_tokens['is_mol_token'] = is_mol_token
|
142 |
+
return rxn_id, graphs, input_prompt_tokens, output_text, input_text
|
143 |
+
|
144 |
+
|
145 |
+
def main(args):
|
146 |
+
device = torch.device('cuda')
|
147 |
+
data_list = json_read('demo.json')
|
148 |
+
pl.seed_everything(args.seed)
|
149 |
+
# model
|
150 |
+
if args.init_checkpoint:
|
151 |
+
model = Blip2Model(args).to(device)
|
152 |
+
ckpt = torch.load(args.init_checkpoint, map_location='cpu')
|
153 |
+
model.load_state_dict(ckpt['state_dict'], strict=False)
|
154 |
+
print(f"loaded model from {args.init_checkpoint}")
|
155 |
+
else:
|
156 |
+
model = Blip2Model(args).to(device)
|
157 |
+
model.eval()
|
158 |
+
|
159 |
+
print('total params:', sum(p.numel() for p in model.parameters()))
|
160 |
+
|
161 |
+
if args.opt_model.find('galactica') >= 0 or args.opt_model.find('t5') >= 0:
|
162 |
+
tokenizer = model.blip2opt.opt_tokenizer
|
163 |
+
elif args.opt_model.find('llama') >= 0 or args.opt_model.find('vicuna') >= 0:
|
164 |
+
tokenizer = model.blip2opt.llm_tokenizer
|
165 |
+
else:
|
166 |
+
raise NotImplementedError
|
167 |
+
|
168 |
+
infer_runner = InferenceRunner(
|
169 |
+
model=model,
|
170 |
+
tokenizer=tokenizer,
|
171 |
+
rxn_max_len=args.rxn_max_len,
|
172 |
+
smi_max_len=args.smi_max_len,
|
173 |
+
device=device,
|
174 |
+
predict_rxn_condition=args.predict_rxn_condition,
|
175 |
+
args=args
|
176 |
+
)
|
177 |
+
|
178 |
+
import time
|
179 |
+
for data_item in data_list:
|
180 |
+
t1 = time.time()
|
181 |
+
result = infer_runner.predict(data_item)
|
182 |
+
print(result)
|
183 |
+
print(f"Time: {time.time() - t1:.2f}s")
|
184 |
+
|
185 |
+
|
186 |
+
def get_args():
|
187 |
+
parser = argparse.ArgumentParser()
|
188 |
+
parser.add_argument('--filename', type=str, default="main")
|
189 |
+
parser.add_argument('--seed', type=int, default=42, help='random seed')
|
190 |
+
# MM settings
|
191 |
+
parser.add_argument('--mode', type=str, default='pretrain', choices=['pretrain', 'ft', 'eval', 'pretrain_eval'])
|
192 |
+
parser.add_argument('--strategy_name', type=str, default='mydeepspeed')
|
193 |
+
parser.add_argument('--iupac_prediction', action='store_true', default=False)
|
194 |
+
parser.add_argument('--ckpt_path', type=str, default=None)
|
195 |
+
# parser = Trainer.add_argparse_args(parser)
|
196 |
+
parser = Blip2Model.add_model_specific_args(parser) # add model args
|
197 |
+
parser = PretrainDM.add_model_specific_args(parser)
|
198 |
+
parser.add_argument('--accelerator', type=str, default='gpu')
|
199 |
+
parser.add_argument('--devices', type=str, default='0,1,2,3')
|
200 |
+
parser.add_argument('--precision', type=str, default='bf16-mixed')
|
201 |
+
parser.add_argument('--downstream_task', type=str, default='action', choices=['action', 'synthesis', 'caption', 'chebi'])
|
202 |
+
parser.add_argument('--max_epochs', type=int, default=10)
|
203 |
+
parser.add_argument('--enable_flash', action='store_true', default=False)
|
204 |
+
parser.add_argument('--disable_graph_cache', action='store_true', default=False)
|
205 |
+
parser.add_argument('--predict_rxn_condition', action='store_true', default=False)
|
206 |
+
parser.add_argument('--generate_restrict_tokens', action='store_true', default=False)
|
207 |
+
parser.add_argument('--train_restrict_tokens', action='store_true', default=False)
|
208 |
+
parser.add_argument('--smiles_type', type=str, default='default', choices=['default', 'canonical', 'restricted', 'unrestricted', 'r_smiles'])
|
209 |
+
parser.add_argument('--accumulate_grad_batches', type=int, default=1)
|
210 |
+
parser.add_argument('--tqdm_interval', type=int, default=50)
|
211 |
+
parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
|
212 |
+
args = parser.parse_args()
|
213 |
+
|
214 |
+
if args.enable_flash:
|
215 |
+
replace_opt_attn_with_flash_attn()
|
216 |
+
print("=========================================")
|
217 |
+
for k, v in sorted(vars(args).items()):
|
218 |
+
print(k, '=', v)
|
219 |
+
print("=========================================")
|
220 |
+
return args
|
221 |
+
|
222 |
+
if __name__ == '__main__':
|
223 |
+
main(get_args())
|
224 |
+
|
environment.yml
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: reactxt
|
2 |
+
channels:
|
3 |
+
- pyg
|
4 |
+
- tmap
|
5 |
+
- pytorch
|
6 |
+
- nvidia
|
7 |
+
- nvidia/label/cuda-11.7.0
|
8 |
+
- conda-forge
|
9 |
+
- defaults
|
10 |
+
dependencies:
|
11 |
+
- _libgcc_mutex=0.1=main
|
12 |
+
- _openmp_mutex=5.1=1_gnu
|
13 |
+
- appdirs=1.4.4=pyhd3eb1b0_0
|
14 |
+
- asttokens=2.2.1=pyhd8ed1ab_0
|
15 |
+
- backcall=0.2.0=pyh9f0ad1d_0
|
16 |
+
- backports=1.0=pyhd8ed1ab_3
|
17 |
+
- backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
|
18 |
+
- blas=1.0=mkl
|
19 |
+
- brotlipy=0.7.0=py38h27cfd23_1003
|
20 |
+
- bzip2=1.0.8=h7b6447c_0
|
21 |
+
- ca-certificates=2023.08.22=h06a4308_0
|
22 |
+
- certifi=2023.11.17=py38h06a4308_0
|
23 |
+
- cffi=1.15.1=py38h5eee18b_3
|
24 |
+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
25 |
+
- cryptography=39.0.1=py38h9ce1e76_2
|
26 |
+
- cuda-cccl=11.7.58=hc415cf5_0
|
27 |
+
- cuda-cudart=11.7.99=0
|
28 |
+
- cuda-cudart-dev=11.7.60=h6a7c232_0
|
29 |
+
- cuda-cupti=11.7.101=0
|
30 |
+
- cuda-driver-dev=11.7.60=0
|
31 |
+
- cuda-libraries=11.7.1=0
|
32 |
+
- cuda-libraries-dev=11.7.0=0
|
33 |
+
- cuda-nvcc=11.7.64=0
|
34 |
+
- cuda-nvrtc=11.7.99=0
|
35 |
+
- cuda-nvrtc-dev=11.7.50=heada363_0
|
36 |
+
- cuda-nvtx=11.7.91=0
|
37 |
+
- cuda-runtime=11.7.0=0
|
38 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
39 |
+
- debugpy=1.5.1=py38h295c915_0
|
40 |
+
- decorator=5.1.1=pyhd8ed1ab_0
|
41 |
+
- entrypoints=0.4=pyhd8ed1ab_0
|
42 |
+
- executing=1.2.0=pyhd8ed1ab_0
|
43 |
+
- ffmpeg=4.3=hf484d3e_0
|
44 |
+
- freetype=2.12.1=h4a9f257_0
|
45 |
+
- giflib=5.2.1=h5eee18b_3
|
46 |
+
- gmp=6.2.1=h295c915_3
|
47 |
+
- gmpy2=2.1.2=py38heeb90bb_0
|
48 |
+
- gnutls=3.6.15=he1e5248_0
|
49 |
+
- idna=3.4=py38h06a4308_0
|
50 |
+
- intel-openmp=2023.1.0=hdb19cb5_46305
|
51 |
+
- ipykernel=6.15.0=pyh210e3f2_0
|
52 |
+
- jedi=0.18.2=pyhd8ed1ab_0
|
53 |
+
- jinja2=3.1.2=py38h06a4308_0
|
54 |
+
- joblib=1.2.0=py38h06a4308_0
|
55 |
+
- jpeg=9e=h5eee18b_1
|
56 |
+
- jupyter_client=7.0.6=pyhd8ed1ab_0
|
57 |
+
- jupyter_core=4.12.0=py38h578d9bd_0
|
58 |
+
- lame=3.100=h7b6447c_0
|
59 |
+
- lcms2=2.12=h3be6417_0
|
60 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
61 |
+
- lerc=3.0=h295c915_0
|
62 |
+
- libcublas=11.10.3.66=0
|
63 |
+
- libcublas-dev=11.10.1.25=h0c8ac2b_0
|
64 |
+
- libcufft=10.7.2.124=h4fbf590_0
|
65 |
+
- libcufft-dev=10.7.2.50=h59a5ac8_0
|
66 |
+
- libcufile=1.7.0.149=0
|
67 |
+
- libcufile-dev=1.3.0.44=0
|
68 |
+
- libcurand=10.3.3.53=0
|
69 |
+
- libcurand-dev=10.2.10.50=hd49a9cd_0
|
70 |
+
- libcusolver=11.4.0.1=0
|
71 |
+
- libcusolver-dev=11.3.5.50=hc6eba6f_0
|
72 |
+
- libcusparse=11.7.4.91=0
|
73 |
+
- libcusparse-dev=11.7.3.50=hc644b96_0
|
74 |
+
- libdeflate=1.17=h5eee18b_0
|
75 |
+
- libffi=3.4.4=h6a678d5_0
|
76 |
+
- libgcc-ng=11.2.0=h1234567_1
|
77 |
+
- libgfortran-ng=11.2.0=h00389a5_1
|
78 |
+
- libgfortran5=11.2.0=h1234567_1
|
79 |
+
- libgomp=11.2.0=h1234567_1
|
80 |
+
- libiconv=1.16=h7f8727e_2
|
81 |
+
- libidn2=2.3.4=h5eee18b_0
|
82 |
+
- libnpp=11.7.4.75=0
|
83 |
+
- libnpp-dev=11.7.3.21=hb6476a9_0
|
84 |
+
- libnvjpeg=11.8.0.2=0
|
85 |
+
- libnvjpeg-dev=11.7.2.34=h2e48410_0
|
86 |
+
- libpng=1.6.39=h5eee18b_0
|
87 |
+
- libsodium=1.0.18=h36c2ea0_1
|
88 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
89 |
+
- libtasn1=4.19.0=h5eee18b_0
|
90 |
+
- libtiff=4.5.0=h6a678d5_2
|
91 |
+
- libunistring=0.9.10=h27cfd23_0
|
92 |
+
- libwebp=1.2.4=h11a3e52_1
|
93 |
+
- libwebp-base=1.2.4=h5eee18b_1
|
94 |
+
- lz4-c=1.9.4=h6a678d5_0
|
95 |
+
- markupsafe=2.1.1=py38h7f8727e_0
|
96 |
+
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
|
97 |
+
- mkl=2023.1.0=h6d00ec8_46342
|
98 |
+
- mkl-service=2.4.0=py38h5eee18b_1
|
99 |
+
- mkl_fft=1.3.6=py38h417a72b_1
|
100 |
+
- mkl_random=1.2.2=py38h417a72b_1
|
101 |
+
- mpc=1.1.0=h10f8cd9_1
|
102 |
+
- mpfr=4.0.2=hb69a4c5_1
|
103 |
+
- mpmath=1.2.1=py38h06a4308_0
|
104 |
+
- ncurses=6.4=h6a678d5_0
|
105 |
+
- nest-asyncio=1.5.6=pyhd8ed1ab_0
|
106 |
+
- nettle=3.7.3=hbbd107a_1
|
107 |
+
- networkx=2.8.4=py38h06a4308_1
|
108 |
+
- numpy-base=1.24.3=py38h060ed82_1
|
109 |
+
- ogdf=1.2.0=h2bc3f7f_0
|
110 |
+
- openh264=2.1.1=h4ff587b_0
|
111 |
+
- openssl=3.0.12=h7f8727e_0
|
112 |
+
- parso=0.8.3=pyhd8ed1ab_0
|
113 |
+
- pexpect=4.8.0=pyh1a96a4e_2
|
114 |
+
- pickleshare=0.7.5=py_1003
|
115 |
+
- pillow=9.4.0=py38h6a678d5_0
|
116 |
+
- pooch=1.4.0=pyhd3eb1b0_0
|
117 |
+
- prompt-toolkit=3.0.39=pyha770c72_0
|
118 |
+
- prompt_toolkit=3.0.39=hd8ed1ab_0
|
119 |
+
- ptyprocess=0.7.0=pyhd3deb0d_0
|
120 |
+
- pure_eval=0.2.2=pyhd8ed1ab_0
|
121 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
122 |
+
- pyg=2.3.0=py38_torch_2.0.0_cu117
|
123 |
+
- pygments=2.15.1=pyhd8ed1ab_0
|
124 |
+
- pyopenssl=23.0.0=py38h06a4308_0
|
125 |
+
- pyparsing=3.0.9=py38h06a4308_0
|
126 |
+
- pysocks=1.7.1=py38h06a4308_0
|
127 |
+
- python=3.8.17=h955ad1f_0
|
128 |
+
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
129 |
+
- python_abi=3.8=2_cp38
|
130 |
+
- pytorch=2.0.1=py3.8_cuda11.7_cudnn8.5.0_0
|
131 |
+
- pytorch-cuda=11.7=h778d358_5
|
132 |
+
- pytorch-mutex=1.0=cuda
|
133 |
+
- readline=8.2=h5eee18b_0
|
134 |
+
- setuptools=67.8.0=py38h06a4308_0
|
135 |
+
- six=1.16.0=pyh6c4a22f_0
|
136 |
+
- sqlite=3.41.2=h5eee18b_0
|
137 |
+
- stack_data=0.6.2=pyhd8ed1ab_0
|
138 |
+
- sympy=1.11.1=py38h06a4308_0
|
139 |
+
- tbb=2021.8.0=hdb19cb5_0
|
140 |
+
- threadpoolctl=2.2.0=pyh0d69192_0
|
141 |
+
- tk=8.6.12=h1ccaba5_0
|
142 |
+
- tmap=1.0.6=py38h2bc3f7f_0
|
143 |
+
- torchaudio=2.0.2=py38_cu117
|
144 |
+
- torchtriton=2.0.0=py38
|
145 |
+
- torchvision=0.15.2=py38_cu117
|
146 |
+
- tqdm=4.65.0=py38hb070fc8_0
|
147 |
+
- traitlets=5.9.0=pyhd8ed1ab_0
|
148 |
+
- typing_extensions=4.6.3=py38h06a4308_0
|
149 |
+
- urllib3=1.26.16=py38h06a4308_0
|
150 |
+
- wcwidth=0.2.6=pyhd8ed1ab_0
|
151 |
+
- wheel=0.38.4=py38h06a4308_0
|
152 |
+
- xz=5.4.2=h5eee18b_0
|
153 |
+
- zeromq=4.3.4=h9c3ff4c_1
|
154 |
+
- zlib=1.2.13=h5eee18b_0
|
155 |
+
- zstd=1.5.5=hc292b87_0
|
156 |
+
- pip:
|
157 |
+
- absl-py==1.4.0
|
158 |
+
- accelerate==0.20.3
|
159 |
+
- aiofiles==23.2.1
|
160 |
+
- aiohttp==3.8.4
|
161 |
+
- aiosignal==1.3.1
|
162 |
+
- aliyun-python-sdk-core==2.13.36
|
163 |
+
- aliyun-python-sdk-kms==2.16.1
|
164 |
+
- altair==4.2.2
|
165 |
+
- annotated-types==0.6.0
|
166 |
+
- antlr4-python3-runtime==4.9.3
|
167 |
+
- anyio==3.7.1
|
168 |
+
- argon2-cffi==23.1.0
|
169 |
+
- argon2-cffi-bindings==21.2.0
|
170 |
+
- arrow==1.2.3
|
171 |
+
- async-lru==2.0.4
|
172 |
+
- async-timeout==4.0.2
|
173 |
+
- attrs==23.1.0
|
174 |
+
- autocommand==2.2.2
|
175 |
+
- babel==2.13.0
|
176 |
+
- backoff==2.2.1
|
177 |
+
- backports-zoneinfo==0.2.1
|
178 |
+
- beautifulsoup4==4.12.2
|
179 |
+
- bigmodelvis==0.0.1
|
180 |
+
- binaryornot==0.4.4
|
181 |
+
- bleach==6.0.0
|
182 |
+
- blessed==1.20.0
|
183 |
+
- blinker==1.6.2
|
184 |
+
- blis==0.7.9
|
185 |
+
- braceexpand==0.1.7
|
186 |
+
- cachetools==5.3.1
|
187 |
+
- catalogue==2.0.8
|
188 |
+
- cfgv==3.3.1
|
189 |
+
- chardet==5.2.0
|
190 |
+
- cheroot==10.0.0
|
191 |
+
- cherrypy==18.8.0
|
192 |
+
- click==8.1.4
|
193 |
+
- cloudpathlib==0.16.0
|
194 |
+
- cmake==3.27.7
|
195 |
+
- colorama==0.4.6
|
196 |
+
- colour==0.1.5
|
197 |
+
- comm==0.1.4
|
198 |
+
- confection==0.1.0
|
199 |
+
- configargparse==1.7
|
200 |
+
- contexttimer==0.3.3
|
201 |
+
- contourpy==1.1.0
|
202 |
+
- cookiecutter==2.4.0
|
203 |
+
- crcmod==1.7
|
204 |
+
- croniter==1.4.1
|
205 |
+
- ctranslate2==3.20.0
|
206 |
+
- cymem==2.0.7
|
207 |
+
- datasets==2.13.1
|
208 |
+
- dateutils==0.6.12
|
209 |
+
- decord==0.6.0
|
210 |
+
- deepdiff==6.3.1
|
211 |
+
- deepspeed==0.10.1+ff7d5275
|
212 |
+
- defusedxml==0.7.1
|
213 |
+
- delta-center-client==0.0.4
|
214 |
+
- dill==0.3.6
|
215 |
+
- diskcache==5.6.3
|
216 |
+
- distlib==0.3.6
|
217 |
+
- distro==1.8.0
|
218 |
+
- dnspython==2.4.2
|
219 |
+
- docker-pycreds==0.4.0
|
220 |
+
- einops==0.6.1
|
221 |
+
- evaluate==0.4.1
|
222 |
+
- exceptiongroup==1.1.2
|
223 |
+
- faerun==0.3.20
|
224 |
+
- fairscale==0.4.4
|
225 |
+
- fastapi==0.100.0
|
226 |
+
- fastjsonschema==2.18.1
|
227 |
+
- fasttext-wheel==0.9.2
|
228 |
+
- ffmpy==0.3.1
|
229 |
+
- filelock==3.12.2
|
230 |
+
- flash-attn==2.3.3
|
231 |
+
- flask==3.0.0
|
232 |
+
- fonttools==4.40.0
|
233 |
+
- fqdn==1.5.1
|
234 |
+
- frozenlist==1.3.3
|
235 |
+
- fsspec==2023.6.0
|
236 |
+
- ftfy==6.1.1
|
237 |
+
- future==0.18.3
|
238 |
+
- gdown==4.7.1
|
239 |
+
- gitdb==4.0.10
|
240 |
+
- gitpython==3.1.37
|
241 |
+
- google-auth==2.23.2
|
242 |
+
- google-auth-oauthlib==1.0.0
|
243 |
+
- gpustat==1.1.1
|
244 |
+
- gradio-client==0.7.0
|
245 |
+
- grpcio==1.59.0
|
246 |
+
- h11==0.14.0
|
247 |
+
- hjson==3.1.0
|
248 |
+
- httpcore==1.0.2
|
249 |
+
- httpx==0.25.1
|
250 |
+
- huggingface-hub==0.16.4
|
251 |
+
- identify==2.5.24
|
252 |
+
- imageio==2.31.1
|
253 |
+
- importlib-metadata==6.8.0
|
254 |
+
- importlib-resources==6.0.0
|
255 |
+
- inflect==7.0.0
|
256 |
+
- inquirer==3.1.3
|
257 |
+
- iopath==0.1.10
|
258 |
+
- ipython==8.12.2
|
259 |
+
- ipython-genutils==0.2.0
|
260 |
+
- ipywidgets==8.1.1
|
261 |
+
- isoduration==20.11.0
|
262 |
+
- itsdangerous==2.1.2
|
263 |
+
- jaraco-collections==4.3.0
|
264 |
+
- jaraco-context==4.3.0
|
265 |
+
- jaraco-functools==3.8.0
|
266 |
+
- jaraco-text==3.12.0
|
267 |
+
- jmespath==0.10.0
|
268 |
+
- json5==0.9.14
|
269 |
+
- jsonpointer==2.4
|
270 |
+
- jsonschema==4.18.0
|
271 |
+
- jsonschema-specifications==2023.6.1
|
272 |
+
- jupyter==1.0.0
|
273 |
+
- jupyter-client==8.4.0
|
274 |
+
- jupyter-console==6.6.3
|
275 |
+
- jupyter-events==0.8.0
|
276 |
+
- jupyter-lsp==2.2.0
|
277 |
+
- jupyter-server==2.8.0
|
278 |
+
- jupyter-server-terminals==0.4.4
|
279 |
+
- jupyterlab==4.0.7
|
280 |
+
- jupyterlab-pygments==0.2.2
|
281 |
+
- jupyterlab-server==2.25.0
|
282 |
+
- jupyterlab-widgets==3.0.9
|
283 |
+
- kaggle==1.5.15
|
284 |
+
- kiwisolver==1.4.4
|
285 |
+
- langcodes==3.3.0
|
286 |
+
- lazy-loader==0.3
|
287 |
+
- levenshtein==0.23.0
|
288 |
+
- lightning==2.1.2
|
289 |
+
- lightning-cloud==0.5.37
|
290 |
+
- lightning-utilities==0.9.0
|
291 |
+
- lit==17.0.6
|
292 |
+
- littleutils==0.2.2
|
293 |
+
- lmppl==0.3.1
|
294 |
+
- lxml==4.9.3
|
295 |
+
- markdown==3.5
|
296 |
+
- markdown-it-py==3.0.0
|
297 |
+
- matplotlib==3.2.2
|
298 |
+
- mdurl==0.1.2
|
299 |
+
- mistune==3.0.2
|
300 |
+
- more-itertools==9.1.0
|
301 |
+
- multidict==6.0.4
|
302 |
+
- multiprocess==0.70.14
|
303 |
+
- murmurhash==1.0.9
|
304 |
+
- nbclient==0.8.0
|
305 |
+
- nbconvert==7.9.2
|
306 |
+
- nbformat==5.9.2
|
307 |
+
- ninja==1.11.1
|
308 |
+
- nltk==3.8.1
|
309 |
+
- nodeenv==1.8.0
|
310 |
+
- notebook==7.0.6
|
311 |
+
- notebook-shim==0.2.3
|
312 |
+
- numpy==1.24.4
|
313 |
+
- nvidia-cublas-cu11==11.10.3.66
|
314 |
+
- nvidia-cublas-cu12==12.1.3.1
|
315 |
+
- nvidia-cuda-cupti-cu11==11.7.101
|
316 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
317 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
318 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
319 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
320 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
321 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
322 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
323 |
+
- nvidia-cufft-cu11==10.9.0.58
|
324 |
+
- nvidia-cufft-cu12==11.0.2.54
|
325 |
+
- nvidia-curand-cu11==10.2.10.91
|
326 |
+
- nvidia-curand-cu12==10.3.2.106
|
327 |
+
- nvidia-cusolver-cu11==11.4.0.1
|
328 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
329 |
+
- nvidia-cusparse-cu11==11.7.4.91
|
330 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
331 |
+
- nvidia-ml-py==12.535.77
|
332 |
+
- nvidia-nccl-cu11==2.14.3
|
333 |
+
- nvidia-nccl-cu12==2.18.1
|
334 |
+
- nvidia-nvjitlink-cu12==12.3.101
|
335 |
+
- nvidia-nvtx-cu11==11.7.91
|
336 |
+
- nvidia-nvtx-cu12==12.1.105
|
337 |
+
- oauthlib==3.2.2
|
338 |
+
- ogb==1.3.6
|
339 |
+
- omegaconf==2.3.0
|
340 |
+
- openai==1.2.4
|
341 |
+
- opencv-python-headless==4.5.5.64
|
342 |
+
- opendatasets==0.1.22
|
343 |
+
- opendelta==0.3.2
|
344 |
+
- opennmt-py==3.4.1
|
345 |
+
- ordered-set==4.1.0
|
346 |
+
- orjson==3.9.10
|
347 |
+
- oss2==2.15.0
|
348 |
+
- outdated==0.2.2
|
349 |
+
- overrides==7.4.0
|
350 |
+
- packaging==23.1
|
351 |
+
- pandas==2.0.3
|
352 |
+
- pandocfilters==1.5.0
|
353 |
+
- paragraph2actions==1.5.0
|
354 |
+
- pathtools==0.1.2
|
355 |
+
- pathy==0.10.2
|
356 |
+
- peft==0.3.0
|
357 |
+
- pip==23.3.1
|
358 |
+
- pkgutil-resolve-name==1.3.10
|
359 |
+
- platformdirs==3.8.1
|
360 |
+
- plotly==5.15.0
|
361 |
+
- portalocker==2.7.0
|
362 |
+
- portend==3.2.0
|
363 |
+
- pre-commit==3.3.3
|
364 |
+
- preshed==3.0.8
|
365 |
+
- prometheus-client==0.17.1
|
366 |
+
- promise==2.3
|
367 |
+
- protobuf==3.19.6
|
368 |
+
- psutil==5.9.5
|
369 |
+
- pubchempy==1.0.4
|
370 |
+
- py-cpuinfo==9.0.0
|
371 |
+
- pyahocorasick==2.0.0
|
372 |
+
- pyarrow==12.0.1
|
373 |
+
- pyasn1==0.5.0
|
374 |
+
- pyasn1-modules==0.3.0
|
375 |
+
- pybind11==2.11.1
|
376 |
+
- pycocoevalcap==1.2
|
377 |
+
- pycocotools==2.0.6
|
378 |
+
- pycryptodome==3.18.0
|
379 |
+
- pydantic==1.10.11
|
380 |
+
- pydantic-core==2.14.3
|
381 |
+
- pydeck==0.8.1b0
|
382 |
+
- pydub==0.25.1
|
383 |
+
- pyjwt==2.7.0
|
384 |
+
- pymongo==4.6.0
|
385 |
+
- pympler==1.0.1
|
386 |
+
- pyonmttok==1.37.1
|
387 |
+
- python-editor==1.0.4
|
388 |
+
- python-json-logger==2.0.7
|
389 |
+
- python-levenshtein==0.23.0
|
390 |
+
- python-magic==0.4.27
|
391 |
+
- python-multipart==0.0.6
|
392 |
+
- python-slugify==8.0.1
|
393 |
+
- pytorch-lightning==2.0.0
|
394 |
+
- pytz==2023.3
|
395 |
+
- pytz-deprecation-shim==0.1.0.post0
|
396 |
+
- pywavelets==1.4.1
|
397 |
+
- pyyaml==6.0.1
|
398 |
+
- pyzmq==25.1.1
|
399 |
+
- qtconsole==5.4.4
|
400 |
+
- qtpy==2.4.0
|
401 |
+
- rapidfuzz==3.4.0
|
402 |
+
- rdkit==2023.3.3
|
403 |
+
- readchar==4.0.5
|
404 |
+
- referencing==0.29.1
|
405 |
+
- regex==2023.6.3
|
406 |
+
- requests==2.31.0
|
407 |
+
- requests-oauthlib==1.3.1
|
408 |
+
- responses==0.18.0
|
409 |
+
- rfc3339-validator==0.1.4
|
410 |
+
- rfc3986-validator==0.1.1
|
411 |
+
- rich==13.4.2
|
412 |
+
- rouge-score==0.1.2
|
413 |
+
- rpds-py==0.8.10
|
414 |
+
- rsa==4.9
|
415 |
+
- rxn-onmt-utils==1.1.0
|
416 |
+
- rxn-opennmt-py==1.1.5
|
417 |
+
- rxn-utils==1.6.0
|
418 |
+
- sacrebleu==2.3.1
|
419 |
+
- safetensors==0.3.1
|
420 |
+
- salesforce-lavis==1.0.0
|
421 |
+
- scikit-image==0.20.0
|
422 |
+
- scikit-learn==0.23.1
|
423 |
+
- scipy==1.4.1
|
424 |
+
- semantic-version==2.10.0
|
425 |
+
- send2trash==1.8.2
|
426 |
+
- sentencepiece==0.1.99
|
427 |
+
- sentry-sdk==1.31.0
|
428 |
+
- setproctitle==1.3.3
|
429 |
+
- shellingham==1.5.4
|
430 |
+
- smart-open==6.3.0
|
431 |
+
- smmap==5.0.1
|
432 |
+
- sniffio==1.3.0
|
433 |
+
- soupsieve==2.4.1
|
434 |
+
- spacy==3.7.2
|
435 |
+
- spacy-legacy==3.0.12
|
436 |
+
- spacy-loggers==1.0.4
|
437 |
+
- srsly==2.4.6
|
438 |
+
- starlette==0.27.0
|
439 |
+
- starsessions==1.3.0
|
440 |
+
- streamlit==1.22.0
|
441 |
+
- tabulate==0.9.0
|
442 |
+
- tempora==5.5.0
|
443 |
+
- tenacity==8.2.2
|
444 |
+
- tensorboard==2.14.0
|
445 |
+
- tensorboard-data-server==0.7.1
|
446 |
+
- terminado==0.17.1
|
447 |
+
- text-unidecode==1.3
|
448 |
+
- textdistance==4.6.0
|
449 |
+
- thinc==8.1.10
|
450 |
+
- tifffile==2023.7.10
|
451 |
+
- timm==0.4.12
|
452 |
+
- tinycss2==1.2.1
|
453 |
+
- tokenizers==0.13.3
|
454 |
+
- toml==0.10.2
|
455 |
+
- tomli==2.0.1
|
456 |
+
- tomlkit==0.12.0
|
457 |
+
- toolz==0.12.0
|
458 |
+
- torch==2.0.1
|
459 |
+
- torchmetrics==1.0.0
|
460 |
+
- torchtext==0.4.0
|
461 |
+
- tornado==6.3.3
|
462 |
+
- transformers==4.33.3
|
463 |
+
- triton==2.0.0
|
464 |
+
- typer==0.9.0
|
465 |
+
- tzdata==2023.3
|
466 |
+
- tzlocal==4.3.1
|
467 |
+
- ujson==5.8.0
|
468 |
+
- uri-template==1.3.0
|
469 |
+
- uvicorn==0.22.0
|
470 |
+
- validators==0.20.0
|
471 |
+
- virtualenv==20.23.1
|
472 |
+
- waitress==2.1.2
|
473 |
+
- wandb==0.15.5
|
474 |
+
- wasabi==1.1.2
|
475 |
+
- watchdog==3.0.0
|
476 |
+
- weasel==0.3.4
|
477 |
+
- web-py==0.62
|
478 |
+
- webcolors==1.13
|
479 |
+
- webdataset==0.2.48
|
480 |
+
- webencodings==0.5.1
|
481 |
+
- websocket-client==1.6.1
|
482 |
+
- websockets==11.0.3
|
483 |
+
- werkzeug==3.0.0
|
484 |
+
- widgetsnbextension==4.0.9
|
485 |
+
- xxhash==3.2.0
|
486 |
+
- yacs==0.1.8
|
487 |
+
- yarl==1.9.2
|
488 |
+
- zc-lockfile==3.0.post1
|
489 |
+
- zipp==3.16.0
|
figures/frameworks.jpg
ADDED
Git LFS Details
|
gin_pretrained/graphcl_80.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc8f685ace701ad71e3d82330fa78add25c573287ebc2908d9f7fddf13bc745f
|
3 |
+
size 7454162
|
graph_gen.ipynb
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import torch\n",
|
10 |
+
"from torch_geometric.data import Data\n",
|
11 |
+
"from ogb.utils import smiles2graph\n",
|
12 |
+
"import os\n",
|
13 |
+
"import json\n",
|
14 |
+
"from rdkit import RDLogger\n",
|
15 |
+
"from rdkit import Chem\n",
|
16 |
+
"RDLogger.DisableLog('rdApp.*')\n",
|
17 |
+
"from tqdm import tqdm\n",
|
18 |
+
"import multiprocessing\n",
|
19 |
+
"\n",
|
20 |
+
"def write_json(data, filename):\n",
|
21 |
+
" with open(filename, 'w') as f:\n",
|
22 |
+
" json.dump(data, f, indent=4, ensure_ascii=False)\n",
|
23 |
+
"\n",
|
24 |
+
"def read_json(filename):\n",
|
25 |
+
" with open(filename, 'r') as f:\n",
|
26 |
+
" data = json.load(f)\n",
|
27 |
+
" return data\n",
|
28 |
+
"\n",
|
29 |
+
"def smiles2data(smiles):\n",
|
30 |
+
" graph = smiles2graph(smiles)\n",
|
31 |
+
" x = torch.from_numpy(graph['node_feat'])\n",
|
32 |
+
" edge_index = torch.from_numpy(graph['edge_index'], )\n",
|
33 |
+
" edge_attr = torch.from_numpy(graph['edge_feat'])\n",
|
34 |
+
" data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n",
|
35 |
+
" return data\n"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": null,
|
41 |
+
"metadata": {},
|
42 |
+
"outputs": [],
|
43 |
+
"source": [
|
44 |
+
"# make pretrain graphs\n",
|
45 |
+
"root = 'data/pretrain_data/'\n",
|
46 |
+
"mol_property_list = read_json(f'{root}/Abstract_property.json')\n",
|
47 |
+
"target_file = f'{root}/mol_graph_map.pt'\n",
|
48 |
+
"\n",
|
49 |
+
"if not os.path.exists(target_file):\n",
|
50 |
+
" mol_graph_map = {}\n",
|
51 |
+
" for mol_dict in tqdm(mol_property_list):\n",
|
52 |
+
" smiles = mol_dict['canon_smiles']\n",
|
53 |
+
" graph = smiles2data(smiles)\n",
|
54 |
+
" mol_graph_map[smiles] = graph\n",
|
55 |
+
" torch.save(mol_graph_map, target_file)"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": null,
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [],
|
63 |
+
"source": [
|
64 |
+
"# make downstrem (action prediction) graphs\n",
|
65 |
+
"root = 'data/action_data'\n",
|
66 |
+
"target_file = f'{root}/mol_graph_map.pt'\n",
|
67 |
+
"\n",
|
68 |
+
"if not os.path.exists(target_file):\n",
|
69 |
+
" all_mols = set()\n",
|
70 |
+
" reaction_list = read_json(f'{root}/processed.json')\n",
|
71 |
+
" rxn_keys = ['REACTANT', 'PRODUCT', 'CATALYST', 'SOLVENT']\n",
|
72 |
+
"\n",
|
73 |
+
" for rxn in reaction_list:\n",
|
74 |
+
" for key in rxn_keys:\n",
|
75 |
+
" for mol in rxn[key]:\n",
|
76 |
+
" if mol in all_mols:\n",
|
77 |
+
" continue\n",
|
78 |
+
" all_mols.add(mol)\n",
|
79 |
+
" mol_graph_map={}\n",
|
80 |
+
"\n",
|
81 |
+
" for smiles in all_mols:\n",
|
82 |
+
" graph = smiles2data(smiles)\n",
|
83 |
+
" mol_graph_map[smiles] = graph\n",
|
84 |
+
" torch.save(mol_graph_map, target_file)"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"execution_count": null,
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"# make downstream (retrosynthesis) graphs\n",
|
94 |
+
"root = 'data/synthesis_data'\n",
|
95 |
+
"\n",
|
96 |
+
"for folder in [\n",
|
97 |
+
" 'USPTO_50K_PtoR',\n",
|
98 |
+
" 'USPTO_50K_PtoR_aug20',\n",
|
99 |
+
" 'USPTO-MIT_PtoR_aug5',\n",
|
100 |
+
" 'USPTO-MIT_RtoP_aug5_mixed',\n",
|
101 |
+
" 'USPTO-MIT_RtoP_aug5_separated',\n",
|
102 |
+
" 'USPTO_full_pretrain_aug5_masked_token',\n",
|
103 |
+
" ]:\n",
|
104 |
+
" mol_graphid_file = f'{root}/{folder}/mol_graphid_map.json'\n",
|
105 |
+
" target_file = f'{root}/{folder}/idx_graph_map.pt'\n",
|
106 |
+
" if not os.path.exists(mol_graphid_file):\n",
|
107 |
+
" canon_idx_map = {}\n",
|
108 |
+
" mol_idx_map = {}\n",
|
109 |
+
" mol_set = set()\n",
|
110 |
+
" for mode in ['train', 'val', 'test']:\n",
|
111 |
+
" for file in ['src', 'tgt']:\n",
|
112 |
+
" if 'pretrain' in folder:\n",
|
113 |
+
" if file=='src':\n",
|
114 |
+
" continue\n",
|
115 |
+
" else:\n",
|
116 |
+
" if file=='tgt':\n",
|
117 |
+
" continue\n",
|
118 |
+
" file_path = f'{root}/{folder}/{mode}/{file}-{mode}.txt'\n",
|
119 |
+
" with open(file_path) as f:\n",
|
120 |
+
" lines = f.readlines()\n",
|
121 |
+
" for line in lines:\n",
|
122 |
+
" line = line.strip().replace(' ', '')\n",
|
123 |
+
" line = line.replace('<separated>', '.')\n",
|
124 |
+
" for smi in line.split('.'):\n",
|
125 |
+
" mol_set.add(smi)\n",
|
126 |
+
" smi_list = list(mol_set)\n",
|
127 |
+
" pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())\n",
|
128 |
+
" canon_list = pool.map(func=Chem.CanonSmiles,iterable=smi_list)\n",
|
129 |
+
" for smi, canon in zip(smi_list, canon_list):\n",
|
130 |
+
" if canon not in canon_idx_map:\n",
|
131 |
+
" canon_idx_map[canon] = len(canon_idx_map)\n",
|
132 |
+
" mol_idx_map[smi] = canon_idx_map[canon]\n",
|
133 |
+
" write_json(mol_idx_map, mol_graphid_file)\n",
|
134 |
+
" else:\n",
|
135 |
+
" mol_idx_map = read_json(mol_graphid_file)\n",
|
136 |
+
"\n",
|
137 |
+
" cid_graph_map = {}\n",
|
138 |
+
" for smiles, graph_id in mol_idx_map.items():\n",
|
139 |
+
" if graph_id in cid_graph_map:\n",
|
140 |
+
" continue\n",
|
141 |
+
" graph = smiles2data(smiles)\n",
|
142 |
+
" cid_graph_map[graph_id] = graph\n",
|
143 |
+
" torch.save(cid_graph_map, target_file)"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": 3,
|
149 |
+
"metadata": {},
|
150 |
+
"outputs": [],
|
151 |
+
"source": [
|
152 |
+
"# make downstream (retrosynthesis) graphs\n",
|
153 |
+
"root = 'data/ChEBI-20_data'\n",
|
154 |
+
"target_file = f'{root}/cid_graph_map.pt'\n",
|
155 |
+
"\n",
|
156 |
+
"cid_graph_map = {}\n",
|
157 |
+
"if not os.path.exists(target_file):\n",
|
158 |
+
" for mode in ['train', 'validation', 'test']:\n",
|
159 |
+
" with open(f'{root}/{mode}.txt') as f:\n",
|
160 |
+
" lines = f.readlines()\n",
|
161 |
+
" for line in lines[1:]:\n",
|
162 |
+
" cid, smiles, _ = line.strip().split('\\t', maxsplit=2)\n",
|
163 |
+
" graph = smiles2data(smiles)\n",
|
164 |
+
" cid_graph_map[cid] = graph\n",
|
165 |
+
" torch.save(cid_graph_map, target_file)"
|
166 |
+
]
|
167 |
+
}
|
168 |
+
],
|
169 |
+
"metadata": {
|
170 |
+
"kernelspec": {
|
171 |
+
"display_name": "pth20v3",
|
172 |
+
"language": "python",
|
173 |
+
"name": "python3"
|
174 |
+
},
|
175 |
+
"language_info": {
|
176 |
+
"codemirror_mode": {
|
177 |
+
"name": "ipython",
|
178 |
+
"version": 3
|
179 |
+
},
|
180 |
+
"file_extension": ".py",
|
181 |
+
"mimetype": "text/x-python",
|
182 |
+
"name": "python",
|
183 |
+
"nbconvert_exporter": "python",
|
184 |
+
"pygments_lexer": "ipython3",
|
185 |
+
"version": "3.8.17"
|
186 |
+
}
|
187 |
+
},
|
188 |
+
"nbformat": 4,
|
189 |
+
"nbformat_minor": 2
|
190 |
+
}
|
lora_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_model_name_or_path": null,
|
3 |
+
"bias": "none",
|
4 |
+
"fan_in_fan_out": false,
|
5 |
+
"inference_mode": false,
|
6 |
+
"init_lora_weights": true,
|
7 |
+
"lora_alpha": 16,
|
8 |
+
"lora_dropout": 0.1,
|
9 |
+
"target_modules": ["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
|
10 |
+
"peft_type": "LORA",
|
11 |
+
"r": 8,
|
12 |
+
"modules_to_save": null,
|
13 |
+
"task_type": "CAUSAL_LM"
|
14 |
+
}
|
main.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import warnings
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from pytorch_lightning import Trainer, strategies
|
7 |
+
import pytorch_lightning.callbacks as plc
|
8 |
+
from pytorch_lightning.loggers import CSVLogger
|
9 |
+
from pytorch_lightning.callbacks import TQDMProgressBar
|
10 |
+
from data_provider.pretrain_dm import PretrainDM
|
11 |
+
from data_provider.tune_dm import TuneDM
|
12 |
+
from model.opt_flash_attention import replace_opt_attn_with_flash_attn
|
13 |
+
from model.blip2_model import Blip2Model
|
14 |
+
from model.dist_funs import MyDeepSpeedStrategy
|
15 |
+
|
16 |
+
## for pyg bug
|
17 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
18 |
+
## for A5000 gpus
|
19 |
+
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)
|
20 |
+
|
21 |
+
try:
|
22 |
+
class MyDDPSpawnStrategy(strategies.DDPSpawnStrategy):
|
23 |
+
def load_model_state_dict(self, checkpoint):
|
24 |
+
assert self.lightning_module is not None
|
25 |
+
self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=False)
|
26 |
+
except:
|
27 |
+
pass
|
28 |
+
|
29 |
+
def main(args):
|
30 |
+
pl.seed_everything(args.seed)
|
31 |
+
# model
|
32 |
+
if args.init_checkpoint:
|
33 |
+
model = Blip2Model(args)
|
34 |
+
ckpt = torch.load(args.init_checkpoint, map_location='cpu')
|
35 |
+
model.load_state_dict(ckpt['state_dict'], strict=False)
|
36 |
+
print(f"loaded model from {args.init_checkpoint}")
|
37 |
+
else:
|
38 |
+
model = Blip2Model(args)
|
39 |
+
|
40 |
+
print('total params:', sum(p.numel() for p in model.parameters()))
|
41 |
+
|
42 |
+
if args.opt_model.find('galactica') >= 0 or args.opt_model.find('t5') >= 0:
|
43 |
+
tokenizer = model.blip2opt.opt_tokenizer
|
44 |
+
elif args.opt_model.find('llama') >= 0 or args.opt_model.find('vicuna') >= 0:
|
45 |
+
tokenizer = model.blip2opt.llm_tokenizer
|
46 |
+
else:
|
47 |
+
raise NotImplementedError
|
48 |
+
# data
|
49 |
+
if args.mode in {'pretrain', 'pretrain_eval'}:
|
50 |
+
dm = PretrainDM(
|
51 |
+
num_workers=args.num_workers,
|
52 |
+
batch_size=args.batch_size,
|
53 |
+
root=args.root,
|
54 |
+
text_max_len=args.text_max_len,
|
55 |
+
rxn_max_len=args.rxn_max_len,
|
56 |
+
smi_max_len=args.smi_max_len,
|
57 |
+
tokenizer=tokenizer,
|
58 |
+
args=args
|
59 |
+
)
|
60 |
+
elif args.mode in {'ft', 'eval'}:
|
61 |
+
dm = TuneDM(
|
62 |
+
num_workers=args.num_workers,
|
63 |
+
batch_size=args.batch_size,
|
64 |
+
root=args.root,
|
65 |
+
text_max_len=args.text_max_len,
|
66 |
+
rxn_max_len=args.rxn_max_len,
|
67 |
+
smi_max_len=args.smi_max_len,
|
68 |
+
tokenizer=tokenizer,
|
69 |
+
downstream_task=args.downstream_task,
|
70 |
+
args=args
|
71 |
+
)
|
72 |
+
|
73 |
+
callbacks = [TQDMProgressBar(refresh_rate=args.tqdm_interval)]
|
74 |
+
## fixme save only used parameters
|
75 |
+
# callbacks.append(plc.ModelCheckpoint(dirpath="all_checkpoints/"+args.filename+"/", every_n_epochs=10, save_top_k=-1))
|
76 |
+
callbacks.append(plc.ModelCheckpoint(dirpath="all_checkpoints/"+args.filename+"/",
|
77 |
+
filename='{epoch:02d}',
|
78 |
+
every_n_epochs=args.save_every_n_epochs,
|
79 |
+
save_last=True,
|
80 |
+
save_top_k=-1,
|
81 |
+
save_on_train_epoch_end=True))
|
82 |
+
if len(args.devices.split(',')) > 1:
|
83 |
+
if args.strategy_name == 'fsdp':
|
84 |
+
strategy = strategies.DDPFullyShardedNativeStrategy()
|
85 |
+
elif args.strategy_name == 'deepspeed':
|
86 |
+
strategy = strategies.DeepSpeedStrategy(stage=3)
|
87 |
+
elif args.strategy_name == 'mydeepspeed':
|
88 |
+
strategy = MyDeepSpeedStrategy(stage=2)
|
89 |
+
else:
|
90 |
+
strategy = MyDDPSpawnStrategy(find_unused_parameters=True)
|
91 |
+
else:
|
92 |
+
strategy = None
|
93 |
+
args.devices = eval(args.devices)
|
94 |
+
logger = CSVLogger(save_dir=f'./all_checkpoints/{args.filename}/')
|
95 |
+
reload_freq = 1 if args.mode == 'pretrain' else 0
|
96 |
+
trainer = Trainer(
|
97 |
+
accelerator=args.accelerator,
|
98 |
+
devices=args.devices,
|
99 |
+
precision=args.precision,
|
100 |
+
max_epochs=args.max_epochs,
|
101 |
+
accumulate_grad_batches=args.accumulate_grad_batches,
|
102 |
+
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
103 |
+
callbacks=callbacks,
|
104 |
+
strategy=strategy,
|
105 |
+
logger=logger,
|
106 |
+
reload_dataloaders_every_n_epochs=reload_freq
|
107 |
+
# limit_train_batches=100,
|
108 |
+
)
|
109 |
+
|
110 |
+
if args.mode in {'pretrain', 'ft'}:
|
111 |
+
trainer.fit(model, datamodule=dm, ckpt_path=args.ckpt_path)
|
112 |
+
elif args.mode in {'eval', 'pretrain_eval'}:
|
113 |
+
trainer.fit_loop.epoch_progress.current.completed = args.caption_eval_epoch - 1
|
114 |
+
trainer.validate(model, datamodule=dm)
|
115 |
+
# trainer.test(model, datamodule=dm)
|
116 |
+
else:
|
117 |
+
raise NotImplementedError()
|
118 |
+
|
119 |
+
def get_args():
|
120 |
+
parser = argparse.ArgumentParser()
|
121 |
+
parser.add_argument('--filename', type=str, default="main")
|
122 |
+
parser.add_argument('--seed', type=int, default=42, help='random seed')
|
123 |
+
# MM settings
|
124 |
+
parser.add_argument('--mode', type=str, default='pretrain', choices=['pretrain', 'ft', 'eval', 'pretrain_eval'])
|
125 |
+
parser.add_argument('--strategy_name', type=str, default='mydeepspeed')
|
126 |
+
parser.add_argument('--iupac_prediction', action='store_true', default=False)
|
127 |
+
parser.add_argument('--ckpt_path', type=str, default=None)
|
128 |
+
# parser = Trainer.add_argparse_args(parser)
|
129 |
+
parser = Blip2Model.add_model_specific_args(parser) # add model args
|
130 |
+
parser = PretrainDM.add_model_specific_args(parser)
|
131 |
+
parser.add_argument('--accelerator', type=str, default='gpu')
|
132 |
+
parser.add_argument('--devices', type=str, default='0,1,2,3')
|
133 |
+
parser.add_argument('--precision', type=str, default='bf16-mixed')
|
134 |
+
parser.add_argument('--downstream_task', type=str, default='action', choices=['action', 'synthesis', 'caption', 'chebi'])
|
135 |
+
parser.add_argument('--max_epochs', type=int, default=10)
|
136 |
+
parser.add_argument('--enable_flash', action='store_true', default=False)
|
137 |
+
parser.add_argument('--disable_graph_cache', action='store_true', default=False)
|
138 |
+
parser.add_argument('--predict_rxn_condition', action='store_true', default=False)
|
139 |
+
parser.add_argument('--generate_restrict_tokens', action='store_true', default=False)
|
140 |
+
parser.add_argument('--train_restrict_tokens', action='store_true', default=False)
|
141 |
+
parser.add_argument('--smiles_type', type=str, default='default', choices=['default', 'canonical', 'restricted', 'unrestricted', 'r_smiles'])
|
142 |
+
parser.add_argument('--accumulate_grad_batches', type=int, default=1)
|
143 |
+
parser.add_argument('--tqdm_interval', type=int, default=50)
|
144 |
+
parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
|
145 |
+
args = parser.parse_args()
|
146 |
+
|
147 |
+
if args.enable_flash:
|
148 |
+
replace_opt_attn_with_flash_attn()
|
149 |
+
print("=========================================")
|
150 |
+
for k, v in sorted(vars(args).items()):
|
151 |
+
print(k, '=', v)
|
152 |
+
print("=========================================")
|
153 |
+
return args
|
154 |
+
|
155 |
+
if __name__ == '__main__':
|
156 |
+
main(get_args())
|
157 |
+
|
model/allowed_words.json
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"30": "(",
|
3 |
+
"31": ")",
|
4 |
+
"57": "C",
|
5 |
+
"39": "1",
|
6 |
+
"89": "c",
|
7 |
+
"69": "O",
|
8 |
+
"40": "2",
|
9 |
+
"51": "=",
|
10 |
+
"2275": "CC",
|
11 |
+
"1030": "cc",
|
12 |
+
"19": "[START_SMILES]",
|
13 |
+
"20": "[END_SMILES]",
|
14 |
+
"221": "\n",
|
15 |
+
"68": "N",
|
16 |
+
"24552": "ccc",
|
17 |
+
"36": ".",
|
18 |
+
"81": "[",
|
19 |
+
"83": "]",
|
20 |
+
"41": "3",
|
21 |
+
"4162": "OC",
|
22 |
+
"60": "F",
|
23 |
+
"19321": "cccc",
|
24 |
+
"100": "n",
|
25 |
+
"2356": "Cl",
|
26 |
+
"11863": "nc",
|
27 |
+
"62": "H",
|
28 |
+
"35": "-",
|
29 |
+
"8183": "Br",
|
30 |
+
"6597": "NC",
|
31 |
+
"43888": "CCC",
|
32 |
+
"54": "@",
|
33 |
+
"29332": "@@",
|
34 |
+
"4015": "CN",
|
35 |
+
"42": "4",
|
36 |
+
"38985": "Nc",
|
37 |
+
"3027": "CO",
|
38 |
+
"73": "S",
|
39 |
+
"25": "#",
|
40 |
+
"33": "+",
|
41 |
+
"27312": "Oc",
|
42 |
+
"16288": "cn",
|
43 |
+
"8095": "nn",
|
44 |
+
"56": "B",
|
45 |
+
"1228": "sc",
|
46 |
+
"37": "/",
|
47 |
+
"63": "I",
|
48 |
+
"105": "s",
|
49 |
+
"408": "oc",
|
50 |
+
"1912": "SC",
|
51 |
+
"5965": "Si",
|
52 |
+
"43": "5",
|
53 |
+
"46183": "Cn",
|
54 |
+
"98": "l",
|
55 |
+
"101": "o",
|
56 |
+
"7662": "NS",
|
57 |
+
"5869": "NN",
|
58 |
+
"8314": "cs",
|
59 |
+
"5396": "CI",
|
60 |
+
"82": "\\",
|
61 |
+
"9835": "Sc",
|
62 |
+
"3470": "CS",
|
63 |
+
"32712": "Fc",
|
64 |
+
"2304": "OS",
|
65 |
+
"3330": "NO",
|
66 |
+
"6882": "FC",
|
67 |
+
"70": "P",
|
68 |
+
"13136": "Sn",
|
69 |
+
"12702": "Mg",
|
70 |
+
"3529": "no",
|
71 |
+
"2812": "co",
|
72 |
+
"14530": "SCC",
|
73 |
+
"6342": "rc",
|
74 |
+
"35011": "BrN",
|
75 |
+
"9677": "NH",
|
76 |
+
"283": "on",
|
77 |
+
"20938": "onc",
|
78 |
+
"37190": "COS",
|
79 |
+
"44": "6",
|
80 |
+
"17952": "OB",
|
81 |
+
"11004": "Zn",
|
82 |
+
"28819": "OO",
|
83 |
+
"2085": "ns",
|
84 |
+
"3696": "CP",
|
85 |
+
"5097": "CF",
|
86 |
+
"978": "con",
|
87 |
+
"6017": "non",
|
88 |
+
"34244": "CNS",
|
89 |
+
"4232": "occ",
|
90 |
+
"10907": "CON",
|
91 |
+
"8072": "Cu",
|
92 |
+
"13346": "CB",
|
93 |
+
"45": "7",
|
94 |
+
"16378": "sn",
|
95 |
+
"1513": "ON",
|
96 |
+
"46": "8",
|
97 |
+
"4939": "OP",
|
98 |
+
"6321": "SN",
|
99 |
+
"26505": "conc",
|
100 |
+
"6913": "Se",
|
101 |
+
"2636": "SS",
|
102 |
+
"422": "se",
|
103 |
+
"47": "9",
|
104 |
+
"48321": "SSC",
|
105 |
+
"47306": "SCN",
|
106 |
+
"15780": "CNN",
|
107 |
+
"48968": "OCI",
|
108 |
+
"27": "%",
|
109 |
+
"38": "0",
|
110 |
+
"6389": "FS",
|
111 |
+
"4864": "On",
|
112 |
+
"27133": "SCO",
|
113 |
+
"2001": "IC",
|
114 |
+
"0": "<s>",
|
115 |
+
"1": "<pad>",
|
116 |
+
"2": "</s>",
|
117 |
+
"3": "<unk>"
|
118 |
+
}
|
model/blip2.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
import contextlib
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from lavis.common.dist_utils import download_cached_file
|
15 |
+
from lavis.common.utils import is_url
|
16 |
+
from lavis.models.base_model import BaseModel
|
17 |
+
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
|
18 |
+
from transformers import BertTokenizer
|
19 |
+
from model.gin_model import GNN
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
class Blip2Base(BaseModel):
|
24 |
+
@classmethod
|
25 |
+
def init_tokenizer(cls):
|
26 |
+
if True:
|
27 |
+
bert_name = 'allenai/scibert_scivocab_uncased'
|
28 |
+
else:
|
29 |
+
bert_name = 'bert_pretrained/'
|
30 |
+
tokenizer = BertTokenizer.from_pretrained(bert_name)
|
31 |
+
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
32 |
+
return tokenizer
|
33 |
+
|
34 |
+
def maybe_autocast(self, dtype=torch.float16):
|
35 |
+
# if on cpu, don't use autocast
|
36 |
+
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
|
37 |
+
enable_autocast = self.device != torch.device("cpu")
|
38 |
+
|
39 |
+
if enable_autocast:
|
40 |
+
return torch.cuda.amp.autocast(dtype=dtype)
|
41 |
+
else:
|
42 |
+
return contextlib.nullcontext()
|
43 |
+
|
44 |
+
@classmethod
|
45 |
+
def init_Qformer(cls, model_name, num_query_token, graph_width, cross_attention_freq=2):
|
46 |
+
assert model_name == 'scibert'
|
47 |
+
print("bert load scibert")
|
48 |
+
if True:
|
49 |
+
bert_name = 'allenai/scibert_scivocab_uncased'
|
50 |
+
else:
|
51 |
+
bert_name = 'bert_pretrained/'
|
52 |
+
|
53 |
+
|
54 |
+
encoder_config = BertConfig.from_pretrained(bert_name)
|
55 |
+
encoder_config.encoder_width = graph_width
|
56 |
+
# insert cross-attention layer every other block
|
57 |
+
encoder_config.add_cross_attention = True
|
58 |
+
encoder_config.cross_attention_freq = cross_attention_freq
|
59 |
+
encoder_config.query_length = num_query_token
|
60 |
+
|
61 |
+
Qformer = BertLMHeadModel.from_pretrained(
|
62 |
+
bert_name, config=encoder_config
|
63 |
+
)
|
64 |
+
query_tokens = nn.Parameter(
|
65 |
+
torch.zeros(1, num_query_token, encoder_config.hidden_size)
|
66 |
+
)
|
67 |
+
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
|
68 |
+
return Qformer, query_tokens
|
69 |
+
|
70 |
+
|
71 |
+
@classmethod
|
72 |
+
def init_graph_encoder(
|
73 |
+
cls, gin_num_layers, gin_hidden_dim, gin_drop_ratio):
|
74 |
+
graph_encoder = GNN(
|
75 |
+
num_layer=gin_num_layers,
|
76 |
+
emb_dim=gin_hidden_dim,
|
77 |
+
gnn_type='gin',
|
78 |
+
drop_ratio=gin_drop_ratio,
|
79 |
+
JK='last',
|
80 |
+
)
|
81 |
+
ckpt = torch.load('gin_pretrained/graphcl_80.pth', map_location=torch.device('cpu'))
|
82 |
+
missing_keys, unexpected_keys = graph_encoder.load_state_dict(ckpt, strict=False)
|
83 |
+
if len(missing_keys) or len(unexpected_keys):
|
84 |
+
print(missing_keys)
|
85 |
+
print(unexpected_keys)
|
86 |
+
|
87 |
+
ln_graph = LayerNorm(graph_encoder.num_features)
|
88 |
+
|
89 |
+
return graph_encoder, ln_graph
|
90 |
+
|
91 |
+
def load_from_pretrained(self, url_or_filename):
|
92 |
+
if is_url(url_or_filename):
|
93 |
+
cached_file = download_cached_file(
|
94 |
+
url_or_filename, check_hash=False, progress=True
|
95 |
+
)
|
96 |
+
checkpoint = torch.load(cached_file, map_location="cpu")
|
97 |
+
elif os.path.isfile(url_or_filename):
|
98 |
+
checkpoint = torch.load(url_or_filename, map_location="cpu")
|
99 |
+
else:
|
100 |
+
raise RuntimeError("checkpoint url or path is invalid")
|
101 |
+
|
102 |
+
state_dict = checkpoint["model"]
|
103 |
+
|
104 |
+
msg = self.load_state_dict(state_dict, strict=False)
|
105 |
+
|
106 |
+
# logging.info("Missing keys {}".format(msg.missing_keys))
|
107 |
+
logging.info("load checkpoint from %s" % url_or_filename)
|
108 |
+
|
109 |
+
return msg
|
110 |
+
|
111 |
+
|
112 |
+
def disabled_train(self, mode=True):
|
113 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
114 |
+
does not change anymore."""
|
115 |
+
return self
|
116 |
+
|
117 |
+
|
118 |
+
class LayerNorm(nn.LayerNorm):
|
119 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
120 |
+
|
121 |
+
def forward(self, x: torch.Tensor, mask=None):
|
122 |
+
orig_type = x.dtype
|
123 |
+
# ret = super().forward(x.type(torch.float32))
|
124 |
+
ret = super().forward(x)
|
125 |
+
return ret.type(orig_type)
|
126 |
+
|
model/blip2_llama.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.cuda.amp import autocast as autocast
|
11 |
+
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel
|
12 |
+
|
13 |
+
from lavis.models.blip2_models.blip2 import (
|
14 |
+
# Blip2Base,
|
15 |
+
disabled_train,
|
16 |
+
)
|
17 |
+
from model.blip2 import Blip2Base
|
18 |
+
from transformers import LlamaTokenizer
|
19 |
+
from model.modeling_llama import LlamaForCausalLM
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
llama_model_list = [
|
24 |
+
"decapoda-research/llama-13b-hf",
|
25 |
+
"decapoda-research/llama-7b-hf",
|
26 |
+
]
|
27 |
+
|
28 |
+
def mask_by_len(input, lens, fill_value=0):
|
29 |
+
'''
|
30 |
+
input: shape = [N, D]
|
31 |
+
lens: shape = [N]
|
32 |
+
'''
|
33 |
+
mask = torch.arange(input.shape[1], device=input.device).reshape(1, -1)
|
34 |
+
mask = mask < lens.reshape(-1, 1)
|
35 |
+
input[mask] = fill_value
|
36 |
+
return input
|
37 |
+
|
38 |
+
# @registry.register_model("blip2")
|
39 |
+
# @registry.register_model("blip2_feature_extractor")
|
40 |
+
class Blip2Llama(Blip2Base):
|
41 |
+
"""
|
42 |
+
BLIP2 first-stage model with Q-former and ViT.
|
43 |
+
Supported model types:
|
44 |
+
- pretrained: pretrained model with vit-g
|
45 |
+
- pretrain_vitL: pretrained model with vit-large
|
46 |
+
- coco: fintuned model on coco
|
47 |
+
Usage:
|
48 |
+
>>> from lavis.models import load_model
|
49 |
+
>>> model = load_model("blip2", "pretrain")
|
50 |
+
"""
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
bert_name,
|
54 |
+
gin_num_layers,
|
55 |
+
gin_hidden_dim,
|
56 |
+
gin_drop_ratio,
|
57 |
+
tune_gnn=False,
|
58 |
+
num_query_token=32,
|
59 |
+
cross_attention_freq=2,
|
60 |
+
lora_tuning=False,
|
61 |
+
peft_dir='',
|
62 |
+
llm_model="decapoda-research/llama-7b-hf",
|
63 |
+
prompt="",
|
64 |
+
args=None,
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
|
68 |
+
self.tune_gnn = tune_gnn
|
69 |
+
if not tune_gnn:
|
70 |
+
for name, param in self.graph_encoder.named_parameters():
|
71 |
+
param.requires_grad = False
|
72 |
+
self.graph_encoder = self.graph_encoder.eval()
|
73 |
+
self.graph_encoder.train = disabled_train
|
74 |
+
logging.info("freeze graph encoder")
|
75 |
+
|
76 |
+
self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
|
77 |
+
### remove the unused parameters
|
78 |
+
self.Qformer.cls = None
|
79 |
+
self.Qformer.bert.embeddings.word_embeddings = None
|
80 |
+
self.Qformer.bert.embeddings.position_embeddings = None
|
81 |
+
for layer in self.Qformer.bert.encoder.layer:
|
82 |
+
layer.output = None
|
83 |
+
layer.intermediate = None
|
84 |
+
|
85 |
+
## initialize opt model
|
86 |
+
self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, padding_side='right')
|
87 |
+
self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
88 |
+
self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
|
89 |
+
self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
|
90 |
+
self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
|
91 |
+
self.llm_model = LlamaForCausalLM.from_pretrained(llm_model, torch_dtype=torch.bfloat16)
|
92 |
+
# self.llm_model = LlamaForCausalLM.from_pretrained(llm_model)
|
93 |
+
self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
|
94 |
+
|
95 |
+
self.lora_tuning = lora_tuning
|
96 |
+
if lora_tuning:
|
97 |
+
if peft_dir:
|
98 |
+
self.llm_model = PeftModel.from_pretrained(self.llm_model, peft_dir, is_trainable=True)
|
99 |
+
else:
|
100 |
+
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
|
101 |
+
self.llm_model = get_peft_model(self.llm_model, peft_config)
|
102 |
+
self.llm_model.print_trainable_parameters()
|
103 |
+
else:
|
104 |
+
for name, param in self.llm_model.named_parameters():
|
105 |
+
param.requires_grad = False
|
106 |
+
|
107 |
+
## fixme: this is different from the original BLIP2
|
108 |
+
self.eos_token_id = self.llm_tokenizer(
|
109 |
+
"\n", add_special_tokens=False
|
110 |
+
).input_ids[0]
|
111 |
+
self.pad_token_id = self.llm_tokenizer.pad_token_id
|
112 |
+
|
113 |
+
self.llm_proj = nn.Linear(
|
114 |
+
self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
|
115 |
+
)
|
116 |
+
|
117 |
+
## fixme: no prompt yet
|
118 |
+
self.prompt = prompt
|
119 |
+
# prompt_tokens = self.opt_tokenizer(self.prompt, return_tensors="pt")
|
120 |
+
# self.prompt_length = prompt_tokens.attention_mask.sum(1)
|
121 |
+
|
122 |
+
def forward(self, batch):
|
123 |
+
graphs, text_tokens, prompt_lens = batch
|
124 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
125 |
+
if not self.tune_gnn:
|
126 |
+
graph_embeds = graph_embeds.detach()
|
127 |
+
graph_embeds = self.ln_graph(graph_embeds, graph_masks)
|
128 |
+
device = graph_embeds.device
|
129 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
130 |
+
query_output = self.Qformer.bert(
|
131 |
+
query_embeds=query_tokens,
|
132 |
+
encoder_hidden_states=graph_embeds,
|
133 |
+
encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
|
134 |
+
return_dict=True,
|
135 |
+
)
|
136 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state)
|
137 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(device)
|
138 |
+
targets = text_tokens.input_ids.masked_fill(
|
139 |
+
text_tokens.input_ids == self.llm_tokenizer.pad_token_id, -100
|
140 |
+
)
|
141 |
+
if self.prompt:
|
142 |
+
targets = mask_by_len(targets, prompt_lens, -100) # do not apply loss to the prompt
|
143 |
+
# targets[:, : self.prompt_length] = -100 # do not apply loss to the prompt
|
144 |
+
|
145 |
+
empty_targets = (
|
146 |
+
torch.ones(atts_llm.size(), dtype=torch.long).to(device).fill_(-100)
|
147 |
+
)
|
148 |
+
targets = torch.cat([empty_targets, targets], dim=1)
|
149 |
+
# if self.lora_tuning:
|
150 |
+
# inputs_embeds = self.llm_model.model.get_decoder().embed_tokens(text_tokens.input_ids)
|
151 |
+
# else:
|
152 |
+
# inputs_embeds = self.llm_model.model.decoder.embed_tokens(text_tokens.input_ids)
|
153 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(text_tokens.input_ids)
|
154 |
+
inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
|
155 |
+
attention_mask = torch.cat([atts_llm, text_tokens.attention_mask], dim=1)
|
156 |
+
|
157 |
+
outputs = self.llm_model(
|
158 |
+
inputs_embeds=inputs_embeds,
|
159 |
+
attention_mask=attention_mask,
|
160 |
+
return_dict=True,
|
161 |
+
labels=targets,
|
162 |
+
# use_cache=False,
|
163 |
+
)
|
164 |
+
loss = outputs.loss
|
165 |
+
return {"loss": loss}
|
166 |
+
|
167 |
+
@torch.no_grad()
|
168 |
+
def generate(
|
169 |
+
self,
|
170 |
+
samples,
|
171 |
+
do_sample=False,
|
172 |
+
num_beams=5,
|
173 |
+
max_length=128,
|
174 |
+
min_length=1,
|
175 |
+
top_p=0.9,
|
176 |
+
repetition_penalty=1.0,
|
177 |
+
length_penalty=1.0,
|
178 |
+
num_captions=1,
|
179 |
+
temperature=1,
|
180 |
+
):
|
181 |
+
"""
|
182 |
+
Args:
|
183 |
+
samples (dict): A dictionary containing the following keys:
|
184 |
+
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
|
185 |
+
num_beams (int): Number of beams for beam search. 1 means no beam search.
|
186 |
+
max_length (int): The maximum length of the sequence to be generated.
|
187 |
+
min_length (int): The minimum length of the sequence to be generated.
|
188 |
+
top_p (float): The cumulative probability for nucleus sampling.
|
189 |
+
repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
|
190 |
+
num_captions (int): Number of captions to be generated for each image.
|
191 |
+
Returns:
|
192 |
+
captions (list): A list of strings of length batch_size * num_captions.
|
193 |
+
"""
|
194 |
+
graphs = samples['graphs']
|
195 |
+
prompt_tokens = samples['prompt_tokens']
|
196 |
+
# prompt_lens = samples['prompt_lens']
|
197 |
+
with self.maybe_autocast():
|
198 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
199 |
+
graph_embeds = self.ln_graph(graph_embeds)
|
200 |
+
|
201 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
202 |
+
query_output = self.Qformer.bert(
|
203 |
+
query_embeds=query_tokens,
|
204 |
+
encoder_hidden_states=graph_embeds,
|
205 |
+
encoder_attention_mask=graph_masks,
|
206 |
+
return_dict=True,
|
207 |
+
)
|
208 |
+
|
209 |
+
device = graph_embeds.device
|
210 |
+
inputs_llm = self.llm_proj(query_output.last_hidden_state)
|
211 |
+
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long, device=device)
|
212 |
+
|
213 |
+
attention_mask = torch.cat([atts_llm, prompt_tokens.attention_mask], dim=1)
|
214 |
+
|
215 |
+
if False:
|
216 |
+
if do_sample:
|
217 |
+
query_embeds = inputs_llm.repeat_interleave(num_captions, dim=0)
|
218 |
+
num_beams = 1
|
219 |
+
else:
|
220 |
+
query_embeds = inputs_llm.repeat_interleave(num_beams, dim=0)
|
221 |
+
|
222 |
+
outputs = self.llm_model.generate(
|
223 |
+
input_ids=prompt_tokens.input_ids,
|
224 |
+
query_embeds=query_embeds,
|
225 |
+
attention_mask=attention_mask,
|
226 |
+
do_sample=do_sample,
|
227 |
+
top_p=top_p,
|
228 |
+
temperature=temperature,
|
229 |
+
num_beams=num_beams,
|
230 |
+
max_new_tokens=max_length,
|
231 |
+
min_length=min_length,
|
232 |
+
eos_token_id=self.eos_token_id,
|
233 |
+
repetition_penalty=repetition_penalty,
|
234 |
+
length_penalty=length_penalty,
|
235 |
+
num_return_sequences=num_captions,
|
236 |
+
)
|
237 |
+
|
238 |
+
prompt_length = prompt_tokens.input_ids.shape[1]
|
239 |
+
output_text = self.opt_tokenizer.batch_decode(
|
240 |
+
outputs[:, prompt_length:], skip_special_tokens=True
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
inputs_embeds = self.llm_model.get_input_embeddings()(prompt_tokens.input_ids)
|
244 |
+
inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
|
245 |
+
attention_mask = torch.cat([atts_llm, prompt_tokens.attention_mask], dim=1)
|
246 |
+
|
247 |
+
outputs = self.llm_model.generate(
|
248 |
+
inputs_embeds=inputs_embeds,
|
249 |
+
attention_mask=attention_mask,
|
250 |
+
do_sample=do_sample,
|
251 |
+
top_p=top_p,
|
252 |
+
temperature=temperature,
|
253 |
+
num_beams=num_beams,
|
254 |
+
max_length=max_length,
|
255 |
+
min_length=min_length,
|
256 |
+
pad_token_id=self.pad_token_id,
|
257 |
+
eos_token_id=self.eos_token_id,
|
258 |
+
repetition_penalty=repetition_penalty,
|
259 |
+
length_penalty=length_penalty,
|
260 |
+
num_return_sequences=num_captions,
|
261 |
+
# use_cache=False,
|
262 |
+
)
|
263 |
+
# outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
|
264 |
+
output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
265 |
+
output_text = [text.strip() for text in output_text]
|
266 |
+
return output_text
|
model/blip2_model.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, Dict
|
3 |
+
import torch
|
4 |
+
from model.blip2_opt import Blip2OPT
|
5 |
+
from model.blip2_llama import Blip2Llama
|
6 |
+
from model.blip2_t5 import Blip2T5
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
from torch import optim
|
9 |
+
from lavis.common.optims import LinearWarmupCosineLRScheduler, LinearWarmupStepLRScheduler
|
10 |
+
import json
|
11 |
+
from model.opt_flash_attention import replace_opt_attn_with_flash_attn, replace_opt_attn_with_original_attn
|
12 |
+
import torch.distributed as dist
|
13 |
+
from peft import LoraConfig, TaskType
|
14 |
+
from model.help_funcs import caption_evaluate, AttrDict
|
15 |
+
from transformers import Adafactor
|
16 |
+
from torch_ema import ExponentialMovingAverage
|
17 |
+
|
18 |
+
def load_ignore_unexpected(model, state_dict):
|
19 |
+
keys = set(model.state_dict().keys())
|
20 |
+
state_dict = {k: v for k, v in state_dict.items() if k in keys}
|
21 |
+
|
22 |
+
## try to print keys that are not included
|
23 |
+
model.load_state_dict(state_dict, strict=True)
|
24 |
+
|
25 |
+
|
26 |
+
# def load_ignore_mismatch(model, state_dict):
|
27 |
+
# keys = set(model.state_dict().keys())
|
28 |
+
# extra_keys = set()
|
29 |
+
# for key in state_dict:
|
30 |
+
# if key not in keys:
|
31 |
+
# extra_keys.add(key)
|
32 |
+
# missing_keys = set()
|
33 |
+
# for key in keys:
|
34 |
+
# if key not in state_dict:
|
35 |
+
# missing_keys.add(key)
|
36 |
+
# ## try to print keys that are not included
|
37 |
+
# model.load_state_dict(state_dict, strict=False)
|
38 |
+
|
39 |
+
|
40 |
+
def get_module_state_dict(state_dict, module_name):
|
41 |
+
module_state_dict = {}
|
42 |
+
for key, value in state_dict.items():
|
43 |
+
if key.startswith(module_name):
|
44 |
+
key = key[len(module_name) + 1:]
|
45 |
+
if key == '':
|
46 |
+
return value
|
47 |
+
module_state_dict[key] = value
|
48 |
+
return module_state_dict
|
49 |
+
# peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
|
50 |
+
class Blip2Model(pl.LightningModule):
|
51 |
+
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
52 |
+
if self.llm_tune != 'full':
|
53 |
+
to_be_removed = []
|
54 |
+
for key in checkpoint['state_dict']:
|
55 |
+
if key.startswith('blip2opt.opt_model') or key.startswith('blip2opt.llm_model'):
|
56 |
+
to_be_removed.append(key)
|
57 |
+
for key in to_be_removed:
|
58 |
+
checkpoint['state_dict'].pop(key)
|
59 |
+
if isinstance(self.args.save_every_n_epochs, int) and self.args.save_every_n_epochs > 0:
|
60 |
+
if self.llm_tune == 'lora' and (self.current_epoch + 1) % self.args.save_every_n_epochs == 0:
|
61 |
+
if self.local_rank == 0: # manually fix a bug in peft module
|
62 |
+
if self.args.peft_config:
|
63 |
+
peft_config = LoraConfig(**LoraConfig.from_json_file(self.args.peft_config))
|
64 |
+
else:
|
65 |
+
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=self.args.lora_r, lora_alpha=self.args.lora_alpha, lora_dropout=self.args.lora_dropout)
|
66 |
+
if hasattr(self.blip2opt, 'opt_model'):
|
67 |
+
self.blip2opt.opt_model.peft_config['default'] = peft_config
|
68 |
+
self.blip2opt.opt_model.save_pretrained(os.path.join(self.logger.save_dir, f'lora_epoch_{self.current_epoch}'))
|
69 |
+
elif hasattr(self.blip2opt, 'llm_model'):
|
70 |
+
self.blip2opt.llm_model.peft_config['default'] = peft_config
|
71 |
+
self.blip2opt.llm_model.save_pretrained(os.path.join(self.logger.save_dir, f'lora_epoch_{self.current_epoch}'))
|
72 |
+
return super().on_save_checkpoint(checkpoint)
|
73 |
+
|
74 |
+
def __init__(self, args):
|
75 |
+
super().__init__()
|
76 |
+
if isinstance(args, dict):
|
77 |
+
args = AttrDict(**args)
|
78 |
+
|
79 |
+
self.args = args
|
80 |
+
if not hasattr(args, 'do_sample'):
|
81 |
+
args.do_sample = False
|
82 |
+
self.caption_eval_epoch = args.caption_eval_epoch
|
83 |
+
self.do_sample = args.do_sample
|
84 |
+
self.num_beams = args.num_beams
|
85 |
+
self.max_inference_len = args.max_inference_len
|
86 |
+
self.min_inference_len = args.min_inference_len
|
87 |
+
self.num_generate_captions = args.num_generate_captions
|
88 |
+
self.reaction_weight = args.reaction_weight
|
89 |
+
self.llm_tune = args.llm_tune
|
90 |
+
self.enable_flash = args.enable_flash
|
91 |
+
if args.opt_model.find('galactica') >= 0:
|
92 |
+
self.blip2opt = Blip2OPT(args.bert_name, args.gin_num_layers, args.gin_hidden_dim, args.drop_ratio, args.tune_gnn, not args.not_tune_qformer, args.num_query_token, args.cross_attention_freq, args.llm_tune, args.peft_dir, args.opt_model, args.prompt, args)
|
93 |
+
elif args.opt_model.find('llama') >= 0 or args.opt_model.find('vicuna') >= 0:
|
94 |
+
self.blip2opt = Blip2Llama(args.bert_name, args.gin_num_layers, args.gin_hidden_dim, args.drop_ratio, args.tune_gnn, args.num_query_token, args.cross_attention_freq, args.llm_tune, args.peft_dir, args.opt_model, args.prompt, args)
|
95 |
+
elif args.opt_model.find('t5') >= 0:
|
96 |
+
self.blip2opt = Blip2T5(args.bert_name, args.gin_num_layers, args.gin_hidden_dim, args.drop_ratio, args.tune_gnn, args.num_query_token, args.cross_attention_freq, args.llm_tune, args.peft_dir, args.opt_model, args.prompt, args)
|
97 |
+
else:
|
98 |
+
raise NotImplementedError()
|
99 |
+
self.tokenizer = self.blip2opt.init_tokenizer()
|
100 |
+
self.mode = args.mode
|
101 |
+
self.downstream_task = args.downstream_task
|
102 |
+
self.save_hyperparameters(args)
|
103 |
+
self.save_ema_checkpoint = args.save_ema_checkpoint
|
104 |
+
if self.save_ema_checkpoint:
|
105 |
+
self.ema = ExponentialMovingAverage(self.parameters(), 0.99)
|
106 |
+
self.save_on_steps = args.save_on_steps
|
107 |
+
|
108 |
+
def load_from_stage1_checkpoint(self, path):
|
109 |
+
ckpt = torch.load(path, map_location='cpu')
|
110 |
+
state_dict = ckpt['state_dict']
|
111 |
+
graph_encoder_dict = get_module_state_dict(state_dict, 'blip2qformer.graph_encoder')
|
112 |
+
qformer_dict = get_module_state_dict(state_dict, 'blip2qformer.Qformer')
|
113 |
+
ln_graph_dict = get_module_state_dict(state_dict, 'blip2qformer.ln_graph')
|
114 |
+
qs_weight = get_module_state_dict(state_dict, 'blip2qformer.query_tokens')
|
115 |
+
load_ignore_unexpected(self.blip2opt.Qformer, qformer_dict)
|
116 |
+
self.blip2opt.graph_encoder.load_state_dict(graph_encoder_dict)
|
117 |
+
self.blip2opt.ln_graph.load_state_dict(ln_graph_dict)
|
118 |
+
self.blip2opt.query_tokens.data.copy_(qs_weight)
|
119 |
+
return self
|
120 |
+
|
121 |
+
# def load_from_stage1_checkpoint(self, path):
|
122 |
+
# ckpt = torch.load(path, map_location='cpu')
|
123 |
+
# state_dict = ckpt['state_dict']
|
124 |
+
# state_dict = {k[13:]: v for k,v in state_dict.items()}
|
125 |
+
# load_ignore_mismatch(self.blip2opt, state_dict)
|
126 |
+
# return self
|
127 |
+
|
128 |
+
def configure_optimizers(self):
|
129 |
+
if self.args.optimizer == 'adafactor':
|
130 |
+
print('Using adafactor optimizer')
|
131 |
+
optimizer = Adafactor(
|
132 |
+
self.parameters(),
|
133 |
+
lr=1e-3,
|
134 |
+
relative_step=False,
|
135 |
+
scale_parameter=False,
|
136 |
+
warmup_init=False
|
137 |
+
)
|
138 |
+
self.scheduler = None
|
139 |
+
else:
|
140 |
+
self.trainer.fit_loop.setup_data()
|
141 |
+
# self.trainer.reset_train_dataloader()
|
142 |
+
warmup_steps = min(len(self.trainer.train_dataloader), self.args.warmup_steps)
|
143 |
+
optimizer = optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=self.args.weight_decay)
|
144 |
+
if self.args.scheduler == 'linear_warmup_cosine_lr':
|
145 |
+
self.scheduler = LinearWarmupCosineLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, warmup_steps, self.args.warmup_lr)
|
146 |
+
elif self.args.scheduler == 'linear_warmup_step_lr':
|
147 |
+
self.scheduler = LinearWarmupStepLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, self.args.lr_decay_rate, self.args.warmup_lr, warmup_steps)
|
148 |
+
elif self.args.scheduler == 'None':
|
149 |
+
self.scheduler = None
|
150 |
+
else:
|
151 |
+
raise NotImplementedError()
|
152 |
+
return optimizer
|
153 |
+
|
154 |
+
def test_epoch_end(self, outputs):
|
155 |
+
print('test epoch end')
|
156 |
+
list_ids, list_predictions, list_targets = zip(*outputs)
|
157 |
+
predictions = [i for ii in list_predictions for i in ii]
|
158 |
+
targets = [i for ii in list_targets for i in ii]
|
159 |
+
|
160 |
+
all_ids = [None for _ in range(self.trainer.world_size)]
|
161 |
+
all_predictions = [None for _ in range(self.trainer.world_size)]
|
162 |
+
all_targets = [None for _ in range(self.trainer.world_size)]
|
163 |
+
|
164 |
+
dist.all_gather_object(all_ids, list_ids)
|
165 |
+
dist.all_gather_object(all_predictions, predictions)
|
166 |
+
dist.all_gather_object(all_targets, targets)
|
167 |
+
print(len(all_ids), len(all_predictions), len(all_targets))
|
168 |
+
if self.global_rank == 0:
|
169 |
+
print(f'saveing predictions to {self.logger.log_dir}')
|
170 |
+
|
171 |
+
all_predictions = [i for ii in all_predictions for i in ii]
|
172 |
+
all_targets = [i for ii in all_targets for i in ii]
|
173 |
+
self.save_predictions(all_ids, all_predictions, all_targets)
|
174 |
+
## fixme: I am not sure if the max length is the same as previous experiments
|
175 |
+
bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = \
|
176 |
+
caption_evaluate(all_predictions, all_targets, self.tokenizer, self.max_inference_len * 2)
|
177 |
+
self.log("bleu2", bleu2, sync_dist=False)
|
178 |
+
self.log("bleu4", bleu4, sync_dist=False)
|
179 |
+
self.log("rouge_1", rouge_1, sync_dist=False)
|
180 |
+
self.log("rouge_2", rouge_2, sync_dist=False)
|
181 |
+
self.log("rouge_l", rouge_l, sync_dist=False)
|
182 |
+
self.log("meteor_score", meteor_score, sync_dist=False)
|
183 |
+
|
184 |
+
def save_predictions(self, rxn_ids, predictions, targets):
|
185 |
+
assert False
|
186 |
+
assert len(rxn_ids) == len(targets)
|
187 |
+
assert len(predictions) == len(targets)
|
188 |
+
with open(os.path.join(self.logger.log_dir, 'predictions.txt'), 'w', encoding='utf8') as f:
|
189 |
+
for i, p, t in zip(rxn_ids, predictions, targets):
|
190 |
+
line = {'index': i, 'prediction': p, 'target': t}
|
191 |
+
f.write(json.dumps(line, ensure_ascii=False) + '\n')
|
192 |
+
|
193 |
+
@torch.no_grad()
|
194 |
+
def test_step(self, batch, batch_idx):
|
195 |
+
assert False
|
196 |
+
|
197 |
+
def gather_dict_results(self, dict_list):
|
198 |
+
list_of_dict_list = [None for _ in range(self.trainer.world_size)]
|
199 |
+
dist.all_gather_object(list_of_dict_list, dict_list)
|
200 |
+
dict_list = [i for ii in list_of_dict_list for i in ii] ## dict list, each dict has values that are lists of predictions, etc.
|
201 |
+
keys = dict_list[0].keys()
|
202 |
+
gathered_dict = {} # each value is a list of predictions, etc.
|
203 |
+
for key in keys:
|
204 |
+
gathered_dict[key] = [i for d in dict_list for i in d[key]]
|
205 |
+
if self.num_generate_captions>1:
|
206 |
+
M = self.num_generate_captions
|
207 |
+
N = len(gathered_dict['index'])
|
208 |
+
assert len(gathered_dict['predictions'])==N*M
|
209 |
+
gathered_dict['predictions'] = [
|
210 |
+
gathered_dict['predictions'][i * M:(i + 1) * M]
|
211 |
+
for i in range(N)
|
212 |
+
]
|
213 |
+
dict_list = []
|
214 |
+
for i in range(len(gathered_dict['predictions'])):
|
215 |
+
d = {k:gathered_dict[k][i] for k in keys}
|
216 |
+
dict_list.append(d)
|
217 |
+
return dict_list
|
218 |
+
|
219 |
+
def save_results(self, dict_list, log_prefix=""):
|
220 |
+
if log_prefix:
|
221 |
+
name = f'{log_prefix}_predictions.txt'
|
222 |
+
else:
|
223 |
+
name = 'predictions.txt'
|
224 |
+
with open(os.path.join(self.logger.log_dir, name), 'w', encoding='utf8') as f:
|
225 |
+
for i in range(len(dict_list)):
|
226 |
+
f.write(json.dumps(dict_list[i], ensure_ascii=True) + '\n')
|
227 |
+
|
228 |
+
def on_validation_epoch_start(self):
|
229 |
+
if self.enable_flash:
|
230 |
+
replace_opt_attn_with_original_attn()
|
231 |
+
self.saved_dict_list = []
|
232 |
+
|
233 |
+
def on_validation_epoch_end(self):
|
234 |
+
if self.enable_flash:
|
235 |
+
replace_opt_attn_with_flash_attn()
|
236 |
+
if (self.current_epoch+1) % self.caption_eval_epoch != 0:
|
237 |
+
return
|
238 |
+
result_list = self.gather_dict_results(self.saved_dict_list)
|
239 |
+
## empty cache
|
240 |
+
self.saved_dict_list = []
|
241 |
+
if self.global_rank == 0:
|
242 |
+
self.save_results(result_list, 'epoch_{}'.format(self.current_epoch))
|
243 |
+
if self.downstream_task == 'synthesis':
|
244 |
+
return
|
245 |
+
all_predictions = [i['predictions'] for i in result_list]
|
246 |
+
all_targets = [i['targets'] for i in result_list]
|
247 |
+
bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = \
|
248 |
+
caption_evaluate(all_predictions, all_targets, self.tokenizer, self.max_inference_len * 2)
|
249 |
+
self.log("bleu2", bleu2, sync_dist=False)
|
250 |
+
self.log("bleu4", bleu4, sync_dist=False)
|
251 |
+
self.log("rouge_1", rouge_1, sync_dist=False)
|
252 |
+
self.log("rouge_2", rouge_2, sync_dist=False)
|
253 |
+
self.log("rouge_l", rouge_l, sync_dist=False)
|
254 |
+
self.log("meteor_score", meteor_score, sync_dist=False)
|
255 |
+
|
256 |
+
@torch.no_grad()
|
257 |
+
def validation_step(self, batch, batch_idx, dataloader_idx=1):
|
258 |
+
if dataloader_idx == 0:
|
259 |
+
return
|
260 |
+
elif dataloader_idx == 1:
|
261 |
+
if (self.current_epoch+1) % self.caption_eval_epoch != 0:
|
262 |
+
return
|
263 |
+
rxn_ids, graphs, prompt_tokens, texts, inputs = batch
|
264 |
+
###============== Captioning Results ===================###
|
265 |
+
samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
|
266 |
+
if self.mode in {'ft', 'eval', 'pretrain_eval'}:
|
267 |
+
predictions = self.blip2opt.generate(
|
268 |
+
samples,
|
269 |
+
do_sample=self.do_sample,
|
270 |
+
num_beams=self.num_beams,
|
271 |
+
max_length=self.max_inference_len,
|
272 |
+
min_length=self.min_inference_len,
|
273 |
+
num_captions=self.num_generate_captions,
|
274 |
+
use_graph=not self.args.disable_graphs
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
raise NotImplementedError()
|
278 |
+
self.saved_dict_list.append({
|
279 |
+
'index': rxn_ids,
|
280 |
+
'input': inputs,
|
281 |
+
'predictions': predictions,
|
282 |
+
'targets': texts
|
283 |
+
})
|
284 |
+
else:
|
285 |
+
raise NotImplementedError
|
286 |
+
|
287 |
+
def on_train_start(self):
|
288 |
+
if hasattr(self, 'ema'):
|
289 |
+
self.ema.to(self.device)
|
290 |
+
|
291 |
+
def on_before_zero_grad(self, *args, **kwargs):
|
292 |
+
if self.save_ema_checkpoint:
|
293 |
+
if self.trainer.global_step % 100 == 0:
|
294 |
+
self.ema.update(self.parameters())
|
295 |
+
if self.trainer.global_step in self.save_on_steps:
|
296 |
+
checkpoint_path = os.path.join(f"all_checkpoints/{self.args.filename}/", f'step{self.trainer.global_step}.ckpt')
|
297 |
+
self.trainer.save_checkpoint(checkpoint_path)
|
298 |
+
|
299 |
+
def on_train_epoch_end(self):
|
300 |
+
save_every_n_epochs = self.args.save_every_n_epochs if self.args.save_every_n_epochs > 0 else self.args.max_epochs
|
301 |
+
if (self.current_epoch + 1) % save_every_n_epochs != 0:
|
302 |
+
return
|
303 |
+
if self.save_ema_checkpoint:
|
304 |
+
with self.ema.average_parameters():
|
305 |
+
checkpoint_path = os.path.join(f"all_checkpoints/{self.args.filename}/", f'ema_epoch{self.current_epoch}.ckpt')
|
306 |
+
self.trainer.save_checkpoint(checkpoint_path)
|
307 |
+
|
308 |
+
def training_step(self, batch, batch_idx):
|
309 |
+
if self.scheduler:
|
310 |
+
self.scheduler.step(self.trainer.current_epoch, self.trainer.global_step)
|
311 |
+
|
312 |
+
batch_size = batch[-1].input_ids.size(0)
|
313 |
+
###============== Overall Loss ===================###
|
314 |
+
if self.mode == 'ft':
|
315 |
+
loss = self.blip2opt.forward_action(batch, use_gragh=not self.args.disable_graphs)
|
316 |
+
elif self.mode == 'pretrain':
|
317 |
+
loss = self.blip2opt.forward_abstract(batch, use_gragh=not self.args.disable_graphs)
|
318 |
+
else:
|
319 |
+
raise NotImplementedError()
|
320 |
+
self.log("molecule loss", float(loss['loss']), batch_size=batch_size, sync_dist=True, prog_bar=True)
|
321 |
+
self.log("lr", self.trainer.optimizers[0].param_groups[0]['lr'], batch_size=batch_size, sync_dist=True, prog_bar=True)
|
322 |
+
return loss['loss']
|
323 |
+
|
324 |
+
@staticmethod
|
325 |
+
def add_model_specific_args(parent_parser):
|
326 |
+
parser = parent_parser.add_argument_group("GINSimclr")
|
327 |
+
# train mode
|
328 |
+
# GIN
|
329 |
+
parser.add_argument('--gin_hidden_dim', type=int, default=300)
|
330 |
+
parser.add_argument('--gin_num_layers', type=int, default=5)
|
331 |
+
parser.add_argument('--drop_ratio', type=float, default=0.0)
|
332 |
+
parser.add_argument('--tune_gnn', action='store_true', default=False)
|
333 |
+
parser.add_argument('--not_tune_qformer', action='store_true', default=False)
|
334 |
+
parser.add_argument('--disable_graphs', action='store_true', default=False)
|
335 |
+
# Bert
|
336 |
+
parser.add_argument('--bert_hidden_dim', type=int, default=2048, help='')
|
337 |
+
parser.add_argument('--bert_name', type=str, default='scibert')
|
338 |
+
parser.add_argument('--cross_attention_freq', type=int, default=2)
|
339 |
+
parser.add_argument('--num_query_token', type=int, default=8)
|
340 |
+
# OPT
|
341 |
+
parser.add_argument('--opt_model', type=str, default="facebook/galactica-1.3b")
|
342 |
+
# parser.add_argument('--prompt', type=str, default='a molecule of ')
|
343 |
+
parser.add_argument('--num_beams', type=int, default=5)
|
344 |
+
parser.add_argument('--do_sample', action='store_true', default=False)
|
345 |
+
parser.add_argument('--max_inference_len', type=int, default=512)
|
346 |
+
parser.add_argument('--min_inference_len', type=int, default=8)
|
347 |
+
parser.add_argument('--llm_tune', type=str, default='freeze')
|
348 |
+
parser.add_argument('--peft_config', type=str, default=None)
|
349 |
+
parser.add_argument('--peft_dir', type=str, default='')
|
350 |
+
|
351 |
+
parser.add_argument('--save_every_n_epochs', type=int, default=0)
|
352 |
+
## quantization
|
353 |
+
parser.add_argument('--load_in_8bit', action='store_true', default=False)
|
354 |
+
|
355 |
+
## lora config
|
356 |
+
parser.add_argument('--lora_r', type=int, default=8)
|
357 |
+
parser.add_argument('--lora_alpha', type=int, default=32)
|
358 |
+
parser.add_argument('--lora_dropout', type=int, default=0.1)
|
359 |
+
|
360 |
+
# optimization
|
361 |
+
parser.add_argument('--reaction_weight', type=float, default=1.0)
|
362 |
+
parser.add_argument('--weight_decay', type=float, default=0.05, help='optimizer weight decay')
|
363 |
+
parser.add_argument('--init_lr', type=float, default=1e-4, help='optimizer init learning rate')
|
364 |
+
parser.add_argument('--min_lr', type=float, default=1e-5, help='optimizer min learning rate')
|
365 |
+
parser.add_argument('--warmup_lr', type=float, default=1e-6, help='optimizer warmup learning rate')
|
366 |
+
parser.add_argument('--warmup_steps', type=int, default=1000, help='optimizer warmup steps')
|
367 |
+
parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='optimizer lr decay rate')
|
368 |
+
parser.add_argument('--scheduler', type=str, default='linear_warmup_cosine_lr', help='type of scheduler') # or linear_warmup_step_lr
|
369 |
+
parser.add_argument('--optimizer', type=str, default='adamw', help='type of scheduler')
|
370 |
+
parser.add_argument('--init_checkpoint', type=str, default='')
|
371 |
+
parser.add_argument('--caption_eval_epoch', type=int, default=10)
|
372 |
+
parser.add_argument('--num_generate_captions', type=int, default=1)
|
373 |
+
|
374 |
+
# OPT Config
|
375 |
+
parser.add_argument('--optconfig_attention_dropout', type=float, default=0.0)
|
376 |
+
parser.add_argument('--optconfig_dropout', type=float, default=0.0)
|
377 |
+
|
378 |
+
# others
|
379 |
+
parser.add_argument('--save_ema_checkpoint', action='store_true', default=False)
|
380 |
+
parser.add_argument('--save_on_steps', default=[], nargs='+', type=int)
|
381 |
+
return parent_parser
|
model/blip2_opt.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.cuda.amp import autocast as autocast
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from torch.nn import CrossEntropyLoss
|
13 |
+
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel
|
14 |
+
from ogb.utils import smiles2graph
|
15 |
+
from torch_geometric.loader.dataloader import Collater
|
16 |
+
from torch_geometric.data import Data
|
17 |
+
import numpy as np
|
18 |
+
from lavis.models.blip2_models.blip2 import (
|
19 |
+
# Blip2Base,
|
20 |
+
disabled_train,
|
21 |
+
)
|
22 |
+
from model.blip2 import Blip2Base
|
23 |
+
from model.help_funcs import get_not_allowed_tokens_ids
|
24 |
+
from transformers import AutoTokenizer
|
25 |
+
from transformers import OPTForCausalLM, OPTConfig
|
26 |
+
# from opendelta import LoraModel
|
27 |
+
# from opendelta.delta_models.lora import LoraConfig
|
28 |
+
# from opendelta.delta_configs
|
29 |
+
|
30 |
+
opt_model_list = [
|
31 |
+
"facebook/galactica-125m",
|
32 |
+
"facebook/galactica-1.3b",
|
33 |
+
"facebook/galactica-6.7b",
|
34 |
+
"facebook/galactica-30b",
|
35 |
+
]
|
36 |
+
|
37 |
+
def mask_by_len(input, lens, fill_value=0):
|
38 |
+
'''
|
39 |
+
input: shape = [N, D]
|
40 |
+
lens: shape = [N]
|
41 |
+
'''
|
42 |
+
mask = torch.arange(input.shape[1], device=input.device).reshape(1, -1)
|
43 |
+
mask = mask < lens.reshape(-1, 1)
|
44 |
+
input[mask] = fill_value
|
45 |
+
return input
|
46 |
+
|
47 |
+
|
48 |
+
def smiles2data(smiles):
|
49 |
+
graph = smiles2graph(smiles)
|
50 |
+
x = torch.from_numpy(graph['node_feat'])
|
51 |
+
edge_index = torch.from_numpy(graph['edge_index'], )
|
52 |
+
edge_attr = torch.from_numpy(graph['edge_feat'])
|
53 |
+
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
54 |
+
return data
|
55 |
+
|
56 |
+
import re
|
57 |
+
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
|
58 |
+
|
59 |
+
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
|
60 |
+
|
61 |
+
|
62 |
+
def _insert_split_marker(m: re.Match):
|
63 |
+
"""
|
64 |
+
Applies split marker based on a regex match of special tokens such as
|
65 |
+
[START_DNA].
|
66 |
+
|
67 |
+
Parameters
|
68 |
+
----------
|
69 |
+
n : str
|
70 |
+
Input text to split
|
71 |
+
|
72 |
+
Returns
|
73 |
+
----------
|
74 |
+
str - the text with the split token added
|
75 |
+
"""
|
76 |
+
start_token, _, sequence, end_token = m.groups()
|
77 |
+
sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
|
78 |
+
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
|
79 |
+
|
80 |
+
def escape_custom_split_sequence(text):
|
81 |
+
"""
|
82 |
+
Applies custom splitting to the text for GALILEO's tokenization
|
83 |
+
|
84 |
+
Parameters
|
85 |
+
----------
|
86 |
+
text : str
|
87 |
+
Input text to split
|
88 |
+
|
89 |
+
Returns
|
90 |
+
----------
|
91 |
+
str - the text with the split token added
|
92 |
+
"""
|
93 |
+
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
|
94 |
+
|
95 |
+
def smiles_handler(text, mol_ph):
|
96 |
+
smiles_list = []
|
97 |
+
for match in CUSTOM_SEQ_RE.finditer(text):
|
98 |
+
smiles = match.group(3)
|
99 |
+
smiles_list.append(smiles)
|
100 |
+
|
101 |
+
text = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
|
102 |
+
text = escape_custom_split_sequence(text)
|
103 |
+
return text, smiles_list
|
104 |
+
|
105 |
+
|
106 |
+
class Blip2OPT(Blip2Base):
|
107 |
+
"""
|
108 |
+
BLIP2 first-stage model with Q-former and ViT.
|
109 |
+
Supported model types:
|
110 |
+
- pretrained: pretrained model with vit-g
|
111 |
+
- pretrain_vitL: pretrained model with vit-large
|
112 |
+
- coco: fintuned model on coco
|
113 |
+
Usage:
|
114 |
+
>>> from lavis.models import load_model
|
115 |
+
>>> model = load_model("blip2", "pretrain")
|
116 |
+
"""
|
117 |
+
def __init__(
|
118 |
+
self,
|
119 |
+
bert_name,
|
120 |
+
gin_num_layers,
|
121 |
+
gin_hidden_dim,
|
122 |
+
gin_drop_ratio,
|
123 |
+
tune_gnn=False,
|
124 |
+
tune_qformer=False,
|
125 |
+
num_query_token=32,
|
126 |
+
cross_attention_freq=2,
|
127 |
+
llm_tune='freeze',
|
128 |
+
peft_dir='',
|
129 |
+
opt_model="facebook/galactica-1.3b",
|
130 |
+
prompt="",
|
131 |
+
args=None,
|
132 |
+
):
|
133 |
+
super().__init__()
|
134 |
+
self.args = args
|
135 |
+
|
136 |
+
self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
|
137 |
+
self.tune_gnn = tune_gnn
|
138 |
+
self.tune_qformer = tune_qformer
|
139 |
+
if not tune_gnn:
|
140 |
+
for name, param in self.graph_encoder.named_parameters():
|
141 |
+
param.requires_grad = False
|
142 |
+
self.graph_encoder = self.graph_encoder.eval()
|
143 |
+
self.graph_encoder.train = disabled_train
|
144 |
+
logging.info("freeze graph encoder")
|
145 |
+
else:
|
146 |
+
logging.info("tune graph encoder")
|
147 |
+
|
148 |
+
self.num_query_token = num_query_token
|
149 |
+
self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
|
150 |
+
if not tune_qformer:
|
151 |
+
for name, param in self.Qformer.named_parameters():
|
152 |
+
param.requires_grad = False
|
153 |
+
self.Qformer = self.Qformer.eval()
|
154 |
+
self.Qformer.train = disabled_train
|
155 |
+
self.query_tokens.requires_grad = False
|
156 |
+
logging.info("freeze qformer encoder")
|
157 |
+
else:
|
158 |
+
logging.info("tune qformer encoder")
|
159 |
+
### remove the unused parameters
|
160 |
+
self.Qformer.cls = None
|
161 |
+
self.Qformer.bert.embeddings.word_embeddings = None
|
162 |
+
self.Qformer.bert.embeddings.position_embeddings = None
|
163 |
+
for layer in self.Qformer.bert.encoder.layer:
|
164 |
+
layer.output = None
|
165 |
+
layer.intermediate = None
|
166 |
+
|
167 |
+
opt_config_params = {k[len("optconfig_"):]: v for k, v in vars(args).items() if k.startswith("optconfig_")}
|
168 |
+
config = OPTConfig.from_pretrained(opt_model, **opt_config_params)
|
169 |
+
## initialize opt model
|
170 |
+
self.opt_tokenizer = AutoTokenizer.from_pretrained(opt_model, use_fast=False, padding_side='right')
|
171 |
+
self.opt_tokenizer.add_special_tokens({'pad_token': '<pad>'})
|
172 |
+
self.opt_tokenizer.add_tokens('<mol>') # molecule placeholder
|
173 |
+
self.mol_token = '<mol>'
|
174 |
+
self.opt_tokenizer.mol_token_id = self.opt_tokenizer("<mol>", add_special_tokens=False).input_ids[0]
|
175 |
+
|
176 |
+
self.collater = Collater([], [])
|
177 |
+
|
178 |
+
if opt_model == 'facebook/galactica-125m':
|
179 |
+
self.opt_model = OPTForCausalLM.from_pretrained(opt_model, config=config)
|
180 |
+
else:
|
181 |
+
if torch.cuda.is_bf16_supported():
|
182 |
+
self.opt_model = OPTForCausalLM.from_pretrained(opt_model, torch_dtype=torch.bfloat16, config=config)
|
183 |
+
else:
|
184 |
+
self.opt_model = OPTForCausalLM.from_pretrained(opt_model, torch_dtype=torch.float16, config=config)
|
185 |
+
self.opt_model.resize_token_embeddings(len(self.opt_tokenizer)) ## this will cause bug when full fine-tuning the opt model
|
186 |
+
|
187 |
+
self.llm_tune = llm_tune
|
188 |
+
if llm_tune == 'lora':
|
189 |
+
if peft_dir:
|
190 |
+
self.opt_model = PeftModel.from_pretrained(self.opt_model, peft_dir, is_trainable=True)
|
191 |
+
else:
|
192 |
+
if self.args.peft_config:
|
193 |
+
peft_config = LoraConfig(**LoraConfig.from_json_file(self.args.peft_config))
|
194 |
+
else:
|
195 |
+
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
|
196 |
+
self.peft_config = peft_config
|
197 |
+
self.opt_model = get_peft_model(self.opt_model, peft_config)
|
198 |
+
self.opt_model.print_trainable_parameters()
|
199 |
+
elif llm_tune == 'freeze':
|
200 |
+
for name, param in self.opt_model.named_parameters():
|
201 |
+
param.requires_grad = False
|
202 |
+
elif llm_tune == 'full':
|
203 |
+
pass
|
204 |
+
else:
|
205 |
+
raise NotImplementedError()
|
206 |
+
|
207 |
+
## fixme: this is different from the original BLIP2
|
208 |
+
if args.mode=='pretrain_eval':
|
209 |
+
self.eos_token_id = self.opt_tokenizer(
|
210 |
+
"[START_SMILES]\n", add_special_tokens=False
|
211 |
+
).input_ids
|
212 |
+
else:
|
213 |
+
self.eos_token_id = self.opt_tokenizer(
|
214 |
+
"\n", add_special_tokens=False
|
215 |
+
).input_ids[0]
|
216 |
+
|
217 |
+
self.opt_proj = nn.Linear(
|
218 |
+
self.Qformer.config.hidden_size, self.opt_model.config.hidden_size
|
219 |
+
)
|
220 |
+
|
221 |
+
## fixme: no prompt yet
|
222 |
+
self.prompt = prompt
|
223 |
+
self.rxn_batch_size = args.rxn_batch_size
|
224 |
+
self.generate_restrict_tokens = args.generate_restrict_tokens
|
225 |
+
self.train_restrict_tokens = args.train_restrict_tokens
|
226 |
+
if self.generate_restrict_tokens or self.train_restrict_tokens:
|
227 |
+
self.bad_words_ids = get_not_allowed_tokens_ids(opt_model)
|
228 |
+
# prompt_tokens = self.opt_tokenizer(self.prompt, return_tensors="pt")
|
229 |
+
# self.prompt_length = prompt_tokens.attention_mask.sum(1)
|
230 |
+
|
231 |
+
def opt_forward_v2(
|
232 |
+
self,
|
233 |
+
inputs_embeds,
|
234 |
+
attention_mask,
|
235 |
+
labels,
|
236 |
+
bad_word_ids=None,
|
237 |
+
):
|
238 |
+
output = self.opt_model(
|
239 |
+
inputs_embeds=inputs_embeds,
|
240 |
+
attention_mask=attention_mask,
|
241 |
+
return_dict=True,
|
242 |
+
labels=labels,
|
243 |
+
)
|
244 |
+
logits = output.logits
|
245 |
+
labels = labels.to(logits.device)
|
246 |
+
# Shift so that tokens < n predict n
|
247 |
+
|
248 |
+
if bad_word_ids:
|
249 |
+
bad_word_ids = torch.tensor(bad_word_ids, device=logits.device, dtype=torch.long)
|
250 |
+
bad_word_ids = bad_word_ids.squeeze()
|
251 |
+
logits[:, :, bad_word_ids] = -100
|
252 |
+
|
253 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
254 |
+
shift_labels = labels[..., 1:].contiguous()
|
255 |
+
shift_logits = shift_logits.view(-1, self.opt_model.config.vocab_size)
|
256 |
+
loss_fct = CrossEntropyLoss()
|
257 |
+
loss = loss_fct(shift_logits, shift_labels.view(-1))
|
258 |
+
return loss
|
259 |
+
|
260 |
+
def forward_action(self, batch, use_gragh=True):
|
261 |
+
# batch unpack
|
262 |
+
rxn_ids, graphs, text_tokens = batch
|
263 |
+
if use_gragh:
|
264 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
265 |
+
if not self.tune_gnn:
|
266 |
+
graph_embeds = graph_embeds.detach()
|
267 |
+
|
268 |
+
# graph embedding calculation
|
269 |
+
graph_embeds = self.ln_graph(graph_embeds, graph_masks)
|
270 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
271 |
+
query_output = self.Qformer.bert(
|
272 |
+
query_embeds=query_tokens,
|
273 |
+
encoder_hidden_states=graph_embeds,
|
274 |
+
encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
|
275 |
+
return_dict=True,
|
276 |
+
)
|
277 |
+
mol_tokens = self.opt_proj(query_output.last_hidden_state) # graph_num x num_query_token x D
|
278 |
+
else:
|
279 |
+
del graphs
|
280 |
+
|
281 |
+
pad_mask = text_tokens.input_ids == self.opt_tokenizer.pad_token_id
|
282 |
+
targets = text_tokens.input_ids.masked_fill(pad_mask, -100)
|
283 |
+
targets = targets.masked_fill(text_tokens.is_mol_token, -100)
|
284 |
+
targets = targets.masked_fill(text_tokens.token_type_ids == 0, -100)
|
285 |
+
|
286 |
+
inputs_embeds = self.opt_model.get_input_embeddings()(text_tokens.input_ids)
|
287 |
+
if use_gragh:
|
288 |
+
inputs_embeds[text_tokens.is_mol_token] = mol_tokens.flatten(0, 1) # graph_num x emb_dim
|
289 |
+
|
290 |
+
if self.train_restrict_tokens:
|
291 |
+
loss = self.opt_forward_v2(
|
292 |
+
inputs_embeds=inputs_embeds,
|
293 |
+
attention_mask=text_tokens.attention_mask,
|
294 |
+
labels=targets,
|
295 |
+
bad_word_ids=self.bad_words_ids,
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
outputs = self.opt_model(
|
299 |
+
inputs_embeds=inputs_embeds,
|
300 |
+
attention_mask=text_tokens.attention_mask,
|
301 |
+
return_dict=True,
|
302 |
+
labels=targets,
|
303 |
+
)
|
304 |
+
loss = outputs.loss
|
305 |
+
return {"loss": loss}
|
306 |
+
|
307 |
+
def forward_abstract(self, batch, use_gragh=True):
|
308 |
+
# batch unpack
|
309 |
+
graphs, text_tokens = batch
|
310 |
+
if use_gragh:
|
311 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
312 |
+
if not self.tune_gnn:
|
313 |
+
graph_embeds = graph_embeds.detach()
|
314 |
+
|
315 |
+
# graph embedding calculation
|
316 |
+
graph_embeds = self.ln_graph(graph_embeds, graph_masks)
|
317 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
318 |
+
query_output = self.Qformer.bert(
|
319 |
+
query_embeds=query_tokens,
|
320 |
+
encoder_hidden_states=graph_embeds,
|
321 |
+
encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
|
322 |
+
return_dict=True,
|
323 |
+
)
|
324 |
+
mol_tokens = self.opt_proj(query_output.last_hidden_state) # graph_num x num_query_token x D
|
325 |
+
else:
|
326 |
+
del graphs
|
327 |
+
|
328 |
+
pad_mask = text_tokens.input_ids == self.opt_tokenizer.pad_token_id
|
329 |
+
targets = text_tokens.input_ids.masked_fill(pad_mask, -100)
|
330 |
+
targets = targets.masked_fill(text_tokens.is_mol_token, -100)
|
331 |
+
|
332 |
+
inputs_embeds = self.opt_model.get_input_embeddings()(text_tokens.input_ids)
|
333 |
+
if use_gragh:
|
334 |
+
inputs_embeds[text_tokens.is_mol_token] = mol_tokens.flatten(0, 1) # graph_num x emb_dim
|
335 |
+
|
336 |
+
outputs = self.opt_model(
|
337 |
+
inputs_embeds=inputs_embeds,
|
338 |
+
attention_mask=text_tokens.attention_mask,
|
339 |
+
return_dict=True,
|
340 |
+
labels=targets,
|
341 |
+
)
|
342 |
+
loss = outputs.loss
|
343 |
+
return {"loss": loss}
|
344 |
+
|
345 |
+
@torch.no_grad()
|
346 |
+
def generate(
|
347 |
+
self,
|
348 |
+
samples,
|
349 |
+
do_sample=False,
|
350 |
+
num_beams=5,
|
351 |
+
max_length=128,
|
352 |
+
min_length=1,
|
353 |
+
top_p=0.9,
|
354 |
+
repetition_penalty=1.0,
|
355 |
+
length_penalty=1.0,
|
356 |
+
num_captions=1,
|
357 |
+
temperature=1,
|
358 |
+
use_graph=True,
|
359 |
+
):
|
360 |
+
"""
|
361 |
+
Args:
|
362 |
+
samples (dict): A dictionary containing the following keys:
|
363 |
+
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
|
364 |
+
num_beams (int): Number of beams for beam search. 1 means no beam search.
|
365 |
+
max_length (int): The maximum length of the sequence to be generated.
|
366 |
+
min_length (int): The minimum length of the sequence to be generated.
|
367 |
+
top_p (float): The cumulative probability for nucleus sampling.
|
368 |
+
repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
|
369 |
+
num_captions (int): Number of captions to be generated for each image.
|
370 |
+
Returns:
|
371 |
+
captions (list): A list of strings of length batch_size * num_captions.
|
372 |
+
"""
|
373 |
+
graphs = samples['graphs']
|
374 |
+
prompt_tokens = samples['prompt_tokens']
|
375 |
+
# prompt_lens = samples['prompt_lens']
|
376 |
+
# with self.maybe_autocast():
|
377 |
+
if use_graph:
|
378 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
379 |
+
graph_embeds = self.ln_graph(graph_embeds)
|
380 |
+
|
381 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
382 |
+
query_output = self.Qformer.bert(
|
383 |
+
query_embeds=query_tokens,
|
384 |
+
encoder_hidden_states=graph_embeds,
|
385 |
+
encoder_attention_mask=graph_masks,
|
386 |
+
return_dict=True,
|
387 |
+
)
|
388 |
+
mol_tokens = self.opt_proj(query_output.last_hidden_state)
|
389 |
+
|
390 |
+
prompt_embeds = self.opt_model.get_input_embeddings()(prompt_tokens.input_ids)
|
391 |
+
if use_graph:
|
392 |
+
prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(dtype=prompt_embeds.dtype)
|
393 |
+
extra_params = {}
|
394 |
+
if self.generate_restrict_tokens:
|
395 |
+
extra_params['bad_words_ids'] = self.bad_words_ids
|
396 |
+
|
397 |
+
outputs = self.opt_model.generate(
|
398 |
+
inputs_embeds=prompt_embeds,
|
399 |
+
attention_mask=prompt_tokens.attention_mask,
|
400 |
+
do_sample=do_sample,
|
401 |
+
top_p=top_p,
|
402 |
+
temperature=temperature,
|
403 |
+
num_beams=num_beams,
|
404 |
+
max_length=max_length,
|
405 |
+
min_length=min_length,
|
406 |
+
# pad_token_id=self.pad_token_id,
|
407 |
+
eos_token_id=self.eos_token_id,
|
408 |
+
repetition_penalty=repetition_penalty,
|
409 |
+
length_penalty=length_penalty,
|
410 |
+
num_return_sequences=num_captions,
|
411 |
+
# use_cache=False,
|
412 |
+
**extra_params
|
413 |
+
)
|
414 |
+
output_text = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
415 |
+
|
416 |
+
output_text = [text.strip() for text in output_text]
|
417 |
+
return output_text
|
model/blip2_t5.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.cuda.amp import autocast as autocast
|
11 |
+
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
|
12 |
+
from lavis.models.blip2_models.blip2 import disabled_train
|
13 |
+
from model.blip2 import Blip2Base
|
14 |
+
# from model.smiles_t5_captioning
|
15 |
+
from lavis.models.blip2_models.modeling_t5 import T5ForConditionalGeneration
|
16 |
+
from transformers import AutoTokenizer, T5TokenizerFast
|
17 |
+
#, T5ForConditionalGeneration
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class Blip2T5(Blip2Base):
|
23 |
+
"""
|
24 |
+
BLIP2 first-stage model with Q-former and ViT.
|
25 |
+
Supported model types:
|
26 |
+
- pretrained: pretrained model with vit-g
|
27 |
+
- pretrain_vitL: pretrained model with vit-large
|
28 |
+
- coco: fintuned model on coco
|
29 |
+
Usage:
|
30 |
+
>>> from lavis.models import load_model
|
31 |
+
>>> model = load_model("blip2", "pretrain")
|
32 |
+
"""
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
bert_name,
|
36 |
+
gin_num_layers,
|
37 |
+
gin_hidden_dim,
|
38 |
+
gin_drop_ratio,
|
39 |
+
tune_gnn=False,
|
40 |
+
num_query_token=32,
|
41 |
+
cross_attention_freq=2,
|
42 |
+
llm_tune='freeze',
|
43 |
+
peft_dir='',
|
44 |
+
opt_model="facebook/galactica-1.3b",
|
45 |
+
prompt="",
|
46 |
+
args=None,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.args = args
|
50 |
+
|
51 |
+
self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
|
52 |
+
self.tune_gnn = tune_gnn
|
53 |
+
if not tune_gnn:
|
54 |
+
for name, param in self.graph_encoder.named_parameters():
|
55 |
+
param.requires_grad = False
|
56 |
+
self.graph_encoder = self.graph_encoder.eval()
|
57 |
+
self.graph_encoder.train = disabled_train
|
58 |
+
logging.info("freeze graph encoder")
|
59 |
+
|
60 |
+
self.num_query_token = num_query_token
|
61 |
+
self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
|
62 |
+
### remove the unused parameters
|
63 |
+
self.Qformer.cls = None
|
64 |
+
self.Qformer.bert.embeddings.word_embeddings = None
|
65 |
+
self.Qformer.bert.embeddings.position_embeddings = None
|
66 |
+
for layer in self.Qformer.bert.encoder.layer:
|
67 |
+
layer.output = None
|
68 |
+
layer.intermediate = None
|
69 |
+
|
70 |
+
# assert opt_model == 'laituan245/molt5-large'
|
71 |
+
## initialize opt model
|
72 |
+
# self.opt_tokenizer = AutoTokenizer.from_pretrained(opt_model)
|
73 |
+
self.opt_tokenizer = T5TokenizerFast.from_pretrained(opt_model)
|
74 |
+
self.opt_tokenizer.add_tokens('<mol>') # molecule placeholder
|
75 |
+
self.mol_token = '<mol>'
|
76 |
+
self.opt_tokenizer.mol_token_id = self.opt_tokenizer("<mol>", add_special_tokens=False).input_ids[0]
|
77 |
+
|
78 |
+
self.opt_model = T5ForConditionalGeneration.from_pretrained(opt_model, torch_dtype=torch.float32)
|
79 |
+
self.opt_model.resize_token_embeddings(len(self.opt_tokenizer)) ## this will cause bug when full fine-tuning the opt model
|
80 |
+
|
81 |
+
self.llm_tune = llm_tune
|
82 |
+
if llm_tune == 'lora':
|
83 |
+
if peft_dir:
|
84 |
+
self.opt_model = PeftModel.from_pretrained(self.opt_model, peft_dir, is_trainable=True)
|
85 |
+
else:
|
86 |
+
if self.args.peft_config:
|
87 |
+
peft_config = LoraConfig(**LoraConfig.from_json_file(self.args.peft_config))
|
88 |
+
else:
|
89 |
+
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
|
90 |
+
self.peft_config = peft_config
|
91 |
+
self.opt_model = get_peft_model(self.opt_model, peft_config)
|
92 |
+
self.opt_model.print_trainable_parameters()
|
93 |
+
elif llm_tune == 'freeze':
|
94 |
+
for name, param in self.opt_model.named_parameters():
|
95 |
+
param.requires_grad = False
|
96 |
+
elif llm_tune == 'full':
|
97 |
+
pass
|
98 |
+
else:
|
99 |
+
raise NotImplementedError()
|
100 |
+
|
101 |
+
## fixme: this is different from the original BLIP2
|
102 |
+
# self.eos_token_id = self.opt_tokenizer(
|
103 |
+
# "\n", add_special_tokens=False
|
104 |
+
# ).input_ids[0]
|
105 |
+
self.eos_token_id = self.opt_tokenizer(
|
106 |
+
"</s>", add_special_tokens=False
|
107 |
+
).input_ids[0]
|
108 |
+
|
109 |
+
self.opt_proj = nn.Linear(
|
110 |
+
self.Qformer.config.hidden_size, self.opt_model.config.hidden_size
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
def forward(self, batch):
|
115 |
+
graphs, prompt_tokens, text_tokens = batch
|
116 |
+
|
117 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
118 |
+
if not self.tune_gnn:
|
119 |
+
graph_embeds = graph_embeds.detach()
|
120 |
+
graph_embeds = self.ln_graph(graph_embeds, graph_masks)
|
121 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
122 |
+
query_output = self.Qformer.bert(
|
123 |
+
query_embeds=query_tokens,
|
124 |
+
encoder_hidden_states=graph_embeds,
|
125 |
+
encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
|
126 |
+
return_dict=True,
|
127 |
+
)
|
128 |
+
mol_tokens = self.opt_proj(query_output.last_hidden_state)
|
129 |
+
|
130 |
+
targets = text_tokens.input_ids.masked_fill(
|
131 |
+
text_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
|
132 |
+
)
|
133 |
+
with self.maybe_autocast(torch.float32):
|
134 |
+
prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
|
135 |
+
prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)
|
136 |
+
outputs = self.opt_model(
|
137 |
+
inputs_embeds=prompt_embeds,
|
138 |
+
attention_mask=prompt_tokens.attention_mask,
|
139 |
+
decoder_attention_mask=text_tokens.attention_mask,
|
140 |
+
return_dict=True,
|
141 |
+
labels=targets,
|
142 |
+
)
|
143 |
+
loss = outputs.loss
|
144 |
+
return {"loss": loss}
|
145 |
+
|
146 |
+
def forward_action(self, batch, use_gragh=True):
|
147 |
+
rxn_ids, graphs, prompt_tokens, text_tokens = batch
|
148 |
+
if use_gragh:
|
149 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
150 |
+
if not self.tune_gnn:
|
151 |
+
graph_embeds = graph_embeds.detach()
|
152 |
+
graph_embeds = self.ln_graph(graph_embeds, graph_masks)
|
153 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
154 |
+
query_output = self.Qformer.bert(
|
155 |
+
query_embeds=query_tokens,
|
156 |
+
encoder_hidden_states=graph_embeds,
|
157 |
+
encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
|
158 |
+
return_dict=True,
|
159 |
+
)
|
160 |
+
mol_tokens = self.opt_proj(query_output.last_hidden_state)
|
161 |
+
else:
|
162 |
+
del graphs
|
163 |
+
|
164 |
+
targets = text_tokens.input_ids.masked_fill(
|
165 |
+
text_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
|
166 |
+
)
|
167 |
+
with self.maybe_autocast(torch.float32):
|
168 |
+
prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
|
169 |
+
if use_gragh:
|
170 |
+
prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)
|
171 |
+
outputs = self.opt_model(
|
172 |
+
inputs_embeds=prompt_embeds,
|
173 |
+
attention_mask=prompt_tokens.attention_mask,
|
174 |
+
decoder_attention_mask=text_tokens.attention_mask,
|
175 |
+
return_dict=True,
|
176 |
+
labels=targets,
|
177 |
+
)
|
178 |
+
loss = outputs.loss
|
179 |
+
return {"loss": loss}
|
180 |
+
|
181 |
+
|
182 |
+
@torch.no_grad()
|
183 |
+
def generate(
|
184 |
+
self,
|
185 |
+
samples,
|
186 |
+
do_sample=False,
|
187 |
+
num_beams=5,
|
188 |
+
max_length=128,
|
189 |
+
min_length=1,
|
190 |
+
top_p=0.9,
|
191 |
+
repetition_penalty=1.0,
|
192 |
+
length_penalty=1.0,
|
193 |
+
num_captions=1,
|
194 |
+
temperature=1,
|
195 |
+
):
|
196 |
+
"""
|
197 |
+
Args:
|
198 |
+
samples (dict): A dictionary containing the following keys:
|
199 |
+
- image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
|
200 |
+
num_beams (int): Number of beams for beam search. 1 means no beam search.
|
201 |
+
max_length (int): The maximum length of the sequence to be generated.
|
202 |
+
min_length (int): The minimum length of the sequence to be generated.
|
203 |
+
top_p (float): The cumulative probability for nucleus sampling.
|
204 |
+
repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
|
205 |
+
num_captions (int): Number of captions to be generated for each image.
|
206 |
+
Returns:
|
207 |
+
captions (list): A list of strings of length batch_size * num_captions.
|
208 |
+
"""
|
209 |
+
graphs = samples['graphs']
|
210 |
+
prompt_tokens = samples['prompt_tokens']
|
211 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
212 |
+
graph_embeds = self.ln_graph(graph_embeds)
|
213 |
+
|
214 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
215 |
+
query_output = self.Qformer.bert(
|
216 |
+
query_embeds=query_tokens,
|
217 |
+
encoder_hidden_states=graph_embeds,
|
218 |
+
encoder_attention_mask=graph_masks,
|
219 |
+
return_dict=True,
|
220 |
+
)
|
221 |
+
mol_tokens = self.opt_proj(query_output.last_hidden_state)
|
222 |
+
with self.maybe_autocast(torch.float32):
|
223 |
+
prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
|
224 |
+
prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)
|
225 |
+
# prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
|
226 |
+
# prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1)
|
227 |
+
|
228 |
+
outputs = self.opt_model.generate(
|
229 |
+
inputs_embeds=prompt_embeds,
|
230 |
+
attention_mask=prompt_tokens.attention_mask,
|
231 |
+
do_sample=do_sample,
|
232 |
+
top_p=top_p,
|
233 |
+
temperature=temperature,
|
234 |
+
num_beams=num_beams,
|
235 |
+
max_length=max_length,
|
236 |
+
min_length=min_length,
|
237 |
+
# pad_token_id=self.pad_token_id,
|
238 |
+
eos_token_id=self.eos_token_id,
|
239 |
+
repetition_penalty=repetition_penalty,
|
240 |
+
length_penalty=length_penalty,
|
241 |
+
num_return_sequences=num_captions,
|
242 |
+
# use_cache=False,
|
243 |
+
)
|
244 |
+
output_text = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
245 |
+
|
246 |
+
output_text = [text.strip() for text in output_text]
|
247 |
+
return output_text
|
248 |
+
|
249 |
+
@torch.no_grad()
|
250 |
+
def generate_action(
|
251 |
+
self,
|
252 |
+
samples,
|
253 |
+
do_sample=False,
|
254 |
+
num_beams=5,
|
255 |
+
max_length=128,
|
256 |
+
min_length=1,
|
257 |
+
top_p=0.9,
|
258 |
+
repetition_penalty=1.0,
|
259 |
+
length_penalty=1.0,
|
260 |
+
num_captions=1,
|
261 |
+
temperature=1,
|
262 |
+
use_graph=True
|
263 |
+
):
|
264 |
+
graphs = samples['graphs']
|
265 |
+
prompt_tokens = samples['prompt_tokens']
|
266 |
+
if use_graph:
|
267 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
268 |
+
graph_embeds = self.ln_graph(graph_embeds)
|
269 |
+
|
270 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
271 |
+
query_output = self.Qformer.bert(
|
272 |
+
query_embeds=query_tokens,
|
273 |
+
encoder_hidden_states=graph_embeds,
|
274 |
+
encoder_attention_mask=graph_masks,
|
275 |
+
return_dict=True,
|
276 |
+
)
|
277 |
+
mol_tokens = self.opt_proj(query_output.last_hidden_state)
|
278 |
+
|
279 |
+
with self.maybe_autocast(torch.float32):
|
280 |
+
prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
|
281 |
+
if use_graph:
|
282 |
+
prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)
|
283 |
+
# prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
|
284 |
+
# prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1)
|
285 |
+
|
286 |
+
outputs = self.opt_model.generate(
|
287 |
+
inputs_embeds=prompt_embeds,
|
288 |
+
attention_mask=prompt_tokens.attention_mask,
|
289 |
+
do_sample=do_sample,
|
290 |
+
top_p=top_p,
|
291 |
+
temperature=temperature,
|
292 |
+
num_beams=num_beams,
|
293 |
+
max_length=max_length,
|
294 |
+
min_length=min_length,
|
295 |
+
# pad_token_id=self.pad_token_id,
|
296 |
+
eos_token_id=self.eos_token_id,
|
297 |
+
repetition_penalty=repetition_penalty,
|
298 |
+
length_penalty=length_penalty,
|
299 |
+
num_return_sequences=num_captions,
|
300 |
+
# use_cache=False,
|
301 |
+
)
|
302 |
+
output_text = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
303 |
+
|
304 |
+
output_text = [text.strip() for text in output_text]
|
305 |
+
return output_text
|
model/blip2qformer.py
ADDED
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.cuda.amp import autocast as autocast
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
# from lavis.common.registry import registry
|
16 |
+
# from lavis.models.base_model import all_gather_with_grad, concat_all_gather
|
17 |
+
from lavis.models.blip2_models.blip2 import (
|
18 |
+
disabled_train,
|
19 |
+
)
|
20 |
+
from lavis.models.blip_models.blip_outputs import BlipOutput
|
21 |
+
from lavis.common.dist_utils import is_dist_avail_and_initialized
|
22 |
+
from model.blip2 import Blip2Base
|
23 |
+
from pytorch_lightning.utilities import distributed
|
24 |
+
|
25 |
+
@torch.no_grad()
|
26 |
+
def concat_all_gather(tensor):
|
27 |
+
"""
|
28 |
+
Performs all_gather operation on the provided tensors.
|
29 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
30 |
+
"""
|
31 |
+
# if use distributed training
|
32 |
+
if not is_dist_avail_and_initialized():
|
33 |
+
return tensor
|
34 |
+
|
35 |
+
tensors_gather = [
|
36 |
+
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
37 |
+
]
|
38 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
39 |
+
|
40 |
+
output = torch.cat(tensors_gather, dim=0)
|
41 |
+
print('running here')
|
42 |
+
return output
|
43 |
+
|
44 |
+
@torch.no_grad()
|
45 |
+
def pl_concat_all_gather(tensor):
|
46 |
+
"""
|
47 |
+
Performs all_gather operation on the provided tensors.
|
48 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
49 |
+
"""
|
50 |
+
# if use distributed training
|
51 |
+
if not is_dist_avail_and_initialized():
|
52 |
+
return tensor
|
53 |
+
|
54 |
+
tensors_gather = distributed.gather_all_tensors(tensor)
|
55 |
+
output = torch.cat(tensors_gather, dim=0)
|
56 |
+
return output
|
57 |
+
|
58 |
+
|
59 |
+
# @registry.register_model("blip2")
|
60 |
+
# @registry.register_model("blip2_feature_extractor")
|
61 |
+
class Blip2Qformer(Blip2Base):
|
62 |
+
"""
|
63 |
+
BLIP2 first-stage model with Q-former and ViT.
|
64 |
+
Supported model types:
|
65 |
+
- pretrained: pretrained model with vit-g
|
66 |
+
- pretrain_vitL: pretrained model with vit-large
|
67 |
+
- coco: fintuned model on coco
|
68 |
+
Usage:
|
69 |
+
>>> from lavis.models import load_model
|
70 |
+
>>> model = load_model("blip2", "pretrain")
|
71 |
+
"""
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
gtm,
|
75 |
+
lm,
|
76 |
+
bert_name,
|
77 |
+
temperature,
|
78 |
+
gin_num_layers,
|
79 |
+
gin_hidden_dim,
|
80 |
+
gin_drop_ratio,
|
81 |
+
tune_gnn=False,
|
82 |
+
num_query_token=32,
|
83 |
+
cross_attention_freq=2,
|
84 |
+
embed_dim=256,
|
85 |
+
):
|
86 |
+
super().__init__()
|
87 |
+
self.gtm = gtm
|
88 |
+
self.lm = lm
|
89 |
+
|
90 |
+
self.tokenizer = self.init_tokenizer()
|
91 |
+
|
92 |
+
self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
|
93 |
+
self.tune_gnn = tune_gnn
|
94 |
+
if not tune_gnn:
|
95 |
+
for name, param in self.graph_encoder.named_parameters():
|
96 |
+
param.requires_grad = False
|
97 |
+
self.graph_encoder = self.graph_encoder.eval()
|
98 |
+
self.graph_encoder.train = disabled_train
|
99 |
+
logging.info("freeze graph encoder")
|
100 |
+
|
101 |
+
self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
|
102 |
+
self.Qformer.resize_token_embeddings(len(self.tokenizer))
|
103 |
+
state_dict = self.Qformer.state_dict()
|
104 |
+
for name, param in self.Qformer.named_parameters():
|
105 |
+
if "_query" in name:
|
106 |
+
key_orig = name.replace("_query", "")
|
107 |
+
param.data.copy_(state_dict[key_orig])
|
108 |
+
|
109 |
+
self.graph_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
|
110 |
+
self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
|
111 |
+
|
112 |
+
self.gtm_head = nn.Linear(self.Qformer.config.hidden_size, 2)
|
113 |
+
|
114 |
+
self.temperature = temperature
|
115 |
+
|
116 |
+
|
117 |
+
def contrast(self, features_graph, features_text, return_sim=False):
|
118 |
+
'''
|
119 |
+
features_graph: shape = [B, num_qs, D]
|
120 |
+
features_text: shape = [B, D]
|
121 |
+
'''
|
122 |
+
batch_size = features_graph.size(0)
|
123 |
+
|
124 |
+
# normalized features
|
125 |
+
features_graph = F.normalize(features_graph, dim=-1)
|
126 |
+
features_text = F.normalize(features_text, dim=-1)
|
127 |
+
|
128 |
+
# cosine similarity as logits
|
129 |
+
sim_q2t = (features_graph.unsqueeze(1) @ features_text.unsqueeze(-1)).squeeze() # shape = [B, 1, num_qs, D]; shape = [B, D, 1]; output shape = [B, B, num_qs]
|
130 |
+
sim_g2t, _ = sim_q2t.max(-1) # shape = [B, B]
|
131 |
+
|
132 |
+
logits_per_graph = sim_g2t / self.temperature
|
133 |
+
logits_per_text = logits_per_graph.t()
|
134 |
+
|
135 |
+
labels = torch.arange(batch_size, dtype=torch.long, device=self.device) # 大小为B
|
136 |
+
loss_graph = F.cross_entropy(logits_per_graph, labels)
|
137 |
+
loss_text = F.cross_entropy(logits_per_text, labels)
|
138 |
+
loss = (loss_graph + loss_text) / 2
|
139 |
+
|
140 |
+
if return_sim:
|
141 |
+
return logits_per_graph, logits_per_text, loss
|
142 |
+
else:
|
143 |
+
return loss
|
144 |
+
|
145 |
+
def contrast_global(self, features_graph, features_text, features_graph_all, features_text_all, return_sim=False):
|
146 |
+
'''
|
147 |
+
features_graph: shape = [B, num_qs, D]
|
148 |
+
features_text: shape = [B, D]
|
149 |
+
features_text_all: shape = [B * num_gpus, D]
|
150 |
+
features_graph_all: shape = [B * num_gpus, num_qs, D]
|
151 |
+
'''
|
152 |
+
bs = features_graph.size(0)
|
153 |
+
|
154 |
+
# cosine similarity as logits
|
155 |
+
sim_q2t = (features_graph.unsqueeze(1) @ features_text_all.unsqueeze(-1)).squeeze() # shape = [B, 1, num_qs, D]; shape = [B * num_gpus, D, 1]; output shape = [B, B * num_gpus, num_qs]
|
156 |
+
sim_g2t, _ = sim_q2t.max(-1) # shape = [B, B * num_gpus]
|
157 |
+
|
158 |
+
logits_per_graph = sim_g2t / self.temperature
|
159 |
+
|
160 |
+
|
161 |
+
sim_t2q = (features_text.unsqueeze(1).unsqueeze(1) @ features_graph_all.permute(0, 2, 1)).squeeze() # shape = [B, 1, 1, D]; [B*num_gpus, D, num_qs]; output shape = [B, B*num_gpus, 1, num_qs]
|
162 |
+
sim_t2g, _ = sim_t2q.max(-1)
|
163 |
+
logits_per_text = sim_t2g / self.temperature
|
164 |
+
|
165 |
+
# labels = torch.arange(bs, dtype=torch.long, device=self.device)
|
166 |
+
rank = dist.get_rank()
|
167 |
+
labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(self.device)
|
168 |
+
|
169 |
+
loss_graph = F.cross_entropy(logits_per_graph, labels)
|
170 |
+
loss_text = F.cross_entropy(logits_per_text, labels)
|
171 |
+
loss = (loss_graph + loss_text) / 2
|
172 |
+
|
173 |
+
if return_sim:
|
174 |
+
return logits_per_graph[:, rank*bs:rank*bs+bs], logits_per_text[:, rank*bs:rank*bs+bs], loss
|
175 |
+
else:
|
176 |
+
return loss
|
177 |
+
|
178 |
+
def forward_old(self, batch):
|
179 |
+
## v1: not gather results from all gpus
|
180 |
+
###============== Image-text Contrastive ===================###
|
181 |
+
graph, text, mask = batch
|
182 |
+
batch_node, batch_mask = self.graph_encoder(graph)
|
183 |
+
batch_node = batch_node.detach()
|
184 |
+
batch_size = batch_node.shape[0]
|
185 |
+
|
186 |
+
batch_node = self.ln_graph(batch_node, batch_mask)
|
187 |
+
query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
|
188 |
+
query_output = self.Qformer.bert(
|
189 |
+
query_embeds=query_tokens,
|
190 |
+
encoder_hidden_states=batch_node,
|
191 |
+
encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct
|
192 |
+
use_cache=True,
|
193 |
+
return_dict=True,
|
194 |
+
)
|
195 |
+
graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D]
|
196 |
+
text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D]
|
197 |
+
text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
|
198 |
+
sim_g2t, sim_t2g, loss_gtc = self.contrast(graph_feats, text_feats, return_sim=True)
|
199 |
+
|
200 |
+
|
201 |
+
###============== Image-text Matching ===================###
|
202 |
+
loss_gtm = 0
|
203 |
+
if self.gtm:
|
204 |
+
g_emb = batch_node
|
205 |
+
g_mask = batch_mask
|
206 |
+
text_ids = text.clone()
|
207 |
+
with torch.no_grad():
|
208 |
+
weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4
|
209 |
+
weights_t2g.fill_diagonal_(0)
|
210 |
+
weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4
|
211 |
+
weights_g2t.fill_diagonal_(0)
|
212 |
+
|
213 |
+
# select a negative graph for each text
|
214 |
+
graph_embeds_neg = []
|
215 |
+
graph_mask_neg = []
|
216 |
+
for b in range(batch_size):
|
217 |
+
neg_idx = torch.multinomial(weights_t2g[b], 1).item()
|
218 |
+
graph_embeds_neg.append(g_emb[neg_idx])
|
219 |
+
graph_mask_neg.append(g_mask[neg_idx])
|
220 |
+
|
221 |
+
graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0)
|
222 |
+
graph_mask_neg = torch.stack(graph_mask_neg, dim=0)
|
223 |
+
|
224 |
+
# select a negative text for each image
|
225 |
+
text_ids_neg = []
|
226 |
+
text_atts_neg = []
|
227 |
+
for b in range(batch_size):
|
228 |
+
neg_idx = torch.multinomial(weights_g2t[b], 1).item()
|
229 |
+
text_ids_neg.append(text_ids[neg_idx])
|
230 |
+
text_atts_neg.append(mask[neg_idx])
|
231 |
+
|
232 |
+
text_ids_neg = torch.stack(text_ids_neg, dim=0)
|
233 |
+
text_atts_neg = torch.stack(text_atts_neg, dim=0)
|
234 |
+
|
235 |
+
text_ids_all = torch.cat(
|
236 |
+
[text_ids, text_ids, text_ids_neg], dim=0
|
237 |
+
) # pos, pos, neg
|
238 |
+
text_atts_all = torch.cat(
|
239 |
+
[mask, mask, text_atts_neg],
|
240 |
+
dim=0,
|
241 |
+
)
|
242 |
+
|
243 |
+
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
|
244 |
+
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text.device)
|
245 |
+
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
|
246 |
+
|
247 |
+
graph_embeds_all = torch.cat([g_emb, graph_embeds_neg, g_emb], dim=0) # pos, neg, pos
|
248 |
+
graph_atts_all = torch.cat([g_mask, graph_mask_neg, g_mask], dim=0)
|
249 |
+
|
250 |
+
output_itm = self.Qformer.bert(
|
251 |
+
text_ids_all,
|
252 |
+
query_embeds=query_tokens_itm,
|
253 |
+
attention_mask=attention_mask_all,
|
254 |
+
encoder_hidden_states=graph_embeds_all,
|
255 |
+
encoder_attention_mask=graph_atts_all,
|
256 |
+
return_dict=True,
|
257 |
+
)
|
258 |
+
|
259 |
+
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only
|
260 |
+
vl_output = self.gtm_head(vl_embeddings)
|
261 |
+
logits = vl_output.mean(dim=1)
|
262 |
+
|
263 |
+
itm_labels = torch.cat(
|
264 |
+
[torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)],
|
265 |
+
dim=0,
|
266 |
+
).to(text.device)
|
267 |
+
loss_gtm = F.cross_entropy(logits, itm_labels)
|
268 |
+
|
269 |
+
##================= Image Captioning ========================##
|
270 |
+
loss_lm = 0
|
271 |
+
if self.lm:
|
272 |
+
decoder_input_ids = text.clone()
|
273 |
+
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
|
274 |
+
labels = decoder_input_ids.masked_fill(
|
275 |
+
decoder_input_ids == self.tokenizer.pad_token_id, -100
|
276 |
+
)
|
277 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device)
|
278 |
+
|
279 |
+
attention_mask = torch.cat([query_atts, mask], dim=1)
|
280 |
+
lm_output = self.Qformer(
|
281 |
+
decoder_input_ids,
|
282 |
+
attention_mask=attention_mask,
|
283 |
+
past_key_values=query_output.past_key_values,
|
284 |
+
return_dict=True,
|
285 |
+
labels=labels,
|
286 |
+
)
|
287 |
+
|
288 |
+
loss_lm = lm_output.loss
|
289 |
+
|
290 |
+
return BlipOutput(
|
291 |
+
loss=loss_gtc + loss_gtm + loss_lm,
|
292 |
+
loss_itc=loss_gtc,
|
293 |
+
loss_itm=loss_gtm,
|
294 |
+
loss_lm=loss_lm,
|
295 |
+
)
|
296 |
+
|
297 |
+
|
298 |
+
def forward(self, batch):
|
299 |
+
## v2: gather results from all gpus
|
300 |
+
###============== Image-text Contrastive ===================###
|
301 |
+
graph, text, mask = batch
|
302 |
+
batch_node, batch_mask = self.graph_encoder(graph)
|
303 |
+
if not self.tune_gnn:
|
304 |
+
batch_node = batch_node.detach()
|
305 |
+
batch_size = batch_node.shape[0]
|
306 |
+
|
307 |
+
batch_node = self.ln_graph(batch_node, batch_mask)
|
308 |
+
query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
|
309 |
+
query_output = self.Qformer.bert(
|
310 |
+
query_embeds=query_tokens,
|
311 |
+
encoder_hidden_states=batch_node,
|
312 |
+
encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct
|
313 |
+
use_cache=True,
|
314 |
+
return_dict=True,
|
315 |
+
)
|
316 |
+
graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D]
|
317 |
+
text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D]
|
318 |
+
text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
|
319 |
+
|
320 |
+
text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1)
|
321 |
+
text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats) # shape = [B * num_gpus, D]
|
322 |
+
sim_g2t, sim_t2g, loss_gtc = self.contrast_global(graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True)
|
323 |
+
|
324 |
+
|
325 |
+
###============== Image-text Matching ===================###
|
326 |
+
loss_gtm = 0
|
327 |
+
if self.gtm:
|
328 |
+
## not aggregate global tensor because of their different shapes
|
329 |
+
g_emb_world = batch_node
|
330 |
+
g_mask_world = batch_mask
|
331 |
+
text_ids_world = text
|
332 |
+
text_mask_world = mask
|
333 |
+
with torch.no_grad():
|
334 |
+
weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4
|
335 |
+
weights_t2g.fill_diagonal_(0)
|
336 |
+
weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4
|
337 |
+
weights_g2t.fill_diagonal_(0)
|
338 |
+
|
339 |
+
# select a negative graph for each text
|
340 |
+
graph_embeds_neg = []
|
341 |
+
graph_mask_neg = []
|
342 |
+
for b in range(batch_size):
|
343 |
+
neg_idx = torch.multinomial(weights_t2g[b], 1).item()
|
344 |
+
graph_embeds_neg.append(g_emb_world[neg_idx])
|
345 |
+
graph_mask_neg.append(g_mask_world[neg_idx])
|
346 |
+
|
347 |
+
graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0)
|
348 |
+
graph_mask_neg = torch.stack(graph_mask_neg, dim=0)
|
349 |
+
|
350 |
+
# select a negative text for each image
|
351 |
+
text_ids_neg = []
|
352 |
+
text_atts_neg = []
|
353 |
+
for b in range(batch_size):
|
354 |
+
neg_idx = torch.multinomial(weights_g2t[b], 1).item()
|
355 |
+
text_ids_neg.append(text_ids_world[neg_idx])
|
356 |
+
text_atts_neg.append(text_mask_world[neg_idx])
|
357 |
+
|
358 |
+
text_ids_neg = torch.stack(text_ids_neg, dim=0)
|
359 |
+
text_atts_neg = torch.stack(text_atts_neg, dim=0)
|
360 |
+
|
361 |
+
text_ids_all = torch.cat(
|
362 |
+
[text, text, text_ids_neg], dim=0
|
363 |
+
) # pos, pos, neg
|
364 |
+
text_atts_all = torch.cat(
|
365 |
+
[mask, mask, text_atts_neg],
|
366 |
+
dim=0,
|
367 |
+
)
|
368 |
+
|
369 |
+
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
|
370 |
+
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text.device)
|
371 |
+
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
|
372 |
+
|
373 |
+
graph_embeds_all = torch.cat([batch_node, graph_embeds_neg, batch_node], dim=0) # pos, neg, pos
|
374 |
+
graph_atts_all = torch.cat([batch_mask, graph_mask_neg, batch_mask], dim=0)
|
375 |
+
|
376 |
+
output_itm = self.Qformer.bert(
|
377 |
+
text_ids_all,
|
378 |
+
query_embeds=query_tokens_itm,
|
379 |
+
attention_mask=attention_mask_all,
|
380 |
+
encoder_hidden_states=graph_embeds_all,
|
381 |
+
encoder_attention_mask=graph_atts_all,
|
382 |
+
return_dict=True,
|
383 |
+
)
|
384 |
+
|
385 |
+
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only
|
386 |
+
vl_output = self.gtm_head(vl_embeddings)
|
387 |
+
logits = vl_output.mean(dim=1)
|
388 |
+
|
389 |
+
itm_labels = torch.cat(
|
390 |
+
[torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)],
|
391 |
+
dim=0,
|
392 |
+
).to(text.device)
|
393 |
+
loss_gtm = F.cross_entropy(logits, itm_labels)
|
394 |
+
|
395 |
+
##================= Image Captioning ========================##
|
396 |
+
loss_lm = 0
|
397 |
+
if self.lm:
|
398 |
+
decoder_input_ids = text.clone()
|
399 |
+
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
|
400 |
+
labels = decoder_input_ids.masked_fill(
|
401 |
+
decoder_input_ids == self.tokenizer.pad_token_id, -100
|
402 |
+
)
|
403 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device)
|
404 |
+
|
405 |
+
attention_mask = torch.cat([query_atts, mask], dim=1)
|
406 |
+
lm_output = self.Qformer(
|
407 |
+
decoder_input_ids,
|
408 |
+
attention_mask=attention_mask,
|
409 |
+
past_key_values=query_output.past_key_values,
|
410 |
+
return_dict=True,
|
411 |
+
labels=labels,
|
412 |
+
)
|
413 |
+
|
414 |
+
loss_lm = lm_output.loss
|
415 |
+
|
416 |
+
return BlipOutput(
|
417 |
+
loss=loss_gtc + loss_gtm + loss_lm,
|
418 |
+
loss_itc=loss_gtc,
|
419 |
+
loss_itm=loss_gtm,
|
420 |
+
loss_lm=loss_lm,
|
421 |
+
)
|
422 |
+
|
423 |
+
def forward_v3(self, batch):
|
424 |
+
## v3: use smiles instruction
|
425 |
+
###============== Image-text Contrastive ===================###
|
426 |
+
graphs, text_tokens, prompt_tokens = batch
|
427 |
+
graph_embeds, graph_masks = self.graph_encoder(graphs)
|
428 |
+
if not self.tune_gnn:
|
429 |
+
graph_embeds = graph_embeds.detach()
|
430 |
+
graph_embeds = self.ln_graph(graph_embeds, graph_masks)
|
431 |
+
|
432 |
+
device = text_tokens.input_ids.device
|
433 |
+
batch_size = graph_embeds.shape[0]
|
434 |
+
|
435 |
+
##
|
436 |
+
query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
|
437 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=device)
|
438 |
+
attention_mask_gtc = torch.cat([query_atts, prompt_tokens.attention_mask], dim=1)
|
439 |
+
query_output = self.Qformer.bert(
|
440 |
+
input_ids=prompt_tokens,
|
441 |
+
query_embeds=query_tokens,
|
442 |
+
attention_mask=attention_mask_gtc,
|
443 |
+
encoder_hidden_states=graph_embeds,
|
444 |
+
encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
|
445 |
+
use_cache=True,
|
446 |
+
return_dict=True,
|
447 |
+
)
|
448 |
+
|
449 |
+
query_output = query_output.last_hidden_state[:, : query_tokens.size(1), :] # keep query tokens only
|
450 |
+
graph_feats = self.graph_proj(query_output) # shape = [B, num_q, D]
|
451 |
+
text_output = self.Qformer.bert(text_tokens.input_ids, attention_mask=text_tokens.attention_mask, return_dict=True) # shape = [B, n_max, D]
|
452 |
+
text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
|
453 |
+
|
454 |
+
text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1)
|
455 |
+
text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats) # shape = [B * num_gpus, D]
|
456 |
+
sim_g2t, sim_t2g, loss_gtc = self.contrast_global(graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True)
|
457 |
+
|
458 |
+
|
459 |
+
###============== Image-text Matching ===================###
|
460 |
+
loss_gtm = 0
|
461 |
+
if self.gtm:
|
462 |
+
## not aggregate global tensor because of their different shapes
|
463 |
+
g_emb_world = graph_embeds
|
464 |
+
g_mask_world = graph_masks
|
465 |
+
text_ids_world = text_tokens.input_ids
|
466 |
+
text_mask_world = text_tokens.attention_mask
|
467 |
+
with torch.no_grad():
|
468 |
+
weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4
|
469 |
+
weights_t2g.fill_diagonal_(0)
|
470 |
+
weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4
|
471 |
+
weights_g2t.fill_diagonal_(0)
|
472 |
+
|
473 |
+
# select a negative graph for each text
|
474 |
+
graph_embeds_neg = []
|
475 |
+
graph_mask_neg = []
|
476 |
+
for b in range(batch_size):
|
477 |
+
neg_idx = torch.multinomial(weights_t2g[b], 1).item()
|
478 |
+
graph_embeds_neg.append(g_emb_world[neg_idx])
|
479 |
+
graph_mask_neg.append(g_mask_world[neg_idx])
|
480 |
+
|
481 |
+
graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0)
|
482 |
+
graph_mask_neg = torch.stack(graph_mask_neg, dim=0)
|
483 |
+
|
484 |
+
# select a negative text for each image
|
485 |
+
text_ids_neg = []
|
486 |
+
text_atts_neg = []
|
487 |
+
for b in range(batch_size):
|
488 |
+
neg_idx = torch.multinomial(weights_g2t[b], 1).item()
|
489 |
+
text_ids_neg.append(text_ids_world[neg_idx])
|
490 |
+
text_atts_neg.append(text_mask_world[neg_idx])
|
491 |
+
|
492 |
+
text_ids_neg = torch.stack(text_ids_neg, dim=0)
|
493 |
+
text_atts_neg = torch.stack(text_atts_neg, dim=0)
|
494 |
+
|
495 |
+
text_ids_all = torch.cat(
|
496 |
+
[text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
|
497 |
+
) # pos, pos, neg
|
498 |
+
text_atts_all = torch.cat(
|
499 |
+
[text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
|
500 |
+
dim=0,
|
501 |
+
)
|
502 |
+
|
503 |
+
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
|
504 |
+
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text_tokens.input_ids.device)
|
505 |
+
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
|
506 |
+
|
507 |
+
graph_embeds_all = torch.cat([graph_embeds, graph_embeds_neg, graph_embeds], dim=0) # pos, neg, pos
|
508 |
+
graph_atts_all = torch.cat([graph_masks, graph_mask_neg, graph_masks], dim=0)
|
509 |
+
|
510 |
+
output_itm = self.Qformer.bert(
|
511 |
+
text_ids_all,
|
512 |
+
query_embeds=query_tokens_itm,
|
513 |
+
attention_mask=attention_mask_all,
|
514 |
+
encoder_hidden_states=graph_embeds_all,
|
515 |
+
encoder_attention_mask=graph_atts_all,
|
516 |
+
return_dict=True,
|
517 |
+
)
|
518 |
+
|
519 |
+
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only
|
520 |
+
vl_output = self.gtm_head(vl_embeddings)
|
521 |
+
logits = vl_output.mean(dim=1)
|
522 |
+
|
523 |
+
itm_labels = torch.cat(
|
524 |
+
[torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)],
|
525 |
+
dim=0,
|
526 |
+
).to(text_tokens.input_ids.device)
|
527 |
+
loss_gtm = F.cross_entropy(logits, itm_labels)
|
528 |
+
|
529 |
+
##================= Image Captioning ========================##
|
530 |
+
loss_lm = 0
|
531 |
+
if self.lm:
|
532 |
+
decoder_input_ids = text_tokens.input_ids.clone()
|
533 |
+
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
|
534 |
+
labels = decoder_input_ids.masked_fill(
|
535 |
+
decoder_input_ids == self.tokenizer.pad_token_id, -100
|
536 |
+
)
|
537 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text_tokens.input_ids.device)
|
538 |
+
|
539 |
+
attention_mask = torch.cat([query_atts, prompt_tokens.attention_mask, text_tokens.attention_mask], dim=1)
|
540 |
+
lm_output = self.Qformer(
|
541 |
+
decoder_input_ids,
|
542 |
+
attention_mask=attention_mask,
|
543 |
+
past_key_values=query_output.past_key_values,
|
544 |
+
return_dict=True,
|
545 |
+
labels=labels,
|
546 |
+
)
|
547 |
+
|
548 |
+
loss_lm = lm_output.loss
|
549 |
+
|
550 |
+
return BlipOutput(
|
551 |
+
loss=loss_gtc + loss_gtm + loss_lm,
|
552 |
+
loss_itc=loss_gtc,
|
553 |
+
loss_itm=loss_gtm,
|
554 |
+
loss_lm=loss_lm,
|
555 |
+
)
|
556 |
+
|
557 |
+
def graph_forward(self, graph):
|
558 |
+
batch_node, batch_mask = self.graph_encoder(graph)
|
559 |
+
batch_node = self.ln_graph(batch_node, batch_mask)
|
560 |
+
query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
|
561 |
+
query_output = self.Qformer.bert(
|
562 |
+
query_embeds=query_tokens,
|
563 |
+
encoder_hidden_states=batch_node,
|
564 |
+
encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct
|
565 |
+
use_cache=False,
|
566 |
+
return_dict=True,
|
567 |
+
)
|
568 |
+
graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D]
|
569 |
+
graph_feats = F.normalize(graph_feats, p=2, dim=-1)
|
570 |
+
return graph_feats, batch_node, batch_mask
|
571 |
+
|
572 |
+
def text_forward(self, text, mask):
|
573 |
+
text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D]
|
574 |
+
text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :] )
|
575 |
+
text_feats = F.normalize(text_feats, dim=-1, p=2)
|
576 |
+
return text_feats
|
577 |
+
|
578 |
+
def compute_gtm(self, batch_node, batch_mask, text_ids, text_atts):
|
579 |
+
'''
|
580 |
+
batch_node shape = [B, N, D]
|
581 |
+
batch_mask shape = [B, N]
|
582 |
+
text_ids shape = [B, N]
|
583 |
+
text_atts shape = [B, N]
|
584 |
+
'''
|
585 |
+
query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1) # shape = [B, Nq, D]
|
586 |
+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
|
587 |
+
batch_node.device
|
588 |
+
) # shape = [B, Nq]
|
589 |
+
attention_mask = torch.cat([query_atts, text_atts], dim=1) # shape = [B, Nq + N]
|
590 |
+
output_gtm = self.Qformer.bert(
|
591 |
+
text_ids,
|
592 |
+
query_embeds=query_tokens,
|
593 |
+
attention_mask=attention_mask,
|
594 |
+
encoder_hidden_states=batch_node,
|
595 |
+
encoder_attention_mask=batch_mask,
|
596 |
+
return_dict=True,
|
597 |
+
)
|
598 |
+
gl_embeddings = output_gtm.last_hidden_state[:, : query_tokens.size(1), :] # shape = [B, Nq, D]
|
599 |
+
gtm_logit = self.gtm_head(gl_embeddings).mean(dim=1) # shape = [B, Nq, 2]
|
600 |
+
# gtm_logit = F.softmax(gtm_logit, dim=-1)[:, 1] # select the axis of the positive class
|
601 |
+
gtm_logit = gtm_logit[:, 1] # select the axis of the positive class
|
602 |
+
return gtm_logit
|
603 |
+
|
model/dist_funs.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union, Dict
|
3 |
+
from pytorch_lightning import strategies
|
4 |
+
from lightning_fabric.utilities.types import _PATH
|
5 |
+
from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
|
6 |
+
|
7 |
+
|
8 |
+
'''
|
9 |
+
overwrite the function in deepspeed
|
10 |
+
'''
|
11 |
+
|
12 |
+
### start overwrite ###
|
13 |
+
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
|
14 |
+
sd = self.module.state_dict(destination, prefix, keep_vars)
|
15 |
+
# Remove frozen parameter weights from state_dict if specified
|
16 |
+
if exclude_frozen_parameters:
|
17 |
+
to_be_removed = []
|
18 |
+
for n in sd:
|
19 |
+
try:
|
20 |
+
if not self.module.get_parameter(n).requires_grad:
|
21 |
+
to_be_removed.append(n)
|
22 |
+
except AttributeError:
|
23 |
+
to_be_removed.append(n)
|
24 |
+
for key in to_be_removed:
|
25 |
+
sd.pop(key)
|
26 |
+
if self.random_ltd_enabled():
|
27 |
+
sd = remove_random_ltd_state_dict(sd)
|
28 |
+
return sd
|
29 |
+
from deepspeed import DeepSpeedEngine
|
30 |
+
DeepSpeedEngine.module_state_dict = module_state_dict
|
31 |
+
### end overwrite ###
|
32 |
+
|
33 |
+
class MyDeepSpeedStrategy(strategies.DeepSpeedStrategy):
|
34 |
+
def save_checkpoint_v1(
|
35 |
+
self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
|
36 |
+
):
|
37 |
+
"""Save model/training states as a checkpoint file through state-dump and file-write.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
checkpoint: dict containing model and trainer state
|
41 |
+
filepath: write-target file's path
|
42 |
+
storage_options: parameter for how to save to st
|
43 |
+
orage, passed to ``CheckpointIO`` plugin
|
44 |
+
"""
|
45 |
+
if self.is_global_zero:
|
46 |
+
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
|
47 |
+
|
48 |
+
def load_model_state_dict(self, checkpoint):
|
49 |
+
assert self.lightning_module is not None
|
50 |
+
self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=False)
|
51 |
+
def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
|
52 |
+
"""Save model/training states as a checkpoint file through state-dump and file-write.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
checkpoint: The checkpoint state dictionary
|
56 |
+
filepath: write-target file's path
|
57 |
+
storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used
|
58 |
+
|
59 |
+
Raises:
|
60 |
+
TypeError:
|
61 |
+
If ``storage_options`` arg is passed in
|
62 |
+
"""
|
63 |
+
# broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath
|
64 |
+
filepath = self.broadcast(filepath)
|
65 |
+
if storage_options is not None:
|
66 |
+
raise TypeError(
|
67 |
+
"`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
|
68 |
+
f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
|
69 |
+
)
|
70 |
+
|
71 |
+
if self.zero_stage_3 and self._multi_device and self.is_global_zero:
|
72 |
+
print(
|
73 |
+
"Warning: When saving the DeepSpeed Stage 3 checkpoint, "
|
74 |
+
"each worker will save a shard of the checkpoint within a directory. "
|
75 |
+
"If a single file is required after training, "
|
76 |
+
"see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
|
77 |
+
"deepspeed-zero-stage-3-single-file for instructions."
|
78 |
+
)
|
79 |
+
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
|
80 |
+
# dump states as a checkpoint dictionary object
|
81 |
+
_exclude_keys = ["state_dict", "optimizer_states"]
|
82 |
+
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
|
83 |
+
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint", exclude_frozen_parameters=True)
|
model/gin_model.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.nn import MessagePassing
|
3 |
+
from torch_geometric.utils import add_self_loops, degree, softmax, to_dense_batch
|
4 |
+
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
|
5 |
+
import torch.nn.functional as F
|
6 |
+
# from torch_scatter import scatter_add
|
7 |
+
from torch_geometric.nn.inits import glorot, zeros
|
8 |
+
|
9 |
+
num_atom_type = 120 #including the extra mask tokens
|
10 |
+
num_chirality_tag = 3
|
11 |
+
|
12 |
+
num_bond_type = 6 #including aromatic and self-loop edge, and extra masked tokens
|
13 |
+
num_bond_direction = 3
|
14 |
+
|
15 |
+
class GINConv(MessagePassing):
|
16 |
+
"""
|
17 |
+
Extension of GIN aggregation to incorporate edge information by concatenation.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
emb_dim (int): dimensionality of embeddings for nodes and edges.
|
21 |
+
embed_input (bool): whether to embed input or not.
|
22 |
+
|
23 |
+
|
24 |
+
See https://arxiv.org/abs/1810.00826
|
25 |
+
"""
|
26 |
+
def __init__(self, emb_dim, aggr = "add"):
|
27 |
+
super(GINConv, self).__init__(aggr = "add")
|
28 |
+
#multi-layer perceptron
|
29 |
+
self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
|
30 |
+
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
|
31 |
+
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
|
32 |
+
|
33 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
|
34 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
|
35 |
+
self.aggr = aggr
|
36 |
+
|
37 |
+
def forward(self, x, edge_index, edge_attr):
|
38 |
+
#add self loops in the edge space
|
39 |
+
# print('--------------------')
|
40 |
+
# print('x:', x.shape)
|
41 |
+
# print('edge_index:',edge_index.shape)
|
42 |
+
edge_index, edge_attr = add_self_loops(edge_index, edge_attr, fill_value=0, num_nodes = x.size(0))
|
43 |
+
|
44 |
+
|
45 |
+
#add features corresponding to self-loop edges.
|
46 |
+
# self_loop_attr = torch.zeros(x.size(0), 2)
|
47 |
+
# self_loop_attr[:,0] = 4 #bond type for self-loop edge
|
48 |
+
# self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
|
49 |
+
# print('edge_attr:',edge_attr.shape)
|
50 |
+
# print('self_loop_attr:',self_loop_attr.shape)
|
51 |
+
# print('--------------------')
|
52 |
+
# edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
|
53 |
+
|
54 |
+
edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
|
55 |
+
|
56 |
+
return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
|
57 |
+
|
58 |
+
def message(self, x_j, edge_attr):
|
59 |
+
return x_j + edge_attr
|
60 |
+
|
61 |
+
def update(self, aggr_out):
|
62 |
+
return self.mlp(aggr_out)
|
63 |
+
|
64 |
+
|
65 |
+
class GCNConv(MessagePassing):
|
66 |
+
|
67 |
+
def __init__(self, emb_dim, aggr = "add"):
|
68 |
+
super(GCNConv, self).__init__()
|
69 |
+
|
70 |
+
self.emb_dim = emb_dim
|
71 |
+
self.linear = torch.nn.Linear(emb_dim, emb_dim)
|
72 |
+
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
|
73 |
+
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
|
74 |
+
|
75 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
|
76 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
|
77 |
+
|
78 |
+
self.aggr = aggr
|
79 |
+
|
80 |
+
def norm(self, edge_index, num_nodes, dtype):
|
81 |
+
### assuming that self-loops have been already added in edge_index
|
82 |
+
edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
|
83 |
+
device=edge_index.device)
|
84 |
+
row, col = edge_index
|
85 |
+
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
|
86 |
+
deg_inv_sqrt = deg.pow(-0.5)
|
87 |
+
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
|
88 |
+
|
89 |
+
return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
|
90 |
+
|
91 |
+
|
92 |
+
def forward(self, x, edge_index, edge_attr):
|
93 |
+
#add self loops in the edge space
|
94 |
+
edge_index = add_self_loops(edge_index, num_nodes = x.size(0))
|
95 |
+
|
96 |
+
#add features corresponding to self-loop edges.
|
97 |
+
self_loop_attr = torch.zeros(x.size(0), 2)
|
98 |
+
self_loop_attr[:,0] = 4 #bond type for self-loop edge
|
99 |
+
self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
|
100 |
+
edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
|
101 |
+
|
102 |
+
edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
|
103 |
+
|
104 |
+
norm = self.norm(edge_index, x.size(0), x.dtype)
|
105 |
+
|
106 |
+
x = self.linear(x)
|
107 |
+
|
108 |
+
return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings, norm = norm)
|
109 |
+
|
110 |
+
def message(self, x_j, edge_attr, norm):
|
111 |
+
return norm.view(-1, 1) * (x_j + edge_attr)
|
112 |
+
|
113 |
+
|
114 |
+
class GATConv(MessagePassing):
|
115 |
+
def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr = "add"):
|
116 |
+
super(GATConv, self).__init__()
|
117 |
+
|
118 |
+
self.aggr = aggr
|
119 |
+
|
120 |
+
self.emb_dim = emb_dim
|
121 |
+
self.heads = heads
|
122 |
+
self.negative_slope = negative_slope
|
123 |
+
|
124 |
+
self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
|
125 |
+
self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))
|
126 |
+
|
127 |
+
self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))
|
128 |
+
|
129 |
+
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)
|
130 |
+
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)
|
131 |
+
|
132 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
|
133 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
|
134 |
+
|
135 |
+
self.reset_parameters()
|
136 |
+
|
137 |
+
def reset_parameters(self):
|
138 |
+
glorot(self.att)
|
139 |
+
zeros(self.bias)
|
140 |
+
|
141 |
+
def forward(self, x, edge_index, edge_attr):
|
142 |
+
|
143 |
+
#add self loops in the edge space
|
144 |
+
edge_index = add_self_loops(edge_index, num_nodes = x.size(0))
|
145 |
+
|
146 |
+
#add features corresponding to self-loop edges.
|
147 |
+
self_loop_attr = torch.zeros(x.size(0), 2)
|
148 |
+
self_loop_attr[:,0] = 4 #bond type for self-loop edge
|
149 |
+
self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
|
150 |
+
edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
|
151 |
+
|
152 |
+
edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
|
153 |
+
|
154 |
+
x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
|
155 |
+
return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings)
|
156 |
+
|
157 |
+
def message(self, edge_index, x_i, x_j, edge_attr):
|
158 |
+
edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
|
159 |
+
x_j += edge_attr
|
160 |
+
|
161 |
+
alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
|
162 |
+
|
163 |
+
alpha = F.leaky_relu(alpha, self.negative_slope)
|
164 |
+
alpha = softmax(alpha, edge_index[0])
|
165 |
+
|
166 |
+
return x_j * alpha.view(-1, self.heads, 1)
|
167 |
+
|
168 |
+
def update(self, aggr_out):
|
169 |
+
aggr_out = aggr_out.mean(dim=1)
|
170 |
+
aggr_out = aggr_out + self.bias
|
171 |
+
|
172 |
+
return aggr_out
|
173 |
+
|
174 |
+
|
175 |
+
class GraphSAGEConv(MessagePassing):
|
176 |
+
def __init__(self, emb_dim, aggr = "mean"):
|
177 |
+
super(GraphSAGEConv, self).__init__()
|
178 |
+
|
179 |
+
self.emb_dim = emb_dim
|
180 |
+
self.linear = torch.nn.Linear(emb_dim, emb_dim)
|
181 |
+
self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
|
182 |
+
self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)
|
183 |
+
|
184 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
|
185 |
+
torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
|
186 |
+
|
187 |
+
self.aggr = aggr
|
188 |
+
|
189 |
+
def forward(self, x, edge_index, edge_attr):
|
190 |
+
#add self loops in the edge space
|
191 |
+
edge_index = add_self_loops(edge_index, num_nodes = x.size(0))
|
192 |
+
|
193 |
+
#add features corresponding to self-loop edges.
|
194 |
+
self_loop_attr = torch.zeros(x.size(0), 2)
|
195 |
+
self_loop_attr[:,0] = 4 #bond type for self-loop edge
|
196 |
+
self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
|
197 |
+
edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
|
198 |
+
|
199 |
+
edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])
|
200 |
+
|
201 |
+
x = self.linear(x)
|
202 |
+
|
203 |
+
return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings)
|
204 |
+
|
205 |
+
def message(self, x_j, edge_attr):
|
206 |
+
return x_j + edge_attr
|
207 |
+
|
208 |
+
def update(self, aggr_out):
|
209 |
+
return F.normalize(aggr_out, p = 2, dim = -1)
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
class GNN(torch.nn.Module):
|
214 |
+
"""
|
215 |
+
|
216 |
+
|
217 |
+
Args:
|
218 |
+
num_layer (int): the number of GNN layers
|
219 |
+
emb_dim (int): dimensionality of embeddings
|
220 |
+
JK (str): last, concat, max or sum.
|
221 |
+
max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
|
222 |
+
drop_ratio (float): dropout rate
|
223 |
+
gnn_type: gin, gcn, graphsage, gat
|
224 |
+
|
225 |
+
Output:
|
226 |
+
node representations
|
227 |
+
|
228 |
+
"""
|
229 |
+
def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"):
|
230 |
+
super(GNN, self).__init__()
|
231 |
+
self.num_layer = num_layer
|
232 |
+
self.drop_ratio = drop_ratio
|
233 |
+
self.JK = JK
|
234 |
+
|
235 |
+
if self.num_layer < 2:
|
236 |
+
raise ValueError("Number of GNN layers must be greater than 1.")
|
237 |
+
|
238 |
+
self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
|
239 |
+
self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)
|
240 |
+
|
241 |
+
torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
|
242 |
+
torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)
|
243 |
+
|
244 |
+
###List of MLPs
|
245 |
+
self.gnns = torch.nn.ModuleList()
|
246 |
+
for layer in range(num_layer):
|
247 |
+
if gnn_type == "gin":
|
248 |
+
self.gnns.append(GINConv(emb_dim, aggr = "add"))
|
249 |
+
elif gnn_type == "gcn":
|
250 |
+
self.gnns.append(GCNConv(emb_dim))
|
251 |
+
elif gnn_type == "gat":
|
252 |
+
self.gnns.append(GATConv(emb_dim))
|
253 |
+
elif gnn_type == "graphsage":
|
254 |
+
self.gnns.append(GraphSAGEConv(emb_dim))
|
255 |
+
|
256 |
+
self.pool = global_mean_pool
|
257 |
+
|
258 |
+
###List of batchnorms
|
259 |
+
self.batch_norms = torch.nn.ModuleList()
|
260 |
+
for layer in range(num_layer):
|
261 |
+
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
|
262 |
+
self.num_features = emb_dim
|
263 |
+
self.cat_grep = True
|
264 |
+
|
265 |
+
#def forward(self, x, edge_index, edge_attr):
|
266 |
+
def forward(self, *argv):
|
267 |
+
if len(argv) == 3:
|
268 |
+
x, edge_index, edge_attr = argv[0], argv[1], argv[2]
|
269 |
+
elif len(argv) == 1:
|
270 |
+
data = argv[0]
|
271 |
+
x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
|
272 |
+
else:
|
273 |
+
raise ValueError("unmatched number of arguments.")
|
274 |
+
|
275 |
+
x = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])
|
276 |
+
|
277 |
+
h_list = [x]
|
278 |
+
for layer in range(self.num_layer):
|
279 |
+
h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
|
280 |
+
h = self.batch_norms[layer](h)
|
281 |
+
#h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
|
282 |
+
if layer == self.num_layer - 1:
|
283 |
+
#remove relu for the last layer
|
284 |
+
h = F.dropout(h, self.drop_ratio, training = self.training)
|
285 |
+
else:
|
286 |
+
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
|
287 |
+
h_list.append(h)
|
288 |
+
|
289 |
+
### Different implementations of Jk-concat
|
290 |
+
if self.JK == "concat":
|
291 |
+
node_representation = torch.cat(h_list, dim = 1)
|
292 |
+
elif self.JK == "last":
|
293 |
+
node_representation = h_list[-1]
|
294 |
+
elif self.JK == "max":
|
295 |
+
h_list = [h.unsqueeze_(0) for h in h_list]
|
296 |
+
node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
|
297 |
+
elif self.JK == "sum":
|
298 |
+
h_list = [h.unsqueeze_(0) for h in h_list]
|
299 |
+
node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0]
|
300 |
+
|
301 |
+
|
302 |
+
h_graph = self.pool(node_representation, batch) # shape = [B, D]
|
303 |
+
batch_node, batch_mask = to_dense_batch(node_representation, batch) # shape = [B, n_max, D],
|
304 |
+
batch_mask = batch_mask.bool()
|
305 |
+
|
306 |
+
if self.cat_grep:
|
307 |
+
batch_node = torch.cat((h_graph.unsqueeze(1), batch_node), dim=1) # shape = [B, n_max+1, D]
|
308 |
+
batch_mask = torch.cat([torch.ones((batch_mask.shape[0], 1), dtype=torch.bool, device=batch.device), batch_mask], dim=1)
|
309 |
+
return batch_node, batch_mask
|
310 |
+
else:
|
311 |
+
return batch_node, batch_mask, h_graph
|
312 |
+
|
313 |
+
|
314 |
+
class GNN_graphpred(torch.nn.Module):
|
315 |
+
"""
|
316 |
+
Extension of GIN to incorporate edge information by concatenation.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
num_layer (int): the number of GNN layers
|
320 |
+
emb_dim (int): dimensionality of embeddings
|
321 |
+
num_tasks (int): number of tasks in multi-task learning scenario
|
322 |
+
drop_ratio (float): dropout rate
|
323 |
+
JK (str): last, concat, max or sum.
|
324 |
+
graph_pooling (str): sum, mean, max, attention, set2set
|
325 |
+
gnn_type: gin, gcn, graphsage, gat
|
326 |
+
|
327 |
+
See https://arxiv.org/abs/1810.00826
|
328 |
+
JK-net: https://arxiv.org/abs/1806.03536
|
329 |
+
"""
|
330 |
+
def __init__(self, num_layer, emb_dim, num_tasks, JK = "last", drop_ratio = 0, graph_pooling = "mean", gnn_type = "gin"):
|
331 |
+
super(GNN_graphpred, self).__init__()
|
332 |
+
self.num_layer = num_layer
|
333 |
+
self.drop_ratio = drop_ratio
|
334 |
+
self.JK = JK
|
335 |
+
self.emb_dim = emb_dim
|
336 |
+
self.num_tasks = num_tasks
|
337 |
+
|
338 |
+
if self.num_layer < 2:
|
339 |
+
raise ValueError("Number of GNN layers must be greater than 1.")
|
340 |
+
|
341 |
+
self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type = gnn_type)
|
342 |
+
|
343 |
+
#Different kind of graph pooling
|
344 |
+
if graph_pooling == "sum":
|
345 |
+
self.pool = global_add_pool
|
346 |
+
elif graph_pooling == "mean":
|
347 |
+
self.pool = global_mean_pool
|
348 |
+
elif graph_pooling == "max":
|
349 |
+
self.pool = global_max_pool
|
350 |
+
elif graph_pooling == "attention":
|
351 |
+
if self.JK == "concat":
|
352 |
+
self.pool = GlobalAttention(gate_nn = torch.nn.Linear((self.num_layer + 1) * emb_dim, 1))
|
353 |
+
else:
|
354 |
+
self.pool = GlobalAttention(gate_nn = torch.nn.Linear(emb_dim, 1))
|
355 |
+
elif graph_pooling[:-1] == "set2set":
|
356 |
+
set2set_iter = int(graph_pooling[-1])
|
357 |
+
if self.JK == "concat":
|
358 |
+
self.pool = Set2Set((self.num_layer + 1) * emb_dim, set2set_iter)
|
359 |
+
else:
|
360 |
+
self.pool = Set2Set(emb_dim, set2set_iter)
|
361 |
+
else:
|
362 |
+
raise ValueError("Invalid graph pooling type.")
|
363 |
+
|
364 |
+
#For graph-level binary classification
|
365 |
+
if graph_pooling[:-1] == "set2set":
|
366 |
+
self.mult = 2
|
367 |
+
else:
|
368 |
+
self.mult = 1
|
369 |
+
|
370 |
+
if self.JK == "concat":
|
371 |
+
self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
|
372 |
+
else:
|
373 |
+
self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)
|
374 |
+
|
375 |
+
def from_pretrained(self, model_file):
|
376 |
+
#self.gnn = GNN(self.num_layer, self.emb_dim, JK = self.JK, drop_ratio = self.drop_ratio)
|
377 |
+
missing_keys, unexpected_keys = self.gnn.load_state_dict(torch.load(model_file))
|
378 |
+
print(missing_keys)
|
379 |
+
print(unexpected_keys)
|
380 |
+
|
381 |
+
def forward(self, *argv):
|
382 |
+
if len(argv) == 4:
|
383 |
+
x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
|
384 |
+
elif len(argv) == 1:
|
385 |
+
data = argv[0]
|
386 |
+
x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
|
387 |
+
else:
|
388 |
+
raise ValueError("unmatched number of arguments.")
|
389 |
+
|
390 |
+
node_representation = self.gnn(x, edge_index, edge_attr)
|
391 |
+
|
392 |
+
return self.graph_pred_linear(self.pool(node_representation, batch))
|
393 |
+
|
394 |
+
|
395 |
+
if __name__ == "__main__":
|
396 |
+
pass
|
397 |
+
|
model/help_funcs.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from nltk.translate.bleu_score import corpus_bleu
|
2 |
+
from nltk.translate.meteor_score import meteor_score
|
3 |
+
from rouge_score import rouge_scorer
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import json
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
|
9 |
+
def caption_evaluate(predictions, targets, tokenizer, text_trunc_length):
|
10 |
+
meteor_scores = []
|
11 |
+
references = []
|
12 |
+
hypotheses = []
|
13 |
+
for gt, out in tqdm(zip(targets, predictions)):
|
14 |
+
gt_tokens = tokenizer.tokenize(gt, truncation=True, max_length=text_trunc_length,
|
15 |
+
padding='max_length')
|
16 |
+
gt_tokens = list(filter(('[PAD]').__ne__, gt_tokens))
|
17 |
+
gt_tokens = list(filter(('[CLS]').__ne__, gt_tokens))
|
18 |
+
gt_tokens = list(filter(('[SEP]').__ne__, gt_tokens))
|
19 |
+
|
20 |
+
out_tokens = tokenizer.tokenize(out, truncation=True, max_length=text_trunc_length,
|
21 |
+
padding='max_length')
|
22 |
+
out_tokens = list(filter(('[PAD]').__ne__, out_tokens))
|
23 |
+
out_tokens = list(filter(('[CLS]').__ne__, out_tokens))
|
24 |
+
out_tokens = list(filter(('[SEP]').__ne__, out_tokens))
|
25 |
+
|
26 |
+
references.append([gt_tokens])
|
27 |
+
hypotheses.append(out_tokens)
|
28 |
+
|
29 |
+
mscore = meteor_score([gt_tokens], out_tokens)
|
30 |
+
meteor_scores.append(mscore)
|
31 |
+
|
32 |
+
bleu2 = corpus_bleu(references, hypotheses, weights=(.5,.5))
|
33 |
+
bleu4 = corpus_bleu(references, hypotheses, weights=(.25,.25,.25,.25))
|
34 |
+
bleu2 *= 100
|
35 |
+
bleu4 *= 100
|
36 |
+
|
37 |
+
print('BLEU-2 score:', bleu2)
|
38 |
+
print('BLEU-4 score:', bleu4)
|
39 |
+
_meteor_score = np.mean(meteor_scores)
|
40 |
+
_meteor_score *= 100
|
41 |
+
print('Average Meteor score:', _meteor_score)
|
42 |
+
|
43 |
+
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
|
44 |
+
|
45 |
+
rouge_scores = []
|
46 |
+
|
47 |
+
references = []
|
48 |
+
hypotheses = []
|
49 |
+
|
50 |
+
for gt, out in tqdm(zip(targets, predictions)):
|
51 |
+
rs = scorer.score(out, gt)
|
52 |
+
rouge_scores.append(rs)
|
53 |
+
|
54 |
+
print('ROUGE score:')
|
55 |
+
rouge_1 = np.mean([rs['rouge1'].fmeasure for rs in rouge_scores]) * 100
|
56 |
+
rouge_2 = np.mean([rs['rouge2'].fmeasure for rs in rouge_scores]) * 100
|
57 |
+
rouge_l = np.mean([rs['rougeL'].fmeasure for rs in rouge_scores]) * 100
|
58 |
+
print('rouge1:', rouge_1)
|
59 |
+
print('rouge2:', rouge_2)
|
60 |
+
print('rougeL:', rouge_l)
|
61 |
+
return bleu2, bleu4, rouge_1, rouge_2, rouge_l, _meteor_score
|
62 |
+
|
63 |
+
|
64 |
+
class AttrDict(dict):
|
65 |
+
def __init__(self, *args, **kwargs):
|
66 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
67 |
+
self.__dict__ = self
|
68 |
+
|
69 |
+
def get_tokens_as_list(tokenizer, word_list):
|
70 |
+
"Converts a sequence of words into a list of tokens"
|
71 |
+
"Source: https://huggingface.co/docs/transformers/internal/generation_utils"
|
72 |
+
tokens_list = []
|
73 |
+
for word in word_list:
|
74 |
+
tokenized_word = tokenizer([word], add_special_tokens=False).input_ids[0]
|
75 |
+
tokens_list.extend(tokenized_word)
|
76 |
+
return tokens_list
|
77 |
+
|
78 |
+
def get_not_allowed_tokens_ids(tokenizer_name, allowed_words_file='model/allowed_words.json'):
|
79 |
+
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(tokenizer_name, add_prefix_space=True)
|
80 |
+
with open(allowed_words_file, 'r') as f:
|
81 |
+
allowed_words = json.load(f)
|
82 |
+
allowed_words = list(allowed_words.values())
|
83 |
+
allowed_tokens_ids = get_tokens_as_list(tokenizer_with_prefix_space, allowed_words)
|
84 |
+
full_token_space = list(range(tokenizer_with_prefix_space.vocab_size))
|
85 |
+
not_allowed_tokens_ids = [[token_id] for token_id in full_token_space if token_id not in allowed_tokens_ids]
|
86 |
+
return not_allowed_tokens_ids
|
model/modeling_llama.py
ADDED
@@ -0,0 +1,888 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" PyTorch LLaMA model."""
|
21 |
+
import math
|
22 |
+
from typing import List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.utils.checkpoint
|
26 |
+
from torch import nn
|
27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
28 |
+
|
29 |
+
from transformers.activations import ACT2FN
|
30 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
31 |
+
from transformers.modeling_utils import PreTrainedModel
|
32 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
33 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__)
|
37 |
+
|
38 |
+
_CONFIG_FOR_DOC = "LlamaConfig"
|
39 |
+
|
40 |
+
|
41 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
42 |
+
def _make_causal_mask(
|
43 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
Make causal mask used for bi-directional self-attention.
|
47 |
+
"""
|
48 |
+
bsz, tgt_len = input_ids_shape
|
49 |
+
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
50 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
51 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
52 |
+
mask = mask.to(dtype)
|
53 |
+
|
54 |
+
if past_key_values_length > 0:
|
55 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
56 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
57 |
+
|
58 |
+
|
59 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
60 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
61 |
+
"""
|
62 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
63 |
+
"""
|
64 |
+
bsz, src_len = mask.size()
|
65 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
66 |
+
|
67 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
68 |
+
|
69 |
+
inverted_mask = 1.0 - expanded_mask
|
70 |
+
|
71 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
72 |
+
|
73 |
+
|
74 |
+
class LlamaRMSNorm(nn.Module):
|
75 |
+
def __init__(self, hidden_size, eps=1e-6):
|
76 |
+
"""
|
77 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
81 |
+
self.variance_epsilon = eps
|
82 |
+
|
83 |
+
def forward(self, hidden_states):
|
84 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
85 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
86 |
+
|
87 |
+
# convert into half-precision if necessary
|
88 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
89 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
90 |
+
|
91 |
+
return self.weight * hidden_states
|
92 |
+
|
93 |
+
|
94 |
+
class LlamaRotaryEmbedding(torch.nn.Module):
|
95 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
96 |
+
super().__init__()
|
97 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
98 |
+
self.register_buffer("inv_freq", inv_freq)
|
99 |
+
|
100 |
+
# Build here to make `torch.jit.trace` work.
|
101 |
+
self.max_seq_len_cached = max_position_embeddings
|
102 |
+
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
103 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
104 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
105 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
106 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
107 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
108 |
+
|
109 |
+
def forward(self, x, seq_len=None):
|
110 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
111 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
112 |
+
if seq_len > self.max_seq_len_cached:
|
113 |
+
self.max_seq_len_cached = seq_len
|
114 |
+
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
115 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
116 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
117 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
118 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
119 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
120 |
+
return (
|
121 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
122 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
123 |
+
)
|
124 |
+
|
125 |
+
|
126 |
+
def rotate_half(x):
|
127 |
+
"""Rotates half the hidden dims of the input."""
|
128 |
+
x1 = x[..., : x.shape[-1] // 2]
|
129 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
130 |
+
return torch.cat((-x2, x1), dim=-1)
|
131 |
+
|
132 |
+
|
133 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
134 |
+
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
135 |
+
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
136 |
+
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
137 |
+
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
138 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
139 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
140 |
+
return q_embed, k_embed
|
141 |
+
|
142 |
+
|
143 |
+
class LlamaMLP(nn.Module):
|
144 |
+
def __init__(
|
145 |
+
self,
|
146 |
+
hidden_size: int,
|
147 |
+
intermediate_size: int,
|
148 |
+
hidden_act: str,
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
152 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
153 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
154 |
+
self.act_fn = ACT2FN[hidden_act]
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
158 |
+
|
159 |
+
|
160 |
+
class LlamaAttention(nn.Module):
|
161 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
162 |
+
|
163 |
+
def __init__(self, config: LlamaConfig):
|
164 |
+
super().__init__()
|
165 |
+
self.config = config
|
166 |
+
self.hidden_size = config.hidden_size
|
167 |
+
self.num_heads = config.num_attention_heads
|
168 |
+
self.head_dim = self.hidden_size // self.num_heads
|
169 |
+
self.max_position_embeddings = config.max_position_embeddings
|
170 |
+
|
171 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
172 |
+
raise ValueError(
|
173 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
174 |
+
f" and `num_heads`: {self.num_heads})."
|
175 |
+
)
|
176 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
177 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
178 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
179 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
180 |
+
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
181 |
+
|
182 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
183 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
184 |
+
|
185 |
+
def forward(
|
186 |
+
self,
|
187 |
+
hidden_states: torch.Tensor,
|
188 |
+
attention_mask: Optional[torch.Tensor] = None,
|
189 |
+
position_ids: Optional[torch.LongTensor] = None,
|
190 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
191 |
+
output_attentions: bool = False,
|
192 |
+
use_cache: bool = False,
|
193 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
194 |
+
bsz, q_len, _ = hidden_states.size()
|
195 |
+
|
196 |
+
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
197 |
+
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
198 |
+
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
199 |
+
|
200 |
+
kv_seq_len = key_states.shape[-2]
|
201 |
+
if past_key_value is not None:
|
202 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
203 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
204 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
205 |
+
# [bsz, nh, t, hd]
|
206 |
+
|
207 |
+
if past_key_value is not None:
|
208 |
+
# reuse k, v, self_attention
|
209 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
210 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
211 |
+
|
212 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
213 |
+
|
214 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
215 |
+
|
216 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
217 |
+
raise ValueError(
|
218 |
+
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
219 |
+
f" {attn_weights.size()}"
|
220 |
+
)
|
221 |
+
|
222 |
+
if attention_mask is not None:
|
223 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
224 |
+
raise ValueError(
|
225 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
226 |
+
)
|
227 |
+
attn_weights = attn_weights + attention_mask
|
228 |
+
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
229 |
+
|
230 |
+
# upcast attention to fp32
|
231 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
232 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
233 |
+
|
234 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
235 |
+
raise ValueError(
|
236 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
237 |
+
f" {attn_output.size()}"
|
238 |
+
)
|
239 |
+
|
240 |
+
attn_output = attn_output.transpose(1, 2)
|
241 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
242 |
+
|
243 |
+
attn_output = self.o_proj(attn_output)
|
244 |
+
|
245 |
+
if not output_attentions:
|
246 |
+
attn_weights = None
|
247 |
+
|
248 |
+
return attn_output, attn_weights, past_key_value
|
249 |
+
|
250 |
+
|
251 |
+
class LlamaDecoderLayer(nn.Module):
|
252 |
+
def __init__(self, config: LlamaConfig):
|
253 |
+
super().__init__()
|
254 |
+
self.hidden_size = config.hidden_size
|
255 |
+
self.self_attn = LlamaAttention(config=config)
|
256 |
+
self.mlp = LlamaMLP(
|
257 |
+
hidden_size=self.hidden_size,
|
258 |
+
intermediate_size=config.intermediate_size,
|
259 |
+
hidden_act=config.hidden_act,
|
260 |
+
)
|
261 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
262 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
263 |
+
|
264 |
+
def forward(
|
265 |
+
self,
|
266 |
+
hidden_states: torch.Tensor,
|
267 |
+
attention_mask: Optional[torch.Tensor] = None,
|
268 |
+
position_ids: Optional[torch.LongTensor] = None,
|
269 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
270 |
+
output_attentions: Optional[bool] = False,
|
271 |
+
use_cache: Optional[bool] = False,
|
272 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
273 |
+
"""
|
274 |
+
Args:
|
275 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
276 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
277 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
278 |
+
output_attentions (`bool`, *optional*):
|
279 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
280 |
+
returned tensors for more detail.
|
281 |
+
use_cache (`bool`, *optional*):
|
282 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
283 |
+
(see `past_key_values`).
|
284 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
285 |
+
"""
|
286 |
+
|
287 |
+
residual = hidden_states
|
288 |
+
|
289 |
+
hidden_states = self.input_layernorm(hidden_states)
|
290 |
+
|
291 |
+
# Self Attention
|
292 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
293 |
+
hidden_states=hidden_states,
|
294 |
+
attention_mask=attention_mask,
|
295 |
+
position_ids=position_ids,
|
296 |
+
past_key_value=past_key_value,
|
297 |
+
output_attentions=output_attentions,
|
298 |
+
use_cache=use_cache,
|
299 |
+
)
|
300 |
+
hidden_states = residual + hidden_states
|
301 |
+
|
302 |
+
# Fully Connected
|
303 |
+
residual = hidden_states
|
304 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
305 |
+
hidden_states = self.mlp(hidden_states)
|
306 |
+
hidden_states = residual + hidden_states
|
307 |
+
|
308 |
+
outputs = (hidden_states,)
|
309 |
+
|
310 |
+
if output_attentions:
|
311 |
+
outputs += (self_attn_weights,)
|
312 |
+
|
313 |
+
if use_cache:
|
314 |
+
outputs += (present_key_value,)
|
315 |
+
|
316 |
+
return outputs
|
317 |
+
|
318 |
+
|
319 |
+
LLAMA_START_DOCSTRING = r"""
|
320 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
321 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
322 |
+
etc.)
|
323 |
+
|
324 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
325 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
326 |
+
and behavior.
|
327 |
+
|
328 |
+
Parameters:
|
329 |
+
config ([`LlamaConfig`]):
|
330 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
331 |
+
load the weights associated with the model, only the configuration. Check out the
|
332 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
333 |
+
"""
|
334 |
+
|
335 |
+
|
336 |
+
@add_start_docstrings(
|
337 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
338 |
+
LLAMA_START_DOCSTRING,
|
339 |
+
)
|
340 |
+
class LlamaPreTrainedModel(PreTrainedModel):
|
341 |
+
config_class = LlamaConfig
|
342 |
+
base_model_prefix = "model"
|
343 |
+
supports_gradient_checkpointing = True
|
344 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
345 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
346 |
+
|
347 |
+
def _init_weights(self, module):
|
348 |
+
std = self.config.initializer_range
|
349 |
+
if isinstance(module, nn.Linear):
|
350 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
351 |
+
if module.bias is not None:
|
352 |
+
module.bias.data.zero_()
|
353 |
+
elif isinstance(module, nn.Embedding):
|
354 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
355 |
+
if module.padding_idx is not None:
|
356 |
+
module.weight.data[module.padding_idx].zero_()
|
357 |
+
|
358 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
359 |
+
if isinstance(module, LlamaModel):
|
360 |
+
module.gradient_checkpointing = value
|
361 |
+
|
362 |
+
|
363 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
364 |
+
Args:
|
365 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
366 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
367 |
+
it.
|
368 |
+
|
369 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
370 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
371 |
+
|
372 |
+
[What are input IDs?](../glossary#input-ids)
|
373 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
374 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
375 |
+
|
376 |
+
- 1 for tokens that are **not masked**,
|
377 |
+
- 0 for tokens that are **masked**.
|
378 |
+
|
379 |
+
[What are attention masks?](../glossary#attention-mask)
|
380 |
+
|
381 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
382 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
383 |
+
|
384 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
385 |
+
`past_key_values`).
|
386 |
+
|
387 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
388 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
389 |
+
information on the default strategy.
|
390 |
+
|
391 |
+
- 1 indicates the head is **not masked**,
|
392 |
+
- 0 indicates the head is **masked**.
|
393 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
394 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
395 |
+
config.n_positions - 1]`.
|
396 |
+
|
397 |
+
[What are position IDs?](../glossary#position-ids)
|
398 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
399 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
400 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
401 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
402 |
+
|
403 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
404 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
405 |
+
|
406 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
407 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
408 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
409 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
410 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
411 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
412 |
+
model's internal embedding lookup matrix.
|
413 |
+
use_cache (`bool`, *optional*):
|
414 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
415 |
+
`past_key_values`).
|
416 |
+
output_attentions (`bool`, *optional*):
|
417 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
418 |
+
tensors for more detail.
|
419 |
+
output_hidden_states (`bool`, *optional*):
|
420 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
421 |
+
more detail.
|
422 |
+
return_dict (`bool`, *optional*):
|
423 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
424 |
+
"""
|
425 |
+
|
426 |
+
|
427 |
+
@add_start_docstrings(
|
428 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
429 |
+
LLAMA_START_DOCSTRING,
|
430 |
+
)
|
431 |
+
class LlamaModel(LlamaPreTrainedModel):
|
432 |
+
"""
|
433 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
434 |
+
|
435 |
+
Args:
|
436 |
+
config: LlamaConfig
|
437 |
+
"""
|
438 |
+
|
439 |
+
def __init__(self, config: LlamaConfig):
|
440 |
+
super().__init__(config)
|
441 |
+
self.padding_idx = config.pad_token_id
|
442 |
+
self.vocab_size = config.vocab_size
|
443 |
+
|
444 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
445 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
446 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
447 |
+
|
448 |
+
self.gradient_checkpointing = False
|
449 |
+
# Initialize weights and apply final processing
|
450 |
+
self.post_init()
|
451 |
+
|
452 |
+
def get_input_embeddings(self):
|
453 |
+
return self.embed_tokens
|
454 |
+
|
455 |
+
def set_input_embeddings(self, value):
|
456 |
+
self.embed_tokens = value
|
457 |
+
|
458 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
459 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
460 |
+
# create causal mask
|
461 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
462 |
+
combined_attention_mask = None
|
463 |
+
if input_shape[-1] > 1:
|
464 |
+
combined_attention_mask = _make_causal_mask(
|
465 |
+
input_shape,
|
466 |
+
inputs_embeds.dtype,
|
467 |
+
device=inputs_embeds.device,
|
468 |
+
past_key_values_length=past_key_values_length,
|
469 |
+
)
|
470 |
+
|
471 |
+
if attention_mask is not None:
|
472 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
473 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
474 |
+
inputs_embeds.device
|
475 |
+
)
|
476 |
+
combined_attention_mask = (
|
477 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
478 |
+
)
|
479 |
+
|
480 |
+
return combined_attention_mask
|
481 |
+
|
482 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
483 |
+
def forward(
|
484 |
+
self,
|
485 |
+
input_ids: torch.LongTensor = None,
|
486 |
+
attention_mask: Optional[torch.Tensor] = None,
|
487 |
+
position_ids: Optional[torch.LongTensor] = None,
|
488 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
489 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
490 |
+
use_cache: Optional[bool] = None,
|
491 |
+
output_attentions: Optional[bool] = None,
|
492 |
+
output_hidden_states: Optional[bool] = None,
|
493 |
+
return_dict: Optional[bool] = None,
|
494 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
495 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
496 |
+
output_hidden_states = (
|
497 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
498 |
+
)
|
499 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
500 |
+
|
501 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
502 |
+
|
503 |
+
# retrieve input_ids and inputs_embeds
|
504 |
+
if input_ids is not None and inputs_embeds is not None:
|
505 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
506 |
+
elif input_ids is not None:
|
507 |
+
batch_size, seq_length = input_ids.shape
|
508 |
+
elif inputs_embeds is not None:
|
509 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
510 |
+
else:
|
511 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
512 |
+
|
513 |
+
seq_length_with_past = seq_length
|
514 |
+
past_key_values_length = 0
|
515 |
+
|
516 |
+
if past_key_values is not None:
|
517 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
518 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
519 |
+
|
520 |
+
if position_ids is None:
|
521 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
522 |
+
position_ids = torch.arange(
|
523 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
524 |
+
)
|
525 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
526 |
+
else:
|
527 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
528 |
+
|
529 |
+
if inputs_embeds is None:
|
530 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
531 |
+
# embed positions
|
532 |
+
if attention_mask is None:
|
533 |
+
attention_mask = torch.ones(
|
534 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
535 |
+
)
|
536 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
537 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
538 |
+
)
|
539 |
+
|
540 |
+
hidden_states = inputs_embeds
|
541 |
+
|
542 |
+
if self.gradient_checkpointing and self.training:
|
543 |
+
if use_cache:
|
544 |
+
logger.warning_once(
|
545 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
546 |
+
)
|
547 |
+
use_cache = False
|
548 |
+
|
549 |
+
# decoder layers
|
550 |
+
all_hidden_states = () if output_hidden_states else None
|
551 |
+
all_self_attns = () if output_attentions else None
|
552 |
+
next_decoder_cache = () if use_cache else None
|
553 |
+
|
554 |
+
for idx, decoder_layer in enumerate(self.layers):
|
555 |
+
if output_hidden_states:
|
556 |
+
all_hidden_states += (hidden_states,)
|
557 |
+
|
558 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
559 |
+
|
560 |
+
if self.gradient_checkpointing and self.training:
|
561 |
+
|
562 |
+
def create_custom_forward(module):
|
563 |
+
def custom_forward(*inputs):
|
564 |
+
# None for past_key_value
|
565 |
+
return module(*inputs, output_attentions, None)
|
566 |
+
|
567 |
+
return custom_forward
|
568 |
+
|
569 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
570 |
+
create_custom_forward(decoder_layer),
|
571 |
+
hidden_states,
|
572 |
+
attention_mask,
|
573 |
+
position_ids,
|
574 |
+
None,
|
575 |
+
)
|
576 |
+
else:
|
577 |
+
layer_outputs = decoder_layer(
|
578 |
+
hidden_states,
|
579 |
+
attention_mask=attention_mask,
|
580 |
+
position_ids=position_ids,
|
581 |
+
past_key_value=past_key_value,
|
582 |
+
output_attentions=output_attentions,
|
583 |
+
use_cache=use_cache,
|
584 |
+
)
|
585 |
+
|
586 |
+
hidden_states = layer_outputs[0]
|
587 |
+
|
588 |
+
if use_cache:
|
589 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
590 |
+
|
591 |
+
if output_attentions:
|
592 |
+
all_self_attns += (layer_outputs[1],)
|
593 |
+
|
594 |
+
hidden_states = self.norm(hidden_states)
|
595 |
+
|
596 |
+
# add hidden states from the last decoder layer
|
597 |
+
if output_hidden_states:
|
598 |
+
all_hidden_states += (hidden_states,)
|
599 |
+
|
600 |
+
next_cache = next_decoder_cache if use_cache else None
|
601 |
+
if not return_dict:
|
602 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
603 |
+
return BaseModelOutputWithPast(
|
604 |
+
last_hidden_state=hidden_states,
|
605 |
+
past_key_values=next_cache,
|
606 |
+
hidden_states=all_hidden_states,
|
607 |
+
attentions=all_self_attns,
|
608 |
+
)
|
609 |
+
|
610 |
+
|
611 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
|
612 |
+
def __init__(self, config):
|
613 |
+
super().__init__(config)
|
614 |
+
self.model = LlamaModel(config)
|
615 |
+
|
616 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
617 |
+
|
618 |
+
# Initialize weights and apply final processing
|
619 |
+
self.post_init()
|
620 |
+
|
621 |
+
def get_input_embeddings(self):
|
622 |
+
return self.model.embed_tokens
|
623 |
+
|
624 |
+
def set_input_embeddings(self, value):
|
625 |
+
self.model.embed_tokens = value
|
626 |
+
|
627 |
+
def get_output_embeddings(self):
|
628 |
+
return self.lm_head
|
629 |
+
|
630 |
+
def set_output_embeddings(self, new_embeddings):
|
631 |
+
self.lm_head = new_embeddings
|
632 |
+
|
633 |
+
def set_decoder(self, decoder):
|
634 |
+
self.model = decoder
|
635 |
+
|
636 |
+
def get_decoder(self):
|
637 |
+
return self.model
|
638 |
+
|
639 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
640 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
641 |
+
def forward(
|
642 |
+
self,
|
643 |
+
input_ids: torch.LongTensor = None,
|
644 |
+
attention_mask: Optional[torch.Tensor] = None,
|
645 |
+
position_ids: Optional[torch.LongTensor] = None,
|
646 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
647 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
648 |
+
labels: Optional[torch.LongTensor] = None,
|
649 |
+
use_cache: Optional[bool] = None,
|
650 |
+
output_attentions: Optional[bool] = None,
|
651 |
+
output_hidden_states: Optional[bool] = None,
|
652 |
+
return_dict: Optional[bool] = None,
|
653 |
+
reduction: Optional[str] = "mean",
|
654 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
655 |
+
r"""
|
656 |
+
Args:
|
657 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
658 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
659 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
660 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
661 |
+
|
662 |
+
Returns:
|
663 |
+
|
664 |
+
Example:
|
665 |
+
|
666 |
+
```python
|
667 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
668 |
+
|
669 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
670 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
671 |
+
|
672 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
673 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
674 |
+
|
675 |
+
>>> # Generate
|
676 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
677 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
678 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
679 |
+
```"""
|
680 |
+
|
681 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
682 |
+
output_hidden_states = (
|
683 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
684 |
+
)
|
685 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
686 |
+
|
687 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
688 |
+
outputs = self.model(
|
689 |
+
input_ids=input_ids,
|
690 |
+
attention_mask=attention_mask,
|
691 |
+
position_ids=position_ids,
|
692 |
+
past_key_values=past_key_values,
|
693 |
+
inputs_embeds=inputs_embeds,
|
694 |
+
use_cache=use_cache,
|
695 |
+
output_attentions=output_attentions,
|
696 |
+
output_hidden_states=output_hidden_states,
|
697 |
+
return_dict=return_dict,
|
698 |
+
)
|
699 |
+
|
700 |
+
hidden_states = outputs[0]
|
701 |
+
logits = self.lm_head(hidden_states)
|
702 |
+
|
703 |
+
loss = None
|
704 |
+
if labels is not None:
|
705 |
+
# Shift so that tokens < n predict n
|
706 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
707 |
+
shift_labels = labels[..., 1:].contiguous()
|
708 |
+
# Flatten the tokens
|
709 |
+
loss_fct = CrossEntropyLoss(reduction=reduction)
|
710 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
711 |
+
shift_labels = shift_labels.view(-1)
|
712 |
+
# Enable model parallelism
|
713 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
714 |
+
loss = loss_fct(shift_logits, shift_labels)
|
715 |
+
if reduction == "none":
|
716 |
+
# loss = loss.view(logits.size(0), -1).sum(1)
|
717 |
+
loss = loss.view(logits.size(0), -1).mean(1)
|
718 |
+
|
719 |
+
if not return_dict:
|
720 |
+
output = (logits,) + outputs[1:]
|
721 |
+
return (loss,) + output if loss is not None else output
|
722 |
+
|
723 |
+
return CausalLMOutputWithPast(
|
724 |
+
loss=loss,
|
725 |
+
logits=logits,
|
726 |
+
past_key_values=outputs.past_key_values,
|
727 |
+
hidden_states=outputs.hidden_states,
|
728 |
+
attentions=outputs.attentions,
|
729 |
+
)
|
730 |
+
|
731 |
+
def prepare_inputs_for_generation(
|
732 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
733 |
+
):
|
734 |
+
if past_key_values:
|
735 |
+
input_ids = input_ids[:, -1:]
|
736 |
+
|
737 |
+
position_ids = kwargs.get("position_ids", None)
|
738 |
+
if attention_mask is not None and position_ids is None:
|
739 |
+
# create position_ids on the fly for batch generation
|
740 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
741 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
742 |
+
if past_key_values:
|
743 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
744 |
+
|
745 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
746 |
+
if inputs_embeds is not None and past_key_values is None:
|
747 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
748 |
+
else:
|
749 |
+
model_inputs = {"input_ids": input_ids}
|
750 |
+
|
751 |
+
model_inputs.update(
|
752 |
+
{
|
753 |
+
"position_ids": position_ids,
|
754 |
+
"past_key_values": past_key_values,
|
755 |
+
"use_cache": kwargs.get("use_cache"),
|
756 |
+
"attention_mask": attention_mask,
|
757 |
+
}
|
758 |
+
)
|
759 |
+
return model_inputs
|
760 |
+
|
761 |
+
@staticmethod
|
762 |
+
def _reorder_cache(past_key_values, beam_idx):
|
763 |
+
reordered_past = ()
|
764 |
+
for layer_past in past_key_values:
|
765 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
766 |
+
return reordered_past
|
767 |
+
|
768 |
+
|
769 |
+
@add_start_docstrings(
|
770 |
+
"""
|
771 |
+
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
772 |
+
|
773 |
+
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
774 |
+
(e.g. GPT-2) do.
|
775 |
+
|
776 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
777 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
778 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
779 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
780 |
+
each row of the batch).
|
781 |
+
""",
|
782 |
+
LLAMA_START_DOCSTRING,
|
783 |
+
)
|
784 |
+
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
785 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
786 |
+
|
787 |
+
def __init__(self, config):
|
788 |
+
super().__init__(config)
|
789 |
+
self.num_labels = config.num_labels
|
790 |
+
self.model = LlamaModel(config)
|
791 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
792 |
+
|
793 |
+
# Initialize weights and apply final processing
|
794 |
+
self.post_init()
|
795 |
+
|
796 |
+
def get_input_embeddings(self):
|
797 |
+
return self.model.embed_tokens
|
798 |
+
|
799 |
+
def set_input_embeddings(self, value):
|
800 |
+
self.model.embed_tokens = value
|
801 |
+
|
802 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
803 |
+
def forward(
|
804 |
+
self,
|
805 |
+
input_ids: torch.LongTensor = None,
|
806 |
+
attention_mask: Optional[torch.Tensor] = None,
|
807 |
+
position_ids: Optional[torch.LongTensor] = None,
|
808 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
809 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
810 |
+
labels: Optional[torch.LongTensor] = None,
|
811 |
+
use_cache: Optional[bool] = None,
|
812 |
+
output_attentions: Optional[bool] = None,
|
813 |
+
output_hidden_states: Optional[bool] = None,
|
814 |
+
return_dict: Optional[bool] = None,
|
815 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
816 |
+
r"""
|
817 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
818 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
819 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
820 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
821 |
+
"""
|
822 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
823 |
+
|
824 |
+
transformer_outputs = self.model(
|
825 |
+
input_ids,
|
826 |
+
attention_mask=attention_mask,
|
827 |
+
position_ids=position_ids,
|
828 |
+
past_key_values=past_key_values,
|
829 |
+
inputs_embeds=inputs_embeds,
|
830 |
+
use_cache=use_cache,
|
831 |
+
output_attentions=output_attentions,
|
832 |
+
output_hidden_states=output_hidden_states,
|
833 |
+
return_dict=return_dict,
|
834 |
+
)
|
835 |
+
hidden_states = transformer_outputs[0]
|
836 |
+
logits = self.score(hidden_states)
|
837 |
+
|
838 |
+
if input_ids is not None:
|
839 |
+
batch_size = input_ids.shape[0]
|
840 |
+
else:
|
841 |
+
batch_size = inputs_embeds.shape[0]
|
842 |
+
|
843 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
844 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
845 |
+
if self.config.pad_token_id is None:
|
846 |
+
sequence_lengths = -1
|
847 |
+
else:
|
848 |
+
if input_ids is not None:
|
849 |
+
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
850 |
+
else:
|
851 |
+
sequence_lengths = -1
|
852 |
+
|
853 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
854 |
+
|
855 |
+
loss = None
|
856 |
+
if labels is not None:
|
857 |
+
labels = labels.to(logits.device)
|
858 |
+
if self.config.problem_type is None:
|
859 |
+
if self.num_labels == 1:
|
860 |
+
self.config.problem_type = "regression"
|
861 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
862 |
+
self.config.problem_type = "single_label_classification"
|
863 |
+
else:
|
864 |
+
self.config.problem_type = "multi_label_classification"
|
865 |
+
|
866 |
+
if self.config.problem_type == "regression":
|
867 |
+
loss_fct = MSELoss()
|
868 |
+
if self.num_labels == 1:
|
869 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
870 |
+
else:
|
871 |
+
loss = loss_fct(pooled_logits, labels)
|
872 |
+
elif self.config.problem_type == "single_label_classification":
|
873 |
+
loss_fct = CrossEntropyLoss()
|
874 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
875 |
+
elif self.config.problem_type == "multi_label_classification":
|
876 |
+
loss_fct = BCEWithLogitsLoss()
|
877 |
+
loss = loss_fct(pooled_logits, labels)
|
878 |
+
if not return_dict:
|
879 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
880 |
+
return ((loss,) + output) if loss is not None else output
|
881 |
+
|
882 |
+
return SequenceClassifierOutputWithPast(
|
883 |
+
loss=loss,
|
884 |
+
logits=pooled_logits,
|
885 |
+
past_key_values=transformer_outputs.past_key_values,
|
886 |
+
hidden_states=transformer_outputs.hidden_states,
|
887 |
+
attentions=transformer_outputs.attentions,
|
888 |
+
)
|
model/modeling_opt.py
ADDED
@@ -0,0 +1,1223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch OPT model."""
|
16 |
+
import random
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
from torch import nn
|
22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
23 |
+
from transformers.activations import ACT2FN
|
24 |
+
# from ...activations import ACT2FN
|
25 |
+
from transformers.modeling_outputs import (
|
26 |
+
BaseModelOutputWithPast,
|
27 |
+
CausalLMOutputWithPast,
|
28 |
+
QuestionAnsweringModelOutput,
|
29 |
+
SequenceClassifierOutputWithPast,
|
30 |
+
)
|
31 |
+
from transformers.modeling_utils import PreTrainedModel
|
32 |
+
from transformers.utils import (
|
33 |
+
add_code_sample_docstrings,
|
34 |
+
add_start_docstrings,
|
35 |
+
add_start_docstrings_to_model_forward,
|
36 |
+
logging,
|
37 |
+
replace_return_docstrings,
|
38 |
+
)
|
39 |
+
from transformers.models.opt.configuration_opt import OPTConfig
|
40 |
+
# from .configuration_opt
|
41 |
+
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
|
46 |
+
_CONFIG_FOR_DOC = "OPTConfig"
|
47 |
+
|
48 |
+
# Base model docstring
|
49 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
|
50 |
+
|
51 |
+
# SequenceClassification docstring
|
52 |
+
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
|
53 |
+
_SEQ_CLASS_EXPECTED_LOSS = 1.71
|
54 |
+
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
|
55 |
+
|
56 |
+
OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
57 |
+
"facebook/opt-125m",
|
58 |
+
"facebook/opt-350m",
|
59 |
+
"facebook/opt-1.3b",
|
60 |
+
"facebook/opt-2.7b",
|
61 |
+
"facebook/opt-6.7b",
|
62 |
+
"facebook/opt-13b",
|
63 |
+
"facebook/opt-30b",
|
64 |
+
# See all OPT models at https://huggingface.co/models?filter=opt
|
65 |
+
]
|
66 |
+
|
67 |
+
|
68 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
69 |
+
def _make_causal_mask(
|
70 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
Make causal mask used for bi-directional self-attention.
|
74 |
+
"""
|
75 |
+
bsz, tgt_len = input_ids_shape
|
76 |
+
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
77 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
78 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
79 |
+
mask = mask.to(dtype)
|
80 |
+
|
81 |
+
if past_key_values_length > 0:
|
82 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
83 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
84 |
+
|
85 |
+
|
86 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
87 |
+
"""
|
88 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
89 |
+
"""
|
90 |
+
bsz, src_len = mask.size()
|
91 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
92 |
+
|
93 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
94 |
+
|
95 |
+
inverted_mask = 1.0 - expanded_mask
|
96 |
+
|
97 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
98 |
+
|
99 |
+
|
100 |
+
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
101 |
+
"""
|
102 |
+
This module learns positional embeddings up to a fixed maximum size.
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(self, num_embeddings: int, embedding_dim: int):
|
106 |
+
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
|
107 |
+
# and adjust num_embeddings appropriately. Other models don't have this hack
|
108 |
+
self.offset = 2
|
109 |
+
super().__init__(num_embeddings + self.offset, embedding_dim)
|
110 |
+
|
111 |
+
def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
|
112 |
+
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
113 |
+
attention_mask = attention_mask.long()
|
114 |
+
|
115 |
+
# create positions depending on attention_mask
|
116 |
+
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
|
117 |
+
|
118 |
+
# cut positions if `past_key_values_length` is > 0
|
119 |
+
positions = positions[:, past_key_values_length:]
|
120 |
+
|
121 |
+
return super().forward(positions + self.offset)
|
122 |
+
|
123 |
+
|
124 |
+
class OPTAttention(nn.Module):
|
125 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
126 |
+
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
embed_dim: int,
|
130 |
+
num_heads: int,
|
131 |
+
dropout: float = 0.0,
|
132 |
+
is_decoder: bool = False,
|
133 |
+
bias: bool = True,
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
self.embed_dim = embed_dim
|
137 |
+
self.num_heads = num_heads
|
138 |
+
self.dropout = dropout
|
139 |
+
self.head_dim = embed_dim // num_heads
|
140 |
+
|
141 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
142 |
+
raise ValueError(
|
143 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
144 |
+
f" and `num_heads`: {num_heads})."
|
145 |
+
)
|
146 |
+
self.scaling = self.head_dim**-0.5
|
147 |
+
self.is_decoder = is_decoder
|
148 |
+
|
149 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
150 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
151 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
152 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
153 |
+
|
154 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
155 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
156 |
+
|
157 |
+
def forward(
|
158 |
+
self,
|
159 |
+
hidden_states: torch.Tensor,
|
160 |
+
key_value_states: Optional[torch.Tensor] = None,
|
161 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
162 |
+
attention_mask: Optional[torch.Tensor] = None,
|
163 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
164 |
+
output_attentions: bool = False,
|
165 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
166 |
+
"""Input shape: Batch x Time x Channel"""
|
167 |
+
|
168 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
169 |
+
# for the decoder
|
170 |
+
is_cross_attention = key_value_states is not None
|
171 |
+
|
172 |
+
bsz, tgt_len, _ = hidden_states.size()
|
173 |
+
|
174 |
+
# get query proj
|
175 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
176 |
+
# get key, value proj
|
177 |
+
if is_cross_attention and past_key_value is not None:
|
178 |
+
# reuse k,v, cross_attentions
|
179 |
+
key_states = past_key_value[0]
|
180 |
+
value_states = past_key_value[1]
|
181 |
+
elif is_cross_attention:
|
182 |
+
# cross_attentions
|
183 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
184 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
185 |
+
elif past_key_value is not None:
|
186 |
+
# reuse k, v, self_attention
|
187 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
188 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
189 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
190 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
191 |
+
else:
|
192 |
+
# self_attention
|
193 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
194 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
195 |
+
|
196 |
+
if self.is_decoder:
|
197 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
198 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
199 |
+
# key/value_states (first "if" case)
|
200 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
201 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
202 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
203 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
204 |
+
past_key_value = (key_states, value_states)
|
205 |
+
|
206 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
207 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
208 |
+
key_states = key_states.view(*proj_shape)
|
209 |
+
value_states = value_states.view(*proj_shape)
|
210 |
+
|
211 |
+
src_len = key_states.size(1)
|
212 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
213 |
+
|
214 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
215 |
+
raise ValueError(
|
216 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
217 |
+
f" {attn_weights.size()}"
|
218 |
+
)
|
219 |
+
|
220 |
+
if attention_mask is not None:
|
221 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
222 |
+
raise ValueError(
|
223 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
224 |
+
)
|
225 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
226 |
+
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
227 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
228 |
+
|
229 |
+
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
230 |
+
if attn_weights.dtype == torch.float16:
|
231 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
|
232 |
+
else:
|
233 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
234 |
+
|
235 |
+
if layer_head_mask is not None:
|
236 |
+
if layer_head_mask.size() != (self.num_heads,):
|
237 |
+
raise ValueError(
|
238 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
239 |
+
f" {layer_head_mask.size()}"
|
240 |
+
)
|
241 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
242 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
243 |
+
|
244 |
+
if output_attentions:
|
245 |
+
# this operation is a bit awkward, but it's required to
|
246 |
+
# make sure that attn_weights keeps its gradient.
|
247 |
+
# In order to do so, attn_weights have to be reshaped
|
248 |
+
# twice and have to be reused in the following
|
249 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
250 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
251 |
+
else:
|
252 |
+
attn_weights_reshaped = None
|
253 |
+
|
254 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
255 |
+
|
256 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
257 |
+
|
258 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
259 |
+
raise ValueError(
|
260 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
261 |
+
f" {attn_output.size()}"
|
262 |
+
)
|
263 |
+
|
264 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
265 |
+
attn_output = attn_output.transpose(1, 2)
|
266 |
+
|
267 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
268 |
+
# partitioned aross GPUs when using tensor-parallelism.
|
269 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
270 |
+
|
271 |
+
attn_output = self.out_proj(attn_output)
|
272 |
+
|
273 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
274 |
+
|
275 |
+
|
276 |
+
class OPTDecoderLayer(nn.Module):
|
277 |
+
def __init__(self, config: OPTConfig):
|
278 |
+
super().__init__()
|
279 |
+
self.embed_dim = config.hidden_size
|
280 |
+
self.self_attn = OPTAttention(
|
281 |
+
embed_dim=self.embed_dim,
|
282 |
+
num_heads=config.num_attention_heads,
|
283 |
+
dropout=config.attention_dropout,
|
284 |
+
is_decoder=True,
|
285 |
+
bias=config.enable_bias,
|
286 |
+
)
|
287 |
+
self.do_layer_norm_before = config.do_layer_norm_before
|
288 |
+
self.dropout = config.dropout
|
289 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
290 |
+
|
291 |
+
self.self_attn_layer_norm = nn.LayerNorm(
|
292 |
+
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
|
293 |
+
)
|
294 |
+
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
|
295 |
+
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
|
296 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
297 |
+
|
298 |
+
def forward(
|
299 |
+
self,
|
300 |
+
hidden_states: torch.Tensor,
|
301 |
+
attention_mask: Optional[torch.Tensor] = None,
|
302 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
303 |
+
output_attentions: Optional[bool] = False,
|
304 |
+
use_cache: Optional[bool] = False,
|
305 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
306 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
307 |
+
"""
|
308 |
+
Args:
|
309 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
310 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
311 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
312 |
+
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
|
313 |
+
`(encoder_attention_heads,)`.
|
314 |
+
output_attentions (`bool`, *optional*):
|
315 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
316 |
+
returned tensors for more detail.
|
317 |
+
use_cache (`bool`, *optional*):
|
318 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
319 |
+
(see `past_key_values`).
|
320 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
321 |
+
"""
|
322 |
+
|
323 |
+
residual = hidden_states
|
324 |
+
|
325 |
+
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
326 |
+
if self.do_layer_norm_before:
|
327 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
328 |
+
|
329 |
+
# Self Attention
|
330 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
331 |
+
hidden_states=hidden_states,
|
332 |
+
past_key_value=past_key_value,
|
333 |
+
attention_mask=attention_mask,
|
334 |
+
layer_head_mask=layer_head_mask,
|
335 |
+
output_attentions=output_attentions,
|
336 |
+
)
|
337 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
338 |
+
hidden_states = residual + hidden_states
|
339 |
+
|
340 |
+
# 350m applies layer norm AFTER attention
|
341 |
+
if not self.do_layer_norm_before:
|
342 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
343 |
+
|
344 |
+
# Fully Connected
|
345 |
+
hidden_states_shape = hidden_states.shape
|
346 |
+
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
|
347 |
+
residual = hidden_states
|
348 |
+
|
349 |
+
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
350 |
+
if self.do_layer_norm_before:
|
351 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
352 |
+
|
353 |
+
hidden_states = self.fc1(hidden_states)
|
354 |
+
hidden_states = self.activation_fn(hidden_states)
|
355 |
+
|
356 |
+
hidden_states = self.fc2(hidden_states)
|
357 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
358 |
+
|
359 |
+
hidden_states = (residual + hidden_states).view(hidden_states_shape)
|
360 |
+
|
361 |
+
# 350m applies layer norm AFTER attention
|
362 |
+
if not self.do_layer_norm_before:
|
363 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
364 |
+
|
365 |
+
outputs = (hidden_states,)
|
366 |
+
|
367 |
+
if output_attentions:
|
368 |
+
outputs += (self_attn_weights,)
|
369 |
+
|
370 |
+
if use_cache:
|
371 |
+
outputs += (present_key_value,)
|
372 |
+
|
373 |
+
return outputs
|
374 |
+
|
375 |
+
|
376 |
+
OPT_START_DOCSTRING = r"""
|
377 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
378 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
379 |
+
etc.)
|
380 |
+
|
381 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
382 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
383 |
+
and behavior.
|
384 |
+
|
385 |
+
Parameters:
|
386 |
+
config ([`OPTConfig`]):
|
387 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
388 |
+
load the weights associated with the model, only the configuration. Check out the
|
389 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
390 |
+
"""
|
391 |
+
|
392 |
+
|
393 |
+
@add_start_docstrings(
|
394 |
+
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
|
395 |
+
OPT_START_DOCSTRING,
|
396 |
+
)
|
397 |
+
class OPTPreTrainedModel(PreTrainedModel):
|
398 |
+
config_class = OPTConfig
|
399 |
+
base_model_prefix = "model"
|
400 |
+
supports_gradient_checkpointing = True
|
401 |
+
_no_split_modules = ["OPTDecoderLayer"]
|
402 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
403 |
+
|
404 |
+
def _init_weights(self, module):
|
405 |
+
std = self.config.init_std
|
406 |
+
if isinstance(module, nn.Linear):
|
407 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
408 |
+
if module.bias is not None:
|
409 |
+
module.bias.data.zero_()
|
410 |
+
elif isinstance(module, nn.Embedding):
|
411 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
412 |
+
if module.padding_idx is not None:
|
413 |
+
module.weight.data[module.padding_idx].zero_()
|
414 |
+
|
415 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
416 |
+
if isinstance(module, (OPTDecoder)):
|
417 |
+
module.gradient_checkpointing = value
|
418 |
+
|
419 |
+
|
420 |
+
OPT_INPUTS_DOCSTRING = r"""
|
421 |
+
Args:
|
422 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
423 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
424 |
+
it.
|
425 |
+
|
426 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
427 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
428 |
+
|
429 |
+
[What are input IDs?](../glossary#input-ids)
|
430 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
431 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
432 |
+
|
433 |
+
- 1 for tokens that are **not masked**,
|
434 |
+
- 0 for tokens that are **masked**.
|
435 |
+
|
436 |
+
[What are attention masks?](../glossary#attention-mask)
|
437 |
+
|
438 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
439 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
440 |
+
|
441 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
442 |
+
`past_key_values`).
|
443 |
+
|
444 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
445 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
446 |
+
information on the default strategy.
|
447 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
448 |
+
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
|
449 |
+
|
450 |
+
- 1 indicates the head is **not masked**,
|
451 |
+
- 0 indicates the head is **masked**.
|
452 |
+
|
453 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
454 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
455 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
456 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
457 |
+
|
458 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
459 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
460 |
+
|
461 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
462 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
463 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
464 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
465 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
466 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
467 |
+
model's internal embedding lookup matrix.
|
468 |
+
use_cache (`bool`, *optional*):
|
469 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
470 |
+
`past_key_values`).
|
471 |
+
output_attentions (`bool`, *optional*):
|
472 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
473 |
+
tensors for more detail.
|
474 |
+
output_hidden_states (`bool`, *optional*):
|
475 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
476 |
+
more detail.
|
477 |
+
return_dict (`bool`, *optional*):
|
478 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
479 |
+
"""
|
480 |
+
|
481 |
+
|
482 |
+
class OPTDecoder(OPTPreTrainedModel):
|
483 |
+
"""
|
484 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
|
485 |
+
|
486 |
+
Args:
|
487 |
+
config: OPTConfig
|
488 |
+
"""
|
489 |
+
|
490 |
+
def __init__(self, config: OPTConfig):
|
491 |
+
super().__init__(config)
|
492 |
+
self.dropout = config.dropout
|
493 |
+
self.layerdrop = config.layerdrop
|
494 |
+
self.padding_idx = config.pad_token_id
|
495 |
+
self.max_target_positions = config.max_position_embeddings
|
496 |
+
self.vocab_size = config.vocab_size
|
497 |
+
|
498 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
|
499 |
+
self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
|
500 |
+
|
501 |
+
if config.word_embed_proj_dim != config.hidden_size:
|
502 |
+
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
|
503 |
+
else:
|
504 |
+
self.project_out = None
|
505 |
+
|
506 |
+
if config.word_embed_proj_dim != config.hidden_size:
|
507 |
+
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
|
508 |
+
else:
|
509 |
+
self.project_in = None
|
510 |
+
|
511 |
+
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
|
512 |
+
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
513 |
+
# see https://github.com/facebookresearch/metaseq/pull/164
|
514 |
+
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
515 |
+
self.final_layer_norm = nn.LayerNorm(
|
516 |
+
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
|
517 |
+
)
|
518 |
+
else:
|
519 |
+
self.final_layer_norm = None
|
520 |
+
|
521 |
+
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
522 |
+
|
523 |
+
self.gradient_checkpointing = False
|
524 |
+
# Initialize weights and apply final processing
|
525 |
+
self.post_init()
|
526 |
+
|
527 |
+
def get_input_embeddings(self):
|
528 |
+
return self.embed_tokens
|
529 |
+
|
530 |
+
def set_input_embeddings(self, value):
|
531 |
+
self.embed_tokens = value
|
532 |
+
|
533 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
534 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
535 |
+
# create causal mask
|
536 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
537 |
+
combined_attention_mask = None
|
538 |
+
if input_shape[-1] > 1:
|
539 |
+
combined_attention_mask = _make_causal_mask(
|
540 |
+
input_shape,
|
541 |
+
inputs_embeds.dtype,
|
542 |
+
device=inputs_embeds.device,
|
543 |
+
past_key_values_length=past_key_values_length,
|
544 |
+
)
|
545 |
+
|
546 |
+
if attention_mask is not None:
|
547 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
548 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
549 |
+
inputs_embeds.device
|
550 |
+
)
|
551 |
+
combined_attention_mask = (
|
552 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
553 |
+
)
|
554 |
+
|
555 |
+
return combined_attention_mask
|
556 |
+
|
557 |
+
def forward(
|
558 |
+
self,
|
559 |
+
input_ids: torch.LongTensor = None,
|
560 |
+
attention_mask: Optional[torch.Tensor] = None,
|
561 |
+
head_mask: Optional[torch.Tensor] = None,
|
562 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
563 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
564 |
+
use_cache: Optional[bool] = None,
|
565 |
+
output_attentions: Optional[bool] = None,
|
566 |
+
output_hidden_states: Optional[bool] = None,
|
567 |
+
return_dict: Optional[bool] = None,
|
568 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
569 |
+
r"""
|
570 |
+
Args:
|
571 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
572 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
573 |
+
provide it.
|
574 |
+
|
575 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
576 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
577 |
+
|
578 |
+
[What are input IDs?](../glossary#input-ids)
|
579 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
580 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
581 |
+
|
582 |
+
- 1 for tokens that are **not masked**,
|
583 |
+
- 0 for tokens that are **masked**.
|
584 |
+
|
585 |
+
[What are attention masks?](../glossary#attention-mask)
|
586 |
+
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
|
587 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
588 |
+
|
589 |
+
- 1 indicates the head is **not masked**,
|
590 |
+
- 0 indicates the head is **masked**.
|
591 |
+
|
592 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
593 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
594 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
595 |
+
|
596 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
597 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
598 |
+
|
599 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
600 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
601 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
602 |
+
|
603 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
604 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
605 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
606 |
+
than the model's internal embedding lookup matrix.
|
607 |
+
output_attentions (`bool`, *optional*):
|
608 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
609 |
+
returned tensors for more detail.
|
610 |
+
output_hidden_states (`bool`, *optional*):
|
611 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
612 |
+
for more detail.
|
613 |
+
return_dict (`bool`, *optional*):
|
614 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
615 |
+
"""
|
616 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
617 |
+
output_hidden_states = (
|
618 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
619 |
+
)
|
620 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
621 |
+
|
622 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
623 |
+
|
624 |
+
# retrieve input_ids and inputs_embeds
|
625 |
+
if input_ids is not None and inputs_embeds is not None:
|
626 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
627 |
+
elif input_ids is not None:
|
628 |
+
input_shape = input_ids.size()
|
629 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
630 |
+
elif inputs_embeds is not None:
|
631 |
+
input_shape = inputs_embeds.size()[:-1]
|
632 |
+
else:
|
633 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
634 |
+
|
635 |
+
if inputs_embeds is None:
|
636 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
637 |
+
|
638 |
+
batch_size, seq_length = input_shape
|
639 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
640 |
+
# required mask seq length can be calculated via length of past
|
641 |
+
mask_seq_length = past_key_values_length + seq_length
|
642 |
+
|
643 |
+
# embed positions
|
644 |
+
if attention_mask is None:
|
645 |
+
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
646 |
+
causal_attention_mask = self._prepare_decoder_attention_mask(
|
647 |
+
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
648 |
+
)
|
649 |
+
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
|
650 |
+
|
651 |
+
if self.project_in is not None:
|
652 |
+
inputs_embeds = self.project_in(inputs_embeds)
|
653 |
+
|
654 |
+
hidden_states = inputs_embeds + pos_embeds
|
655 |
+
|
656 |
+
if self.gradient_checkpointing and self.training:
|
657 |
+
if use_cache:
|
658 |
+
logger.warning_once(
|
659 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
660 |
+
)
|
661 |
+
use_cache = False
|
662 |
+
|
663 |
+
# decoder layers
|
664 |
+
all_hidden_states = () if output_hidden_states else None
|
665 |
+
all_self_attns = () if output_attentions else None
|
666 |
+
next_decoder_cache = () if use_cache else None
|
667 |
+
|
668 |
+
# check if head_mask has a correct number of layers specified if desired
|
669 |
+
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
|
670 |
+
if attn_mask is not None:
|
671 |
+
if attn_mask.size()[0] != (len(self.layers)):
|
672 |
+
raise ValueError(
|
673 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
674 |
+
f" {head_mask.size()[0]}."
|
675 |
+
)
|
676 |
+
|
677 |
+
for idx, decoder_layer in enumerate(self.layers):
|
678 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
679 |
+
if output_hidden_states:
|
680 |
+
all_hidden_states += (hidden_states,)
|
681 |
+
|
682 |
+
dropout_probability = random.uniform(0, 1)
|
683 |
+
if self.training and (dropout_probability < self.layerdrop):
|
684 |
+
continue
|
685 |
+
|
686 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
687 |
+
|
688 |
+
if self.gradient_checkpointing and self.training:
|
689 |
+
|
690 |
+
def create_custom_forward(module):
|
691 |
+
def custom_forward(*inputs):
|
692 |
+
# None for past_key_value
|
693 |
+
return module(*inputs, output_attentions, None)
|
694 |
+
|
695 |
+
return custom_forward
|
696 |
+
|
697 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
698 |
+
create_custom_forward(decoder_layer),
|
699 |
+
hidden_states,
|
700 |
+
causal_attention_mask,
|
701 |
+
head_mask[idx] if head_mask is not None else None,
|
702 |
+
None,
|
703 |
+
)
|
704 |
+
else:
|
705 |
+
layer_outputs = decoder_layer(
|
706 |
+
hidden_states,
|
707 |
+
attention_mask=causal_attention_mask,
|
708 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
709 |
+
past_key_value=past_key_value,
|
710 |
+
output_attentions=output_attentions,
|
711 |
+
use_cache=use_cache,
|
712 |
+
)
|
713 |
+
|
714 |
+
hidden_states = layer_outputs[0]
|
715 |
+
|
716 |
+
if use_cache:
|
717 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
718 |
+
|
719 |
+
if output_attentions:
|
720 |
+
all_self_attns += (layer_outputs[1],)
|
721 |
+
|
722 |
+
if self.final_layer_norm is not None:
|
723 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
724 |
+
|
725 |
+
if self.project_out is not None:
|
726 |
+
hidden_states = self.project_out(hidden_states)
|
727 |
+
|
728 |
+
# add hidden states from the last decoder layer
|
729 |
+
if output_hidden_states:
|
730 |
+
all_hidden_states += (hidden_states,)
|
731 |
+
|
732 |
+
next_cache = next_decoder_cache if use_cache else None
|
733 |
+
if not return_dict:
|
734 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
735 |
+
return BaseModelOutputWithPast(
|
736 |
+
last_hidden_state=hidden_states,
|
737 |
+
past_key_values=next_cache,
|
738 |
+
hidden_states=all_hidden_states,
|
739 |
+
attentions=all_self_attns,
|
740 |
+
)
|
741 |
+
|
742 |
+
|
743 |
+
@add_start_docstrings(
|
744 |
+
"The bare OPT Model outputting raw hidden-states without any specific head on top.",
|
745 |
+
OPT_START_DOCSTRING,
|
746 |
+
)
|
747 |
+
class OPTModel(OPTPreTrainedModel):
|
748 |
+
def __init__(self, config: OPTConfig):
|
749 |
+
super().__init__(config)
|
750 |
+
self.decoder = OPTDecoder(config)
|
751 |
+
# Initialize weights and apply final processing
|
752 |
+
self.post_init()
|
753 |
+
|
754 |
+
def get_input_embeddings(self):
|
755 |
+
return self.decoder.embed_tokens
|
756 |
+
|
757 |
+
def set_input_embeddings(self, value):
|
758 |
+
self.decoder.embed_tokens = value
|
759 |
+
|
760 |
+
def get_decoder(self):
|
761 |
+
return self.decoder
|
762 |
+
|
763 |
+
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
|
764 |
+
@add_code_sample_docstrings(
|
765 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
766 |
+
output_type=BaseModelOutputWithPast,
|
767 |
+
config_class=_CONFIG_FOR_DOC,
|
768 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
769 |
+
)
|
770 |
+
def forward(
|
771 |
+
self,
|
772 |
+
input_ids: torch.LongTensor = None,
|
773 |
+
attention_mask: Optional[torch.Tensor] = None,
|
774 |
+
head_mask: Optional[torch.Tensor] = None,
|
775 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
776 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
777 |
+
use_cache: Optional[bool] = None,
|
778 |
+
output_attentions: Optional[bool] = None,
|
779 |
+
output_hidden_states: Optional[bool] = None,
|
780 |
+
return_dict: Optional[bool] = None,
|
781 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
782 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
783 |
+
output_hidden_states = (
|
784 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
785 |
+
)
|
786 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
787 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
788 |
+
|
789 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
790 |
+
decoder_outputs = self.decoder(
|
791 |
+
input_ids=input_ids,
|
792 |
+
attention_mask=attention_mask,
|
793 |
+
head_mask=head_mask,
|
794 |
+
past_key_values=past_key_values,
|
795 |
+
inputs_embeds=inputs_embeds,
|
796 |
+
use_cache=use_cache,
|
797 |
+
output_attentions=output_attentions,
|
798 |
+
output_hidden_states=output_hidden_states,
|
799 |
+
return_dict=return_dict,
|
800 |
+
)
|
801 |
+
|
802 |
+
if not return_dict:
|
803 |
+
return decoder_outputs
|
804 |
+
|
805 |
+
return BaseModelOutputWithPast(
|
806 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
807 |
+
past_key_values=decoder_outputs.past_key_values,
|
808 |
+
hidden_states=decoder_outputs.hidden_states,
|
809 |
+
attentions=decoder_outputs.attentions,
|
810 |
+
)
|
811 |
+
|
812 |
+
|
813 |
+
class OPTForCausalLM(OPTPreTrainedModel):
|
814 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
815 |
+
|
816 |
+
def __init__(self, config):
|
817 |
+
super().__init__(config)
|
818 |
+
self.model = OPTModel(config)
|
819 |
+
|
820 |
+
# the lm_head weight is automatically tied to the embed tokens weight
|
821 |
+
self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
|
822 |
+
|
823 |
+
# Initialize weights and apply final processing
|
824 |
+
self.post_init()
|
825 |
+
|
826 |
+
def get_input_embeddings(self):
|
827 |
+
return self.model.decoder.embed_tokens
|
828 |
+
|
829 |
+
def set_input_embeddings(self, value):
|
830 |
+
self.model.decoder.embed_tokens = value
|
831 |
+
|
832 |
+
def get_output_embeddings(self):
|
833 |
+
return self.lm_head
|
834 |
+
|
835 |
+
def set_output_embeddings(self, new_embeddings):
|
836 |
+
self.lm_head = new_embeddings
|
837 |
+
|
838 |
+
def set_decoder(self, decoder):
|
839 |
+
self.model.decoder = decoder
|
840 |
+
|
841 |
+
def get_decoder(self):
|
842 |
+
return self.model.decoder
|
843 |
+
|
844 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
845 |
+
def forward(
|
846 |
+
self,
|
847 |
+
input_ids: torch.LongTensor = None,
|
848 |
+
attention_mask: Optional[torch.Tensor] = None,
|
849 |
+
head_mask: Optional[torch.Tensor] = None,
|
850 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
851 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
852 |
+
labels: Optional[torch.LongTensor] = None,
|
853 |
+
use_cache: Optional[bool] = None,
|
854 |
+
output_attentions: Optional[bool] = None,
|
855 |
+
output_hidden_states: Optional[bool] = None,
|
856 |
+
return_dict: Optional[bool] = None,
|
857 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
858 |
+
r"""
|
859 |
+
Args:
|
860 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
861 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
862 |
+
provide it.
|
863 |
+
|
864 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
865 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
866 |
+
|
867 |
+
[What are input IDs?](../glossary#input-ids)
|
868 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
869 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
870 |
+
|
871 |
+
- 1 for tokens that are **not masked**,
|
872 |
+
- 0 for tokens that are **masked**.
|
873 |
+
|
874 |
+
[What are attention masks?](../glossary#attention-mask)
|
875 |
+
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
|
876 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
877 |
+
|
878 |
+
- 1 indicates the head is **not masked**,
|
879 |
+
- 0 indicates the head is **masked**.
|
880 |
+
|
881 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
882 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
883 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
884 |
+
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
|
885 |
+
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
|
886 |
+
|
887 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
888 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
889 |
+
|
890 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
891 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
892 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
893 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
894 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
895 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
896 |
+
than the model's internal embedding lookup matrix.
|
897 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
898 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
899 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
900 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
901 |
+
use_cache (`bool`, *optional*):
|
902 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
903 |
+
(see `past_key_values`).
|
904 |
+
output_attentions (`bool`, *optional*):
|
905 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
906 |
+
returned tensors for more detail.
|
907 |
+
output_hidden_states (`bool`, *optional*):
|
908 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
909 |
+
for more detail.
|
910 |
+
return_dict (`bool`, *optional*):
|
911 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
912 |
+
|
913 |
+
Returns:
|
914 |
+
|
915 |
+
Example:
|
916 |
+
|
917 |
+
```python
|
918 |
+
>>> from transformers import AutoTokenizer, OPTForCausalLM
|
919 |
+
|
920 |
+
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
|
921 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
922 |
+
|
923 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
924 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
925 |
+
|
926 |
+
>>> # Generate
|
927 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
928 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
929 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
930 |
+
```"""
|
931 |
+
|
932 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
933 |
+
output_hidden_states = (
|
934 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
935 |
+
)
|
936 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
937 |
+
|
938 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
939 |
+
outputs = self.model.decoder(
|
940 |
+
input_ids=input_ids,
|
941 |
+
attention_mask=attention_mask,
|
942 |
+
head_mask=head_mask,
|
943 |
+
past_key_values=past_key_values,
|
944 |
+
inputs_embeds=inputs_embeds,
|
945 |
+
use_cache=use_cache,
|
946 |
+
output_attentions=output_attentions,
|
947 |
+
output_hidden_states=output_hidden_states,
|
948 |
+
return_dict=return_dict,
|
949 |
+
)
|
950 |
+
|
951 |
+
logits = self.lm_head(outputs[0]).contiguous()
|
952 |
+
|
953 |
+
loss = None
|
954 |
+
if labels is not None:
|
955 |
+
# Shift so that tokens < n predict n
|
956 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
957 |
+
shift_labels = labels[..., 1:].contiguous()
|
958 |
+
# Flatten the tokens
|
959 |
+
loss_fct = CrossEntropyLoss()
|
960 |
+
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
|
961 |
+
|
962 |
+
if not return_dict:
|
963 |
+
output = (logits,) + outputs[1:]
|
964 |
+
return (loss,) + output if loss is not None else output
|
965 |
+
|
966 |
+
return CausalLMOutputWithPast(
|
967 |
+
loss=loss,
|
968 |
+
logits=logits,
|
969 |
+
past_key_values=outputs.past_key_values,
|
970 |
+
hidden_states=outputs.hidden_states,
|
971 |
+
attentions=outputs.attentions,
|
972 |
+
)
|
973 |
+
|
974 |
+
def prepare_inputs_for_generation(
|
975 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
976 |
+
):
|
977 |
+
if past_key_values:
|
978 |
+
input_ids = input_ids[:, -1:]
|
979 |
+
|
980 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
981 |
+
if inputs_embeds is not None and past_key_values is None:
|
982 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
983 |
+
else:
|
984 |
+
model_inputs = {"input_ids": input_ids}
|
985 |
+
|
986 |
+
model_inputs.update(
|
987 |
+
{
|
988 |
+
"past_key_values": past_key_values,
|
989 |
+
"use_cache": kwargs.get("use_cache"),
|
990 |
+
"attention_mask": attention_mask,
|
991 |
+
}
|
992 |
+
)
|
993 |
+
return model_inputs
|
994 |
+
|
995 |
+
@staticmethod
|
996 |
+
def _reorder_cache(past_key_values, beam_idx):
|
997 |
+
reordered_past = ()
|
998 |
+
for layer_past in past_key_values:
|
999 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
1000 |
+
return reordered_past
|
1001 |
+
|
1002 |
+
|
1003 |
+
@add_start_docstrings(
|
1004 |
+
"""
|
1005 |
+
The OPT Model transformer with a sequence classification head on top (linear layer).
|
1006 |
+
|
1007 |
+
[`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1008 |
+
(e.g. GPT-2) do.
|
1009 |
+
|
1010 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1011 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1012 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1013 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1014 |
+
each row of the batch).
|
1015 |
+
""",
|
1016 |
+
OPT_START_DOCSTRING,
|
1017 |
+
)
|
1018 |
+
class OPTForSequenceClassification(OPTPreTrainedModel):
|
1019 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
1020 |
+
|
1021 |
+
def __init__(self, config: OPTConfig):
|
1022 |
+
super().__init__(config)
|
1023 |
+
self.num_labels = config.num_labels
|
1024 |
+
self.model = OPTModel(config)
|
1025 |
+
self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
|
1026 |
+
|
1027 |
+
# Initialize weights and apply final processing
|
1028 |
+
self.post_init()
|
1029 |
+
|
1030 |
+
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
|
1031 |
+
@add_code_sample_docstrings(
|
1032 |
+
checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
|
1033 |
+
output_type=SequenceClassifierOutputWithPast,
|
1034 |
+
config_class=_CONFIG_FOR_DOC,
|
1035 |
+
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
1036 |
+
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
1037 |
+
)
|
1038 |
+
def forward(
|
1039 |
+
self,
|
1040 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1041 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1042 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1043 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1044 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1045 |
+
labels: Optional[torch.LongTensor] = None,
|
1046 |
+
use_cache: Optional[bool] = None,
|
1047 |
+
output_attentions: Optional[bool] = None,
|
1048 |
+
output_hidden_states: Optional[bool] = None,
|
1049 |
+
return_dict: Optional[bool] = None,
|
1050 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1051 |
+
r"""
|
1052 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1053 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1054 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1055 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1056 |
+
"""
|
1057 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1058 |
+
|
1059 |
+
transformer_outputs = self.model(
|
1060 |
+
input_ids,
|
1061 |
+
past_key_values=past_key_values,
|
1062 |
+
attention_mask=attention_mask, # shape = [B, max_len]
|
1063 |
+
head_mask=head_mask,
|
1064 |
+
inputs_embeds=inputs_embeds,
|
1065 |
+
use_cache=use_cache,
|
1066 |
+
output_attentions=output_attentions,
|
1067 |
+
output_hidden_states=output_hidden_states,
|
1068 |
+
return_dict=return_dict,
|
1069 |
+
)
|
1070 |
+
hidden_states = transformer_outputs[0]
|
1071 |
+
logits = self.score(hidden_states) # shape = [B, max_len, D]
|
1072 |
+
|
1073 |
+
denom = torch.sum(attention_mask, -1, keepdim=True) # shape = [B, 1]
|
1074 |
+
pooled_logits = torch.sum(logits * attention_mask.unsqueeze(-1), dim=1) # shape = [B, D]
|
1075 |
+
pooled_logits = pooled_logits / denom
|
1076 |
+
|
1077 |
+
loss = None
|
1078 |
+
return SequenceClassifierOutputWithPast(
|
1079 |
+
loss=loss,
|
1080 |
+
logits=pooled_logits,
|
1081 |
+
past_key_values=transformer_outputs.past_key_values,
|
1082 |
+
hidden_states=transformer_outputs.hidden_states,
|
1083 |
+
attentions=transformer_outputs.attentions,
|
1084 |
+
)
|
1085 |
+
|
1086 |
+
def get_input_embeddings(self):
|
1087 |
+
return self.model.decoder.embed_tokens
|
1088 |
+
|
1089 |
+
def set_input_embeddings(self, value):
|
1090 |
+
self.model.decoder.embed_tokens = value
|
1091 |
+
|
1092 |
+
|
1093 |
+
@add_start_docstrings(
|
1094 |
+
"""
|
1095 |
+
The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
|
1096 |
+
(a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1097 |
+
""",
|
1098 |
+
OPT_START_DOCSTRING,
|
1099 |
+
)
|
1100 |
+
class OPTForQuestionAnswering(OPTPreTrainedModel):
|
1101 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
1102 |
+
|
1103 |
+
def __init__(self, config: OPTConfig):
|
1104 |
+
super().__init__(config)
|
1105 |
+
self.model = OPTModel(config)
|
1106 |
+
self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)
|
1107 |
+
|
1108 |
+
# Initialize weights and apply final processing
|
1109 |
+
self.post_init()
|
1110 |
+
|
1111 |
+
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
|
1112 |
+
@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
|
1113 |
+
def forward(
|
1114 |
+
self,
|
1115 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1116 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1117 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1118 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1119 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1120 |
+
start_positions: Optional[torch.LongTensor] = None,
|
1121 |
+
end_positions: Optional[torch.LongTensor] = None,
|
1122 |
+
use_cache: Optional[bool] = None,
|
1123 |
+
output_attentions: Optional[bool] = None,
|
1124 |
+
output_hidden_states: Optional[bool] = None,
|
1125 |
+
return_dict: Optional[bool] = None,
|
1126 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
1127 |
+
r"""
|
1128 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1129 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1130 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1131 |
+
are not taken into account for computing the loss.
|
1132 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1133 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1134 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1135 |
+
are not taken into account for computing the loss.
|
1136 |
+
|
1137 |
+
Returns:
|
1138 |
+
|
1139 |
+
Example:
|
1140 |
+
|
1141 |
+
```python
|
1142 |
+
>>> from transformers import AutoTokenizer, OPTForQuestionAnswering
|
1143 |
+
>>> import torch
|
1144 |
+
|
1145 |
+
>>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
|
1146 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
1147 |
+
|
1148 |
+
>>> # note: we are loading a OPTForQuestionAnswering from the hub here,
|
1149 |
+
>>> # so the head will be randomly initialized, hence the predictions will be random
|
1150 |
+
>>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
|
1151 |
+
|
1152 |
+
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
1153 |
+
|
1154 |
+
>>> inputs = tokenizer(question, text, return_tensors="pt")
|
1155 |
+
>>> with torch.no_grad():
|
1156 |
+
... outputs = model(**inputs)
|
1157 |
+
|
1158 |
+
>>> answer_start_index = outputs.start_logits.argmax()
|
1159 |
+
>>> answer_end_index = outputs.end_logits.argmax()
|
1160 |
+
|
1161 |
+
>>> answer_offset = len(tokenizer(question)[0])
|
1162 |
+
|
1163 |
+
>>> predict_answer_tokens = inputs.input_ids[
|
1164 |
+
... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
|
1165 |
+
... ]
|
1166 |
+
>>> predicted = tokenizer.decode(predict_answer_tokens)
|
1167 |
+
>>> predicted
|
1168 |
+
' a nice puppet'
|
1169 |
+
```"""
|
1170 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1171 |
+
|
1172 |
+
transformer_outputs = self.model(
|
1173 |
+
input_ids,
|
1174 |
+
past_key_values=past_key_values,
|
1175 |
+
attention_mask=attention_mask,
|
1176 |
+
head_mask=head_mask,
|
1177 |
+
inputs_embeds=inputs_embeds,
|
1178 |
+
use_cache=use_cache,
|
1179 |
+
output_attentions=output_attentions,
|
1180 |
+
output_hidden_states=output_hidden_states,
|
1181 |
+
return_dict=return_dict,
|
1182 |
+
)
|
1183 |
+
hidden_states = transformer_outputs[0]
|
1184 |
+
|
1185 |
+
logits = self.qa_outputs(hidden_states)
|
1186 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1187 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1188 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1189 |
+
|
1190 |
+
total_loss = None
|
1191 |
+
if start_positions is not None and end_positions is not None:
|
1192 |
+
# If we are on multi-GPU, split add a dimension
|
1193 |
+
if len(start_positions.size()) > 1:
|
1194 |
+
start_positions = start_positions.squeeze(-1)
|
1195 |
+
if len(end_positions.size()) > 1:
|
1196 |
+
end_positions = end_positions.squeeze(-1)
|
1197 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1198 |
+
ignored_index = start_logits.size(1)
|
1199 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
1200 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
1201 |
+
|
1202 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1203 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1204 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1205 |
+
total_loss = (start_loss + end_loss) / 2
|
1206 |
+
|
1207 |
+
if not return_dict:
|
1208 |
+
output = (start_logits, end_logits) + transformer_outputs[2:]
|
1209 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1210 |
+
|
1211 |
+
return QuestionAnsweringModelOutput(
|
1212 |
+
loss=total_loss,
|
1213 |
+
start_logits=start_logits,
|
1214 |
+
end_logits=end_logits,
|
1215 |
+
hidden_states=transformer_outputs.hidden_states,
|
1216 |
+
attentions=transformer_outputs.attentions,
|
1217 |
+
)
|
1218 |
+
|
1219 |
+
def get_input_embeddings(self):
|
1220 |
+
return self.model.decoder.embed_tokens
|
1221 |
+
|
1222 |
+
def set_input_embeddings(self, value):
|
1223 |
+
self.model.decoder.embed_tokens = value
|
model/opt_flash_attention.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
import transformers
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
11 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
12 |
+
from transformers.models.opt.modeling_opt import _make_causal_mask, _expand_mask
|
13 |
+
|
14 |
+
|
15 |
+
def _prepare_decoder_attention_mask_original(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
16 |
+
# create causal mask
|
17 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
18 |
+
combined_attention_mask = None
|
19 |
+
if input_shape[-1] > 1:
|
20 |
+
combined_attention_mask = _make_causal_mask(
|
21 |
+
input_shape,
|
22 |
+
inputs_embeds.dtype,
|
23 |
+
device=inputs_embeds.device,
|
24 |
+
past_key_values_length=past_key_values_length,
|
25 |
+
)
|
26 |
+
|
27 |
+
if attention_mask is not None:
|
28 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
29 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
30 |
+
inputs_embeds.device
|
31 |
+
)
|
32 |
+
combined_attention_mask = (
|
33 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
34 |
+
)
|
35 |
+
|
36 |
+
return combined_attention_mask
|
37 |
+
|
38 |
+
def forward_original(
|
39 |
+
self,
|
40 |
+
hidden_states: torch.Tensor,
|
41 |
+
key_value_states: Optional[torch.Tensor] = None,
|
42 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
43 |
+
attention_mask: Optional[torch.Tensor] = None,
|
44 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
45 |
+
output_attentions: bool = False,
|
46 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
47 |
+
"""Input shape: Batch x Time x Channel"""
|
48 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
49 |
+
# for the decoder
|
50 |
+
is_cross_attention = key_value_states is not None
|
51 |
+
|
52 |
+
bsz, tgt_len, _ = hidden_states.size()
|
53 |
+
|
54 |
+
# get query proj
|
55 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
56 |
+
# get key, value proj
|
57 |
+
if is_cross_attention and past_key_value is not None:
|
58 |
+
# reuse k,v, cross_attentions
|
59 |
+
key_states = past_key_value[0]
|
60 |
+
value_states = past_key_value[1]
|
61 |
+
elif is_cross_attention:
|
62 |
+
# cross_attentions
|
63 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
64 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
65 |
+
elif past_key_value is not None:
|
66 |
+
# reuse k, v, self_attention
|
67 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
68 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
69 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
70 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
71 |
+
else:
|
72 |
+
# self_attention
|
73 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
74 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
75 |
+
|
76 |
+
if self.is_decoder:
|
77 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
78 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
79 |
+
# key/value_states (first "if" case)
|
80 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
81 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
82 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
83 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
84 |
+
past_key_value = (key_states, value_states)
|
85 |
+
|
86 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
87 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
88 |
+
key_states = key_states.view(*proj_shape)
|
89 |
+
value_states = value_states.view(*proj_shape)
|
90 |
+
|
91 |
+
src_len = key_states.size(1)
|
92 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
93 |
+
|
94 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
95 |
+
raise ValueError(
|
96 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
97 |
+
f" {attn_weights.size()}"
|
98 |
+
)
|
99 |
+
|
100 |
+
if attention_mask is not None:
|
101 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
102 |
+
raise ValueError(
|
103 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
104 |
+
)
|
105 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
106 |
+
attn_weights = torch.max(
|
107 |
+
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
|
108 |
+
)
|
109 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
110 |
+
|
111 |
+
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
112 |
+
if attn_weights.dtype == torch.float16:
|
113 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
|
114 |
+
else:
|
115 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
116 |
+
|
117 |
+
if layer_head_mask is not None:
|
118 |
+
if layer_head_mask.size() != (self.num_heads,):
|
119 |
+
raise ValueError(
|
120 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
121 |
+
f" {layer_head_mask.size()}"
|
122 |
+
)
|
123 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
124 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
125 |
+
|
126 |
+
if output_attentions:
|
127 |
+
# this operation is a bit awkward, but it's required to
|
128 |
+
# make sure that attn_weights keeps its gradient.
|
129 |
+
# In order to do so, attn_weights have to be reshaped
|
130 |
+
# twice and have to be reused in the following
|
131 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
132 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
133 |
+
else:
|
134 |
+
attn_weights_reshaped = None
|
135 |
+
|
136 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
137 |
+
|
138 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
139 |
+
|
140 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
141 |
+
raise ValueError(
|
142 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
143 |
+
f" {attn_output.size()}"
|
144 |
+
)
|
145 |
+
|
146 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
147 |
+
attn_output = attn_output.transpose(1, 2)
|
148 |
+
|
149 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
150 |
+
# partitioned aross GPUs when using tensor-parallelism.
|
151 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
152 |
+
|
153 |
+
attn_output = self.out_proj(attn_output)
|
154 |
+
|
155 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
156 |
+
|
157 |
+
|
158 |
+
def forward(
|
159 |
+
self,
|
160 |
+
hidden_states: torch.Tensor,
|
161 |
+
key_value_states: Optional[torch.Tensor] = None,
|
162 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
163 |
+
attention_mask: Optional[torch.Tensor] = None,
|
164 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
165 |
+
output_attentions: bool = False,
|
166 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
167 |
+
"""Input shape: Batch x Time x Channel"""
|
168 |
+
|
169 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
170 |
+
# for the decoder
|
171 |
+
is_cross_attention = key_value_states is not None
|
172 |
+
assert not is_cross_attention, "Cross attention is not supported for flash attention"
|
173 |
+
assert past_key_value is None, "past_key_value is not None is not supported for flash attention"
|
174 |
+
assert not output_attentions, "output_attentions is not supported for flash attention"
|
175 |
+
|
176 |
+
bsz, tgt_len, _ = hidden_states.size()
|
177 |
+
|
178 |
+
# get query proj
|
179 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
180 |
+
# get key, value proj
|
181 |
+
|
182 |
+
if past_key_value is not None:
|
183 |
+
# reuse k, v, self_attention
|
184 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
185 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
186 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
187 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
188 |
+
else:
|
189 |
+
# self_attention
|
190 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
191 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
192 |
+
|
193 |
+
if self.is_decoder:
|
194 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
195 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
196 |
+
# key/value_states (first "if" case)
|
197 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
198 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
199 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
200 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
201 |
+
past_key_value = (key_states, value_states)
|
202 |
+
|
203 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
204 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
205 |
+
key_states = key_states.view(*proj_shape)
|
206 |
+
value_states = value_states.view(*proj_shape)
|
207 |
+
|
208 |
+
## for flash attention
|
209 |
+
flash_shape = (bsz, self.num_heads, tgt_len, self.head_dim)
|
210 |
+
query_states = query_states.view(*flash_shape)
|
211 |
+
key_states = key_states.view(*flash_shape)
|
212 |
+
value_states = value_states.view(*flash_shape)
|
213 |
+
qkv = torch.stack([query_states, key_states, value_states], dim=2) # shape = [bsz, num_heads, 3, tgt_len, head_dim]
|
214 |
+
qkv = qkv.transpose(1, 3) # [bsz, tgt_len, 3, num_heads, head_dim]
|
215 |
+
|
216 |
+
key_padding_mask = attention_mask
|
217 |
+
|
218 |
+
|
219 |
+
assert key_padding_mask is not None
|
220 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
221 |
+
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
|
222 |
+
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=self.num_heads)
|
223 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
224 |
+
x_unpad, cu_seqlens, max_s, self.dropout if self.training else 0.0,
|
225 |
+
softmax_scale=1, causal=True, return_attn_probs=False
|
226 |
+
)
|
227 |
+
|
228 |
+
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
|
229 |
+
indices, bsz, tgt_len),
|
230 |
+
'b s (h d) -> b s h d', h=self.num_heads)
|
231 |
+
|
232 |
+
attn_output = self.out_proj(rearrange(output, "b s h d -> b s (h d)"))
|
233 |
+
return attn_output, None, past_key_value
|
234 |
+
|
235 |
+
|
236 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
237 |
+
# requires the attention mask to be the same as the key_padding_mask
|
238 |
+
def _prepare_decoder_attention_mask(
|
239 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
240 |
+
):
|
241 |
+
# [bsz, seq_len]
|
242 |
+
return attention_mask
|
243 |
+
|
244 |
+
|
245 |
+
def replace_opt_attn_with_flash_attn():
|
246 |
+
cuda_major, cuda_minor = torch.cuda.get_device_capability()
|
247 |
+
if cuda_major < 8:
|
248 |
+
logging.warning(
|
249 |
+
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
|
250 |
+
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
|
251 |
+
)
|
252 |
+
transformers.models.opt.modeling_opt.OPTDecoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
|
253 |
+
transformers.models.opt.modeling_opt.OPTAttention.forward = forward
|
254 |
+
|
255 |
+
def replace_opt_attn_with_original_attn():
|
256 |
+
transformers.models.opt.modeling_opt.OPTDecoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_original
|
257 |
+
transformers.models.opt.modeling_opt.OPTAttention.forward = forward_original
|
258 |
+
|
259 |
+
if __name__ == '__main__':
|
260 |
+
## generate tests to verify the equivalence between forward_original and forward
|
261 |
+
import torch.nn as nn
|
262 |
+
import math
|
263 |
+
class FakeNN(nn.Module):
|
264 |
+
def __init__(self, ):
|
265 |
+
super().__init__()
|
266 |
+
self.scaling = 1 / math.sqrt(2048)
|
267 |
+
if False:
|
268 |
+
self.q_proj = nn.Linear(2048, 2048)
|
269 |
+
self.k_proj = nn.Linear(2048, 2048)
|
270 |
+
self.v_proj = nn.Linear(2048, 2048)
|
271 |
+
self.out_proj = nn.Linear(2048, 2048)
|
272 |
+
else:
|
273 |
+
self.q_proj = nn.Identity()
|
274 |
+
self.k_proj = nn.Identity()
|
275 |
+
self.v_proj = nn.Identity()
|
276 |
+
self.out_proj = nn.Identity()
|
277 |
+
|
278 |
+
self.is_decoder = True
|
279 |
+
self.num_heads = 2
|
280 |
+
self.head_dim = 128
|
281 |
+
self.embed_dim = 256
|
282 |
+
self.dropout = 0
|
283 |
+
|
284 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
285 |
+
# create causal mask
|
286 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
287 |
+
combined_attention_mask = None
|
288 |
+
if input_shape[-1] > 1:
|
289 |
+
combined_attention_mask = _make_causal_mask(
|
290 |
+
input_shape,
|
291 |
+
inputs_embeds.dtype,
|
292 |
+
device=inputs_embeds.device,
|
293 |
+
past_key_values_length=past_key_values_length,
|
294 |
+
)
|
295 |
+
|
296 |
+
if attention_mask is not None:
|
297 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
298 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
299 |
+
inputs_embeds.device
|
300 |
+
)
|
301 |
+
combined_attention_mask = (
|
302 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
303 |
+
)
|
304 |
+
|
305 |
+
return combined_attention_mask
|
306 |
+
|
307 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
308 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
309 |
+
|
310 |
+
fakenn = FakeNN().to(torch.bfloat16).to('cuda:0')
|
311 |
+
|
312 |
+
t_len = 3
|
313 |
+
fake_input = torch.randn(2, t_len, fakenn.embed_dim).to(torch.bfloat16).to('cuda:0')
|
314 |
+
if False:
|
315 |
+
fake_lens = torch.randint(0, t_len, (2,)).to('cuda:0')
|
316 |
+
fake_lens = torch.LongTensor([3, 2]).to('cuda:0')
|
317 |
+
# fake_lens = torch.ones((2,)).to('cuda:0') * 3
|
318 |
+
fake_mask = torch.arange(t_len).unsqueeze(0).to('cuda:0') < fake_lens.unsqueeze(1)
|
319 |
+
else:
|
320 |
+
fake_mask = torch.randint(0, t_len, (2, t_len)).bool().to('cuda:0')
|
321 |
+
|
322 |
+
fake_mask2 = fakenn._prepare_decoder_attention_mask(fake_mask, (2,t_len), fake_input, 0)
|
323 |
+
attn_output0, _, _ = forward_original(fakenn, fake_input, None, None, fake_mask2, None, False)
|
324 |
+
attn_output1, _, _ = forward(fakenn, fake_input, None, None, fake_mask, None, False) # shape = [2, 3, 256]
|
325 |
+
attn_output0 = attn_output0 * fake_mask.unsqueeze(-1)
|
326 |
+
|
327 |
+
print(torch.isclose(attn_output0, attn_output1).all())
|
328 |
+
print(attn_output0.shape, attn_output1.shape)
|
329 |
+
difference = (attn_output0- attn_output1).abs()
|
330 |
+
print(difference)
|
331 |
+
print(difference.sum())
|
read_results/baselines.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import *
|
2 |
+
import torch
|
3 |
+
from rxnfp.transformer_fingerprints import (
|
4 |
+
RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints
|
5 |
+
)
|
6 |
+
|
7 |
+
class Reaction_model:
|
8 |
+
def __init__(self, train_list, test_list):
|
9 |
+
self.train_list = train_list
|
10 |
+
self.test_list = test_list
|
11 |
+
|
12 |
+
model, tokenizer = get_default_model_and_tokenizer()
|
13 |
+
self.rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)
|
14 |
+
|
15 |
+
@time_it
|
16 |
+
def generate_random(self):
|
17 |
+
pred = random.sample(self.train_list, k=len(self.test_list))
|
18 |
+
pred = [i['actions'] for i in pred]
|
19 |
+
return pred
|
20 |
+
|
21 |
+
@time_it
|
22 |
+
def generate_random_compatible_old(self):
|
23 |
+
pred_list = []
|
24 |
+
len_id_map = defaultdict(list)
|
25 |
+
for train_rxn in self.train_list:
|
26 |
+
len_id_map[len(train_rxn['extracted_molecules'])-1].append(train_rxn['index'])
|
27 |
+
|
28 |
+
keys = sorted(k for k in len_id_map.keys())
|
29 |
+
accumulated_counts = {}
|
30 |
+
count = 0
|
31 |
+
for key in keys:
|
32 |
+
count += len(len_id_map[key])
|
33 |
+
accumulated_counts[key] = count
|
34 |
+
|
35 |
+
for rxn in self.test_list:
|
36 |
+
test_token_num = len(rxn['extracted_molecules'])-1
|
37 |
+
idx = random.randint(0, accumulated_counts[test_token_num] - 1)
|
38 |
+
for key in keys:
|
39 |
+
if idx < len(len_id_map[key]):
|
40 |
+
pred_list.append(self.train_list[len_id_map[key][idx]]['actions'])
|
41 |
+
break
|
42 |
+
else:
|
43 |
+
idx -= len(len_id_map[key])
|
44 |
+
return pred_list
|
45 |
+
|
46 |
+
@time_it
|
47 |
+
def generate_random_compatible(self):
|
48 |
+
pred_list = []
|
49 |
+
len_id_map = defaultdict(list)
|
50 |
+
for train_rxn in self.train_list:
|
51 |
+
len_id_map[len(train_rxn['extracted_molecules'])-1].append(train_rxn['index'])
|
52 |
+
|
53 |
+
for rxn in self.test_list:
|
54 |
+
mole_num = len(rxn['extracted_molecules'])-1
|
55 |
+
pred_list.append(self.train_list[random.choice(len_id_map[mole_num])]['actions'])
|
56 |
+
return pred_list
|
57 |
+
|
58 |
+
@time_it
|
59 |
+
def generate_nn(self, batch_size=2048):
|
60 |
+
train_rxns = [f"{'.'.join(rxn['REACTANT'])}>>{rxn['PRODUCT'][0]}" for rxn in self.train_list]
|
61 |
+
test_rxns = [f"{'.'.join(rxn['REACTANT'])}>>{rxn['PRODUCT'][0]}" for rxn in self.test_list]
|
62 |
+
|
63 |
+
train_rxns_batches = [train_rxns[i:i+batch_size] for i in range(0, len(train_rxns), batch_size)]
|
64 |
+
test_rxns_batches = [test_rxns[i:i+batch_size] for i in range(0, len(test_rxns), batch_size)]
|
65 |
+
|
66 |
+
device = torch.device("cuda")
|
67 |
+
train_fps = []
|
68 |
+
for batch in tqdm(train_rxns_batches, desc='Generating fingerprints for training reactions'):
|
69 |
+
batch_fps = self.rxnfp_generator.convert_batch(batch)
|
70 |
+
train_fps.extend(batch_fps)
|
71 |
+
train_fps = torch.tensor(train_fps, device=device) # N x 256
|
72 |
+
|
73 |
+
most_similar_indices = []
|
74 |
+
for batch in tqdm(test_rxns_batches, desc='Generating fingerprints for test reactions'):
|
75 |
+
batch_fps = self.rxnfp_generator.convert_batch(batch)
|
76 |
+
batch_fps = torch.tensor(batch_fps, device=device) # BS x 256
|
77 |
+
batch_fps = batch_fps / torch.norm(batch_fps, dim=1, keepdim=True)
|
78 |
+
|
79 |
+
similarity_matrix = torch.matmul(train_fps, batch_fps.T) # N x BS
|
80 |
+
most_similar_indices.extend(torch.argmax(similarity_matrix, dim=0).tolist())
|
81 |
+
|
82 |
+
return [self.train_list[i]['actions'] for i in most_similar_indices]
|
83 |
+
|
84 |
+
def save_results(self, gt_list, pred_list, target_file):
|
85 |
+
text_dict_list = [{
|
86 |
+
"targets": gt,
|
87 |
+
"indices": i,
|
88 |
+
"predictions": pred,
|
89 |
+
} for i, (gt, pred) in enumerate(zip(gt_list, pred_list))]
|
90 |
+
|
91 |
+
with open(target_file, 'w') as f:
|
92 |
+
json.dump(text_dict_list, f, indent=4)
|
93 |
+
|
94 |
+
def parse_args():
|
95 |
+
parser = argparse.ArgumentParser(description="A simple argument parser")
|
96 |
+
|
97 |
+
parser.add_argument('--name', default='none', type=str)
|
98 |
+
parser.add_argument('--train_file', default=None, type=str)
|
99 |
+
parser.add_argument('--test_file', default=None, type=str)
|
100 |
+
parser.add_argument('--use_tok', default=False, action='store_true')
|
101 |
+
args = parser.parse_args()
|
102 |
+
return args
|
103 |
+
|
104 |
+
def read_dataset(args):
|
105 |
+
print(f'Reading {args.train_file}...')
|
106 |
+
with open(args.train_file, 'r', encoding='utf-8') as f:
|
107 |
+
train_ds = json.load(f)
|
108 |
+
print(f'{len(train_ds)} samples read.')
|
109 |
+
print(f'Reading {args.test_file}...')
|
110 |
+
with open(args.test_file, 'r', encoding='utf-8') as f:
|
111 |
+
test_ds = json.load(f)
|
112 |
+
print(f'{len(test_ds)} samples read.')
|
113 |
+
return train_ds, test_ds
|
114 |
+
|
115 |
+
def run_baselines(args):
|
116 |
+
set_random_seed(0)
|
117 |
+
|
118 |
+
train_ds, test_ds = read_dataset(args)
|
119 |
+
model = Reaction_model(train_ds, test_ds)
|
120 |
+
calculator = Metric_calculator()
|
121 |
+
gt_list = [i['actions'] for i in test_ds]
|
122 |
+
|
123 |
+
print('Random:')
|
124 |
+
pred_list = model.generate_random()
|
125 |
+
calculator(gt_list, pred_list, args.use_tok)
|
126 |
+
model.save_results(gt_list, pred_list, f'results/{args.name}/random.json')
|
127 |
+
|
128 |
+
print('Random (compatible pattern):')
|
129 |
+
pred_list = model.generate_random_compatible()
|
130 |
+
calculator(gt_list, pred_list, args.use_tok)
|
131 |
+
model.save_results(gt_list, pred_list, f'results/{args.name}/random_compatible.json')
|
132 |
+
|
133 |
+
print('Nearest neighbor:')
|
134 |
+
pred_list = model.generate_nn()
|
135 |
+
calculator(gt_list, pred_list, args.use_tok)
|
136 |
+
model.save_results(gt_list, pred_list, f'results/{args.name}/nn.json')
|
137 |
+
# assert 0
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
args=parse_args()
|
141 |
+
run_baselines(args)
|
read_results/read_results.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import *
|
2 |
+
|
3 |
+
def parse_args():
|
4 |
+
parser = argparse.ArgumentParser(description="A simple argument parser")
|
5 |
+
|
6 |
+
parser.add_argument('--name', default='none', type=str)
|
7 |
+
parser.add_argument('--path', default=None, type=str)
|
8 |
+
parser.add_argument('--use_tok', default=False, action='store_true')
|
9 |
+
args = parser.parse_args()
|
10 |
+
return args
|
11 |
+
|
12 |
+
def read_dataset(args):
|
13 |
+
print(f'Reading {args.path}...')
|
14 |
+
with open(args.path, 'r', encoding='utf-8') as f:
|
15 |
+
test_tgt = [json.loads(line) for line in f.readlines()]
|
16 |
+
print(f'{len(test_tgt)} samples read.')
|
17 |
+
gt_list = [i['targets'] for i in test_tgt]
|
18 |
+
pred_list = [i['predictions'] for i in test_tgt]
|
19 |
+
return gt_list, pred_list
|
20 |
+
|
21 |
+
def read_result(args):
|
22 |
+
gt_list, pred_list = read_dataset(args)
|
23 |
+
calculator = Metric_calculator()
|
24 |
+
calculator(gt_list, pred_list, args.use_tok)
|
25 |
+
|
26 |
+
if __name__ == "__main__":
|
27 |
+
args=parse_args()
|
28 |
+
read_result(args)
|
read_results/score.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rdkit import Chem
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
from tqdm import tqdm
|
5 |
+
import multiprocessing
|
6 |
+
import pandas as pd
|
7 |
+
from rdkit import RDLogger
|
8 |
+
import re
|
9 |
+
from utils import *
|
10 |
+
|
11 |
+
lg = RDLogger.logger()
|
12 |
+
lg.setLevel(RDLogger.CRITICAL)
|
13 |
+
|
14 |
+
|
15 |
+
def extract_smiles(s):
|
16 |
+
start_token = "[START_SMILES]"
|
17 |
+
end_token = "[END_SMILES]"
|
18 |
+
start_index = s.find(start_token) + len(start_token)
|
19 |
+
end_index = s.find(end_token)
|
20 |
+
if start_index > -1 and end_index > -1:
|
21 |
+
return s[start_index:end_index].strip()
|
22 |
+
return s
|
23 |
+
|
24 |
+
def canonicalize_smiles_clear_map(smiles,return_max_frag=True):
|
25 |
+
mol = Chem.MolFromSmiles(smiles,sanitize=not opt.synthon)
|
26 |
+
if mol is not None:
|
27 |
+
[atom.ClearProp('molAtomMapNumber') for atom in mol.GetAtoms() if atom.HasProp('molAtomMapNumber')]
|
28 |
+
try:
|
29 |
+
smi = Chem.MolToSmiles(mol, isomericSmiles=False)
|
30 |
+
except:
|
31 |
+
if return_max_frag:
|
32 |
+
return '',''
|
33 |
+
else:
|
34 |
+
return ''
|
35 |
+
if return_max_frag:
|
36 |
+
sub_smi = smi.split(".")
|
37 |
+
sub_mol = [Chem.MolFromSmiles(smiles,sanitize=not opt.synthon) for smiles in sub_smi]
|
38 |
+
sub_mol_size = [(sub_smi[i], len(m.GetAtoms())) for i, m in enumerate(sub_mol) if m is not None]
|
39 |
+
if len(sub_mol_size) > 0:
|
40 |
+
return smi, canonicalize_smiles_clear_map(sorted(sub_mol_size,key=lambda x:x[1],reverse=True)[0][0],return_max_frag=False)
|
41 |
+
else:
|
42 |
+
return smi, ''
|
43 |
+
else:
|
44 |
+
return smi
|
45 |
+
else:
|
46 |
+
if return_max_frag:
|
47 |
+
return '',''
|
48 |
+
else:
|
49 |
+
return ''
|
50 |
+
|
51 |
+
|
52 |
+
def compute_rank(input_smiles, prediction,raw=False,alpha=1.0):
|
53 |
+
valid_score = [[k for k in range(len(prediction[j]))] for j in range(len(prediction))]
|
54 |
+
invalid_rates = [0 for k in range(len(prediction[0]))]
|
55 |
+
rank = {}
|
56 |
+
max_frag_rank = {}
|
57 |
+
highest = {}
|
58 |
+
if raw:
|
59 |
+
# no test augmentation
|
60 |
+
assert len(prediction) == 1
|
61 |
+
for j in range(len(prediction)):
|
62 |
+
for k in range(len(prediction[j])):
|
63 |
+
if prediction[j][k][0] == "":
|
64 |
+
invalid_rates[k] += 1
|
65 |
+
# error detection
|
66 |
+
de_error = [i[0] for i in sorted(list(zip(prediction[j], valid_score[j])), key=lambda x: x[1]) if i[0][0] != ""]
|
67 |
+
prediction[j] = list(set(de_error))
|
68 |
+
prediction[j].sort(key=de_error.index)
|
69 |
+
for k, data in enumerate(prediction[j]):
|
70 |
+
rank[data] = 1 / (alpha * k + 1)
|
71 |
+
else:
|
72 |
+
for j in range(len(prediction)): # aug_num, beam_size, 2
|
73 |
+
for k in range(len(prediction[j])):
|
74 |
+
# predictions[i][j][k] = canonicalize_smiles_clear_map(predictions[i][j][k])
|
75 |
+
if prediction[j][k][0] == "":
|
76 |
+
valid_score[j][k] = opt.beam_size + 1
|
77 |
+
invalid_rates[k] += 1
|
78 |
+
# error detection and deduplication
|
79 |
+
de_error = [i[0] for i in sorted(list(zip(prediction[j], valid_score[j])), key=lambda x: x[1]) if i[0][0] != ""]
|
80 |
+
prediction[j] = list(set(de_error))
|
81 |
+
prediction[j].sort(key=de_error.index)
|
82 |
+
for k, data in enumerate(prediction[j]):
|
83 |
+
if data in rank:
|
84 |
+
rank[data] += 1 / (alpha * k + 1)
|
85 |
+
else:
|
86 |
+
rank[data] = 1 / (alpha * k + 1)
|
87 |
+
if data in highest:
|
88 |
+
highest[data] = min(k,highest[data])
|
89 |
+
else:
|
90 |
+
highest[data] = k
|
91 |
+
for key in rank.keys():
|
92 |
+
rank[key] += highest[key] * -1
|
93 |
+
rank[key] += abs(len(key[0])-len(input_smiles)) * -0.2
|
94 |
+
rank[key] += len(key[0]) * -0.2
|
95 |
+
return rank,invalid_rates
|
96 |
+
|
97 |
+
def read_dataset(opt):
|
98 |
+
print(f'Reading {opt.path}...')
|
99 |
+
with open(opt.path, 'r', encoding='utf-8') as f:
|
100 |
+
test_tgt = [json.loads(line) for line in f.readlines()]
|
101 |
+
if opt.raw:
|
102 |
+
test_tgt = test_tgt[::opt.augmentation]
|
103 |
+
filtered_tgt = {}
|
104 |
+
idx_key = 'ds_idx' if 'ds_idx' in test_tgt[0] else 'index'
|
105 |
+
for dic in test_tgt:
|
106 |
+
if dic[idx_key] not in filtered_tgt:
|
107 |
+
filtered_tgt[dic[idx_key]] = dic
|
108 |
+
test_tgt = list(filtered_tgt.values())
|
109 |
+
test_tgt.sort(key=lambda x: x[idx_key])
|
110 |
+
print(f'{len(test_tgt)} samples read.')
|
111 |
+
input_list = [extract_smiles(i['input']) for i in test_tgt]
|
112 |
+
gt_list = [i['targets'].replace('[START_SMILES]', '').replace('[END_SMILES]', '').replace('SPL1T-TH1S-Pl3A5E','').strip().replace(' ','.') for i in test_tgt]
|
113 |
+
pred_list = [[smi.strip().replace(' ','.') for smi in i['predictions']] for i in test_tgt]
|
114 |
+
return input_list, gt_list, pred_list
|
115 |
+
|
116 |
+
def main(opt):
|
117 |
+
input_list, gt_list, pred_list = read_dataset(opt)
|
118 |
+
if opt.raw:
|
119 |
+
opt.augmentation=1
|
120 |
+
print('Reading predictions from file ...')
|
121 |
+
|
122 |
+
# inputs
|
123 |
+
print("Input Length", len(gt_list))
|
124 |
+
ras_src_smiles = input_list[::opt.augmentation]
|
125 |
+
with multiprocessing.Pool(processes=opt.process_number) as pool:
|
126 |
+
ras_src_smiles = pool.map(func=canonicalize_smiles_clear_map,iterable=ras_src_smiles)
|
127 |
+
ras_src_smiles = [i[0] for i in ras_src_smiles]
|
128 |
+
|
129 |
+
# predictions
|
130 |
+
print("Prediction Length", len(pred_list))
|
131 |
+
pred_lines = [i.split('>')[0] for d in pred_list for i in d]
|
132 |
+
data_size = len(pred_lines) // (opt.augmentation * opt.beam_size) if opt.length == -1 else opt.length
|
133 |
+
pred_lines = pred_lines[:data_size * (opt.augmentation * opt.beam_size)]
|
134 |
+
print("Canonicalizing predictions using Process Number ",opt.process_number)
|
135 |
+
with multiprocessing.Pool(processes=opt.process_number) as pool:
|
136 |
+
raw_predictions = pool.map(func=canonicalize_smiles_clear_map,iterable=pred_lines)
|
137 |
+
|
138 |
+
predictions = [[[] for j in range(opt.augmentation)] for i in range(data_size)] # data_len x augmentation x beam_size
|
139 |
+
for i, line in enumerate(raw_predictions):
|
140 |
+
predictions[i // (opt.beam_size * opt.augmentation)][i % (opt.beam_size * opt.augmentation) // opt.beam_size].append(line)
|
141 |
+
|
142 |
+
# ground truth
|
143 |
+
print("Origin Length", len(gt_list))
|
144 |
+
targets = [''.join(gt_list[i].strip().split(' ')) for i in tqdm(range(0,data_size * opt.augmentation,opt.augmentation))]
|
145 |
+
with multiprocessing.Pool(processes=opt.process_number) as pool:
|
146 |
+
targets = pool.map(func=canonicalize_smiles_clear_map, iterable=targets)
|
147 |
+
|
148 |
+
print("predictions Length", len(predictions), len(predictions[0]), len(predictions[0][0]))
|
149 |
+
print("Target Length", len(targets))
|
150 |
+
|
151 |
+
ground_truth = targets
|
152 |
+
print("Origin Target Lentgh, ", len(ground_truth))
|
153 |
+
print("Cutted Length, ",data_size)
|
154 |
+
print('\n')
|
155 |
+
accuracy = [0 for j in range(opt.n_best)]
|
156 |
+
topn_accuracy_chirality = [0 for _ in range(opt.n_best)]
|
157 |
+
topn_accuracy_wochirality = [0 for _ in range(opt.n_best)]
|
158 |
+
topn_accuracy_ringopening = [0 for _ in range(opt.n_best)]
|
159 |
+
topn_accuracy_ringformation = [0 for _ in range(opt.n_best)]
|
160 |
+
topn_accuracy_woring = [0 for _ in range(opt.n_best)]
|
161 |
+
total_chirality = 0
|
162 |
+
total_ringopening = 0
|
163 |
+
total_ringformation = 0
|
164 |
+
atomsize_topk = []
|
165 |
+
accurate_indices = [[] for j in range(opt.n_best)]
|
166 |
+
max_frag_accuracy = [0 for j in range(opt.n_best)]
|
167 |
+
invalid_rates = [0 for j in range(opt.beam_size)]
|
168 |
+
sorted_invalid_rates = [0 for j in range(opt.beam_size)]
|
169 |
+
unique_rates = 0
|
170 |
+
ranked_results = []
|
171 |
+
|
172 |
+
for i in tqdm(range(len(predictions))):
|
173 |
+
accurate_flag = False
|
174 |
+
if opt.detailed:
|
175 |
+
chirality_flag = False
|
176 |
+
ringopening_flag = False
|
177 |
+
ringformation_flag = False
|
178 |
+
pro_mol = Chem.MolFromSmiles(ras_src_smiles[i])
|
179 |
+
rea_mol = Chem.MolFromSmiles(ground_truth[i][0])
|
180 |
+
try:
|
181 |
+
pro_ringinfo = pro_mol.GetRingInfo()
|
182 |
+
rea_ringinfo = rea_mol.GetRingInfo()
|
183 |
+
pro_ringnum = pro_ringinfo.NumRings()
|
184 |
+
rea_ringnum = rea_ringinfo.NumRings()
|
185 |
+
size = len(rea_mol.GetAtoms()) - len(pro_mol.GetAtoms())
|
186 |
+
# if (int(ras_src_smiles[i].count("@") > 0) + int(ground_truth[i][0].count("@") > 0)) == 1:
|
187 |
+
if ras_src_smiles[i].count("@") > 0 or ground_truth[i][0].count("@") > 0:
|
188 |
+
total_chirality += 1
|
189 |
+
chirality_flag = True
|
190 |
+
if pro_ringnum < rea_ringnum:
|
191 |
+
total_ringopening += 1
|
192 |
+
ringopening_flag = True
|
193 |
+
if pro_ringnum > rea_ringnum:
|
194 |
+
total_ringformation += 1
|
195 |
+
ringformation_flag = True
|
196 |
+
except:
|
197 |
+
pass
|
198 |
+
# continue
|
199 |
+
|
200 |
+
inputs = input_list[i*opt.augmentation:(i+1)*opt.augmentation]
|
201 |
+
gt = gt_list[i*opt.augmentation:(i+1)*opt.augmentation]
|
202 |
+
rank, invalid_rate = compute_rank(ras_src_smiles[i], predictions[i], raw=opt.raw,alpha=opt.score_alpha)
|
203 |
+
|
204 |
+
rank_ = {k[0]: v for k, v in sorted(rank.items(), key=lambda item: item[1], reverse=True)}
|
205 |
+
if opt.detailed:
|
206 |
+
print('Index', i)
|
207 |
+
print('inputs', json.dumps(inputs, indent=4))
|
208 |
+
print('targets', json.dumps(gt, indent=4))
|
209 |
+
print('input', ras_src_smiles[i])
|
210 |
+
print('target', targets[i][0])
|
211 |
+
print('rank', json.dumps(rank_,indent=4))
|
212 |
+
print('invalid_rate', json.dumps(invalid_rate,indent=4))
|
213 |
+
print('\n')
|
214 |
+
for j in range(opt.beam_size):
|
215 |
+
invalid_rates[j] += invalid_rate[j]
|
216 |
+
rank = list(zip(rank.keys(),rank.values()))
|
217 |
+
rank.sort(key=lambda x:x[1],reverse=True)
|
218 |
+
rank = rank[:opt.n_best]
|
219 |
+
ranked_results.append([item[0][0] for item in rank])
|
220 |
+
|
221 |
+
for j, item in enumerate(rank):
|
222 |
+
if item[0][0] == ground_truth[i][0]:
|
223 |
+
if not accurate_flag:
|
224 |
+
accurate_flag = True
|
225 |
+
accurate_indices[j].append(i)
|
226 |
+
for k in range(j, opt.n_best):
|
227 |
+
accuracy[k] += 1
|
228 |
+
if opt.detailed:
|
229 |
+
atomsize_topk.append((size,j))
|
230 |
+
if chirality_flag:
|
231 |
+
for k in range(j,opt.n_best):
|
232 |
+
topn_accuracy_chirality[k] += 1
|
233 |
+
else:
|
234 |
+
for k in range(j,opt.n_best):
|
235 |
+
topn_accuracy_wochirality[k] += 1
|
236 |
+
if ringopening_flag:
|
237 |
+
for k in range(j,opt.n_best):
|
238 |
+
topn_accuracy_ringopening[k] += 1
|
239 |
+
if ringformation_flag:
|
240 |
+
for k in range(j,opt.n_best):
|
241 |
+
topn_accuracy_ringformation[k] += 1
|
242 |
+
if not ringopening_flag and not ringformation_flag:
|
243 |
+
for k in range(j,opt.n_best):
|
244 |
+
topn_accuracy_woring[k] += 1
|
245 |
+
|
246 |
+
if opt.detailed and not accurate_flag:
|
247 |
+
atomsize_topk.append((size,opt.n_best))
|
248 |
+
for j, item in enumerate(rank):
|
249 |
+
if item[0][1] == ground_truth[i][1]:
|
250 |
+
for k in range(j,opt.n_best):
|
251 |
+
max_frag_accuracy[k] += 1
|
252 |
+
break
|
253 |
+
for j in range(len(rank),opt.beam_size):
|
254 |
+
sorted_invalid_rates[j] += 1
|
255 |
+
unique_rates += len(rank)
|
256 |
+
|
257 |
+
for i in range(opt.n_best):
|
258 |
+
if i in [0,1,2,3,4,5,6,7,8,9,19,49]:
|
259 |
+
# if i in range(10):
|
260 |
+
print("Top-{} Acc:{:.3f}%, MaxFrag {:.3f}%,".format(i+1,accuracy[i] / data_size * 100,max_frag_accuracy[i] / data_size * 100),
|
261 |
+
" Invalid SMILES:{:.3f}% Sorted Invalid SMILES:{:.3f}%".format(invalid_rates[i] / data_size / opt.augmentation * 100,sorted_invalid_rates[i] / data_size / opt.augmentation * 100))
|
262 |
+
print(' '.join([f'{accuracy[i] / data_size * 100:.3f}' for i in [0,2,4,9]]))
|
263 |
+
print("Unique Rates:{:.3f}%".format(unique_rates / data_size / opt.beam_size * 100))
|
264 |
+
|
265 |
+
if opt.detailed:
|
266 |
+
print_topk = [1,3,5,10]
|
267 |
+
save_dict = {}
|
268 |
+
atomsize_topk.sort(key=lambda x:x[0])
|
269 |
+
differ_now = atomsize_topk[0][0]
|
270 |
+
topn_accuracy_bydiffer = [0 for _ in range(opt.n_best)]
|
271 |
+
total_bydiffer = 0
|
272 |
+
for i,item in enumerate(atomsize_topk):
|
273 |
+
if differ_now < 11 and differ_now != item[0]:
|
274 |
+
for j in range(opt.n_best):
|
275 |
+
if (j+1) in print_topk:
|
276 |
+
save_dict[f'top-{j+1}_size_{differ_now}'] = topn_accuracy_bydiffer[j] / total_bydiffer * 100
|
277 |
+
print("Top-{} Atom differ size {} Acc:{:.3f}%, Number:{:.3f}%".format(j+1,
|
278 |
+
differ_now,
|
279 |
+
topn_accuracy_bydiffer[j] / total_bydiffer * 100,
|
280 |
+
total_bydiffer/data_size * 100))
|
281 |
+
total_bydiffer = 0
|
282 |
+
topn_accuracy_bydiffer = [0 for _ in range(opt.n_best)]
|
283 |
+
differ_now = item[0]
|
284 |
+
for k in range(item[1],opt.n_best):
|
285 |
+
topn_accuracy_bydiffer[k] += 1
|
286 |
+
total_bydiffer += 1
|
287 |
+
for j in range(opt.n_best):
|
288 |
+
if (j + 1) in print_topk:
|
289 |
+
print("Top-{} Atom differ size {} Acc:{:.3f}%, Number:{:.3f}%".format(j + 1,
|
290 |
+
differ_now,
|
291 |
+
topn_accuracy_bydiffer[j] / total_bydiffer * 100,
|
292 |
+
total_bydiffer / data_size * 100))
|
293 |
+
save_dict[f'top-{j+1}_size_{differ_now}'] = topn_accuracy_bydiffer[j] / total_bydiffer * 100
|
294 |
+
|
295 |
+
for i in range(opt.n_best):
|
296 |
+
if (i+1) in print_topk:
|
297 |
+
if total_chirality > 0:
|
298 |
+
print("Top-{} Accuracy with chirality:{:.3f}%".format(i + 1, topn_accuracy_chirality[i] / total_chirality * 100))
|
299 |
+
save_dict[f'top-{i+1}_chilarity'] = topn_accuracy_chirality[i] / total_chirality * 100
|
300 |
+
print("Top-{} Accuracy without chirality:{:.3f}%".format(i + 1, topn_accuracy_wochirality[i] / (data_size - total_chirality) * 100))
|
301 |
+
save_dict[f'top-{i+1}_wochilarity'] = topn_accuracy_wochirality[i] / (data_size - total_chirality) * 100
|
302 |
+
if total_ringopening > 0:
|
303 |
+
print("Top-{} Accuracy ring Opening:{:.3f}%".format(i + 1, topn_accuracy_ringopening[i] / total_ringopening * 100))
|
304 |
+
save_dict[f'top-{i+1}_ringopening'] = topn_accuracy_ringopening[i] / total_ringopening * 100
|
305 |
+
if total_ringformation > 0:
|
306 |
+
print("Top-{} Accuracy ring Formation:{:.3f}%".format(i + 1, topn_accuracy_ringformation[i] / total_ringformation * 100))
|
307 |
+
save_dict[f'top-{i+1}_ringformation'] = topn_accuracy_ringformation[i] / total_ringformation * 100
|
308 |
+
print("Top-{} Accuracy without ring:{:.3f}%".format(i + 1, topn_accuracy_woring[i] / (data_size - total_ringopening - total_ringformation) * 100))
|
309 |
+
save_dict[f'top-{i+1}_wocring'] = topn_accuracy_woring[i] / (data_size - total_ringopening - total_ringformation)* 100
|
310 |
+
print(total_chirality)
|
311 |
+
print(total_ringformation)
|
312 |
+
print(total_ringopening)
|
313 |
+
# df = pd.DataFrame(list(save_dict.items()))
|
314 |
+
df = pd.DataFrame(save_dict,index=[0])
|
315 |
+
df.to_csv("detailed_results.csv")
|
316 |
+
if opt.save_accurate_indices != "":
|
317 |
+
with open(opt.save_accurate_indices, "w") as f:
|
318 |
+
total_accurate_indices = []
|
319 |
+
for indices in accurate_indices:
|
320 |
+
total_accurate_indices.extend(indices)
|
321 |
+
total_accurate_indices.sort()
|
322 |
+
|
323 |
+
# for index in total_accurate_indices:
|
324 |
+
for index in accurate_indices[0]:
|
325 |
+
f.write(str(index))
|
326 |
+
f.write("\n")
|
327 |
+
|
328 |
+
if opt.save_file != "":
|
329 |
+
with open(opt.save_file,"w") as f:
|
330 |
+
for res in ranked_results:
|
331 |
+
for smi in res:
|
332 |
+
f.write(smi)
|
333 |
+
f.write("\n")
|
334 |
+
for i in range(len(res),opt.n_best):
|
335 |
+
f.write("")
|
336 |
+
f.write("\n")
|
337 |
+
|
338 |
+
|
339 |
+
if __name__ == "__main__":
|
340 |
+
parser = argparse.ArgumentParser(
|
341 |
+
description='score.py',
|
342 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
343 |
+
parser.add_argument('--beam_size', type=int, default=10,help='Beam size')
|
344 |
+
parser.add_argument('--n_best', type=int, default=10,help='n best')
|
345 |
+
parser.add_argument('--path', type=str, required=True, help="Path to file containing the predictions and ground truth.")
|
346 |
+
parser.add_argument('--augmentation', type=int, default=20)
|
347 |
+
parser.add_argument('--score_alpha', type=float, default=1.0)
|
348 |
+
parser.add_argument('--length', type=int, default=-1)
|
349 |
+
parser.add_argument('--process_number', type=int, default=multiprocessing.cpu_count())
|
350 |
+
parser.add_argument('--synthon', action="store_true", default=False)
|
351 |
+
parser.add_argument('--detailed', action="store_true", default=False)
|
352 |
+
parser.add_argument('--raw', action="store_true", default=False)
|
353 |
+
parser.add_argument('--save_file', type=str,default="")
|
354 |
+
parser.add_argument('--save_accurate_indices', type=str,default="")
|
355 |
+
|
356 |
+
opt = parser.parse_args()
|
357 |
+
print(opt)
|
358 |
+
main(opt)
|
read_results/t_test.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import *
|
2 |
+
import scipy.stats as stats
|
3 |
+
|
4 |
+
def parse_args():
|
5 |
+
parser = argparse.ArgumentParser(description="A simple argument parser")
|
6 |
+
|
7 |
+
parser.add_argument('--name', default='none', type=str)
|
8 |
+
parser.add_argument('--path_exp', default=None, type=str)
|
9 |
+
parser.add_argument('--path_ref', default=None, type=str)
|
10 |
+
parser.add_argument('--use_tok', default=False, action='store_true')
|
11 |
+
args = parser.parse_args()
|
12 |
+
return args
|
13 |
+
|
14 |
+
def read_dataset(data_path):
|
15 |
+
print(f'Reading {data_path}...')
|
16 |
+
with open(data_path, 'r', encoding='utf-8') as f:
|
17 |
+
test_tgt = [json.loads(line) for line in f.readlines()]
|
18 |
+
print(f'{len(test_tgt)} samples read.')
|
19 |
+
gt_list = [i['targets'] for i in test_tgt]
|
20 |
+
pred_list = [i['predictions'] for i in test_tgt]
|
21 |
+
return gt_list, pred_list
|
22 |
+
|
23 |
+
def t_test(mean_exp, std_exp, mean_ref, std_ref, n):
|
24 |
+
numerator = mean_exp - mean_ref
|
25 |
+
denominator = np.sqrt((std_exp**2 / n) + (std_ref**2 / n))
|
26 |
+
t_statistic = numerator / denominator
|
27 |
+
df = (((std_exp**2 / n) + (std_ref**2 / n))**2) / (((std_exp**2 / n)**2 / (n-1)) + ((std_ref**2 / n)**2 / (n-1)))
|
28 |
+
|
29 |
+
p_value = 2 * stats.t.sf(np.abs(t_statistic), df)
|
30 |
+
return t_statistic, p_value
|
31 |
+
|
32 |
+
def read_result(args):
|
33 |
+
gt_list_exp, pred_list_exp = read_dataset(args.path_exp)
|
34 |
+
gt_list_ref, pred_list_ref = read_dataset(args.path_ref)
|
35 |
+
calculator = Metric_calculator()
|
36 |
+
result_exp = calculator.get_result_list(gt_list_exp, pred_list_exp, args.use_tok)
|
37 |
+
result_ref = calculator.get_result_list(gt_list_ref, pred_list_ref, args.use_tok)
|
38 |
+
|
39 |
+
for key in ['bleu2', 'bleu4', 'rouge_1', 'rouge_2', 'rouge_l', 'lev_score', 'meteor_score']:
|
40 |
+
if not isinstance(result_exp[key], list):
|
41 |
+
continue
|
42 |
+
levene_s, levene_p = stats.levene(result_exp[key], result_ref[key])
|
43 |
+
t_stat, p_val = stats.ttest_ind(result_exp[key], result_ref[key], equal_var=(levene_p > 0.05))
|
44 |
+
print(f'{key} (mean={float(np.mean(result_exp[key])):.4f}, levene p={levene_p:.3f}):\t{t_stat:.6f}\t{p_val}')
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
args=parse_args()
|
48 |
+
read_result(args)
|
read_results/utils.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Levenshtein import distance as lev_distance
|
2 |
+
import random
|
3 |
+
import json
|
4 |
+
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction, corpus_bleu
|
5 |
+
from nltk.translate.meteor_score import meteor_score
|
6 |
+
from rouge_score import rouge_scorer
|
7 |
+
from tqdm import tqdm
|
8 |
+
import random
|
9 |
+
import numpy as np
|
10 |
+
import argparse
|
11 |
+
from paragraph2actions.readable_converter import ReadableConverter
|
12 |
+
import re
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
from collections import defaultdict
|
15 |
+
import time
|
16 |
+
from functools import wraps
|
17 |
+
import os
|
18 |
+
import torch
|
19 |
+
import textdistance
|
20 |
+
from typing import List
|
21 |
+
|
22 |
+
def levenshtein_similarity(truth: List[str], pred: List[str]) -> List[float]:
|
23 |
+
assert len(truth) == len(pred)
|
24 |
+
scores: List[float] = [
|
25 |
+
textdistance.levenshtein.normalized_similarity(t, p)
|
26 |
+
for t, p in zip(truth, pred)
|
27 |
+
]
|
28 |
+
return scores
|
29 |
+
|
30 |
+
def modified_bleu(truth: List[str], pred: List[str], bleu_n=4) -> float:
|
31 |
+
"""
|
32 |
+
Calculates the BLEU score of a translation, with a small modification in order not to penalize sentences
|
33 |
+
with less than 4 words.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
value between 0 and 1.
|
37 |
+
"""
|
38 |
+
references = [sentence.split() for sentence in truth]
|
39 |
+
candidates = [sentence.split() for sentence in pred]
|
40 |
+
|
41 |
+
# BLEU penalizes sentences with only one word. Even correct translations get a score of zero.
|
42 |
+
references = [r + max(0, bleu_n - len(r)) * [""] for r in references]
|
43 |
+
candidates = [c + max(0, bleu_n - len(c)) * [""] for c in candidates]
|
44 |
+
|
45 |
+
# references must have a larger depth because it supports multiple choices
|
46 |
+
refs = [[r] for r in references]
|
47 |
+
weights = {
|
48 |
+
2: (0.5, 0.5),
|
49 |
+
4: (0.25, 0.25, 0.25, 0.25),
|
50 |
+
}
|
51 |
+
return 100*corpus_bleu(refs, candidates, weights=weights[bleu_n]) # type: ignore[no-any-return]
|
52 |
+
|
53 |
+
def set_random_seed(seed):
|
54 |
+
random.seed(seed)
|
55 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
56 |
+
np.random.seed(seed)
|
57 |
+
torch.manual_seed(seed)
|
58 |
+
torch.cuda.manual_seed(seed)
|
59 |
+
torch.cuda.manual_seed_all(seed) # If using multi-GPU.
|
60 |
+
torch.backends.cudnn.deterministic = True
|
61 |
+
torch.backends.cudnn.benchmark = False
|
62 |
+
|
63 |
+
def time_it(func):
|
64 |
+
@wraps(func)
|
65 |
+
def wrapper(*args, **kwargs):
|
66 |
+
start_time = time.time()
|
67 |
+
result = func(*args, **kwargs)
|
68 |
+
end_time = time.time()
|
69 |
+
print(f"Function {func.__name__} finished in {end_time - start_time:.5f} seconds.\n")
|
70 |
+
return result
|
71 |
+
return wrapper
|
72 |
+
|
73 |
+
def accuracy_score(score_list, threshold):
|
74 |
+
matches = sum(score>=threshold for score in score_list)
|
75 |
+
acc = matches / len(score_list)
|
76 |
+
return acc
|
77 |
+
|
78 |
+
def extract_tokenized_entities(text):
|
79 |
+
pattern = r'\$[^\$]+\$|#[^#]+#|@[^\@]+@'
|
80 |
+
return re.findall(pattern, text)
|
81 |
+
|
82 |
+
def extract_reactant_cnt(text):
|
83 |
+
max_id = None
|
84 |
+
for token in text.split():
|
85 |
+
if token.startswith('$') and token.endswith('$'):
|
86 |
+
try:
|
87 |
+
current_id = int(token.strip('$'))
|
88 |
+
if max_id is None or current_id > max_id:
|
89 |
+
max_id = current_id
|
90 |
+
except ValueError:
|
91 |
+
pass # Ignore tokens that do not represent an integer
|
92 |
+
if not max_id:
|
93 |
+
return 0
|
94 |
+
return max_id
|
95 |
+
|
96 |
+
class Metric_calculator:
|
97 |
+
def __init__(self, text_trunc_length=1024):
|
98 |
+
self.converter = ReadableConverter(separator=' ; ')
|
99 |
+
self.tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', use_fast=False, padding_side='right')
|
100 |
+
self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
|
101 |
+
self.text_trunc_length = text_trunc_length
|
102 |
+
self.scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
|
103 |
+
|
104 |
+
def tokenize(self, gt_list, pred_list):
|
105 |
+
references = []
|
106 |
+
hypotheses = []
|
107 |
+
|
108 |
+
for gt, out in tqdm(zip(gt_list, pred_list)):
|
109 |
+
gt_tokens = self.tokenizer.tokenize(gt)
|
110 |
+
## added for galactica
|
111 |
+
gt_tokens = list(filter(('<pad>').__ne__, gt_tokens))
|
112 |
+
gt_tokens = list(filter(('[PAD]').__ne__, gt_tokens))
|
113 |
+
gt_tokens = list(filter(('[CLS]').__ne__, gt_tokens))
|
114 |
+
gt_tokens = list(filter(('[SEP]').__ne__, gt_tokens))
|
115 |
+
|
116 |
+
out_tokens = self.tokenizer.tokenize(out)
|
117 |
+
out_tokens = list(filter(('<pad>').__ne__, out_tokens))
|
118 |
+
out_tokens = list(filter(('[PAD]').__ne__, out_tokens))
|
119 |
+
out_tokens = list(filter(('[CLS]').__ne__, out_tokens))
|
120 |
+
out_tokens = list(filter(('[SEP]').__ne__, out_tokens))
|
121 |
+
|
122 |
+
references.append([gt_tokens])
|
123 |
+
hypotheses.append(out_tokens)
|
124 |
+
return references, hypotheses
|
125 |
+
|
126 |
+
@time_it
|
127 |
+
def __call__(self, gt_list, pred_list, use_tokenizer=False):
|
128 |
+
gt_list = [gt.strip() for gt in gt_list]
|
129 |
+
pred_list = [pred.strip() for pred in pred_list]
|
130 |
+
|
131 |
+
if use_tokenizer:
|
132 |
+
references, hypotheses = self.tokenize(gt_list, pred_list)
|
133 |
+
bleu2, bleu4 = self.bleu(references, hypotheses)
|
134 |
+
_meteor_score = self.meteor(references, hypotheses)
|
135 |
+
else:
|
136 |
+
bleu2 = modified_bleu(gt_list, pred_list, bleu_n=2)
|
137 |
+
bleu4 = modified_bleu(gt_list, pred_list, bleu_n=4)
|
138 |
+
_meteor_score = 0
|
139 |
+
rouge_1, rouge_2, rouge_l = self.rouge(gt_list, pred_list)
|
140 |
+
|
141 |
+
validity = self.validity(gt_list, pred_list)
|
142 |
+
acc_100, acc_90, acc_75, acc_50 = self.accuracy(gt_list, pred_list)
|
143 |
+
|
144 |
+
print('BLEU-2 score:', bleu2)
|
145 |
+
print('BLEU-4 score:', bleu4)
|
146 |
+
print('Average Meteor score:', _meteor_score)
|
147 |
+
print('rouge1:', rouge_1)
|
148 |
+
print('rouge2:', rouge_2)
|
149 |
+
print('rougeL:', rouge_l)
|
150 |
+
|
151 |
+
print(f'Validity: {validity:.6f}')
|
152 |
+
print(f'Accuracy (100): {acc_100:.6f}')
|
153 |
+
print(f'Accuracy (90): {acc_90:.6f}')
|
154 |
+
print(f'Accuracy (75): {acc_75:.6f}')
|
155 |
+
print(f'Accuracy (50): {acc_50:.6f}')
|
156 |
+
|
157 |
+
line = ''
|
158 |
+
for score in [validity, bleu2, bleu4, acc_100, acc_90, acc_75, acc_50, rouge_1, rouge_2, rouge_l, _meteor_score]:
|
159 |
+
line += f'{score:.6f} '
|
160 |
+
print(line)
|
161 |
+
|
162 |
+
return {
|
163 |
+
'bleu2': bleu2,
|
164 |
+
'bleu4': bleu4,
|
165 |
+
'rouge_1': rouge_1,
|
166 |
+
'rouge_2': rouge_2,
|
167 |
+
'rouge_l': rouge_l,
|
168 |
+
'meteor_score': _meteor_score,
|
169 |
+
'validity': validity,
|
170 |
+
'acc_100': acc_100,
|
171 |
+
'acc_90': acc_90,
|
172 |
+
'acc_75': acc_75,
|
173 |
+
'acc_50': acc_50,
|
174 |
+
}
|
175 |
+
|
176 |
+
def get_result_list(self, gt_list, pred_list, use_tokenizer=False):
|
177 |
+
gt_list = [gt.strip() for gt in gt_list]
|
178 |
+
pred_list = [pred.strip() for pred in pred_list]
|
179 |
+
|
180 |
+
if use_tokenizer:
|
181 |
+
references, hypotheses = self.tokenize(gt_list, pred_list)
|
182 |
+
bleu2 = [corpus_bleu([gt], [pred], weights=(.5,.5)) for gt, pred in zip(references, hypotheses)]
|
183 |
+
bleu4 = [corpus_bleu([gt], [pred], weights=(.25,.25,.25,.25)) for gt, pred in zip(references, hypotheses)]
|
184 |
+
_meteor_score = [meteor_score(gt, out) for gt, out in zip(references, hypotheses)]
|
185 |
+
else:
|
186 |
+
bleu2 = [modified_bleu([gt], [pred], bleu_n=2) for gt, pred in zip(gt_list, pred_list)]
|
187 |
+
bleu4 = [modified_bleu([gt], [pred], bleu_n=4) for gt, pred in zip(gt_list, pred_list)]
|
188 |
+
_meteor_score = 0
|
189 |
+
rouge_1, rouge_2, rouge_l = self.rouge(gt_list, pred_list, return_list=True)
|
190 |
+
|
191 |
+
lev_score = levenshtein_similarity(gt_list, pred_list)
|
192 |
+
|
193 |
+
return {
|
194 |
+
'bleu2': bleu2,
|
195 |
+
'bleu4': bleu4,
|
196 |
+
'rouge_1': rouge_1,
|
197 |
+
'rouge_2': rouge_2,
|
198 |
+
'rouge_l': rouge_l,
|
199 |
+
'meteor_score': _meteor_score,
|
200 |
+
'lev_score': lev_score,
|
201 |
+
}
|
202 |
+
|
203 |
+
def bleu(self, references, hypotheses):
|
204 |
+
bleu2 = corpus_bleu(references, hypotheses, weights=(.5,.5))
|
205 |
+
bleu4 = corpus_bleu(references, hypotheses, weights=(.25,.25,.25,.25))
|
206 |
+
bleu2 *= 100
|
207 |
+
bleu4 *= 100
|
208 |
+
return bleu2, bleu4
|
209 |
+
|
210 |
+
def meteor(self, references, hypotheses):
|
211 |
+
meteor_scores = []
|
212 |
+
for gt, out in zip(references, hypotheses):
|
213 |
+
mscore = meteor_score(gt, out)
|
214 |
+
meteor_scores.append(mscore)
|
215 |
+
_meteor_score = np.mean(meteor_scores)
|
216 |
+
_meteor_score *= 100
|
217 |
+
return _meteor_score
|
218 |
+
|
219 |
+
def rouge(self, targets, predictions, return_list=False):
|
220 |
+
rouge_scores = []
|
221 |
+
for gt, out in zip(targets, predictions):
|
222 |
+
rs = self.scorer.score(out, gt)
|
223 |
+
rouge_scores.append(rs)
|
224 |
+
|
225 |
+
rouge_1 = [rs['rouge1'].fmeasure for rs in rouge_scores]
|
226 |
+
rouge_2 = [rs['rouge2'].fmeasure for rs in rouge_scores]
|
227 |
+
rouge_l = [rs['rougeL'].fmeasure for rs in rouge_scores]
|
228 |
+
if return_list:
|
229 |
+
return rouge_1, rouge_2, rouge_l
|
230 |
+
|
231 |
+
rouge_1 = np.mean(rouge_1) * 100
|
232 |
+
rouge_2 = np.mean(rouge_2) * 100
|
233 |
+
rouge_l = np.mean(rouge_l) * 100
|
234 |
+
return rouge_1, rouge_2, rouge_l
|
235 |
+
|
236 |
+
|
237 |
+
def validity(self, gt_list, pred_list):
|
238 |
+
num_valid, n = 0, len(pred_list)
|
239 |
+
for pred, gt in zip(pred_list, gt_list):
|
240 |
+
try:
|
241 |
+
actions = self.converter.string_to_actions(pred)
|
242 |
+
max_token_pred = extract_reactant_cnt(pred)
|
243 |
+
max_token_gt = extract_reactant_cnt(gt)
|
244 |
+
assert max_token_gt >= max_token_pred
|
245 |
+
num_valid += 1
|
246 |
+
except:
|
247 |
+
pass
|
248 |
+
return 100*(num_valid / n)
|
249 |
+
|
250 |
+
def accuracy(self, gt_list, pred_list):
|
251 |
+
score_list = levenshtein_similarity(gt_list, pred_list)
|
252 |
+
acc_100 = 100*accuracy_score(score_list, 1.0)
|
253 |
+
acc_90 = 100*accuracy_score(score_list, 0.90)
|
254 |
+
acc_75 = 100*accuracy_score(score_list, 0.75)
|
255 |
+
acc_50 = 100*accuracy_score(score_list, 0.50)
|
256 |
+
return acc_100, acc_90, acc_75, acc_50
|
visualize_context_gen.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data_provider.context_gen import *
|
2 |
+
|
3 |
+
def parse_args():
|
4 |
+
parser = argparse.ArgumentParser(description="A simple argument parser")
|
5 |
+
|
6 |
+
# Script arguments
|
7 |
+
parser.add_argument('--name', default='none', type=str)
|
8 |
+
parser.add_argument('--seed', default=0, type=int)
|
9 |
+
parser.add_argument('--epochs', default=100, type=int)
|
10 |
+
parser.add_argument('--chunk_size', default=100, type=int)
|
11 |
+
parser.add_argument('--rxn_num', default=50000, type=int)
|
12 |
+
parser.add_argument('--k', default=4, type=int)
|
13 |
+
parser.add_argument('--root', default='data/pretrain_data', type=str)
|
14 |
+
|
15 |
+
args = parser.parse_args()
|
16 |
+
return args
|
17 |
+
|
18 |
+
def pad_shorter_array(arr1, arr2):
|
19 |
+
len1 = arr1.shape[0]
|
20 |
+
len2 = arr2.shape[0]
|
21 |
+
if len1 > len2:
|
22 |
+
arr2 = np.pad(arr2, (0, len1 - len2), 'constant')
|
23 |
+
elif len2 > len1:
|
24 |
+
arr1 = np.pad(arr1, (0, len2 - len1), 'constant')
|
25 |
+
return arr1, arr2
|
26 |
+
|
27 |
+
def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue'):
|
28 |
+
num_full_chunks = len(values) // chunk_size
|
29 |
+
values = np.mean(values[:num_full_chunks*chunk_size].reshape(-1, chunk_size), axis=1)
|
30 |
+
values = np.sort(values)[::-1]
|
31 |
+
plt.figure(figsize=(10, 4), dpi=100)
|
32 |
+
x = np.arange(len(values))
|
33 |
+
plt.bar(x, values, color=color)
|
34 |
+
current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
|
35 |
+
plt.xticks((current_values/chunk_size).astype(int), current_values)
|
36 |
+
plt.ylabel('Molecule Frequency', fontsize=20)
|
37 |
+
if x_lim:
|
38 |
+
plt.xlim(*x_lim)
|
39 |
+
if y_lim:
|
40 |
+
plt.ylim(*y_lim)
|
41 |
+
plt.tick_params(axis='both', which='major', labelsize=12)
|
42 |
+
plt.tight_layout(pad=0.5)
|
43 |
+
plt.savefig(target_path)
|
44 |
+
print(f'Figure saved to {target_path}')
|
45 |
+
plt.clf()
|
46 |
+
|
47 |
+
def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None, labels=['Random', 'Ours'], colors=['blue', 'orange'], chunk_size=100):
|
48 |
+
num_full_chunks = len(list1) // chunk_size
|
49 |
+
list1, list2 = pad_shorter_array(list1, list2)
|
50 |
+
values1, values2 = [
|
51 |
+
np.sort(np.mean(values[:num_full_chunks*chunk_size].reshape(-1, chunk_size), axis=1))[::-1]
|
52 |
+
for values in (list1, list2)]
|
53 |
+
|
54 |
+
plt.figure(figsize=(10, 6), dpi=100)
|
55 |
+
x = np.arange(len(values1))
|
56 |
+
plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
|
57 |
+
plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
|
58 |
+
current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
|
59 |
+
plt.xticks((current_values/chunk_size).astype(int), current_values)
|
60 |
+
plt.ylabel('Molecule Frequency', fontsize=20)
|
61 |
+
if x_lim:
|
62 |
+
plt.xlim(*x_lim)
|
63 |
+
if y_lim:
|
64 |
+
plt.ylim(*y_lim)
|
65 |
+
plt.tick_params(axis='both', which='major', labelsize=18)
|
66 |
+
plt.tight_layout(pad=0.5)
|
67 |
+
plt.legend(fontsize=24, loc='upper right')
|
68 |
+
plt.savefig(target_path)
|
69 |
+
print(f'Figure saved to {target_path}')
|
70 |
+
plt.clf()
|
71 |
+
|
72 |
+
def statistics(args):
|
73 |
+
if args.seed:
|
74 |
+
set_random_seed(args.seed)
|
75 |
+
# 1141864 rxns from ord
|
76 |
+
# 1120773 rxns from uspto
|
77 |
+
cluster = Reaction_Cluster(args.root)
|
78 |
+
|
79 |
+
rxn_num = len(cluster.reaction_data)
|
80 |
+
abstract_num = 0
|
81 |
+
property_num = 0
|
82 |
+
calculated_property_num = 0
|
83 |
+
experimental_property_num = 0
|
84 |
+
avg_calculated_property_len = 0
|
85 |
+
avg_experimental_property_len = 0
|
86 |
+
mol_set = set()
|
87 |
+
for rxn_dict in cluster.reaction_data:
|
88 |
+
for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
|
89 |
+
for mol in rxn_dict[key]:
|
90 |
+
mol_set.add(mol)
|
91 |
+
mol_num = len(mol_set)
|
92 |
+
|
93 |
+
for mol_dict in cluster.property_data:
|
94 |
+
if 'abstract' in mol_dict:
|
95 |
+
abstract_num += 1
|
96 |
+
if 'property' in mol_dict:
|
97 |
+
property_num += 1
|
98 |
+
if 'Experimental Properties' in mol_dict['property']:
|
99 |
+
experimental_property_num += 1
|
100 |
+
avg_experimental_property_len += len(mol_dict['property']['Experimental Properties'])
|
101 |
+
if 'Computed Properties' in mol_dict['property']:
|
102 |
+
calculated_property_num += 1
|
103 |
+
avg_calculated_property_len += len(mol_dict['property']['Computed Properties'])
|
104 |
+
|
105 |
+
print(f'Reaction Number: {rxn_num}')
|
106 |
+
print(f'Molecule Number: {mol_num}')
|
107 |
+
print(f'Abstract Number: {abstract_num}/{mol_num}({abstract_num/mol_num*100:.2f}%)')
|
108 |
+
print(f'Property Number: {property_num}/{mol_num}({property_num/mol_num*100:.2f}%)')
|
109 |
+
print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}({experimental_property_num/property_num*100:.2f}%), {avg_experimental_property_len/mol_num:.2f} items per molecule')
|
110 |
+
print(f'- Computed Properties: {calculated_property_num}/{property_num}({calculated_property_num/property_num*100:.2f}%), {avg_calculated_property_len/mol_num:.2f} items per molecule')
|
111 |
+
|
112 |
+
def visualize(args):
|
113 |
+
if args.seed:
|
114 |
+
set_random_seed(args.seed)
|
115 |
+
cluster = Reaction_Cluster(args.root)
|
116 |
+
prob_values, rxn_weights = cluster.visualize_mol_distribution()
|
117 |
+
rand_prob_values, rand_rxn_weights = cluster._randomly(
|
118 |
+
cluster.visualize_mol_distribution
|
119 |
+
)
|
120 |
+
fig_root = f'results/{args.name}/'
|
121 |
+
|
122 |
+
plot_distribution(prob_values, fig_root+'mol_distribution.pdf')
|
123 |
+
plot_distribution(rxn_weights, fig_root+'rxns_distribution.pdf')
|
124 |
+
plot_distribution(rand_prob_values, fig_root+'mol_distribution_random.pdf')
|
125 |
+
plot_distribution(rand_rxn_weights, fig_root+'rxns_distribution_random.pdf')
|
126 |
+
|
127 |
+
plot_compare_distribution(prob_values, rand_prob_values, fig_root+'Compare_mol.pdf', y_lim=(-0.5,15.5))
|
128 |
+
plot_compare_distribution(rxn_weights, rand_rxn_weights, fig_root+'Compare_rxns.pdf')
|
129 |
+
|
130 |
+
|
131 |
+
def visualize_frequency(args):
|
132 |
+
if args.seed:
|
133 |
+
set_random_seed(args.seed)
|
134 |
+
fig_root = f'results/{args.name}/'
|
135 |
+
name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
|
136 |
+
cache_path = f'{fig_root}/freq_{name_suffix}.npy'
|
137 |
+
if os.path.exists(cache_path):
|
138 |
+
mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
|
139 |
+
else:
|
140 |
+
cluster = Reaction_Cluster(args.root)
|
141 |
+
mol_freq, rxn_freq = cluster.visualize_mol_frequency(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
|
142 |
+
rand_mol_freq, rand_rxn_freq = cluster._randomly(
|
143 |
+
cluster.visualize_mol_frequency,
|
144 |
+
rxn_num=args.rxn_num, k=args.k, epochs=args.epochs
|
145 |
+
)
|
146 |
+
np.save(cache_path, np.array([mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq], dtype=object), allow_pickle=True)
|
147 |
+
|
148 |
+
color1 = '#FA7F6F'
|
149 |
+
color2 = '#80AFBF'
|
150 |
+
color3 = '#FFBE7A'
|
151 |
+
plot_distribution(mol_freq, fig_root+f'mol_frequency_{name_suffix}.pdf', x_lim=(-50000//args.chunk_size, 1200000//args.chunk_size), y_lim=(-2, 62), chunk_size=args.chunk_size, color=color2)
|
152 |
+
# plot_distribution(rxn_freq, fig_root+f'rxns_frequency_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
|
153 |
+
plot_distribution(rand_mol_freq, fig_root+f'mol_frequency_random_{name_suffix}.pdf', x_lim=(-50000//args.chunk_size, 1200000//args.chunk_size), y_lim=(-2, 62), chunk_size=args.chunk_size, color=color2)
|
154 |
+
# plot_distribution(rand_rxn_freq, fig_root+f'rxns_frequency_random_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
|
155 |
+
|
156 |
+
plot_compare_distribution(rand_mol_freq, mol_freq, fig_root+f'Compare_mol_{name_suffix}.pdf', y_lim=(-2, 62), labels=['Before Adjustment', 'After Adjustment'], colors=[color1, color2], chunk_size=args.chunk_size)
|
157 |
+
# plot_compare_distribution(rxn_freq, rand_rxn_freq, fig_root+f'Compare_rxns_{name_suffix}.pdf', chunk_size=args.chunk_size)
|
158 |
+
|
159 |
+
if __name__=='__main__':
|
160 |
+
args = parse_args()
|
161 |
+
print(args, flush=True)
|
162 |
+
# statistics(args)
|
163 |
+
# visualize(args)
|
164 |
+
visualize_frequency(args)
|