SyrWin committed on
Commit
95f97c5
·
1 Parent(s): a281fc1
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/frameworks.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ kvplm_pretrained
2
+ __pycache__
3
+ test*
4
+ log*
5
+ *.sh.e*
6
+ *.sh.o*
7
+ .d*
8
+ llms
9
+ Text2graph
10
+ fig/
11
+ results/
12
+ *.out
13
+ *.err
14
+ debug*
15
+ data
16
+ scripts
17
+ conda_env
18
+ tmp*
README.md CHANGED
@@ -1,13 +1,49 @@
1
- ---
2
- title: ReactXT
3
- emoji: 🏆
4
- colorFrom: green
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 4.36.0
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-sa-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReactXT: Understanding Molecular “Reaction-ship” via Reaction-Contextualized Molecule-Text Pretraining
2
+
3
+ ## Comparison to previous molecule-text generative modeling methods
4
+
5
+ ![fig1](./figures/comparison.pdf)
6
+
7
+
8
+ ## Framework of ReactXT
9
+
10
+ ![fig2](./figures/frameworks.jpg)
11
+
12
+
13
+ ## Requirements
14
+
15
+ Our environment is detailed in `environment.yml`. To create a new environment `reactxt`, run the following command:
16
+
17
+ ```bash
18
+ conda env create -f environment.yml
19
+ ```
20
+
21
+
22
+ ## Reproduce the results
23
+
24
+ ### Reaction-Contextualized Molecule-Text Pretraining
25
+
26
+ ```bash
27
+ bash scripts/run_pretrain.sh
28
+ ```
29
+
30
+ ### Finetuning on downstream tasks
31
+
32
+ 1. Experimental Procedure Prediction on OpenExp
33
+
34
+ ```bash
35
+ bash scripts/run_action.sh
36
+ ```
37
+
38
+ 2. Molecule Captioning on PubChem324k and ChEBI-20
39
+
40
+ ```bash
41
+ bash scripts/run_caption.sh
42
+ bash scripts/run_chebi.sh
43
+ ```
44
+
45
+ 3. Retro-synthesis Prediction on USPTO-50k
46
+
47
+ ```bash
48
+ bash scripts/run_retro.sh
49
+ ```
all_checkpoints/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
app.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import warnings
5
+ from rdkit import Chem
6
+ from rdkit.Chem import CanonSmiles
7
+ from rdkit.Chem import MolFromSmiles, MolToSmiles
8
+ from data_provider.pretrain_dm import PretrainDM
9
+ from data_provider.tune_dm import *
10
+ from model.opt_flash_attention import replace_opt_attn_with_flash_attn
11
+ from model.blip2_model import Blip2Model
12
+ from data_provider.data_utils import json_read, json_write
13
+ from data_provider.data_utils import smiles2data, reformat_smiles
14
+ import gradio as gr
15
+ from datetime import datetime
16
+
17
## for pyg bug: hide the UserWarning whose message starts with 'TypedStorage is deprecated'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
## for A5000 gpus: choose the float32 matmul precision level used by torch
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)
21
+
22
def smiles_split(string, separator='.'):
    """Split a multi-component SMILES string into charge-balanced parts.

    The input is cut on ``separator`` and every piece is parsed with RDKit;
    pieces that fail to parse are dropped. Neutral molecules become their own
    part. Consecutive charged fragments are collected until their summed
    formal charge returns to zero and are then emitted together as one
    canonical dot-separated SMILES. Charged fragments still pending when a
    neutral molecule is met, or at the end of the input, are emitted as a
    group as well.

    Returns:
        list[str]: SMILES strings, in input order.
    """
    def _merge(fragments):
        # One canonical multi-component SMILES for a group of fragments.
        return CanonSmiles('.'.join([MolToSmiles(frag) for frag in fragments]))

    parsed = [MolFromSmiles(token) for token in str(string).split(separator)]
    molecules = [m for m in parsed if m is not None]  # skip invalid SMILES strings

    parts = []
    pending = []      # charged fragments waiting for their counter-ions
    net_charge = 0    # summed formal charge of `pending`

    for molecule in molecules:
        charge = Chem.GetFormalCharge(molecule)
        if charge != 0:
            net_charge += charge
            pending.append(molecule)
            if net_charge == 0:
                parts.append(_merge(pending))
                pending, net_charge = [], 0
            continue

        # Neutral molecule: flush any unbalanced group first, then emit it.
        if pending:
            parts.append(_merge(pending))
            pending, net_charge = [], 0
        parts.append(MolToSmiles(molecule))

    if pending:
        parts.append(_merge(pending))

    return parts
60
+
61
def get_args():
    """Build the command-line parser, parse ``sys.argv`` and return the namespace.

    Model- and data-specific options are contributed by
    ``Blip2Model.add_model_specific_args`` and
    ``PretrainDM.add_model_specific_args``. When ``--enable_flash`` is given,
    OPT attention is swapped for flash attention before returning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', type=str, default="main")
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    # MM settings
    parser.add_argument('--mode', type=str, default='pretrain', choices=['pretrain', 'ft', 'eval', 'pretrain_eval'])
    parser.add_argument('--strategy_name', type=str, default='mydeepspeed')
    parser.add_argument('--iupac_prediction', action='store_true', default=False)
    parser.add_argument('--ckpt_path', type=str, default=None)

    # Options owned by the model and the data module.
    parser = Blip2Model.add_model_specific_args(parser)
    parser = PretrainDM.add_model_specific_args(parser)

    for option, fallback in (
        ('--accelerator', 'gpu'),
        ('--devices', '0,1,2,3'),
        ('--precision', 'bf16-mixed'),
    ):
        parser.add_argument(option, type=str, default=fallback)
    parser.add_argument('--downstream_task', type=str, default='action', choices=['action', 'synthesis', 'caption', 'chebi'])
    parser.add_argument('--max_epochs', type=int, default=10)

    for switch in (
        '--enable_flash',
        '--disable_graph_cache',
        '--generate_restrict_tokens',
        '--train_restrict_tokens',
    ):
        parser.add_argument(switch, action='store_true', default=False)
    parser.add_argument('--smiles_type', type=str, default='default', choices=['default', 'canonical', 'restricted', 'unrestricted', 'r_smiles'])

    for option, fallback in (
        ('--accumulate_grad_batches', 1),
        ('--tqdm_interval', 50),
        ('--check_val_every_n_epoch', 1),
    ):
        parser.add_argument(option, type=int, default=fallback)

    parsed = parser.parse_args()
    if parsed.enable_flash:
        replace_opt_attn_with_flash_attn()
    return parsed
91
+
92
# Fixed settings for the demo. The entry point merges this dict into the parsed
# argument namespace, so each key here overrides the argument of the same name.
app_config = {
    "init_checkpoint": "all_checkpoints/ckpt_tune_hybridFeb11_May31/last_converted.ckpt",
    "filename": "app",  # used by InferenceRunner to build its results/<filename>/... directory
    "opt_model": "facebook/galactica-1.3b",
    "num_workers": 4,
    "rxn_max_len": 512,
    "text_max_len": 512,
    "precision": "bf16-mixed",
    "max_inference_len": 512,
}
102
+
103
class InferenceRunner:
    """Turn a reaction string into a model prompt, generate a prediction and store it.

    Args:
        model: object exposing ``blip2opt.generate``.
        tokenizer: tokenizer exposing ``mol_token_id``.
        rxn_max_len: ``max_length`` used when tokenizing the prompt.
        smi_max_len: maximum number of characters kept per SMILES in the prompt.
        smiles_type: forwarded to ``reformat_smiles``.
        device: device the graphs and prompt tokens are moved to.
        args: namespace providing ``num_query_token``, ``opt_model``,
            ``filename`` and the generation settings read in ``predict``.
            Although it defaults to ``None`` it is required in practice.
    """

    def __init__(self, model, tokenizer, rxn_max_len, smi_max_len,
                 smiles_type='default', device='cuda', args=None):
        self.model = model
        self.rxn_max_len = rxn_max_len
        self.smi_max_len = smi_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = '<mol>' * args.num_query_token
        self.mol_token_id = tokenizer.mol_token_id
        self.is_gal = args.opt_model.find('galactica') >= 0
        self.device = device
        self.smiles_type = smiles_type
        self.args = args
        # One results directory per runner, named after its creation time.
        time_stamp = datetime.now().strftime("%Y.%m.%d-%H:%M")
        self.cache_dir = f'results/{self.args.filename}/{time_stamp}'
        os.makedirs(self.cache_dir, exist_ok=True)

    def make_query_dict(self, rxn_string):
        """Parse ``reactants>reagents>products`` into the dict consumed by ``predict``.

        Raises:
            KeyError: if the string does not have exactly three ``>``-separated
                fields, or yields no reactant or no product.
        """
        try:
            reactant, solvent, product = rxn_string.split('>')
            reactant = smiles_split(reactant)
            product = smiles_split(product)
            solvent = smiles_split(solvent) if solvent else []
            # Explicit check instead of `assert`, which is removed under `python -O`.
            if not (reactant and product):
                raise ValueError('a reaction needs at least one reactant and one product')
        except Exception as err:
            # KeyError is kept as the outward exception type for existing callers.
            raise KeyError('Please input a valid reaction string') from err

        # Placeholder labels: the first product is "$-1$", the other molecules
        # are numbered in order of appearance.
        extracted_molecules = {product[0]: "$-1$"}
        for mol in reactant+solvent:
            extracted_molecules[mol] = f"${len(extracted_molecules)}$"

        result_dict = {}
        result_dict['time_stamp'] = datetime.now().strftime("%Y.%m.%d %H:%M:%S.%f")[:-3]
        result_dict['reaction_string'] = rxn_string
        result_dict['REACTANT'] = reactant
        result_dict['SOLVENT'] = solvent
        result_dict['CATALYST'] = []
        result_dict['PRODUCT'] = product
        result_dict['extracted_molecules'] = extracted_molecules
        return result_dict

    def save_prediction(self, result_dict):
        """Write ``result_dict`` as ``<cache_dir>/<time_stamp>.json``."""
        os.makedirs(self.cache_dir, exist_ok=True)
        result_id = result_dict['time_stamp']
        result_path = os.path.join(self.cache_dir, f'{result_id}.json')
        json_write(result_path, result_dict)

    def make_prompt(self, param_dict, smi_max_len=128):
        """Build the text prompt and the list of SMILES it mentions, in order.

        Molecules that have a placeholder label in
        ``param_dict['extracted_molecules']`` are written as
        ``<label>: [START_SMILES]...[END_SMILES] ``; unlabelled ones are written
        without the label prefix. This applies to every role, so additional
        products (only the first product is labelled) no longer raise KeyError.

        Returns:
            tuple[str, list[str]]: the prompt and the SMILES list.
        """
        labels = param_dict["extracted_molecules"]
        smiles_list = []

        def describe(smi):
            # Text for one molecule; also records it for graph construction.
            smiles_list.append(smi)
            body = f'[START_SMILES]{reformat_smiles(smi, smiles_type=self.smiles_type)[:smi_max_len]}[END_SMILES] '
            if smi in labels:
                return f'{labels[smi]}: {body}'
            return body

        prompt = 'Reactants: '
        for smi in param_dict['REACTANT']:
            prompt += describe(smi)

        prompt += 'Product: '
        for smi in param_dict['PRODUCT']:
            prompt += describe(smi)

        # Catalyst and solvent sections are emitted only when non-empty.
        for header, role in (('Catalysts: ', 'CATALYST'), ('Solvents: ', 'SOLVENT')):
            if param_dict[role]:
                prompt += header
                for smi in param_dict[role]:
                    prompt += describe(smi)

        prompt += 'Action Squence: '
        return prompt, smiles_list

    def get_action_elements(self, rxn_dict):
        """Return ``(graph_list, prompt_text)`` for one reaction dict."""
        input_text, smiles_list = self.make_prompt(rxn_dict, self.smi_max_len)
        graph_list = [smiles2data(smiles) for smiles in smiles_list]
        return graph_list, input_text

    @torch.no_grad()
    def predict(self, rxn_dict, temperature=1):
        """Generate a prediction for ``rxn_dict`` and store it under ``'prediction'``.

        ``rxn_dict`` is modified in place and returned. Placeholder labels in
        the generated text are replaced by the SMILES they stand for.
        """
        graphs, prompt_tokens = self.tokenize(rxn_dict)
        result_dict = rxn_dict
        samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
        prediction = self.model.blip2opt.generate(
            samples,
            do_sample=self.args.do_sample,
            num_beams=self.args.num_beams,
            max_length=self.args.max_inference_len,
            min_length=self.args.min_inference_len,
            num_captions=self.args.num_generate_captions,
            temperature=temperature,
            use_graph=True
        )[0]
        for k, v in result_dict['extracted_molecules'].items():
            prediction = prediction.replace(v, k)
        result_dict['prediction'] = prediction
        return result_dict

    def tokenize(self, rxn_dict):
        """Collate the molecule graphs and tokenize the prompt (left-padded).

        Returns:
            tuple: ``(graphs, prompt_tokens)``. ``graphs`` is ``None`` when the
            reaction produced no graphs; ``prompt_tokens`` carries an extra
            ``is_mol_token`` mask marking the molecule placeholder positions.
        """
        graph_list, input_text = self.get_action_elements(rxn_dict)
        graphs = None  # stays None instead of being unbound when graph_list is empty
        if graph_list:
            graphs = self.collater(graph_list).to(self.device)
        input_prompt = smiles_handler(input_text, self.mol_ph, self.is_gal)[0]

        ## deal with prompt
        self.tokenizer.padding_side = 'left'
        input_prompt_tokens = self.tokenizer(input_prompt,
                                             truncation=True,
                                             padding='max_length',
                                             add_special_tokens=True,
                                             max_length=self.rxn_max_len,
                                             return_tensors='pt',
                                             return_attention_mask=True).to(self.device)
        is_mol_token = input_prompt_tokens.input_ids == self.mol_token_id
        input_prompt_tokens['is_mol_token'] = is_mol_token
        return graphs, input_prompt_tokens
234
+
235
def main(args):
    """Load the model, build the Gradio demo and launch it.

    Reads ``init_checkpoint``, ``opt_model``, ``rxn_max_len`` and
    ``smi_max_len`` from ``args`` (further fields are read by ``Blip2Model``
    and ``InferenceRunner``). Example reactions are read from ``demo.json`` in
    the working directory.
    """
    device = torch.device('cuda')
    # model
    if args.init_checkpoint:
        model = Blip2Model(args).to(device)
        ckpt = torch.load(args.init_checkpoint, map_location='cpu')
        # strict=False: keys that do not match between checkpoint and model are tolerated
        model.load_state_dict(ckpt['state_dict'], strict=False)
        print(f"loaded model from {args.init_checkpoint}")
    else:
        model = Blip2Model(args).to(device)
    model.eval()

    print('total params:', sum(p.numel() for p in model.parameters()))

    # The tokenizer attribute depends on which language-model family is configured.
    if args.opt_model.find('galactica') >= 0 or args.opt_model.find('t5') >= 0:
        tokenizer = model.blip2opt.opt_tokenizer
    elif args.opt_model.find('llama') >= 0 or args.opt_model.find('vicuna') >= 0:
        tokenizer = model.blip2opt.llm_tokenizer
    else:
        raise NotImplementedError

    infer_runner = InferenceRunner(
        model=model,
        tokenizer=tokenizer,
        rxn_max_len=args.rxn_max_len,
        smi_max_len=args.smi_max_len,
        device=device,
        args=args
    )
    example_inputs = json_read('demo.json')
    # gr.Examples receives one list per example (a single input component here).
    example_inputs = [[e] for e in example_inputs]

    def online_chat(reaction_string, temperature=1):
        # Parse -> predict -> save to disk -> put each ' ; '-separated step on its own line.
        data_item = infer_runner.make_query_dict(reaction_string)
        result = infer_runner.predict(data_item, temperature=temperature)
        infer_runner.save_prediction(result)
        prediction = result['prediction'].replace(' ; ', ' ;\n')
        return prediction

    # NOTE(review): the nesting of the widgets below was inferred from context;
    # confirm it against the intended page layout.
    with gr.Blocks(css="""
.center { display: flex; justify-content: center; }
""") as demo:
        gr.HTML(
            """
<center><h1><b>ReactXT</b></h1></center>
<p style="font-size:20px; font-weight:bold;">This is the demo page of our ACL 2024 paper
<i>ReactXT: Understanding Molecular “Reaction-ship” via Reaction-Contextualized Molecule-Text Pretraining.</i></p>
""")
        with gr.Row(elem_classes="center"):
            gr.Image(value="./figures/frameworks.jpg", elem_classes="center", width=800, label="Framework of ReactXT")
        gr.HTML(
            """
<p style="font-size:16px;"> Please input one chemical reaction below, and we will generate the predicted experimental procedure.</p>
<p style="font-size:16px;"> The reaction should be in form of <b>Reactants>Reagents>Product</b>.</p>
""")

        reaction_string = gr.Textbox(placeholder="Input one reaction", label='Input Reaction')
        gr.Examples(example_inputs, [reaction_string,], fn=online_chat, label='Example Reactions')
        with gr.Row():
            btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        temperature = gr.Slider(0.1, 1, value=1, label='Temperature')
        with gr.Row():
            out = gr.Textbox(label="ReactXT's Output", placeholder="Predicted experimental procedure")
        btn.click(fn=online_chat, inputs=[reaction_string, temperature], outputs=[out])
        # "Clear" empties both the input box and the output box.
        clear_btn.click(fn=lambda:("", ""), inputs=[], outputs=[reaction_string, out])

    demo.launch(share=True)
303
+
304
+
305
+
306
if __name__ == '__main__':
    args = get_args()
    # Entries of app_config overwrite the parsed arguments of the same name.
    vars(args).update(app_config)
    main(args)
average_ckpt.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
def average_checkpoints(checkpoint_paths):
    """Average the ``state_dict`` tensors of several checkpoints.

    The last path provides the returned checkpoint dict; its ``'state_dict'``
    entry is replaced by the element-wise mean over all checkpoints, and every
    other entry is kept from that last checkpoint.

    Floating-point (and complex) tensors are accumulated in at least float32,
    so fp16/bf16 weights do not lose precision while being summed, and are
    cast back to their original dtype. Non-floating tensors (integer counters,
    index buffers, boolean masks) are not averaged; the value from the last
    checkpoint is kept so that their dtype is preserved.

    Args:
        checkpoint_paths: non-empty sequence of paths loadable by ``torch.load``.

    Returns:
        dict: the last checkpoint with an averaged ``'state_dict'``.

    Raises:
        ValueError: if ``checkpoint_paths`` is empty.
        KeyError: if an earlier checkpoint has a key missing from the last one.
    """
    if not checkpoint_paths:
        raise ValueError('checkpoint_paths must contain at least one path')

    cpu = torch.device('cpu')
    averaged_ckpt = torch.load(checkpoint_paths[-1], map_location=cpu)
    reference_state = averaged_ckpt['state_dict']

    # Running sums, only for tensors that can meaningfully be averaged.
    param_sum_dict = {}
    for key, value in reference_state.items():
        if torch.is_tensor(value) and (value.is_floating_point() or value.is_complex()):
            acc_dtype = torch.promote_types(value.dtype, torch.float32)
            param_sum_dict[key] = value.to(acc_dtype, copy=True)

    num_checkpoints = len(checkpoint_paths)
    for ckpt_path in checkpoint_paths[:-1]:
        checkpoint = torch.load(ckpt_path, map_location=cpu)
        for key, value in checkpoint['state_dict'].items():
            if key not in reference_state:
                raise KeyError(key)
            if key in param_sum_dict:
                param_sum_dict[key] += value

    averaged_state = {}
    for key, value in reference_state.items():
        if key in param_sum_dict:
            averaged_state[key] = (param_sum_dict[key] / num_checkpoints).to(value.dtype)
        else:
            averaged_state[key] = value
    averaged_ckpt['state_dict'] = averaged_state

    return averaged_ckpt
21
+
22
def parse_arguments():
    """Parse the command line of the checkpoint-averaging script.

    Both ``--checkpoint_paths`` (one or more paths) and ``--output_path`` are
    mandatory.
    """
    cli = argparse.ArgumentParser(
        description="Averages the weights of multiple transformer model checkpoints.",
    )
    cli.add_argument(
        '--checkpoint_paths',
        nargs='+',
        required=True,
        help='List of paths to the checkpoints to be averaged. Example: --checkpoint_paths path1 path2 path3',
    )
    cli.add_argument('--output_path', type=str, required=True)
    return cli.parse_args()
28
+
29
if __name__ == "__main__":
    # Parse the CLI, average the listed checkpoints and save the result.
    args = parse_arguments()
    # Despite its name, this is the full checkpoint dict returned by
    # average_checkpoints, not only the weights.
    averaged_state_dict = average_checkpoints(args.checkpoint_paths)
    torch.save(averaged_state_dict, args.output_path)
    print(f"Averaged checkpoint saved to {args.output_path}")
convert.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
4
+
5
if __name__ == '__main__':
    ## read a path using argparse and pass it to convert_zero_checkpoint_to_fp32_state_dict
    parser = argparse.ArgumentParser()
    # --input is mandatory: without it the script used to fail later with a
    # TypeError from Path(None) instead of a usage message.
    parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
    parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
    args = parser.parse_args()
    if args.output is None:
        # Default output location: 'converted.ckpt' inside the input folder.
        args.output = Path(args.input) / 'converted.ckpt'
    convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)
data_provider/__init__.py ADDED
File without changes
data_provider/caption_dataset.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Dataset
3
+ import os
4
+ from torch_geometric.data import InMemoryDataset
5
+ from .data_utils import reformat_smiles
6
+ import random
7
+ import json
8
+
9
class PubChemDataset(InMemoryDataset):
    """In-memory dataset restored from a saved ``(data, slices)`` pair."""

    def __init__(self, path):
        super().__init__()
        stored = torch.load(path)
        self.data, self.slices = stored

    def __getitem__(self, idx):
        # Indexing goes straight to `get`.
        return self.get(idx)
16
+
17
class CaptionDataset(Dataset):
    """Molecule-caption samples read from ``<root>/<mode>.pt``.

    Each item is ``(index, graph_list, text, smiles_prompt)``. ``graph_list``
    holds the stored data object when ``use_graph`` is true and is empty
    otherwise. ``disable_graph_cache`` is accepted but not used here.
    """

    def __init__(self, root, mode, smi_max_len=128, use_graph=True, disable_graph_cache=False, smiles_type='default'):
        super().__init__(root)
        self.root = root
        self.smi_max_len = smi_max_len
        self.use_graph = use_graph
        self.smiles_type = smiles_type
        self.tokenizer = None
        self.file_path = os.path.join(root, f'{mode}.pt')

        self.data = PubChemDataset(self.file_path)

    def get(self, index):
        return self.__getitem__(index)

    def len(self):
        return len(self)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data[index]
        smiles = reformat_smiles(record.smiles, smiles_type=self.smiles_type)
        smiles_prompt = f'[START_I_SMILES]{smiles[:self.smi_max_len]}[END_I_SMILES]. '

        # At most the first 101 lines of the caption are kept and joined into
        # a single line.
        kept_lines = record.text.split('\n')[:101]
        text = ' '.join([line.strip() for line in kept_lines]) + '\n'

        if self.use_graph:
            graph_list = [record]
        else:
            graph_list = []
        return index, graph_list, text, smiles_prompt
54
+
55
class PretrainCaptionDataset(Dataset):
    """Concatenation of the 'pretrain' and 'train' caption splits.

    Indices below ``len(pre_train_data)`` address the pretrain split, the
    rest address the train split. Items are ``(graph_item, text,
    smiles_prompt)``, where text-like attributes have been removed from the
    graph object. ``disable_graph_cache`` is accepted but not used here.
    """

    def __init__(self, root, smi_max_len=128, use_graph=True, disable_graph_cache=False):
        super().__init__(root)
        shared = dict(smi_max_len=smi_max_len, use_graph=use_graph)
        self.pre_train_data = CaptionDataset(root, 'pretrain', **shared)
        self.train_data = CaptionDataset(root, 'train', **shared)

    def get(self, index):
        return self.__getitem__(index)

    def len(self):
        return len(self)

    def __len__(self):
        return len(self.pre_train_data) + len(self.train_data)

    def __getitem__(self, index):
        boundary = len(self.pre_train_data)
        if index < boundary:
            sample = self.pre_train_data[index]
        else:
            sample = self.train_data[index - boundary]
        _, graph_list, text, smiles_prompt = sample

        graph_item = graph_list[0]
        # Optional attributes are removed only when present ...
        for optional in ('iupac', 'cid'):
            if hasattr(graph_item, optional):
                delattr(graph_item, optional)
        # ... while text and smiles are always expected to exist.
        del graph_item.text
        del graph_item.smiles

        return graph_item, text, smiles_prompt
data_provider/chebi_dataset.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Dataset
3
+ import os
4
+ from torch_geometric.data import InMemoryDataset
5
+ import random
6
+ import json
7
+ from .data_utils import reformat_smiles
8
+
9
class ChEBI_dataset(Dataset):
    """Caption samples read from the tab-separated file ``<root>/<mode>.txt``.

    The first line of the file is skipped as a header. Every remaining
    non-blank line is split into three fields, unpacked as
    ``cid, smiles, text``. When ``use_graph`` is true, graphs are looked up by
    ``cid`` in ``<root>/cid_graph_map.pt``.

    Each item is ``(index, graph_list, text, smiles_prompt)``; ``graph_list``
    is empty when ``use_graph`` is false. ``disable_graph_cache`` is accepted
    but not used here.
    """

    def __init__(self, root, mode, smi_max_len=128, use_graph=True, disable_graph_cache=False, smiles_type='default'):
        super(ChEBI_dataset, self).__init__(root)
        self.root = root
        self.file_path = os.path.join(root, f'{mode}.txt')
        self.smi_max_len = smi_max_len
        self.tokenizer = None
        self.use_graph = use_graph
        self.smiles_type = smiles_type
        if self.use_graph:
            self.idx_graph_map = torch.load(os.path.join(root, 'cid_graph_map.pt'))
        with open(self.file_path) as f:
            lines = f.readlines()
        # Skip the header row; ignore blank lines (e.g. a trailing newline at
        # the end of the file), which cannot be unpacked into three fields.
        self.data = [line.split('\t', maxsplit=2) for line in lines[1:] if line.strip()]

    def get(self, index):
        return self.__getitem__(index)

    def len(self):
        return len(self)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        cid, smiles, text = self.data[index]
        smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
        smiles_prompt = f'[START_I_SMILES]{smiles[:self.smi_max_len]}[END_I_SMILES]. '
        text = text.strip() + '\n'
        # Always bind graph_list: previously it was undefined (UnboundLocalError)
        # when use_graph was False.
        graph_list = [self.idx_graph_map[cid]] if self.use_graph else []

        return index, graph_list, text, smiles_prompt
data_provider/context_gen.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import os
3
+ import numpy as np
4
+ import argparse
5
+ import json
6
+ from collections import defaultdict
7
+ from matplotlib import pyplot as plt
8
+ from collections import Counter
9
+ from .data_utils import json_read
10
+
11
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and ``PYTHONHASHSEED``."""
    os.environ['PYTHONHASHSEED'] = f'{seed}'
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
15
+
16
class Reaction_Cluster:
    """Sample groups of molecules that co-occur in a reaction.

    Reactions are read from ``<root>/<reaction_filename>`` and molecule
    properties from ``<root>/Abstract_property.json`` (indexed by their
    ``'canon_smiles'`` field). A molecule is *valid* for a reaction when it has
    a property record containing an ``'abstract'`` entry. Construction adds
    these fields to every reaction dict:

    * ``valid_mols`` / ``mol_role_map`` - valid molecules and the role
      (REACTANT, CATALYST, SOLVENT or PRODUCT) of every known molecule;
    * ``weight`` - P(r): sum of ``1 / count(mol)`` over the valid molecules,
      normalised over all reactions;
    * ``mol_weight`` - P(i|r): ``1 / count(mol)`` normalised within the reaction.

    Sampled property dicts are copies, so attaching ``role`` / ``r_smiles`` to
    them never alters ``mol_property_map`` or previously returned batches.
    """

    def __init__(self, root, reaction_filename, reverse_ratio=0.5):
        self.root = root
        self.reaction_data = json_read(os.path.join(self.root, reaction_filename))
        self.property_data = json_read(os.path.join(self.root, 'Abstract_property.json'))
        self.mol_property_map = {d['canon_smiles']: d for d in self.property_data}
        self.reverse_ratio = reverse_ratio
        self.rxn_mols_attr = defaultdict(lambda:{
            'freq': 0,
            'occurrence': 0,
            'in_caption': False,
        })

        self._read_reaction_mols() # add `valid_mols` in each rxn_dict
        self.mol_counter = Counter(mol for rxn_dict in self.reaction_data for mol in rxn_dict['valid_mols'])
        self._calculate_Pr() # calculate P(r), add `weight` in each rxn_dict
        self._calculate_Pir() # calculate P(i|r), add `mol_weight` in each rxn_dict

    def _read_reaction_mols(self):
        """Fill ``valid_mols`` / ``mol_role_map`` per reaction and ``valid_rxn_indices``."""
        self.valid_rxn_indices = []
        for rxn_id, rxn_dict in enumerate(self.reaction_data):
            mol_role_map = {}
            for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
                for m in rxn_dict[key]:
                    if m in mol_role_map:
                        continue  # the first role seen for a molecule wins
                    if m in self.mol_property_map:
                        mol_role_map[m] = key
            valid_mols = []
            for mol in mol_role_map:
                assert mol in self.mol_property_map # this is guaranteed by the above if statement
                if 'abstract' not in self.mol_property_map[mol]:
                    continue
                valid_mols.append(mol) # here the molecules should be in the R, C, S, P order.
            if len(valid_mols) > 0:
                self.valid_rxn_indices.append(rxn_id)
            rxn_dict['valid_mols'] = valid_mols
            rxn_dict['mol_role_map'] = mol_role_map

    def _calculate_Pr(self):
        """Store the normalised reaction sampling probability in ``rxn_dict['weight']``."""
        total_weights = 0
        for rxn_dict in self.reaction_data:
            rxn_weight = sum([1/self.mol_counter[mol] for mol in rxn_dict['valid_mols']])
            rxn_dict['weight'] = rxn_weight
            total_weights += rxn_weight
        for rxn_dict in self.reaction_data:
            rxn_dict['weight'] = rxn_dict['weight'] / total_weights

    def _calculate_Pir(self):
        """Store per-reaction molecule probabilities in ``rxn_dict['mol_weight']``."""
        for rxn_dict in self.reaction_data:
            mol_weight = {}
            for mol in rxn_dict['valid_mols']:
                mol_weight[mol] = 1/self.mol_counter[mol]
            total_weight = sum(mol_weight.values())
            rxn_dict['mol_weight'] = {m:w/total_weight for m, w in mol_weight.items()}

    def _collect_properties(self, rxn, mols):
        """Return a copy of each molecule's property dict, tagged with its role in ``rxn``."""
        batch = []
        for mol in mols:
            # Copy: the same molecule may play different roles in other
            # reactions, and the shared record must stay untouched.
            mol_property = dict(self.mol_property_map[mol])
            mol_property['role'] = rxn['mol_role_map'][mol]
            batch.append(mol_property)
        return batch

    def choose_mol(self, valid_mols, k=4, weights=None):
        """Pick up to ``k`` molecules, keeping (or, by chance, reversing) their order.

        All molecules are returned when ``k >= len(valid_mols)``; otherwise
        ``k`` are drawn without replacement using ``weights``.
        """
        if k>=len(valid_mols):
            sampled_indices = list(range(len(valid_mols)))
        else:
            sampled_indices = np.random.choice(len(valid_mols), k, replace=False, p=weights)
            sampled_indices = list(sampled_indices)
        sampled_indices = sorted(sampled_indices)
        if random.random() < self.reverse_ratio: # reverse the indices with reverse_ratio chance.
            sampled_indices.reverse()
        sampled_mols = [valid_mols[i] for i in sampled_indices]
        return sampled_mols

    def sample_mol_batch(self, index=None, k=4):
        """Sample up to ``k`` molecules from one reaction (a random one if ``index`` is None).

        Returns a list of property dicts carrying ``'role'`` and, when the
        reaction has an ``'rsmiles_map'``, ``'r_smiles'`` for molecules found
        in one randomly chosen map.
        """
        if index is None:
            index = self.sample_rxn_index(1)[0]
        assert index < len(self.reaction_data)
        rxn = self.reaction_data[index]
        valid_mols, weights = zip(*rxn['mol_weight'].items())

        sampled_mols = self.choose_mol(valid_mols, k=k, weights=weights)
        mol_property_batch = self._collect_properties(rxn, sampled_mols)
        if 'rsmiles_map' in rxn:
            rsmiles_map = random.choice(rxn['rsmiles_map'])
            for mol_property in mol_property_batch:
                canon_smiles = mol_property['canon_smiles']
                if canon_smiles in rsmiles_map:
                    mol_property['r_smiles'] = rsmiles_map[canon_smiles]
        return mol_property_batch

    def sample_rxn_index(self, num_samples):
        """Draw ``num_samples`` distinct reaction indices according to ``weight``."""
        indices = range(len(self.reaction_data))
        weights = [d['weight'] for d in self.reaction_data]
        return np.random.choice(indices, num_samples, replace=False, p=weights)

    def __call__(self, rxn_num=1000, k=4):
        """Sample ``rxn_num`` reactions and one molecule batch from each."""
        sampled_indices = self.sample_rxn_index(rxn_num)
        sampled_batch = [self.sample_mol_batch(idx, k=k) for idx in sampled_indices]
        return sampled_batch

    def generate_batch_uniform_rxn(self, rxn_num=1000, k=4):
        """Like ``__call__`` but with reactions and molecules chosen uniformly."""
        assert rxn_num <= len(self.valid_rxn_indices)
        sampled_rxn_indices = random.sample(self.valid_rxn_indices, rxn_num)
        sampled_batch = []
        for rxn_id in sampled_rxn_indices:
            rxn = self.reaction_data[rxn_id]
            sampled_mols = self.choose_mol(rxn['valid_mols'], k=k, weights=None)
            sampled_batch.append(self._collect_properties(rxn, sampled_mols))
        return sampled_batch

    def generate_batch_uniform_mol(self, rxn_num=1000, k=4):
        """Return ``rxn_num`` groups of ``k`` property records drawn from all molecule occurrences."""
        valid_mols = list(self.mol_counter.elements())
        assert rxn_num*k <= len(valid_mols)
        sampled_batch = []
        sampled_mol_ids = random.sample(range(len(valid_mols)), rxn_num*k)
        for i in range(rxn_num):
            sampled_batch.append([self.mol_property_map[valid_mols[mol_id]] for mol_id in sampled_mol_ids[i*k:(i+1)*k]])
        return sampled_batch

    def generate_batch_single(self, rxn_num=1000):
        """Return ``rxn_num`` single-molecule groups drawn from all molecule occurrences."""
        valid_mols = list(self.mol_counter.elements())
        sampled_mols = random.sample(valid_mols, rxn_num)
        total_valid_mols = [[self.mol_property_map[mol]] for mol in sampled_mols]
        return total_valid_mols

    # visualize probability for molecules in caption dataset.
    def visualize_mol_distribution(self):
        """Return molecule and reaction sampling probabilities, each scaled by its population size."""
        prob_dict = {mol:0.0 for mol in self.mol_property_map.keys()}
        N = len(prob_dict)
        M = len(self.reaction_data)
        assert N == len(self.mol_property_map)
        print(f'Number of molecules in Caption Dataset: {N}')
        print(f'Number of Reactions in Reaction Dataset: {M}')

        # prob distribution for molecules
        for rxn_dict in self.reaction_data:
            for mol, weight in rxn_dict['mol_weight'].items():
                prob_dict[mol] += weight * rxn_dict['weight']
        # sum of prob_dict.values() should already be 1.
        prob_values = np.array(list(prob_dict.values()))
        prob_values *= N

        # prob distribution for reactions
        rxn_weights = np.array([d['weight'] for d in self.reaction_data])
        # sum of rxn_weights should already be 1.
        rxn_weights *= M

        return prob_values, rxn_weights

    # visualize the frequency for molecules in caption dataset.
    def visualize_mol_frequency(self, rxn_num=1000, k=4, epochs=100):
        """Simulate ``epochs`` rounds of sampling and return molecule / reaction hit counts."""
        sampled_mols_counter = Counter()
        sampled_rxns_counter = Counter()
        for _ in range(epochs):
            rxn_indices = self.sample_rxn_index(rxn_num)
            sampled_rxns_counter.update(rxn_indices)
            for index in rxn_indices:
                rxn = self.reaction_data[index]
                if len(rxn['valid_mols']) ==0:
                    continue
                valid_mols, weights = zip(*rxn['mol_weight'].items())
                mol_batch = self.choose_mol(valid_mols, k=k, weights=weights)
                sampled_mols_counter.update(mol_batch)
        sampled_mols_count = np.array([c for _, c in sorted(sampled_mols_counter.items())])
        sampled_rxns_count = np.array([c for _, c in sorted(sampled_rxns_counter.items())])
        return sampled_mols_count, sampled_rxns_count

    def _randomly(self, func, *args, **kwargs):
        """Run ``func`` with all sampling weights temporarily made uniform.

        The real weights are restored afterwards, also when ``func`` raises.
        """
        # make fake weights and backup the weights
        for rxn_dict in self.reaction_data:
            rxn_dict['weight_bak'] = rxn_dict['weight']
            rxn_dict['weight'] = 1/len(self.reaction_data)
            rxn_dict['mol_weight_bak'] = rxn_dict['mol_weight']
            rxn_dict['mol_weight'] = {m:1/len(rxn_dict['mol_weight']) for m in rxn_dict['mol_weight']}

        try:
            # run the function
            result = func(*args, **kwargs)
        finally:
            # weights recovery
            for rxn_dict in self.reaction_data:
                rxn_dict['weight'] = rxn_dict['weight_bak']
                del rxn_dict['weight_bak']
                rxn_dict['mol_weight'] = rxn_dict['mol_weight_bak']
                del rxn_dict['mol_weight_bak']

        return result
data_provider/data_utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Data
3
+ from ogb.utils import smiles2graph
4
+ from rdkit import Chem
5
+ import random
6
+ import os
7
+ import json
8
+ from rdkit import RDLogger
9
+ RDLogger.DisableLog('rdApp.*')
10
+ from .r_smiles import multi_process
11
+ import multiprocessing
12
+
13
def reformat_smiles(smiles, smiles_type='default'):
    """Re-encode a SMILES string according to ``smiles_type``.

    Supported values:
        * ``'default'`` / ``'r_smiles'``: return the input unchanged
          (root-aligned SMILES are produced in r_smiles.py, not here).
        * ``'canonical'``: canonical, non-isomeric SMILES.
        * ``'restricted'``: non-canonical SMILES written after randomly
          renumbering the atoms.
        * ``'unrestricted'``: RDKit ``doRandom=True`` SMILES.

    Returns ``None`` for an empty/falsy input and raises
    ``NotImplementedError`` for an unknown ``smiles_type``.
    """
    if not smiles:
        return None
    if smiles_type in ('default', 'r_smiles'):
        return smiles
    if smiles_type not in ('canonical', 'restricted', 'unrestricted'):
        raise NotImplementedError(f"smiles_type {smiles_type} not implemented")

    mol = Chem.MolFromSmiles(smiles)
    if smiles_type == 'canonical':
        return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
    if smiles_type == 'unrestricted':
        return Chem.MolToSmiles(mol, canonical=False, doRandom=True, isomericSmiles=False)

    # 'restricted': shuffle the atom numbering, then write without canonicalizing.
    atom_order = list(range(mol.GetNumAtoms()))
    random.shuffle(atom_order)
    shuffled_mol = Chem.RenumberAtoms(mol, newOrder=atom_order)
    return Chem.MolToSmiles(shuffled_mol, canonical=False, isomericSmiles=False)
35
+
36
def json_read(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 (the JSON interchange encoding)
    instead of the platform's locale encoding, so files containing non-ASCII
    text load identically on every machine.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
40
+
41
def json_write(path, data):
    """Serialize ``data`` as indented JSON to ``path``.

    ``ensure_ascii=False`` emits non-ASCII characters verbatim, so the file is
    opened explicitly as UTF-8; relying on the locale encoding could raise
    ``UnicodeEncodeError`` (or write a non-UTF-8 file) on some platforms.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
44
+
45
def format_float_from_string(s):
    """Return ``s`` rendered with two decimals if it parses as a float.

    Strings that ``float()`` rejects are returned unchanged.
    """
    try:
        parsed = float(s)
    except ValueError:
        return s
    return f'{parsed:.2f}'
51
+
52
def make_abstract(mol_dict, abstract_max_len=256, property_max_len=256):
    """Build the text prompt describing one molecule.

    The prompt is the concatenation of an optional ``[Abstract] ...`` section
    (the ``abstract`` entry cut to ``abstract_max_len`` characters) and an
    optional ``[Properties] ...`` section listing ``key: value; `` entries
    from the 'Experimental Properties' and 'Computed Properties' groups of
    ``mol_dict['property']``. Floats and float-like strings are shown with two
    decimals. Within a group, entries stop being added as soon as the next one
    would push the property text past ``property_max_len`` characters.

    Returns an empty string when neither section applies.
    """
    sections = []
    if 'abstract' in mol_dict:
        clipped = mol_dict['abstract'][:abstract_max_len]
        sections.append(f'[Abstract] {clipped} ')

    properties = mol_dict.get('property', {})
    collected = ''
    for group_name in ('Experimental Properties', 'Computed Properties'):
        if group_name not in properties:
            continue
        for name, raw in properties[group_name].items():
            if isinstance(raw, float):
                rendered = f'{raw:.2f}'
            elif isinstance(raw, str):
                rendered = format_float_from_string(raw)
            else:
                rendered = raw
            entry = f'{name}: {rendered}; '
            # Only leaves the current group; the next group is still visited.
            if len(collected) + len(entry) > property_max_len:
                break
            collected += entry

    if collected:
        sections.append(f'[Properties] {collected[:property_max_len]}. ')
    return ''.join(sections)
78
+
79
def smiles2data(smiles):
    """Convert a SMILES string into a ``torch_geometric`` ``Data`` graph.

    The graph dictionary produced by ``smiles2graph`` supplies the node
    features (``x``), edge index and edge features, each wrapped as a tensor
    via ``torch.from_numpy``.
    """
    mol_graph = smiles2graph(smiles)
    return Data(
        x=torch.from_numpy(mol_graph['node_feat']),
        edge_index=torch.from_numpy(mol_graph['edge_index']),
        edge_attr=torch.from_numpy(mol_graph['edge_feat']),
    )
86
+
87
+ import re
88
# Marker inserted between the characters of a special sequence. The digits are
# interpolated so the literal marker text does not appear in this source file.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

# Matches [START_X]...[END_X] spans for the supported sequence kinds.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")


def _insert_split_marker(m: re.Match):
    """
    Rebuild one matched special-token span with split markers inserted.

    Parameters
    ----------
    m : re.Match
        A match of ``CUSTOM_SEQ_RE`` (start token, kind, sequence, end token)

    Returns
    ----------
    str - the span with ``SPLIT_MARKER`` before every sequence character and
    once more before the end token
    """
    opening, _kind, payload, closing = m.groups()
    marked = "".join(SPLIT_MARKER + char for char in payload)
    return opening + marked + SPLIT_MARKER + closing


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added inside every
    [START_*]...[END_*] span; text outside such spans is untouched
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
125
+
126
def generate_rsmiles(reactants, products, augmentation=20):
    """
    Generate root-aligned SMILES pairs for a list of reactions.

    reactants: list of N, reactant smiles
    products: list of N, product smiles
    augmentation: int, number of augmentations

    return: tuple ``(reactant_smiles, product_smiles)``; each is the flat
        concatenation of the per-reaction ``tgt_data`` / ``src_data`` lists
        returned by ``multi_process`` (N x augmentation entries when every
        reaction yields ``augmentation`` results).
    """
    data = [{
        'reactant': r.strip().replace(' ', ''),
        'product': p.strip().replace(' ', ''),
        'augmentation': augmentation,
        'root_aligned': True,
    } for r, p in zip(reactants, products)]
    # The context manager tears the worker pool down once map() has returned;
    # previously the pool was never closed and its processes were leaked.
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        results = pool.map(func=multi_process, iterable=data)
    product_smiles = [smi for r in results for smi in r['src_data']]
    reactant_smiles = [smi for r in results for smi in r['tgt_data']]
    return reactant_smiles, product_smiles
data_provider/molecule_abstract_dataset.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Dataset
3
+ import os
4
+ from .context_gen import Reaction_Cluster
5
+ import json
6
+ from .data_utils import smiles2data, reformat_smiles
7
+ from collections import defaultdict
8
+ import random
9
+ from data_provider.caption_dataset import PretrainCaptionDataset
10
+ from data_provider.synthesis_dataset import SynthesisDataset
11
+
12
def format_float_from_string(s):
    """Format ``s`` with two decimals when it is float-like, else return it as is."""
    try:
        return '{:.2f}'.format(float(s))
    except ValueError:
        return s
18
+
19
class MoleculeAbstract(Dataset):
    """Pretraining dataset mixing reaction contexts, captions and synthesis samples.

    Items are laid out by index as:
      * ``[0, len(data_list))``: molecule batches produced by ``Reaction_Cluster``;
      * next ``len(caption_index_list)`` indices: groups of caption samples
        (only when ``use_caption_dataset``);
      * remaining indices: single synthesis samples (only when a synthesis
        dataset path was given).

    Every item is a triple ``(graph_list, mol_prompt_list, text_prompt_list)``.
    """

    def __init__(self,
            root,
            rxn_num=1000,
            rxn_batch_size=4,
            smi_max_len=128,
            prompt=None,
            disable_graph_cache=False,
            disable_graphs=False,
            context_style='weighted_rxn',
            use_caption_dataset=False,
            caption_batch_num=10000,
            synthesis_datasetpath=None,
            synthesis_batch_num=10000,
            reverse_ratio=0.5,
            enable_abstract=True,
            enable_property=True,
            smiles_type='default',
            mode='train'
        ):
        # NOTE(review): `prompt` is accepted but not used anywhere in this class.
        super(MoleculeAbstract, self).__init__(root)
        self.root = root
        self.rxn_num = rxn_num
        self.rxn_batch_size = rxn_batch_size
        self.smi_max_len = smi_max_len
        self.context_style = context_style
        self.tokenizer = None
        self.disable_graph_cache = disable_graph_cache
        self.disable_graphs = disable_graphs
        self.use_caption_dataset = use_caption_dataset
        self.smiles_type = smiles_type
        if use_caption_dataset:
            # Caption data is expected in a sibling directory of `root`.
            self.caption_dataset = PretrainCaptionDataset(
                os.path.join(root, '../caption_data'),
                smi_max_len=smi_max_len,
                use_graph=not self.disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type=smiles_type,
            )
            self.caption_batch_num = caption_batch_num
        self.use_synthesis_dataset = bool(synthesis_datasetpath)
        if self.use_synthesis_dataset:
            # The synthesis dataset always uses the 'train' split and 'default' SMILES.
            self.synthesis_dataset = SynthesisDataset(
                synthesis_datasetpath,
                'train',
                smi_max_len,
                roundrobin_train=True,
                use_graph=not disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type='default',
            )
            self.synthesis_batch_num = synthesis_batch_num
        if not self.disable_graphs:
            # Precomputed SMILES -> graph lookup, loaded whenever graphs are enabled.
            self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))
        reaction_filename = 'reactions/reactions_test.json' if (mode=='test') else 'reactions/reactions.json'
        if smiles_type=='r_smiles':
            # Overrides the choice above for both train and test modes.
            reaction_filename = 'reactions/reactions_wRSMILES.json'
        self.cluster = Reaction_Cluster(self.root, reaction_filename=reaction_filename, reverse_ratio=reverse_ratio)
        self.reload_data_list()
        self.abstract_max_len = 10240
        self.property_max_len = 10240
        self.enable_abstract = enable_abstract
        self.enable_property = enable_property

    def get(self, index):
        """Alias of ``__getitem__`` (method name expected by the base ``Dataset``)."""
        return self.__getitem__(index)

    def len(self):
        """Alias of ``__len__`` (method name expected by the base ``Dataset``)."""
        return len(self)

    def __len__(self):
        """Total number of items across reaction, caption and synthesis parts."""
        data_len = len(self.data_list)
        if self.use_caption_dataset:
            data_len += len(self.caption_index_list)
        if self.use_synthesis_dataset:
            data_len += len(self.synthesis_index_list)
        return data_len

    def reload_data_list(self):
        """Resample the reaction batches and the caption/synthesis index lists.

        Raises ``NotImplementedError`` for an unknown ``context_style`` and
        ``AssertionError`` when more caption/synthesis samples are requested
        than the underlying datasets contain.
        """
        k = self.rxn_batch_size
        if self.context_style == 'weighted_rxn':
            self.data_list = self.cluster(self.rxn_num, k=k)
        elif self.context_style == 'uniform_rxn':
            self.data_list = self.cluster.generate_batch_uniform_rxn(self.rxn_num, k=k)
        elif self.context_style == 'uniform_mol':
            self.data_list = self.cluster.generate_batch_uniform_mol(self.rxn_num, k=k)
        elif self.context_style == 'single_mol':
            self.data_list = self.cluster.generate_batch_single(self.rxn_num)
        elif self.context_style == 'hybrid':
            # Half weighted-reaction batches, half uniform-molecule batches.
            self.data_list = self.cluster(self.rxn_num//2, k=k)
            self.data_list += self.cluster.generate_batch_uniform_mol(self.rxn_num//2, k=k)
        else:
            raise NotImplementedError
        if self.use_caption_dataset:
            assert self.caption_batch_num*k <= len(self.caption_dataset)
            # Sample without replacement, then chunk into groups of k indices.
            caption_index_list = random.sample(range(len(self.caption_dataset)), self.caption_batch_num*k)
            self.caption_index_list = [caption_index_list[i*k:(i+1)*k] for i in range(self.caption_batch_num)]
        else:
            self.caption_index_list = []
        if self.use_synthesis_dataset:
            if self.synthesis_dataset.roundrobin_train:
                self.synthesis_dataset.reload_data()
            assert self.synthesis_batch_num <= len(self.synthesis_dataset)
            self.synthesis_index_list = random.sample(range(len(self.synthesis_dataset)), self.synthesis_batch_num)
        else:
            self.synthesis_index_list = []

    def make_prompt(self, mol_batch, smi_max_len=128):
        """Build the molecule and text prompts for every molecule in ``mol_batch``.

        Returns ``(mol_prompt_list, text_prompt_list)``. A molecule prompt is
        the SMILES (cut to ``smi_max_len`` characters) wrapped in
        ``[START_SMILES]``/``[END_SMILES]``; when the molecule has a ``role``
        that differs from the previous molecule's role, the prompt is prefixed
        with the role name. A ``role`` outside the four mapped values raises
        ``KeyError``.
        """
        mol_prompt_list, text_prompt_list = [], []
        last_role = None
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.smiles_type=='r_smiles':
                # Use the precomputed root-aligned SMILES when present;
                # otherwise the canonical SMILES is kept as is.
                if 'r_smiles' in mol_dict:
                    smiles = mol_dict['r_smiles']
                # else:
                #     smiles = reformat_smiles(smiles, smiles_type='restricted')
            else:
                smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
            mol_prompt = f'[START_SMILES]{smiles[:smi_max_len]}[END_SMILES]. '
            if 'role' in mol_dict:
                role = {
                    'REACTANT': 'Reactant',
                    'CATALYST': 'Catalyst',
                    'SOLVENT': 'Solvent',
                    'PRODUCT': 'Product',
                }[mol_dict['role']]
                # Only the first molecule of a run of equal roles gets the label.
                if last_role != role:
                    mol_prompt = f'{role}: {mol_prompt}'
                last_role = role
            text_prompt = self.make_abstract(mol_dict)
            mol_prompt_list.append(mol_prompt)
            text_prompt_list.append(text_prompt)
        return mol_prompt_list, text_prompt_list

    def make_abstract(self, mol_dict):
        """Build the ``[Abstract] ... [Properties] ...`` text for one molecule.

        Each section is emitted only when its ``enable_*`` flag is set and the
        corresponding data is present; an empty string is returned otherwise.
        """
        prompt = ''
        if self.enable_abstract and 'abstract' in mol_dict:
            abstract_string = mol_dict['abstract'][:self.abstract_max_len]
            prompt += f'[Abstract] {abstract_string} '

        if self.enable_property:
            property_string = ''
            property_dict = mol_dict['property'] if 'property' in mol_dict else {}
            for property_key in ['Experimental Properties', 'Computed Properties']:
                if not property_key in property_dict:
                    continue
                for key, value in property_dict[property_key].items():
                    if isinstance(value, float):
                        key_value_string = f'{key}: {value:.2f}; '
                    elif isinstance(value, str):
                        float_value = format_float_from_string(value)
                        key_value_string = f'{key}: {float_value}; '
                    else:
                        key_value_string = f'{key}: {value}; '
                    # Stops the current property group only; the next group is still visited.
                    if len(property_string+key_value_string) > self.property_max_len:
                        break
                    property_string += key_value_string
            if property_string:
                property_string = property_string[:self.property_max_len]
                prompt += f'[Properties] {property_string}. '
        return prompt

    def get_caption_data(self, index):
        """Return the ``index``-th group of caption samples as three parallel lists."""
        caption_index = self.caption_index_list[index]
        graph_list, mol_prompt_list, text_prompt_list = [], [], []
        for idx in caption_index:
            graph_item, text, smiles_prompt = self.caption_dataset[idx]
            graph_list.append(graph_item)
            mol_prompt_list.append(smiles_prompt)
            text_prompt_list.append(text)

        return graph_list, mol_prompt_list, text_prompt_list

    def get_synthesis_data(self, index):
        """Return one synthesis sample; its input/output texts become single-element lists."""
        synthesis_index = self.synthesis_index_list[index]
        _, graph_list, output_text, input_text = self.synthesis_dataset[synthesis_index]
        return graph_list, [input_text], [output_text]

    def __getitem__(self, index):
        """Return ``(graph_list, mol_prompt_list, text_prompt_list)`` for ``index``.

        Indices past the reaction batches are dispatched to the caption part
        and then to the synthesis part; an index beyond every enabled part
        fails the ``assert`` (or raises ``IndexError`` from the index list).
        """
        if index < len(self.data_list):
            mol_batch = self.data_list[index]
        elif index < len(self.data_list)+len(self.caption_index_list):
            assert self.use_caption_dataset
            return self.get_caption_data(index-len(self.data_list))
        else:
            assert self.use_synthesis_dataset
            return self.get_synthesis_data(index-(len(self.data_list)+len(self.caption_index_list)))

        graph_list = []
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.disable_graphs:
                # Keep a placeholder so graph_list stays aligned with mol_batch.
                graph_item = None
            else:
                if self.disable_graph_cache:
                    graph_item = smiles2data(smiles)
                else:
                    assert smiles in self.mol_graph_map
                    graph_item = self.mol_graph_map[smiles]
            graph_list.append(graph_item)
        mol_prompt_list, text_prompt_list = self.make_prompt(mol_batch, smi_max_len=self.smi_max_len)

        return graph_list, mol_prompt_list, text_prompt_list
data_provider/pretrain_dm.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ import torch
4
+ from pytorch_lightning import LightningDataModule
5
+ import torch_geometric
6
+ # from torch_geometric.loader import DataLoader
7
+ from torch.utils.data import DataLoader
8
+ from torch_geometric.loader.dataloader import Collater
9
+ from data_provider.molecule_abstract_dataset import MoleculeAbstract
10
+ import re
11
+ from transformers import BatchEncoding
12
+
13
# Regex locating special sequences such as [START_SMILES]...[END_SMILES]; the
# characters inside them are tokenized individually.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Token used to implement the custom sequence tokenization. It is added at the
# corpus cleaning step and removed in pretokenization. The digits make an
# accidental occurrence in the corpus unlikely, and they are interpolated so
# that the token never appears literally in this source file.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Rebuild one matched special-token span with split markers inserted.

    Parameters
    ----------
    m : re.Match
        A match of ``CUSTOM_SEQ_RE``

    Returns
    ----------
    str - the span with the split token before each sequence character and
    before the end token
    """
    opening, _kind, payload, closing = m.groups()
    marked = "".join(SPLIT_MARKER + char for char in payload)
    return opening + marked + SPLIT_MARKER + closing


def smiles_handler(text, mol_ph, is_gal=True):
    """
    Append the molecule placeholder after each special sequence in ``text``.

    Parameters
    ----------
    text : str
        Prompt containing [START_*]...[END_*] spans
    mol_ph : str
        Placeholder text inserted right after every span
    is_gal : bool
        When true the start/end tokens are kept and the span is escaped with
        split markers; otherwise the tokens are dropped and only the raw
        sequence followed by the placeholder remains

    Returns
    ----------
    (str, list[str]) - the rewritten text and the sequences found in the
    original text, in order of appearance
    """
    smiles_list = [found.group(3) for found in CUSTOM_SEQ_RE.finditer(text)]
    if not is_gal:
        return CUSTOM_SEQ_RE.sub(r'\3' + mol_ph, text), smiles_list
    with_placeholder = CUSTOM_SEQ_RE.sub(r'\1\3\4' + mol_ph, text)
    return escape_custom_split_sequence(with_placeholder), smiles_list


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
69
+
70
+
71
def tokenize_and_merge_batched_qa_pairs(tokenizer, qa_pairs_list, max_length):
    """Tokenize groups of strings and pack each group into one padded row.

    Every group in ``qa_pairs_list`` is a list of strings. Each string gets an
    equal share (``max_length // len(group)``) of the token budget, the
    resulting token ids are concatenated in order, and the row is right-padded
    with ``tokenizer.pad_token_id`` up to ``max_length``.

    Fixes over the previous version:
      * the batch dimension is removed with ``[0]`` instead of ``.squeeze()``;
        ``squeeze()`` collapsed a one-token encoding into a 0-d tensor whose
        ``.tolist()`` is a plain int, which made ``list.extend`` raise
        ``TypeError``;
      * an empty group yields an all-padding row instead of raising
        ``ZeroDivisionError``.

    Returns:
        ``BatchEncoding`` with ``input_ids`` and ``attention_mask`` tensors of
        shape ``(len(qa_pairs_list), max_length)``.
    """
    tokenized_batches = {
        'input_ids': [],
        'attention_mask': []
    }
    for qa_pairs in qa_pairs_list:
        max_length_per_qa = max_length // max(len(qa_pairs), 1)
        batch_input_ids = []
        batch_attention_mask = []
        for qa in qa_pairs:
            # here qa should be string
            tokens = tokenizer(qa,
                               truncation=True,
                               padding=False,
                               add_special_tokens=False,
                               max_length=max_length_per_qa,
                               return_tensors='pt',
                               return_attention_mask=True)
            # A single string is encoded with a leading batch dimension of 1.
            batch_input_ids.extend(tokens['input_ids'][0].tolist())
            batch_attention_mask.extend(tokens['attention_mask'][0].tolist())

        # Pad the batch to max_length
        padding_length = max_length - len(batch_input_ids)
        batch_input_ids.extend([tokenizer.pad_token_id] * padding_length)
        batch_attention_mask.extend([0] * padding_length)

        tokenized_batches['input_ids'].append(torch.tensor(batch_input_ids).unsqueeze(0))
        tokenized_batches['attention_mask'].append(torch.tensor(batch_attention_mask).unsqueeze(0))

    tokenized_batches['input_ids'] = torch.cat(tokenized_batches['input_ids'], dim=0)
    tokenized_batches['attention_mask'] = torch.cat(tokenized_batches['attention_mask'], dim=0)

    tokenized_batch = BatchEncoding(data=tokenized_batches, tensor_type='pt')
    return tokenized_batch
105
+
106
class TrainCollater:
    """Collate training items into a graph batch and one packed token batch."""

    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False):
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs
        self.collater = Collater([], [])

    def __call__(self, batch):
        """Return ``(graphs, qa_batch)`` for a list of dataset items.

        Each item is ``(graph_list, mol_prompt_list, text_prompt_list)``. With
        graphs enabled, all graphs are flattened and collated; otherwise the
        per-item graph lists are passed through untouched. ``qa_batch`` gains
        an ``is_mol_token`` mask marking the molecule placeholder token.
        """
        graph_groups, mol_groups, text_groups = zip(*batch)
        graphs = graph_groups
        if not self.disable_graphs:
            flat_graphs = [graph for group in graph_groups for graph in group]
            graphs = self.collater(flat_graphs)

        # One list of "<molecule prompt> <text prompt>" strings per item.
        qa_pairs = [
            [
                f'{smiles_handler(mol_text, self.mol_ph, self.is_gal)[0]} {description}'
                for mol_text, description in zip(mols, texts)
            ]
            for mols, texts in zip(mol_groups, text_groups)
        ]

        self.tokenizer.padding_side = 'right'
        qa_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, qa_pairs, self.text_max_len)
        qa_batch['is_mol_token'] = qa_batch.input_ids == self.mol_token_id

        return graphs, qa_batch
137
+
138
class InferenceCollater:
    """Collate evaluation items: context prompts as input, last text as target."""

    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False, last_only=False):
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs
        self.last_only = last_only
        self.collater = Collater([], [])

    def __call__(self, batch):
        """Return ``(rxn_ids, graphs, input_batch, output_text, input_text)``.

        For every item, all molecules except the last contribute a
        "<molecule prompt> <text prompt>" context string; the last molecule
        contributes only its molecule prompt, and its text prompt becomes the
        expected output. With ``last_only`` the context is dropped and only
        the last molecule of each item is kept. ``rxn_ids`` is a list of
        zeros, one per item.
        """
        graph_groups, mol_groups, text_groups = zip(*batch)
        rxn_ids = [0] * len(mol_groups)
        if self.last_only:
            mol_groups = [[mols[-1]] for mols in mol_groups]
            text_groups = [[texts[-1]] for texts in text_groups]
            graph_groups = [[group[-1]] for group in graph_groups]

        graphs = graph_groups
        if not self.disable_graphs:
            graphs = self.collater([graph for group in graph_groups for graph in group])

        input_text, output_text = [], []
        for mols, texts in zip(mol_groups, text_groups):
            context_pairs = list(zip(mols, texts))[:-1]
            prompts = [
                f'{smiles_handler(mol_text, self.mol_ph, self.is_gal)[0]} {description}'
                for mol_text, description in context_pairs
            ]
            # The query molecule is given without its description.
            prompts.append(f'{smiles_handler(mols[-1], self.mol_ph, self.is_gal)[0]} ')
            output_text.append(texts[-1])
            input_text.append(prompts)

        self.tokenizer.padding_side = 'right'
        input_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, input_text, self.text_max_len)
        input_batch['is_mol_token'] = input_batch.input_ids == self.mol_token_id

        return rxn_ids, graphs, input_batch, output_text, input_text
177
+
178
+
179
class PretrainDM(LightningDataModule):
    """Lightning data module for reaction-contextualized pretraining.

    Builds a training and a test ``MoleculeAbstract`` dataset from ``root`` and
    the options in ``args``, and exposes them through ``train_dataloader`` /
    ``val_dataloader`` with the matching collaters.

    Changes over the previous version:
      * ``persistent_workers`` is enabled only when ``num_workers > 0``;
        ``DataLoader`` raises ``ValueError`` for ``persistent_workers=True``
        with ``num_workers=0``, which is this class's own default;
      * ``add_model_specific_args`` is declared as a ``staticmethod`` so it
        can be called on the class or on an instance.
    """

    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        text_max_len: int = 128,
        rxn_max_len: int = 128,
        smi_max_len: int = 128,
        tokenizer=None,
        args=None,
    ):
        """Create both datasets and register the tokenizer.

        ``args`` must provide the attributes read below (the options added by
        ``add_model_specific_args`` plus ``disable_graph_cache``,
        ``disable_graphs``, ``smiles_type``, ``num_query_token`` and
        ``opt_model``), and ``tokenizer`` must expose ``mol_token_id``.
        """
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = num_workers
        self.text_max_len = text_max_len
        self.rxn_max_len = rxn_max_len
        self.pretrain_dataset = MoleculeAbstract(
            root,
            rxn_num=args.pretrain_rxn_num,
            rxn_batch_size=args.rxn_batch_size,
            smi_max_len=smi_max_len,
            disable_graph_cache=args.disable_graph_cache,
            context_style=args.context_style,
            disable_graphs=args.disable_graphs,
            use_caption_dataset=args.pretrain_use_caption,
            caption_batch_num=args.caption_batch_num,
            synthesis_datasetpath=args.pretrain_synthesis_path,
            synthesis_batch_num=args.synthesis_batch_num,
            reverse_ratio=args.reverse_ratio,
            enable_abstract=not args.disable_abstract,
            enable_property=not args.disable_property,
            smiles_type=args.smiles_type,
        )
        # The test dataset does not receive the synthesis dataset options.
        self.test_dataset = MoleculeAbstract(
            root,
            rxn_num=args.pretrain_rxn_num,
            rxn_batch_size=args.rxn_batch_size,
            smi_max_len=smi_max_len,
            disable_graph_cache=args.disable_graph_cache,
            context_style=args.context_style,
            disable_graphs=args.disable_graphs,
            use_caption_dataset=args.pretrain_use_caption,
            caption_batch_num=args.caption_batch_num,
            reverse_ratio=args.reverse_ratio,
            enable_abstract=not args.disable_abstract,
            enable_property=not args.disable_property,
            smiles_type=args.smiles_type,
            mode='test',
        )
        self.init_tokenizer(tokenizer)
        # One '<mol>' placeholder per query token.
        self.mol_ph_token = '<mol>' * self.args.num_query_token
        self.is_gal = args.opt_model.find('galactica') >= 0
        self.disable_graphs = args.disable_graphs
        self.last_only = args.pretrain_eval_last_only

    def init_tokenizer(self, tokenizer):
        """Store ``tokenizer`` on the module and both datasets; cache its ``mol_token_id``."""
        self.tokenizer = tokenizer
        self.pretrain_dataset.tokenizer = tokenizer
        self.test_dataset.tokenizer = tokenizer
        self.mol_token_id = self.tokenizer.mol_token_id
        # self.tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]

    def train_dataloader(self):
        """Resample the training data and return a shuffled training loader."""
        self.pretrain_dataset.reload_data_list()
        loader = DataLoader(
            self.pretrain_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            collate_fn=TrainCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
            ),
        )
        return loader

    def val_dataloader(self):
        """Return a one-element list holding the unshuffled test loader."""
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
                last_only=self.last_only,
            ),
        )
        return [test_loader]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the data-module command-line options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=4)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--use_smiles', action='store_true', default=False)
        parser.add_argument('--root', type=str, default='data/action_data')
        parser.add_argument('--context_style', type=str, default='weighted_rxn', choices=['weighted_rxn', 'uniform_rxn', 'uniform_mol', 'single_mol', 'hybrid'])
        parser.add_argument('--rxn_max_len', type=int, default=512)
        parser.add_argument('--text_max_len', type=int, default=512)
        parser.add_argument('--smi_max_len', type=int, default=128)
        parser.add_argument('--pretrain_rxn_num', type=int, default=50000)
        parser.add_argument('--reverse_ratio', type=float, default=0.5, help='ratio of reversed reactions (retro reactions)')
        parser.add_argument('--disable_abstract', action='store_true', default=False)
        parser.add_argument('--disable_property', action='store_true', default=False)
        parser.add_argument('--pretrain_use_caption', action='store_true', default=False)
        parser.add_argument('--caption_batch_num', type=int, default=5000)
        parser.add_argument('--pretrain_synthesis_path', type=str, default=None)
        parser.add_argument('--synthesis_batch_num', type=int, default=5000)
        parser.add_argument('--rxn_batch_size', type=int, default=4)
        parser.add_argument('--roundrobin_train', action='store_true', default=False)
        parser.add_argument('--test_subset', type=int, default=-1)
        parser.add_argument('--pretrain_eval_last_only', default=False, action='store_true')
        parser.add_argument('--prompt', type=str, default=None)
        return parent_parser
data_provider/r_smiles.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import argparse
3
+ import re
4
+ import random
5
+ import textdistance
6
+
7
+ from rdkit import Chem
8
+
9
+
10
+ from rdkit import RDLogger
11
+ RDLogger.DisableLog('rdApp.*')
12
+
13
+
14
# Compiled once at import time. Written as a raw string: the previous non-raw
# literal relied on invalid escape sequences such as "\[" which newer Python
# versions warn about.
_SMI_TOKEN_RE = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def smi_tokenizer(smi):
    """Split a SMILES string into space-separated tokens.

    Bracket atoms (``[...]``), two-letter halogens (``Br``, ``Cl``) and
    two-digit ring closures (``%NN``) are kept as single tokens.

    Raises:
        AssertionError: if the tokens do not reassemble into ``smi``, i.e. the
            string contains a character the pattern does not cover.
    """
    tokens = _SMI_TOKEN_RE.findall(smi)
    assert smi == ''.join(tokens)
    return ' '.join(tokens)
20
+
21
+
22
def clear_map_canonical_smiles(smi, canonical=True, root=-1):
    """Return ``smi`` re-written without atom-map numbers.

    The molecule is parsed, the ``molAtomMapNumber`` property is cleared from
    every atom that has it, and the result is written as isomeric SMILES
    rooted at atom index ``root`` (canonical or not as requested). When the
    input cannot be parsed it is returned unchanged.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return smi
    mapped_atoms = (atom for atom in mol.GetAtoms() if atom.HasProp('molAtomMapNumber'))
    for mapped in mapped_atoms:
        mapped.ClearProp('molAtomMapNumber')
    return Chem.MolToSmiles(mol, isomericSmiles=True, rootedAtAtom=root, canonical=canonical)
31
+
32
+
33
def get_cano_map_number(smi,root=-1):
    """Return the atom-map numbers of ``smi`` in the atom order of its cleared SMILES.

    ``smi`` is an atom-mapped SMILES. It is re-written without map numbers via
    ``clear_map_canonical_smiles`` (rooted at ``root``), and the atoms of that
    re-written molecule are matched back to the mapped molecule.

    Returns a list with one map number per atom of the re-written molecule,
    or ``None`` when no full-size substructure match is found in the fallback
    direction.
    """
    atommap_mol = Chem.MolFromSmiles(smi)
    canonical_mol = Chem.MolFromSmiles(clear_map_canonical_smiles(smi,root=root))
    # Index i of the result refers to atom i of canonical_mol; the value is
    # the index of the matched atom in atommap_mol.
    cano2atommapIdx = atommap_mol.GetSubstructMatch(canonical_mol)
    # Sanity check: each matched pair must carry the same element symbol.
    correct_mapped = [canonical_mol.GetAtomWithIdx(i).GetSymbol() == atommap_mol.GetAtomWithIdx(index).GetSymbol() for i,index in enumerate(cano2atommapIdx)]
    atom_number = len(canonical_mol.GetAtoms())
    if np.sum(correct_mapped) < atom_number or len(cano2atommapIdx) < atom_number:
        # Fallback: match in the opposite direction and invert the mapping.
        cano2atommapIdx = [0] * atom_number
        atommap2canoIdx = canonical_mol.GetSubstructMatch(atommap_mol)
        if len(atommap2canoIdx) != atom_number:
            return None
        for i, index in enumerate(atommap2canoIdx):
            cano2atommapIdx[index] = i
    id2atommap = [atom.GetAtomMapNum() for atom in atommap_mol.GetAtoms()]

    return [id2atommap[cano2atommapIdx[i]] for i in range(atom_number)]
49
+
50
+
51
def get_root_id(mol,root_map_number):
    """Return the index of the first atom whose map number equals ``root_map_number``.

    Returns ``-1`` when no atom of ``mol`` carries that map number.
    """
    matching_indices = (
        idx
        for idx, atom in enumerate(mol.GetAtoms())
        if atom.GetAtomMapNum() == root_map_number
    )
    return next(matching_indices, -1)
62
+
63
+
64
def get_forward_rsmiles(data):
    """Build augmented (reactants -> products) token-string pairs for one reaction.

    Parameters
    ----------
    data : dict
        Reads the keys ``'product'`` and ``'reactant'`` (atom-mapped SMILES,
        components separated by ``.``), ``'augmentation'`` (number of pairs to
        produce) and ``'separated'`` (when truthy, components sharing no map
        number with the product are emitted after a ``<separated>`` marker).

    Returns
    -------
    dict
        ``status`` (left at 0 here), ``src_data`` (reactant token strings),
        ``tgt_data`` (product token strings) and ``edit_distance`` (mean
        token-level Levenshtein distance between each src/tgt pair).
    """
    pt = re.compile(r':(\d+)]')
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    separated = data['separated']
    # NOTE(review): pro_mol, rea_mol, rids and pids are computed but not used below.
    pro_mol = Chem.MolFromSmiles(product)
    rea_mol = Chem.MolFromSmiles(reactant)
    """checking data quality"""
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    return_status = {
        "status":0,
        "src_data":[],
        "tgt_data":[],
        "edit_distance":0,
    }
    reactant = reactant.split(".")
    product = product.split(".")
    # Atom-map numbers found in each reactant component, in textual order.
    rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
    # Number of distinct root combinations (one root per component).
    max_times = np.prod([len(map_numbers) for map_numbers in rea_atom_map_numbers])
    times = min(augmentation, max_times)
    # First entry uses -1 for every component; get_root_id returns -1 when no atom matches.
    reactant_roots = [[-1 for _ in reactant]]
    j = 0
    # Rejection-sample until `times` new, previously unseen root combinations are collected.
    while j < times:
        reactant_roots.append([random.sample(rea_atom_map_numbers[k], 1)[0] for k in range(len(reactant))])
        if reactant_roots[-1] in reactant_roots[:-1]:
            reactant_roots.pop()
        else:
            j += 1
    # Not enough distinct combinations: pad by re-drawing already chosen ones.
    if j < augmentation:
        reactant_roots.extend(random.choices(reactant_roots, k=augmentation - times))
        times = augmentation
    reversable = False  # no reverse
    assert times == augmentation
    if reversable:
        times = int(times / 2)

    pro_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", pro))) for pro in product]
    full_pro_atom_map_numbers = set(map(int, re.findall(r"(?<=:)\d+", ".".join(product))))
    for k in range(times):
        # Shuffle component order while keeping SMILES, root and map numbers in lockstep.
        tmp = list(zip(reactant, reactant_roots[k],rea_atom_map_numbers))
        random.shuffle(tmp)
        reactant_k, reactant_roots_k,rea_atom_map_numbers_k = [i[0] for i in tmp], [i[1] for i in tmp], [i[2] for i in tmp]
        aligned_reactants = []
        aligned_products = []
        aligned_products_order = []
        all_atom_map = []
        for i, rea in enumerate(reactant_k):
            rea_root_atom_map = reactant_roots_k[i]
            rea_root = get_root_id(Chem.MolFromSmiles(rea), root_map_number=rea_root_atom_map)
            # presumably the map numbers in rooted canonical atom order -- helper is defined elsewhere
            cano_atom_map = get_cano_map_number(rea, rea_root)
            if cano_atom_map is None:
                print(f"Reactant Failed to find Canonical Mol with Atom MapNumber")
                # NOTE(review): skipping here leaves aligned_reactants shorter than
                # rea_atom_map_numbers_k, which the `separated` branch indexes by position.
                continue
            rea_smi = clear_map_canonical_smiles(rea, canonical=True, root=rea_root)
            aligned_reactants.append(rea_smi)
            all_atom_map.extend(cano_atom_map)

        for i, pro_map_number in enumerate(pro_atom_map_numbers):
            reactant_candidates = []
            selected_reactant = []
            # For each reactant component that shares an atom with this product, record the
            # first shared map number, its position in all_atom_map, and the component size.
            for j, map_number in enumerate(all_atom_map):
                if map_number in pro_map_number:
                    for rea_index, rea_atom_map_number in enumerate(rea_atom_map_numbers_k):
                        if map_number in rea_atom_map_number and rea_index not in selected_reactant:
                            selected_reactant.append(rea_index)
                            reactant_candidates.append((map_number, j, len(rea_atom_map_number)))

            # select maximal reactant
            reactant_candidates.sort(key=lambda x: x[2], reverse=True)
            # NOTE(review): raises IndexError if no reactant atom maps into this product.
            map_number = reactant_candidates[0][0]
            j = reactant_candidates[0][1]
            pro_root = get_root_id(Chem.MolFromSmiles(product[i]), root_map_number=map_number)
            pro_smi = clear_map_canonical_smiles(product[i], canonical=True, root=pro_root)
            aligned_products.append(pro_smi)
            aligned_products_order.append(j)

        # Order products by where their root atom appears among the reactant atoms.
        sorted_products = sorted(list(zip(aligned_products, aligned_products_order)), key=lambda x: x[1])
        aligned_products = [item[0] for item in sorted_products]
        pro_smi = ".".join(aligned_products)
        if separated:
            reactants = []
            reagents = []
            # A component sharing no map number with the product is treated as a reagent.
            for i,cano_atom_map in enumerate(rea_atom_map_numbers_k):
                if len(set(cano_atom_map) & full_pro_atom_map_numbers) > 0:
                    reactants.append(aligned_reactants[i])
                else:
                    reagents.append(aligned_reactants[i])
            rea_smi = ".".join(reactants)
            reactant_tokens = smi_tokenizer(rea_smi)
            if len(reagents) > 0 :
                reactant_tokens += " <separated> " + smi_tokenizer(".".join(reagents))
        else:
            rea_smi = ".".join(aligned_reactants)
            reactant_tokens = smi_tokenizer(rea_smi)
        product_tokens = smi_tokenizer(pro_smi)
        return_status['src_data'].append(reactant_tokens)
        return_status['tgt_data'].append(product_tokens)
        # Dead branch while `reversable` is hard-coded to False.
        if reversable:
            aligned_reactants.reverse()
            aligned_products.reverse()
            pro_smi = ".".join(aligned_products)
            rea_smi = ".".join(aligned_reactants)
            product_tokens = smi_tokenizer(pro_smi)
            reactant_tokens = smi_tokenizer(rea_smi)
            return_status['src_data'].append(reactant_tokens)
            return_status['tgt_data'].append(product_tokens)
    edit_distances = []
    for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
        edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
    return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
177
+
178
+
179
def get_retro_rsmiles(data):
    """Build augmented (product -> reactants) token-string pairs for one reaction.

    Parameters
    ----------
    data : dict
        Reads ``'product'`` and ``'reactant'`` (atom-mapped SMILES) and
        ``'augmentation'`` (number of pairs to produce; the value 999 means
        "use every product atom-map number as a root").

    Returns
    -------
    dict
        ``status`` (0, or ``"error_mapping"`` if ``get_cano_map_number`` returns
        None), ``src_data`` (product token strings), ``tgt_data`` (reactant token
        strings) and ``edit_distance`` (mean token-level Levenshtein distance).
    """
    pt = re.compile(r':(\d+)]')
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    pro_mol = Chem.MolFromSmiles(product)
    # NOTE(review): rea_mol, rids and pids are computed but not used below.
    rea_mol = Chem.MolFromSmiles(reactant)
    """checking data quality"""
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    return_status = {
        "status":0,
        "src_data":[],
        "tgt_data":[],
        "edit_distance":0,
    }
    pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
    reactant = reactant.split(".")
    reversable = False  # no shuffle
    # augmentation = 100
    if augmentation == 999:
        product_roots = pro_atom_map_numbers
        times = len(product_roots)
    else:
        # First root is -1; get_root_id returns -1 when no atom carries that map number.
        product_roots = [-1]
        # reversable = len(reactant) > 1

        max_times = len(pro_atom_map_numbers)
        times = min(augmentation, max_times)
        if times < augmentation:  # times = max_times
            # Fewer atoms than requested samples: take every atom, then pad with repeats.
            product_roots.extend(pro_atom_map_numbers)
            product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
        else:  # times = augmentation
            # Rejection-sample distinct roots until `times` entries are collected.
            while len(product_roots) < times:
                product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
                # pro_atom_map_numbers.remove(product_roots[-1])
                if product_roots[-1] in product_roots[:-1]:
                    product_roots.pop()
        times = len(product_roots)
        assert times == augmentation
    if reversable:
        times = int(times / 2)
    # candidates = []
    for k in range(times):
        pro_root_atom_map = product_roots[k]
        pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
        # presumably the map numbers in rooted canonical atom order -- helper is defined elsewhere
        cano_atom_map = get_cano_map_number(product, root=pro_root)
        if cano_atom_map is None:
            return_status["status"] = "error_mapping"
            return return_status
        pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
        aligned_reactants = []
        aligned_reactants_order = []
        rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
        used_indices = []
        for i, rea_map_number in enumerate(rea_atom_map_numbers):
            for j, map_number in enumerate(cano_atom_map):
                # select mapping reactants
                # Root each reactant at its first atom (in product order) shared with the
                # product; reactants sharing no atom with the product are dropped.
                if map_number in rea_map_number:
                    rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
                    rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
                    aligned_reactants.append(rea_smi)
                    aligned_reactants_order.append(j)
                    used_indices.append(i)
                    break
        # Order reactants by the position of their root atom in the product ordering.
        sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
        aligned_reactants = [item[0] for item in sorted_reactants]
        reactant_smi = ".".join(aligned_reactants)
        product_tokens = smi_tokenizer(pro_smi)
        reactant_tokens = smi_tokenizer(reactant_smi)

        return_status['src_data'].append(product_tokens)
        return_status['tgt_data'].append(reactant_tokens)

        # Dead branch while `reversable` is hard-coded to False.
        if reversable:
            aligned_reactants.reverse()
            reactant_smi = ".".join(aligned_reactants)
            product_tokens = smi_tokenizer(pro_smi)
            reactant_tokens = smi_tokenizer(reactant_smi)
            return_status['src_data'].append(product_tokens)
            return_status['tgt_data'].append(reactant_tokens)
    # NOTE(review): with augmentation == 999 this only holds if the product has
    # exactly 999 mapped atoms -- confirm the intended behaviour of that mode.
    assert len(return_status['src_data']) == data['augmentation']
    edit_distances = []
    for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
        edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
    return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
266
+
267
+
268
def multi_process(data):
    """Validate one reaction and build augmented (product -> reactants) pairs.

    Parameters
    ----------
    data : dict
        Reads ``'product'`` and ``'reactant'`` (atom-mapped SMILES),
        ``'augmentation'`` (number of pairs; 999 means "every product atom as
        root") and ``'root_aligned'`` (choose root-aligned SMILES vs. randomly
        generated SMILES for the augmented copies).

    Returns
    -------
    dict
        ``status`` is 0 on success, otherwise one of ``"empty_p"``,
        ``"empty_r"``, ``"invalid_r"``, ``"small_r"``, ``"invalid_p"``,
        ``"small_p"``, ``"error_mapping_p"`` or ``"error_mapping"``.
        ``src_data`` / ``tgt_data`` hold product / reactant token strings and
        ``edit_distance`` their mean token-level Levenshtein distance; these
        stay at their empty defaults when the quality checks fail.
    """
    product = data['product']
    reactant = data['reactant']
    augmentation = data['augmentation']
    pro_mol = Chem.MolFromSmiles(product)
    rea_mol = Chem.MolFromSmiles(reactant)
    return_status = {
        "status":0,
        "src_data":[],
        "tgt_data":[],
        "edit_distance":0,
    }

    # --- data quality checks (a later failing check overrides an earlier one) ---
    if "" == product:
        return_status["status"] = "empty_p"
    if "" == reactant:
        return_status["status"] = "empty_r"
    # MolFromSmiles returns None for unparsable input, so the atom-count
    # checks must only run on a real molecule; previously they raised
    # AttributeError instead of reporting invalid_r / invalid_p.
    if rea_mol is None:
        return_status["status"] = "invalid_r"
    elif len(rea_mol.GetAtoms()) < 5:
        return_status["status"] = "small_r"
    if pro_mol is None:
        return_status["status"] = "invalid_p"
    else:
        if len(pro_mol.GetAtoms()) == 1:
            return_status["status"] = "small_p"
        if not all([a.HasProp('molAtomMapNumber') for a in pro_mol.GetAtoms()]):
            return_status["status"] = "error_mapping_p"
    # --- end of data quality checks ---

    if return_status['status'] == 0:
        pro_atom_map_numbers = list(map(int, re.findall(r"(?<=:)\d+", product)))
        reactant = reactant.split(".")
        if data['root_aligned']:
            reversable = False  # no shuffle
            if augmentation == 999:
                product_roots = pro_atom_map_numbers
                times = len(product_roots)
            else:
                # First root is -1; get_root_id returns -1 when no atom matches.
                product_roots = [-1]

                max_times = len(pro_atom_map_numbers)
                times = min(augmentation, max_times)
                if times < augmentation:  # times = max_times
                    # Fewer atoms than requested samples: use all, then pad with repeats.
                    product_roots.extend(pro_atom_map_numbers)
                    product_roots.extend(random.choices(product_roots, k=augmentation - len(product_roots)))
                else:  # times = augmentation
                    # Rejection-sample distinct roots until `times` are collected.
                    while len(product_roots) < times:
                        product_roots.append(random.sample(pro_atom_map_numbers, 1)[0])
                        if product_roots[-1] in product_roots[:-1]:
                            product_roots.pop()
                times = len(product_roots)
                assert times == augmentation
            if reversable:
                times = int(times / 2)
            for k in range(times):
                pro_root_atom_map = product_roots[k]
                pro_root = get_root_id(pro_mol, root_map_number=pro_root_atom_map)
                cano_atom_map = get_cano_map_number(product, root=pro_root)
                if cano_atom_map is None:
                    return_status["status"] = "error_mapping"
                    return return_status
                pro_smi = clear_map_canonical_smiles(product, canonical=True, root=pro_root)
                aligned_reactants = []
                aligned_reactants_order = []
                rea_atom_map_numbers = [list(map(int, re.findall(r"(?<=:)\d+", rea))) for rea in reactant]
                for i, rea_map_number in enumerate(rea_atom_map_numbers):
                    for j, map_number in enumerate(cano_atom_map):
                        # Root each reactant at its first atom (in product
                        # order) that also occurs in the product.
                        if map_number in rea_map_number:
                            rea_root = get_root_id(Chem.MolFromSmiles(reactant[i]), root_map_number=map_number)
                            rea_smi = clear_map_canonical_smiles(reactant[i], canonical=True, root=rea_root)
                            aligned_reactants.append(rea_smi)
                            aligned_reactants_order.append(j)
                            break
                # Order reactants by the position of their root atom in the product.
                sorted_reactants = sorted(list(zip(aligned_reactants, aligned_reactants_order)), key=lambda x: x[1])
                aligned_reactants = [item[0] for item in sorted_reactants]
                reactant_smi = ".".join(aligned_reactants)
                product_tokens = smi_tokenizer(pro_smi)
                reactant_tokens = smi_tokenizer(reactant_smi)

                return_status['src_data'].append(product_tokens)
                return_status['tgt_data'].append(reactant_tokens)

                # Dead branch while `reversable` is hard-coded to False.
                if reversable:
                    aligned_reactants.reverse()
                    reactant_smi = ".".join(aligned_reactants)
                    product_tokens = smi_tokenizer(pro_smi)
                    reactant_tokens = smi_tokenizer(reactant_smi)
                    return_status['src_data'].append(product_tokens)
                    return_status['tgt_data'].append(reactant_tokens)
            # NOTE(review): with augmentation == 999 this only holds for a
            # product with exactly 999 mapped atoms -- confirm intended use.
            assert len(return_status['src_data']) == data['augmentation']
        else:
            cano_product = clear_map_canonical_smiles(product)
            # Keep only reactant components that share a map number with the product.
            cano_reactanct = ".".join([clear_map_canonical_smiles(rea) for rea in reactant if len(set(map(int, re.findall(r"(?<=:)\d+", rea))) & set(pro_atom_map_numbers)) > 0 ])
            return_status['src_data'].append(smi_tokenizer(cano_product))
            return_status['tgt_data'].append(smi_tokenizer(cano_reactanct))
            pro_mol = Chem.MolFromSmiles(cano_product)
            rea_mols = [Chem.MolFromSmiles(rea) for rea in cano_reactanct.split(".")]
            # The remaining copies are randomly generated SMILES of the same molecules.
            for i in range(int(augmentation-1)):
                pro_smi = Chem.MolToSmiles(pro_mol,doRandom=True)
                rea_smi = [Chem.MolToSmiles(rea_mol,doRandom=True) for rea_mol in rea_mols]
                rea_smi = ".".join(rea_smi)
                return_status['src_data'].append(smi_tokenizer(pro_smi))
                return_status['tgt_data'].append(smi_tokenizer(rea_smi))
        edit_distances = []
        for src,tgt in zip(return_status['src_data'],return_status['tgt_data']):
            edit_distances.append(textdistance.levenshtein.distance(src.split(),tgt.split()))
        return_status['edit_distance'] = np.mean(edit_distances)
    return return_status
392
+
393
if __name__ == '__main__':
    # Command-line entry point: root-align a single atom-mapped reaction
    # SMILES ("reactants>reagents>products") and print the resulting pairs.
    parser = argparse.ArgumentParser()
    parser.add_argument('-rxn',type=str,required=True)
    parser.add_argument('-mode',type=str,default="retro",)
    parser.add_argument('-forward_mode',type=str,default="separated",)
    parser.add_argument("-augmentation",type=int,default=1)
    parser.add_argument("-seed",type=int,default=33)
    args = parser.parse_args()
    print(args)

    # The root sampling below uses the `random` module; seed it so that the
    # -seed option actually makes runs reproducible (it was parsed but unused).
    random.seed(args.seed)

    rxn_parts = args.rxn.split(">")
    if len(rxn_parts) != 3:
        # Report a malformed reaction instead of failing on tuple unpacking.
        print("Invalid reaction SMILES! Expected the form 'reactants>reagents>products'.")
        raise SystemExit(1)
    reactant, reagent, product = rxn_parts

    pt = re.compile(r':(\d+)]')
    rids = sorted(re.findall(pt, reactant))
    pids = sorted(re.findall(pt, product))
    if len(rids) == 0 or len(pids) == 0:
        print("No atom mapping found!")
        raise SystemExit(1)
    if args.mode == "retro":
        args.input = product
        args.output = reactant
    else:
        args.input = reactant
        args.output = product

    print("Original input:", args.input)
    print("Original output:",args.output)
    src_smi = clear_map_canonical_smiles(args.input)
    tgt_smi = clear_map_canonical_smiles(args.output)
    if src_smi == "" or tgt_smi == "":
        print("Invalid SMILES!")
        raise SystemExit(1)
    print("Canonical input:", src_smi)
    print("Canonical output:",tgt_smi)

    # The mapping is considered clean only when reactant and product carry
    # exactly the same map numbers and no number is repeated on either side.
    mapping_check = (
        rids == pids
        and len(set(rids)) == len(rids)
        and len(set(pids)) == len(pids)
    )
    if not mapping_check:
        print("The quality of the atom mapping may not be good enough, which can affect the effect of root alignment.")
    data = {
        'product':product,
        'reactant':reactant,
        'augmentation':args.augmentation,
        'separated':args.forward_mode == "separated"
    }
    if args.mode == "retro":
        res = get_retro_rsmiles(data)
    else:
        res = get_forward_rsmiles(data)
    for index,(src,tgt) in enumerate(zip(res['src_data'], res['tgt_data'])):
        print(f"ID:{index}")
        print(f"R-SMILES input:{''.join(src.split())}")
        print(f"R-SMILES output:{''.join(tgt.split())}")
    print("Avg. edit distance:", res['edit_distance'])
data_provider/reaction_action_dataset.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Dataset
3
+ import os
4
+ import random
5
+ import json
6
+ from .data_utils import smiles2data, reformat_smiles
7
+
8
class ActionDataset(Dataset):
    """Dataset of reactions paired with their experimental action sequences.

    Records are read from ``<root>/<mode>.json``. When ``use_graph`` is set,
    a SMILES -> graph mapping is also loaded from ``<root>/mol_graph_map.pt``.
    Each item is ``(rxn_id, graph_list, output_text, input_text)``.
    """

    def __init__(self, root, mode, smi_max_len, use_graph=True, disable_graph_cache=False, predict_rxn_condition=False, smiles_type='default'):
        super().__init__(root)
        self.root = root
        self.smi_max_len = smi_max_len
        self.tokenizer = None
        self.use_graph = use_graph
        self.disable_graph_cache = disable_graph_cache
        self.predict_rxn_condition = predict_rxn_condition
        self.smiles_type = smiles_type

        record_path = os.path.join(self.root, f'{mode}.json')
        with open(record_path, encoding='utf-8') as f:
            self.data_list = json.load(f)
        if self.use_graph:
            graph_path = os.path.join(self.root, 'mol_graph_map.pt')
            self.mol_graph_map = torch.load(graph_path)

    def get(self, index):
        """Return the item at ``index`` (delegates to ``__getitem__``)."""
        return self.__getitem__(index)

    def len(self):
        """Return the number of reaction records."""
        return self.__len__()

    def __len__(self):
        return len(self.data_list)

    def make_prompt(self, param_dict, smi_max_len=128, predict_rxn_condition=False):
        """Assemble the text prompt for one reaction record.

        Returns ``(prompt, smiles_list, action_sequence)``: the prompt text,
        the SMILES strings in the order they appear in the prompt, and the
        target action sequence. With ``predict_rxn_condition`` the duration
        and temperature tokens in the action sequence are replaced by their
        values; otherwise the token/value pairs are listed in the prompt.
        """
        action_sequence = param_dict['actions']
        smiles_list = []
        pieces = ['Reactants: ']

        def tagged(smi):
            # Reformatted SMILES, truncated to smi_max_len, wrapped in the SMILES tags.
            body = reformat_smiles(smi, smiles_type=self.smiles_type)[:smi_max_len]
            return f'[START_SMILES]{body}[END_SMILES] '

        def named(smi):
            # Same as tagged(), prefixed with the molecule's placeholder name.
            label = param_dict["extracted_molecules"][smi]
            return f'{label}: ' + tagged(smi)

        for smi in param_dict['REACTANT']:
            pieces.append(named(smi))
            smiles_list.append(smi)

        pieces.append('Product: ')
        for smi in param_dict['PRODUCT']:
            pieces.append(named(smi))
            smiles_list.append(smi)

        # Catalysts and solvents are optional, and may lack a placeholder name.
        for header, key in (('Catalysts: ', 'CATALYST'), ('Solvents: ', 'SOLVENT')):
            if not param_dict[key]:
                continue
            pieces.append(header)
            for smi in param_dict[key]:
                if smi in param_dict["extracted_molecules"]:
                    pieces.append(named(smi))
                else:
                    pieces.append(tagged(smi))
                smiles_list.append(smi)

        if predict_rxn_condition:
            for condition_key in ('extracted_duration', 'extracted_temperature'):
                for value, token in param_dict[condition_key].items():
                    action_sequence = action_sequence.replace(token, value)
        else:
            pieces.append('Temperatures: ')
            for value, token in param_dict['extracted_temperature'].items():
                pieces.append(f'{token}: {value} ')

            pieces.append('Durations: ')
            for value, token in param_dict['extracted_duration'].items():
                pieces.append(f'{token}: {value} ')

        pieces.append('Action Squence: ')
        return ''.join(pieces), smiles_list, action_sequence

    def __getitem__(self, index):
        record = self.data_list[index]
        rxn_id = record['index']
        input_text, smiles_list, output_text = self.make_prompt(record, self.smi_max_len, self.predict_rxn_condition)
        output_text = output_text.strip() + '\n'

        graph_list = []
        if self.use_graph:
            for smiles in smiles_list:
                if not self.disable_graph_cache:
                    assert smiles in self.mol_graph_map
                    graph_list.append(self.mol_graph_map[smiles])
                else:
                    graph_list.append(smiles2data(smiles))
        return rxn_id, graph_list, output_text, input_text
100
+
data_provider/synthesis_dataset.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Dataset
3
+ import os
4
+ import random
5
+ import json
6
+ from .data_utils import smiles2data, escape_custom_split_sequence, reformat_smiles, generate_rsmiles
7
+
8
class SynthesisDataset(Dataset):
    """Dataset of (input SMILES, output SMILES) pairs for synthesis tasks.

    The task is inferred from ``root``: a path containing ``'PtoR'`` is
    retrosynthesis, ``'pretrain'`` is masked-molecule pretraining and
    ``'RtoP'`` is forward prediction. Pairs are read from
    ``<root>/<mode>/src-<mode>.txt`` and ``tgt-<mode>.txt``, or, with
    ``roundrobin_train`` in train mode, from the ten shards
    ``train/src-train_<k>.txt`` / ``train/tgt-train_<k>.txt`` loaded one at a
    time through :meth:`reload_data`.

    Each item is ``(index, graph_list, output_text, input_text)``.
    """

    def __init__(self,
        root,
        mode,
        smi_max_len=128,
        use_graph=True,
        disable_graph_cache=False,
        smiles_type='default',
        roundrobin_train=False,
        test_subset=-1
    ):
        """Load the pair lists (and graph caches) for ``mode``.

        ``test_subset`` > 0 keeps only the first ``test_subset`` pairs when
        ``mode == 'test'``. Raises ``NotImplementedError`` if the task cannot
        be inferred from ``root``.
        """
        super(SynthesisDataset, self).__init__(root)
        self.root = root
        if 'PtoR' in root:
            self.task = 'retro'
        elif 'pretrain' in root:
            self.task = 'pretrain'
        elif 'RtoP' in root:
            self.task = 'forward'
        else:
            raise NotImplementedError(f'Invalid task: {root}')
        if mode=='valid':
            mode='val'
        self.mode = mode
        self.smi_max_len = smi_max_len
        self.tokenizer = None
        self.use_graph = use_graph
        self.disable_graph_cache = disable_graph_cache
        self.smiles_type = smiles_type
        self.roundrobin_train = roundrobin_train
        with open(os.path.join(root, 'mol_graphid_map.json')) as f:
            self.mol_idx_map = json.load(f)
        if self.use_graph:
            self.idx_graph_map = torch.load(os.path.join(root, 'idx_graph_map.pt'))

        if self.roundrobin_train and mode=='train':
            # reload_data() increments the counter modulo 10 before loading,
            # so starting at -2 makes this first call load shard 9.
            self.reload_counter=-2
            self.reload_data()
        else:
            with open(os.path.join(root, mode, f'src-{mode}.txt')) as f:
                self.input_list = f.readlines()
            with open(os.path.join(root, mode, f'tgt-{mode}.txt')) as f:
                self.output_list = f.readlines()
            assert len(self.input_list) == len(self.output_list)
            self.renew_r_smiles()
            self.input_list = [smi.strip().replace(' ','') for smi in self.input_list]
            self.output_list = [smi.strip().replace(' ','') for smi in self.output_list]
            if test_subset>0 and mode=='test':
                assert test_subset<=len(self.input_list)
                # Truncate BOTH lists so inputs and targets stay paired
                # (the target list was previously left untruncated).
                self.input_list = self.input_list[:test_subset]
                self.output_list = self.output_list[:test_subset]

    def reload_data(self):
        """Advance to the next training shard and load it (round-robin only).

        Pairs whose input and output contain a different number of ``.``
        separators are dropped.
        """
        if not self.roundrobin_train:
            return
        self.reload_counter = (self.reload_counter+1)%10
        # Drop the previous shard before reading the next one.
        if hasattr(self, 'input_list'):
            del self.input_list
        if hasattr(self, 'output_list'):
            del self.output_list
        with open(os.path.join(self.root, f'train/src-train_{self.reload_counter}.txt')) as f:
            self.input_list = f.readlines()
        with open(os.path.join(self.root, f'train/tgt-train_{self.reload_counter}.txt')) as f:
            self.output_list = f.readlines()
        assert len(self.input_list) == len(self.output_list)
        self.renew_r_smiles()
        self.input_list = [smi.strip().replace(' ','') for smi in self.input_list]
        self.output_list = [smi.strip().replace(' ','') for smi in self.output_list]
        input_list, output_list = [], []
        for input_smiles, output_smiles in zip(self.input_list, self.output_list):
            if input_smiles.count('.') != output_smiles.count('.'):
                continue
            input_list.append(input_smiles)
            output_list.append(output_smiles)
        # Report the length AFTER filtering, as the message states.
        print(f'Reloaded data from {self.root}/train/src-train_{self.reload_counter}.txt, filtered len={len(input_list)}', flush=True)
        self.input_list = input_list
        self.output_list = output_list

    def renew_r_smiles(self):
        """Regenerate root-aligned SMILES for the training set.

        Only acts when ``smiles_type == 'r_smiles'`` and ``mode == 'train'``.
        """
        if self.smiles_type == 'r_smiles' and self.mode == 'train':
            # only renew r_smiles for training set
            if not hasattr(self, 'input_list_mapped'):
                # here we back up the original input_list and output_list
                # NOTE(review): the backup is taken once, so after a later
                # reload_data() the regeneration still uses the first shard's
                # lists -- confirm this is intended with roundrobin_train.
                self.input_list_mapped = self.input_list
                self.output_list_mapped = self.output_list
            self.output_list, self.input_list = generate_rsmiles(self.output_list_mapped, self.input_list_mapped)
            self.input_list = [smi.strip().replace(' ','') for smi in self.input_list]
            self.output_list = [smi.strip().replace(' ','') for smi in self.output_list]

    def get(self, index):
        """Return the item at ``index`` (delegates to ``__getitem__``)."""
        return self.__getitem__(index)

    def len(self):
        """Return the number of pairs currently loaded."""
        return len(self)

    def __len__(self):
        return len(self.input_list)

    def make_prompt(self, input_smiles, output_smiles, smi_max_len=512):
        """Build the task prompt and the tagged target for one pair.

        Returns ``(input_prompt, smiles_list, output_smiles)``.
        ``smiles_list`` holds the molecules of the input, except for the
        pretrain task where it holds the molecules of the output.
        """
        FORWARD_PROMPT = 'Question: Given the following reactant molecules: {}, what are the expected products? Answer: The product molecules are '
        FORWARD_CATALYST_PROMPT = '{}, and the following catalyst molecules: {}'
        RETRO_PROMPT = 'Question: Given the following product molecules: {}, what are the reactants that produce them? Answer: The reactant molecules are '
        PRETRAIN_PROMPT = 'Reconstruct the masked molecule: {}. Answer: '
        smiles_wrapper = lambda x: reformat_smiles(x, smiles_type=self.smiles_type)[:smi_max_len]
        if self.task=='retro':
            assert '<separated>' not in input_smiles
            smiles_list = input_smiles.split('.')
            in_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in smiles_list])
            input_prompt = RETRO_PROMPT.format(in_prompt)
        elif self.task=='forward':
            if '<separated>' in input_smiles:
                # Reactants and reagents are listed separately in the prompt.
                reactant_smiles, reagent_smiles = input_smiles.split('<separated>')
                reactant_smiles = reactant_smiles.split('.')
                reagent_smiles = reagent_smiles.split('.')
                reactant_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in reactant_smiles])
                reagent_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in reagent_smiles])
                smiles_list = reactant_smiles+reagent_smiles
                input_prompt = FORWARD_CATALYST_PROMPT.format(reactant_prompt, reagent_prompt)
            else:
                smiles_list = input_smiles.split('.')
                reactant_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in smiles_list])
                input_prompt = reactant_prompt
            input_prompt = FORWARD_PROMPT.format(input_prompt)
        elif self.task=='pretrain':
            in_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in input_smiles.split('.')])
            input_prompt = PRETRAIN_PROMPT.format(in_prompt)
            smiles_list = output_smiles.split('.')
        # The whole target is wrapped in one pair of SMILES tags, then the
        # split marker is inserted between its characters.
        output_smiles = f'[START_SMILES]{output_smiles}[END_SMILES]'
        output_smiles = escape_custom_split_sequence(output_smiles)

        return input_prompt, smiles_list, output_smiles

    def __getitem__(self, index):
        input_smiles = self.input_list[index]
        output_smiles = self.output_list[index]
        input_text, smiles_list, output_text = self.make_prompt(input_smiles, output_smiles, smi_max_len=self.smi_max_len)
        output_text = output_text.strip()+'\n'

        graph_list = []
        if self.use_graph:
            for smiles in smiles_list:
                if self.disable_graph_cache:
                    graph_item = smiles2data(smiles)
                else:
                    # Cached graphs are looked up SMILES -> id -> graph.
                    assert smiles in self.mol_idx_map
                    idx = self.mol_idx_map[smiles]
                    assert idx in self.idx_graph_map
                    graph_item = self.idx_graph_map[idx]
                graph_list.append(graph_item)

        return index, graph_list, output_text, input_text
data_provider/tune_dm.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ import torch
4
+ from pytorch_lightning import LightningDataModule
5
+ import torch_geometric
6
+ # from torch_geometric.loader import DataLoader
7
+ from torch.utils.data import DataLoader
8
+ from torch_geometric.loader.dataloader import Collater
9
+ from data_provider.reaction_action_dataset import ActionDataset
10
+ from data_provider.synthesis_dataset import SynthesisDataset
11
+ from data_provider.caption_dataset import CaptionDataset
12
+ from data_provider.chebi_dataset import ChEBI_dataset
13
+ import re
14
+
15
# we split individual characters inside special tokens like [START_DNA]
# Capture groups: 1 = start tag, 2 = tag name, 3 = enclosed sequence (non-greedy),
# 4 = end tag whose name must match group 2.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
23
+
24
def _insert_split_marker(m: re.Match):
    """
    Rebuild one matched special-token span with the split marker inserted
    before every character of the enclosed sequence and once more before
    the end token.

    Parameters
    ----------
    m : re.Match
        A match of CUSTOM_SEQ_RE (start token, tag name, sequence, end token)

    Returns
    ----------
    str - the matched text with the split marker added
    """
    opening, _tag, payload, closing = m.groups()
    marked = "".join(SPLIT_MARKER + ch for ch in payload)
    return opening + marked + SPLIT_MARKER + closing
41
+
42
def smiles_handler(text, mol_ph, is_gal=True):
    """
    Collect the sequences enclosed in special tokens and append the molecule
    placeholder after each of them.

    With ``is_gal`` the start/end tokens are kept and the result is passed
    through escape_custom_split_sequence; otherwise the tokens are removed
    and only the sequence followed by ``mol_ph`` remains.

    Returns
    ----------
    (str, list) - the rewritten text and the enclosed sequences in order
    """
    smiles_list = [found.group(3) for found in CUSTOM_SEQ_RE.finditer(text)]
    if not is_gal:
        stripped = CUSTOM_SEQ_RE.sub(r'\3%s' % (mol_ph), text)
        return stripped, smiles_list
    tagged = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
    return escape_custom_split_sequence(tagged), smiles_list
54
+
55
def escape_custom_split_sequence(text):
    """
    Insert the split marker inside every special-token span of the text
    (each match of CUSTOM_SEQ_RE is rewritten by _insert_split_marker).

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    escaped = re.sub(CUSTOM_SEQ_RE, _insert_split_marker, text)
    return escaped
69
+
70
class TrainCollater:
    """Collate training samples of the form (rxn_id, graphs, text, prompt).

    With ``use_qa_pair`` the prompt and target text are tokenized together as
    one pair (``collate_qa``); otherwise they are tokenized separately
    (``collate``).
    """

    def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id, is_gal=True, use_graph=True, use_qa_pair=True):
        self.rxn_max_len = rxn_max_len
        self.text_max_len = text_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.use_graph = use_graph
        self.use_qa_pair = use_qa_pair

    def __call__(self, batch):
        handler = self.collate_qa if self.use_qa_pair else self.collate
        return handler(batch)

    def collate(self, batch):
        """Tokenize prompts (left-padded, untruncated) and targets (right-padded) separately."""
        rxn_ids, graphs, texts, smiles_prompt = zip(*batch)
        if graphs:
            graphs = self.collater(graphs)

        ## deal with prompt
        if self.use_graph:
            smiles_prompt = [smiles_handler(prompt, self.mol_ph, self.is_gal)[0] for prompt in smiles_prompt]
        else:
            smiles_prompt = [escape_custom_split_sequence(prompt) for prompt in smiles_prompt]

        tok = self.tokenizer
        tok.padding_side = 'left'
        smiles_prompt_tokens = tok(
            text=smiles_prompt,
            truncation=False,
            padding='longest',
            add_special_tokens=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        # Boolean mask marking the positions of the molecule placeholder token.
        smiles_prompt_tokens['is_mol_token'] = smiles_prompt_tokens.input_ids == self.mol_token_id

        tok.padding_side = 'right'
        text_tokens = tok(
            text=texts,
            truncation=True,
            padding='longest',
            add_special_tokens=True,
            max_length=self.text_max_len,
            return_tensors='pt',
            return_attention_mask=True,
        )
        return rxn_ids, graphs, smiles_prompt_tokens, text_tokens

    def collate_qa(self, batch):
        """Tokenize each (prompt, target) as one right-padded pair with token type ids."""
        rxn_ids, graphs, texts, input_prompt = zip(*batch)
        # Flatten the per-sample graph lists into one list for batching.
        flat_graphs = []
        for graph_batch in graphs:
            flat_graphs.extend(graph_batch)
        graphs = flat_graphs
        if graphs:
            graphs = self.collater(graphs)

        ## deal with prompt
        if self.use_graph:
            input_prompt = [smiles_handler(prompt, self.mol_ph, self.is_gal)[0] for prompt in input_prompt]
        else:
            input_prompt = [escape_custom_split_sequence(prompt) for prompt in input_prompt]

        self.tokenizer.padding_side = 'right'
        qa_pair = [list(pair) for pair in zip(input_prompt, texts)]
        qa_batch = self.tokenizer(
            qa_pair,
            truncation=True,
            padding='longest',
            add_special_tokens=True,
            max_length=self.rxn_max_len + self.text_max_len,
            return_tensors='pt',
            return_attention_mask=True,
            return_token_type_ids=True,
        )
        qa_batch['is_mol_token'] = qa_batch.input_ids == self.mol_token_id
        return rxn_ids, graphs, qa_batch
142
+
143
class InferenceCollater:
    """Collate function for inference: tokenizes prompts only.

    Target texts and the raw prompts are passed through untouched so the
    caller can compare them with generated output.
    """

    def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id, is_gal=True):
        """Store the tokenizer, length limits and placeholder settings.

        Args:
            tokenizer: Callable tokenizer; its ``padding_side`` is set to
                ``'left'`` on every call.
            text_max_len: Stored on the instance; not read by ``__call__``.
            rxn_max_len: ``max_length`` used when tokenizing prompts.
            mol_ph: Molecule placeholder string forwarded to ``smiles_handler``.
            mol_token_id: Token id used to build the ``is_mol_token`` mask.
            is_gal: Flag forwarded to ``smiles_handler``.
        """
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.rxn_max_len = rxn_max_len
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.collater = Collater([], [])

    def __call__(self, batch):
        """Collate ``(rxn_id, graph_list, text, input_prompt)`` samples.

        Returns:
            ``(rxn_ids, graphs, input_prompt_tokens, texts, inputs)`` where
            ``inputs`` is the tuple of unprocessed prompts.
        """
        rxn_ids, graph_groups, texts, raw_prompts = zip(*batch)

        # Flatten the per-sample graph lists; an empty list is passed through as is.
        flat_graphs = [graph for group in graph_groups for graph in group]
        graphs = self.collater(flat_graphs) if flat_graphs else flat_graphs

        prompts = [smiles_handler(item, self.mol_ph, self.is_gal)[0] for item in raw_prompts]

        self.tokenizer.padding_side = 'left'
        prompt_tokens = self.tokenizer(
            prompts,
            truncation=True,
            padding='longest',
            add_special_tokens=True,
            max_length=self.rxn_max_len,
            return_tensors='pt',
            return_attention_mask=True,
        )
        # Mark the positions holding the molecule placeholder token.
        prompt_tokens['is_mol_token'] = prompt_tokens.input_ids == self.mol_token_id
        return rxn_ids, graphs, prompt_tokens, texts, raw_prompts
174
+
175
class TuneDM(LightningDataModule):
    """Lightning data module for the downstream fine-tuning tasks.

    Builds train/valid/test datasets for one of the ``'action'``,
    ``'synthesis'``, ``'caption'`` or ``'chebi'`` tasks and exposes the
    corresponding data loaders.
    """

    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        text_max_len: int = 128,
        smi_max_len: int = 128,
        rxn_max_len: int = 128,
        tokenizer=None,
        downstream_task='action',
        args=None,
    ):
        """Create the datasets and record loader settings.

        Args:
            num_workers: Worker process count for every loader.
            batch_size: Batch size of the training loader.
            root: Dataset root passed to the dataset constructors.
            text_max_len: Maximum target-text length given to the collaters.
            smi_max_len: Passed to the dataset constructors.
            rxn_max_len: Maximum prompt length given to the collaters.
            tokenizer: Tokenizer exposing ``mol_token_id``; attached to all
                three datasets.
            downstream_task: Key selecting the dataset class.
            args: Parsed options namespace. Despite the ``None`` default it
                is dereferenced immediately, so it must be provided.

        Raises:
            KeyError: If ``downstream_task`` is not a known task name.
        """
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = num_workers
        self.rxn_max_len = rxn_max_len
        self.text_max_len = text_max_len
        self.prompt = args.prompt
        DownstreamDataset = {
            'action': ActionDataset,
            'synthesis': SynthesisDataset,
            'caption': CaptionDataset,
            'chebi': ChEBI_dataset,
        }[downstream_task]
        ds_args = {
            'use_graph': not args.disable_graphs,
            'disable_graph_cache': args.disable_graph_cache,
            'smiles_type': args.smiles_type,
        }
        # Task-specific dataset options.
        if downstream_task == 'action':
            ds_args['predict_rxn_condition'] = args.predict_rxn_condition
        if downstream_task == 'synthesis':
            ds_args['roundrobin_train'] = args.roundrobin_train
            ds_args['test_subset'] = args.test_subset
        self.train_dataset = DownstreamDataset(root, 'train', smi_max_len, **ds_args)
        self.val_dataset = DownstreamDataset(root, 'valid', smi_max_len, **ds_args)
        self.test_dataset = DownstreamDataset(root, 'test', smi_max_len, **ds_args)
        self.init_tokenizer(tokenizer)
        self.mol_ph_token = '<mol>' * self.args.num_query_token
        self.is_gal = args.opt_model.find('galactica') >= 0
        self.use_graph = not args.disable_graphs
        self.is_t5 = args.opt_model.find('t5') >= 0

    def init_tokenizer(self, tokenizer):
        """Attach ``tokenizer`` to the module and to all three datasets."""
        self.tokenizer = tokenizer
        self.train_dataset.tokenizer = tokenizer
        self.val_dataset.tokenizer = tokenizer
        self.test_dataset.tokenizer = tokenizer
        self.mol_token_id = self.tokenizer.mol_token_id

    def _persistent_workers(self):
        """Return whether loaders may keep their workers alive.

        ``DataLoader`` rejects ``persistent_workers=True`` when
        ``num_workers`` is 0, which is this module's default.
        """
        return self.num_workers > 0

    def _inference_loader(self, dataset):
        """Build an unshuffled loader over ``dataset`` using ``InferenceCollater``."""
        return DataLoader(
            dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=self._persistent_workers(),
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
            ),
        )

    def train_dataloader(self):
        """Return the shuffled training loader.

        Before building the loader the training set is refreshed:
        ``reload_data()`` is called when ``args.roundrobin_train`` is set,
        and ``renew_r_smiles()`` is called when the dataset provides it.
        """
        if self.args.roundrobin_train:
            self.train_dataset.reload_data()
        if hasattr(self.train_dataset, 'renew_r_smiles'):
            self.train_dataset.renew_r_smiles()
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=self._persistent_workers(),
            collate_fn=TrainCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                use_graph=self.use_graph,
                # T5 models use the separate prompt/text collation path.
                use_qa_pair=not self.is_t5,
            ),
        )
        return loader

    def val_dataloader(self):
        """Return a one-element list holding the test-set loader.

        NOTE: validation runs on the test split. The previous code built a
        loader over ``self.val_dataset`` only after the ``return`` statement,
        so it was never reachable; that dead code has been removed without
        changing the returned value.
        """
        return [self._inference_loader(self.test_dataset)]

    def test_dataloader(self):
        """Return the test-set loader."""
        return self._inference_loader(self.test_dataset)
312
+
demo.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [
2
+ "[BH4-].[Na+].COC(=O)c1nc2ccccn2c1C>C1CCOC1.CO>Cc1c(CO)nc2ccccn12",
3
+ "NCCN.O=C(O)c1cc2nc(-c3ccc(-c4ccccc4)cc3)c(Cl)cc2n1CO>CN(C)C=O>O=C(O)c1cc2nc(-c3ccc(-c4ccccc4)cc3)c(Cl)cc2[nH]1",
4
+ "CC[O-].[Na+].CC(C)[N+](=O)[O-].Cc1cc(C)nc(NC(=O)NS(=O)(=O)c2ccccc2CCl)n1>CCO>Cc1cc(C)nc(NC(=O)NS(=O)(=O)c2ccccc2C=O)n1",
5
+ "COC(=O)c1ccc2c(c1)nc(C(C)(C)C)n2CC1CCC(F)(F)CC1.O=S(=O)([O-])O.[K+]>[Li+].[OH-].C1COCCO1>CC(C)(C)c1nc2cc(C(=O)O)ccc2n1CC1CCC(F)(F)CC1",
6
+ "FC(F)(F)c1cccc2c(Br)c(Cc3ccccc3)cnc12.CC1(C)OB(c2cncc(C=O)c2)OC1(C)C>O=C([O-])[O-].[Na+].[Na+].Cc1ccccc1.CCO>O=Cc1cncc(-c2c(Cc3ccccc3)cnc3c(C(F)(F)F)cccc23)c1"
7
+ ]
demo.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import warnings
5
+ import pytorch_lightning as pl
6
+ from pytorch_lightning import Trainer, strategies
7
+ import pytorch_lightning.callbacks as plc
8
+ from pytorch_lightning.loggers import CSVLogger
9
+ from pytorch_lightning.callbacks import TQDMProgressBar
10
+ from data_provider.pretrain_dm import PretrainDM
11
+ from data_provider.tune_dm import *
12
+ from model.opt_flash_attention import replace_opt_attn_with_flash_attn
13
+ from model.blip2_model import Blip2Model
14
+ from model.dist_funs import MyDeepSpeedStrategy
15
+ from data_provider.reaction_action_dataset import ActionDataset
16
+ from data_provider.data_utils import json_read, json_write
17
+ from data_provider.data_utils import smiles2data, reformat_smiles
18
+
19
+ ## for pyg bug
20
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
21
+ ## for A5000 gpus
22
+ torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)
23
+
24
+
25
class InferenceRunner:
    """Run single-reaction inference with a loaded model.

    Builds a text prompt plus molecule graphs from one reaction record,
    tokenizes the prompt and calls ``model.blip2opt.generate``.
    """

    def __init__(self, model, tokenizer, rxn_max_len, smi_max_len,
                 smiles_type='default', device='cuda', predict_rxn_condition=True, args=None):
        """Store the model, tokenizer and generation settings.

        Args:
            model: Object exposing ``blip2opt.generate``.
            tokenizer: Callable tokenizer exposing ``mol_token_id``.
            rxn_max_len: Length the tokenized prompt is padded/truncated to.
            smi_max_len: Character limit applied to each SMILES in the prompt.
            smiles_type: Forwarded to ``reformat_smiles``.
            device: Device the graphs and prompt tokens are moved to.
            predict_rxn_condition: Forwarded to ``make_prompt``.
            args: Parsed options namespace. Despite the ``None`` default it
                is dereferenced immediately, so it must be provided.
        """
        self.model = model
        self.rxn_max_len = rxn_max_len
        self.smi_max_len = smi_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = '<mol>' * args.num_query_token
        self.mol_token_id = tokenizer.mol_token_id
        self.is_gal = args.opt_model.find('galactica') >= 0
        self.device = device
        self.smiles_type = smiles_type
        self.predict_rxn_condition = predict_rxn_condition
        self.args = args

    def _tag_smiles(self, smi, smi_max_len):
        """Return the reformatted, length-capped SMILES wrapped in SMILES tags."""
        smiles = reformat_smiles(smi, smiles_type=self.smiles_type)[:smi_max_len]
        return f'[START_SMILES]{smiles}[END_SMILES] '

    def make_prompt(self, param_dict, smi_max_len=128, predict_rxn_condition=False):
        """Build the text prompt for one reaction record.

        Args:
            param_dict: Reaction record providing ``'actions'``,
                ``'extracted_molecules'``, ``'REACTANT'``, ``'PRODUCT'``,
                ``'CATALYST'``, ``'SOLVENT'``, ``'extracted_temperature'``
                and ``'extracted_duration'``.
            smi_max_len: Character limit applied to each reformatted SMILES.
            predict_rxn_condition: When true, temperature and duration
                tokens in the action sequence are replaced by their values
                and the conditions are left out of the prompt; otherwise
                the conditions are listed in the prompt.

        Returns:
            ``(prompt, smiles_list, action_sequence)``; ``smiles_list`` holds
            the raw SMILES in the order they appear in the prompt.

        Raises:
            KeyError: If a reactant or product has no entry in
                ``'extracted_molecules'``.
        """
        action_sequence = param_dict['actions']
        smiles_list = []
        parts = []

        # Reactants and products must have a name in 'extracted_molecules'.
        for header, key in (('Reactants: ', 'REACTANT'), ('Product: ', 'PRODUCT')):
            parts.append(header)
            for smi in param_dict[key]:
                name = param_dict['extracted_molecules'][smi]
                parts.append(f'{name}: {self._tag_smiles(smi, smi_max_len)}')
                smiles_list.append(smi)

        # Catalysts and solvents are optional and may be unnamed.
        for header, key in (('Catalysts: ', 'CATALYST'), ('Solvents: ', 'SOLVENT')):
            if not param_dict[key]:
                continue
            parts.append(header)
            for smi in param_dict[key]:
                if smi in param_dict['extracted_molecules']:
                    name = param_dict['extracted_molecules'][smi]
                    parts.append(f'{name}: {self._tag_smiles(smi, smi_max_len)}')
                else:
                    parts.append(self._tag_smiles(smi, smi_max_len))
                smiles_list.append(smi)

        if predict_rxn_condition:
            for value, token in param_dict['extracted_duration'].items():
                action_sequence = action_sequence.replace(token, value)
            for value, token in param_dict['extracted_temperature'].items():
                action_sequence = action_sequence.replace(token, value)
        else:
            parts.append('Temperatures: ')
            for value, token in param_dict['extracted_temperature'].items():
                parts.append(f'{token}: {value} ')

            parts.append('Durations: ')
            for value, token in param_dict['extracted_duration'].items():
                parts.append(f'{token}: {value} ')

        # NOTE(review): 'Squence' is spelled this way on purpose here; the
        # string is kept byte-for-byte because it is part of the model input.
        # Presumably it matches the training prompts; confirm before fixing.
        parts.append('Action Squence: ')
        return ''.join(parts), smiles_list, action_sequence

    def get_action_elements(self, rxn_dict):
        """Return ``(rxn_id, graph_list, output_text, input_text)`` for one record.

        ``output_text`` is the stripped action sequence plus a trailing
        newline; ``graph_list`` has one ``smiles2data`` result per prompt SMILES.
        """
        rxn_id = rxn_dict['index']
        input_text, smiles_list, output_text = self.make_prompt(rxn_dict, self.smi_max_len, self.predict_rxn_condition)
        output_text = output_text.strip() + '\n'

        graph_list = [smiles2data(smiles) for smiles in smiles_list]
        return rxn_id, graph_list, output_text, input_text

    @torch.no_grad()
    def predict(self, rxn_dict):
        """Generate a prediction for one reaction record.

        Returns:
            Dict with the keys ``'raw'``, ``'index'``, ``'input'``,
            ``'target'`` and ``'prediction'``.
        """
        rxn_id, graphs, prompt_tokens, output_text, input_text = self.tokenize(rxn_dict)
        result_dict = {
            'raw': rxn_dict,
            'index': rxn_id,
            'input': input_text,
            'target': output_text
        }
        samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
        # Gradients are already disabled by the decorator.
        result_dict['prediction'] = self.model.blip2opt.generate(
            samples,
            do_sample=self.args.do_sample,
            num_beams=self.args.num_beams,
            max_length=self.args.max_inference_len,
            min_length=self.args.min_inference_len,
            num_captions=self.args.num_generate_captions,
            use_graph=True
        )
        return result_dict

    def tokenize(self, rxn_dict):
        """Turn one reaction record into batched graphs and prompt tokens.

        Returns:
            ``(rxn_id, graphs, input_prompt_tokens, output_text, input_text)``.
            ``graphs`` is the collated batch moved to ``self.device``, or the
            empty list when the record contains no molecules.
        """
        rxn_id, graph_list, output_text, input_text = self.get_action_elements(rxn_dict)
        # Always bind ``graphs``: previously an empty graph list left the name
        # undefined and the return statement raised UnboundLocalError.
        graphs = graph_list
        if graph_list:
            graphs = self.collater(graph_list).to(self.device)
        input_prompt = smiles_handler(input_text, self.mol_ph, self.is_gal)[0]

        self.tokenizer.padding_side = 'left'
        input_prompt_tokens = self.tokenizer(input_prompt,
                                             truncation=True,
                                             padding='max_length',
                                             add_special_tokens=True,
                                             max_length=self.rxn_max_len,
                                             return_tensors='pt',
                                             return_attention_mask=True).to(self.device)
        # Mark the positions holding the molecule placeholder token.
        is_mol_token = input_prompt_tokens.input_ids == self.mol_token_id
        input_prompt_tokens['is_mol_token'] = is_mol_token
        return rxn_id, graphs, input_prompt_tokens, output_text, input_text
143
+
144
+
145
def main(args):
    """Load the model, then predict and print a result for each entry of ``demo.json``.

    Raises:
        NotImplementedError: If ``args.opt_model`` names none of the
            supported language-model families.
    """
    import time

    device = torch.device('cuda')
    data_list = json_read('demo.json')
    pl.seed_everything(args.seed)

    # The model is built the same way with or without a checkpoint; weights
    # are loaded on top only when a checkpoint path is given.
    model = Blip2Model(args).to(device)
    if args.init_checkpoint:
        ckpt = torch.load(args.init_checkpoint, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'], strict=False)
        print(f"loaded model from {args.init_checkpoint}")
    model.eval()

    print('total params:', sum(p.numel() for p in model.parameters()))

    # Pick the tokenizer attribute that matches the language-model family.
    llm_name = args.opt_model
    if any(tag in llm_name for tag in ('galactica', 't5')):
        tokenizer = model.blip2opt.opt_tokenizer
    elif any(tag in llm_name for tag in ('llama', 'vicuna')):
        tokenizer = model.blip2opt.llm_tokenizer
    else:
        raise NotImplementedError

    infer_runner = InferenceRunner(
        model=model,
        tokenizer=tokenizer,
        rxn_max_len=args.rxn_max_len,
        smi_max_len=args.smi_max_len,
        device=device,
        predict_rxn_condition=args.predict_rxn_condition,
        args=args
    )

    for data_item in data_list:
        started = time.time()
        print(infer_runner.predict(data_item))
        print(f"Time: {time.time() - started:.2f}s")
184
+
185
+
186
def get_args():
    """Parse the command line, optionally enable flash attention, and echo the options.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', type=str, default="main")
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    # MM settings
    parser.add_argument('--mode', type=str, default='pretrain', choices=['pretrain', 'ft', 'eval', 'pretrain_eval'])
    parser.add_argument('--strategy_name', type=str, default='mydeepspeed')
    parser.add_argument('--iupac_prediction', action='store_true', default=False)
    parser.add_argument('--ckpt_path', type=str, default=None)

    # Model and data-module specific options.
    parser = Blip2Model.add_model_specific_args(parser)
    parser = PretrainDM.add_model_specific_args(parser)

    parser.add_argument('--accelerator', type=str, default='gpu')
    parser.add_argument('--devices', type=str, default='0,1,2,3')
    parser.add_argument('--precision', type=str, default='bf16-mixed')
    parser.add_argument('--downstream_task', type=str, default='action', choices=['action', 'synthesis', 'caption', 'chebi'])
    parser.add_argument('--max_epochs', type=int, default=10)

    # Boolean switches, registered in their original order.
    for flag in (
        '--enable_flash',
        '--disable_graph_cache',
        '--predict_rxn_condition',
        '--generate_restrict_tokens',
        '--train_restrict_tokens',
    ):
        parser.add_argument(flag, action='store_true', default=False)

    parser.add_argument('--smiles_type', type=str, default='default', choices=['default', 'canonical', 'restricted', 'unrestricted', 'r_smiles'])

    # Integer options, registered in their original order.
    for option, default in (
        ('--accumulate_grad_batches', 1),
        ('--tqdm_interval', 50),
        ('--check_val_every_n_epoch', 1),
    ):
        parser.add_argument(option, type=int, default=default)

    args = parser.parse_args()

    if args.enable_flash:
        replace_opt_attn_with_flash_attn()

    banner = "========================================="
    print(banner)
    for key, value in sorted(vars(args).items()):
        print(key, '=', value)
    print(banner)
    return args
221
+
222
+ if __name__ == '__main__':
223
+ main(get_args())
224
+
environment.yml ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: reactxt
2
+ channels:
3
+ - pyg
4
+ - tmap
5
+ - pytorch
6
+ - nvidia
7
+ - nvidia/label/cuda-11.7.0
8
+ - conda-forge
9
+ - defaults
10
+ dependencies:
11
+ - _libgcc_mutex=0.1=main
12
+ - _openmp_mutex=5.1=1_gnu
13
+ - appdirs=1.4.4=pyhd3eb1b0_0
14
+ - asttokens=2.2.1=pyhd8ed1ab_0
15
+ - backcall=0.2.0=pyh9f0ad1d_0
16
+ - backports=1.0=pyhd8ed1ab_3
17
+ - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
18
+ - blas=1.0=mkl
19
+ - brotlipy=0.7.0=py38h27cfd23_1003
20
+ - bzip2=1.0.8=h7b6447c_0
21
+ - ca-certificates=2023.08.22=h06a4308_0
22
+ - certifi=2023.11.17=py38h06a4308_0
23
+ - cffi=1.15.1=py38h5eee18b_3
24
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
25
+ - cryptography=39.0.1=py38h9ce1e76_2
26
+ - cuda-cccl=11.7.58=hc415cf5_0
27
+ - cuda-cudart=11.7.99=0
28
+ - cuda-cudart-dev=11.7.60=h6a7c232_0
29
+ - cuda-cupti=11.7.101=0
30
+ - cuda-driver-dev=11.7.60=0
31
+ - cuda-libraries=11.7.1=0
32
+ - cuda-libraries-dev=11.7.0=0
33
+ - cuda-nvcc=11.7.64=0
34
+ - cuda-nvrtc=11.7.99=0
35
+ - cuda-nvrtc-dev=11.7.50=heada363_0
36
+ - cuda-nvtx=11.7.91=0
37
+ - cuda-runtime=11.7.0=0
38
+ - cycler=0.11.0=pyhd3eb1b0_0
39
+ - debugpy=1.5.1=py38h295c915_0
40
+ - decorator=5.1.1=pyhd8ed1ab_0
41
+ - entrypoints=0.4=pyhd8ed1ab_0
42
+ - executing=1.2.0=pyhd8ed1ab_0
43
+ - ffmpeg=4.3=hf484d3e_0
44
+ - freetype=2.12.1=h4a9f257_0
45
+ - giflib=5.2.1=h5eee18b_3
46
+ - gmp=6.2.1=h295c915_3
47
+ - gmpy2=2.1.2=py38heeb90bb_0
48
+ - gnutls=3.6.15=he1e5248_0
49
+ - idna=3.4=py38h06a4308_0
50
+ - intel-openmp=2023.1.0=hdb19cb5_46305
51
+ - ipykernel=6.15.0=pyh210e3f2_0
52
+ - jedi=0.18.2=pyhd8ed1ab_0
53
+ - jinja2=3.1.2=py38h06a4308_0
54
+ - joblib=1.2.0=py38h06a4308_0
55
+ - jpeg=9e=h5eee18b_1
56
+ - jupyter_client=7.0.6=pyhd8ed1ab_0
57
+ - jupyter_core=4.12.0=py38h578d9bd_0
58
+ - lame=3.100=h7b6447c_0
59
+ - lcms2=2.12=h3be6417_0
60
+ - ld_impl_linux-64=2.38=h1181459_1
61
+ - lerc=3.0=h295c915_0
62
+ - libcublas=11.10.3.66=0
63
+ - libcublas-dev=11.10.1.25=h0c8ac2b_0
64
+ - libcufft=10.7.2.124=h4fbf590_0
65
+ - libcufft-dev=10.7.2.50=h59a5ac8_0
66
+ - libcufile=1.7.0.149=0
67
+ - libcufile-dev=1.3.0.44=0
68
+ - libcurand=10.3.3.53=0
69
+ - libcurand-dev=10.2.10.50=hd49a9cd_0
70
+ - libcusolver=11.4.0.1=0
71
+ - libcusolver-dev=11.3.5.50=hc6eba6f_0
72
+ - libcusparse=11.7.4.91=0
73
+ - libcusparse-dev=11.7.3.50=hc644b96_0
74
+ - libdeflate=1.17=h5eee18b_0
75
+ - libffi=3.4.4=h6a678d5_0
76
+ - libgcc-ng=11.2.0=h1234567_1
77
+ - libgfortran-ng=11.2.0=h00389a5_1
78
+ - libgfortran5=11.2.0=h1234567_1
79
+ - libgomp=11.2.0=h1234567_1
80
+ - libiconv=1.16=h7f8727e_2
81
+ - libidn2=2.3.4=h5eee18b_0
82
+ - libnpp=11.7.4.75=0
83
+ - libnpp-dev=11.7.3.21=hb6476a9_0
84
+ - libnvjpeg=11.8.0.2=0
85
+ - libnvjpeg-dev=11.7.2.34=h2e48410_0
86
+ - libpng=1.6.39=h5eee18b_0
87
+ - libsodium=1.0.18=h36c2ea0_1
88
+ - libstdcxx-ng=11.2.0=h1234567_1
89
+ - libtasn1=4.19.0=h5eee18b_0
90
+ - libtiff=4.5.0=h6a678d5_2
91
+ - libunistring=0.9.10=h27cfd23_0
92
+ - libwebp=1.2.4=h11a3e52_1
93
+ - libwebp-base=1.2.4=h5eee18b_1
94
+ - lz4-c=1.9.4=h6a678d5_0
95
+ - markupsafe=2.1.1=py38h7f8727e_0
96
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
97
+ - mkl=2023.1.0=h6d00ec8_46342
98
+ - mkl-service=2.4.0=py38h5eee18b_1
99
+ - mkl_fft=1.3.6=py38h417a72b_1
100
+ - mkl_random=1.2.2=py38h417a72b_1
101
+ - mpc=1.1.0=h10f8cd9_1
102
+ - mpfr=4.0.2=hb69a4c5_1
103
+ - mpmath=1.2.1=py38h06a4308_0
104
+ - ncurses=6.4=h6a678d5_0
105
+ - nest-asyncio=1.5.6=pyhd8ed1ab_0
106
+ - nettle=3.7.3=hbbd107a_1
107
+ - networkx=2.8.4=py38h06a4308_1
108
+ - numpy-base=1.24.3=py38h060ed82_1
109
+ - ogdf=1.2.0=h2bc3f7f_0
110
+ - openh264=2.1.1=h4ff587b_0
111
+ - openssl=3.0.12=h7f8727e_0
112
+ - parso=0.8.3=pyhd8ed1ab_0
113
+ - pexpect=4.8.0=pyh1a96a4e_2
114
+ - pickleshare=0.7.5=py_1003
115
+ - pillow=9.4.0=py38h6a678d5_0
116
+ - pooch=1.4.0=pyhd3eb1b0_0
117
+ - prompt-toolkit=3.0.39=pyha770c72_0
118
+ - prompt_toolkit=3.0.39=hd8ed1ab_0
119
+ - ptyprocess=0.7.0=pyhd3deb0d_0
120
+ - pure_eval=0.2.2=pyhd8ed1ab_0
121
+ - pycparser=2.21=pyhd3eb1b0_0
122
+ - pyg=2.3.0=py38_torch_2.0.0_cu117
123
+ - pygments=2.15.1=pyhd8ed1ab_0
124
+ - pyopenssl=23.0.0=py38h06a4308_0
125
+ - pyparsing=3.0.9=py38h06a4308_0
126
+ - pysocks=1.7.1=py38h06a4308_0
127
+ - python=3.8.17=h955ad1f_0
128
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
129
+ - python_abi=3.8=2_cp38
130
+ - pytorch=2.0.1=py3.8_cuda11.7_cudnn8.5.0_0
131
+ - pytorch-cuda=11.7=h778d358_5
132
+ - pytorch-mutex=1.0=cuda
133
+ - readline=8.2=h5eee18b_0
134
+ - setuptools=67.8.0=py38h06a4308_0
135
+ - six=1.16.0=pyh6c4a22f_0
136
+ - sqlite=3.41.2=h5eee18b_0
137
+ - stack_data=0.6.2=pyhd8ed1ab_0
138
+ - sympy=1.11.1=py38h06a4308_0
139
+ - tbb=2021.8.0=hdb19cb5_0
140
+ - threadpoolctl=2.2.0=pyh0d69192_0
141
+ - tk=8.6.12=h1ccaba5_0
142
+ - tmap=1.0.6=py38h2bc3f7f_0
143
+ - torchaudio=2.0.2=py38_cu117
144
+ - torchtriton=2.0.0=py38
145
+ - torchvision=0.15.2=py38_cu117
146
+ - tqdm=4.65.0=py38hb070fc8_0
147
+ - traitlets=5.9.0=pyhd8ed1ab_0
148
+ - typing_extensions=4.6.3=py38h06a4308_0
149
+ - urllib3=1.26.16=py38h06a4308_0
150
+ - wcwidth=0.2.6=pyhd8ed1ab_0
151
+ - wheel=0.38.4=py38h06a4308_0
152
+ - xz=5.4.2=h5eee18b_0
153
+ - zeromq=4.3.4=h9c3ff4c_1
154
+ - zlib=1.2.13=h5eee18b_0
155
+ - zstd=1.5.5=hc292b87_0
156
+ - pip:
157
+ - absl-py==1.4.0
158
+ - accelerate==0.20.3
159
+ - aiofiles==23.2.1
160
+ - aiohttp==3.8.4
161
+ - aiosignal==1.3.1
162
+ - aliyun-python-sdk-core==2.13.36
163
+ - aliyun-python-sdk-kms==2.16.1
164
+ - altair==4.2.2
165
+ - annotated-types==0.6.0
166
+ - antlr4-python3-runtime==4.9.3
167
+ - anyio==3.7.1
168
+ - argon2-cffi==23.1.0
169
+ - argon2-cffi-bindings==21.2.0
170
+ - arrow==1.2.3
171
+ - async-lru==2.0.4
172
+ - async-timeout==4.0.2
173
+ - attrs==23.1.0
174
+ - autocommand==2.2.2
175
+ - babel==2.13.0
176
+ - backoff==2.2.1
177
+ - backports-zoneinfo==0.2.1
178
+ - beautifulsoup4==4.12.2
179
+ - bigmodelvis==0.0.1
180
+ - binaryornot==0.4.4
181
+ - bleach==6.0.0
182
+ - blessed==1.20.0
183
+ - blinker==1.6.2
184
+ - blis==0.7.9
185
+ - braceexpand==0.1.7
186
+ - cachetools==5.3.1
187
+ - catalogue==2.0.8
188
+ - cfgv==3.3.1
189
+ - chardet==5.2.0
190
+ - cheroot==10.0.0
191
+ - cherrypy==18.8.0
192
+ - click==8.1.4
193
+ - cloudpathlib==0.16.0
194
+ - cmake==3.27.7
195
+ - colorama==0.4.6
196
+ - colour==0.1.5
197
+ - comm==0.1.4
198
+ - confection==0.1.0
199
+ - configargparse==1.7
200
+ - contexttimer==0.3.3
201
+ - contourpy==1.1.0
202
+ - cookiecutter==2.4.0
203
+ - crcmod==1.7
204
+ - croniter==1.4.1
205
+ - ctranslate2==3.20.0
206
+ - cymem==2.0.7
207
+ - datasets==2.13.1
208
+ - dateutils==0.6.12
209
+ - decord==0.6.0
210
+ - deepdiff==6.3.1
211
+ - deepspeed==0.10.1+ff7d5275
212
+ - defusedxml==0.7.1
213
+ - delta-center-client==0.0.4
214
+ - dill==0.3.6
215
+ - diskcache==5.6.3
216
+ - distlib==0.3.6
217
+ - distro==1.8.0
218
+ - dnspython==2.4.2
219
+ - docker-pycreds==0.4.0
220
+ - einops==0.6.1
221
+ - evaluate==0.4.1
222
+ - exceptiongroup==1.1.2
223
+ - faerun==0.3.20
224
+ - fairscale==0.4.4
225
+ - fastapi==0.100.0
226
+ - fastjsonschema==2.18.1
227
+ - fasttext-wheel==0.9.2
228
+ - ffmpy==0.3.1
229
+ - filelock==3.12.2
230
+ - flash-attn==2.3.3
231
+ - flask==3.0.0
232
+ - fonttools==4.40.0
233
+ - fqdn==1.5.1
234
+ - frozenlist==1.3.3
235
+ - fsspec==2023.6.0
236
+ - ftfy==6.1.1
237
+ - future==0.18.3
238
+ - gdown==4.7.1
239
+ - gitdb==4.0.10
240
+ - gitpython==3.1.37
241
+ - google-auth==2.23.2
242
+ - google-auth-oauthlib==1.0.0
243
+ - gpustat==1.1.1
244
+ - gradio-client==0.7.0
245
+ - grpcio==1.59.0
246
+ - h11==0.14.0
247
+ - hjson==3.1.0
248
+ - httpcore==1.0.2
249
+ - httpx==0.25.1
250
+ - huggingface-hub==0.16.4
251
+ - identify==2.5.24
252
+ - imageio==2.31.1
253
+ - importlib-metadata==6.8.0
254
+ - importlib-resources==6.0.0
255
+ - inflect==7.0.0
256
+ - inquirer==3.1.3
257
+ - iopath==0.1.10
258
+ - ipython==8.12.2
259
+ - ipython-genutils==0.2.0
260
+ - ipywidgets==8.1.1
261
+ - isoduration==20.11.0
262
+ - itsdangerous==2.1.2
263
+ - jaraco-collections==4.3.0
264
+ - jaraco-context==4.3.0
265
+ - jaraco-functools==3.8.0
266
+ - jaraco-text==3.12.0
267
+ - jmespath==0.10.0
268
+ - json5==0.9.14
269
+ - jsonpointer==2.4
270
+ - jsonschema==4.18.0
271
+ - jsonschema-specifications==2023.6.1
272
+ - jupyter==1.0.0
273
+ - jupyter-client==8.4.0
274
+ - jupyter-console==6.6.3
275
+ - jupyter-events==0.8.0
276
+ - jupyter-lsp==2.2.0
277
+ - jupyter-server==2.8.0
278
+ - jupyter-server-terminals==0.4.4
279
+ - jupyterlab==4.0.7
280
+ - jupyterlab-pygments==0.2.2
281
+ - jupyterlab-server==2.25.0
282
+ - jupyterlab-widgets==3.0.9
283
+ - kaggle==1.5.15
284
+ - kiwisolver==1.4.4
285
+ - langcodes==3.3.0
286
+ - lazy-loader==0.3
287
+ - levenshtein==0.23.0
288
+ - lightning==2.1.2
289
+ - lightning-cloud==0.5.37
290
+ - lightning-utilities==0.9.0
291
+ - lit==17.0.6
292
+ - littleutils==0.2.2
293
+ - lmppl==0.3.1
294
+ - lxml==4.9.3
295
+ - markdown==3.5
296
+ - markdown-it-py==3.0.0
297
+ - matplotlib==3.2.2
298
+ - mdurl==0.1.2
299
+ - mistune==3.0.2
300
+ - more-itertools==9.1.0
301
+ - multidict==6.0.4
302
+ - multiprocess==0.70.14
303
+ - murmurhash==1.0.9
304
+ - nbclient==0.8.0
305
+ - nbconvert==7.9.2
306
+ - nbformat==5.9.2
307
+ - ninja==1.11.1
308
+ - nltk==3.8.1
309
+ - nodeenv==1.8.0
310
+ - notebook==7.0.6
311
+ - notebook-shim==0.2.3
312
+ - numpy==1.24.4
313
+ - nvidia-cublas-cu11==11.10.3.66
314
+ - nvidia-cublas-cu12==12.1.3.1
315
+ - nvidia-cuda-cupti-cu11==11.7.101
316
+ - nvidia-cuda-cupti-cu12==12.1.105
317
+ - nvidia-cuda-nvrtc-cu11==11.7.99
318
+ - nvidia-cuda-nvrtc-cu12==12.1.105
319
+ - nvidia-cuda-runtime-cu11==11.7.99
320
+ - nvidia-cuda-runtime-cu12==12.1.105
321
+ - nvidia-cudnn-cu11==8.5.0.96
322
+ - nvidia-cudnn-cu12==8.9.2.26
323
+ - nvidia-cufft-cu11==10.9.0.58
324
+ - nvidia-cufft-cu12==11.0.2.54
325
+ - nvidia-curand-cu11==10.2.10.91
326
+ - nvidia-curand-cu12==10.3.2.106
327
+ - nvidia-cusolver-cu11==11.4.0.1
328
+ - nvidia-cusolver-cu12==11.4.5.107
329
+ - nvidia-cusparse-cu11==11.7.4.91
330
+ - nvidia-cusparse-cu12==12.1.0.106
331
+ - nvidia-ml-py==12.535.77
332
+ - nvidia-nccl-cu11==2.14.3
333
+ - nvidia-nccl-cu12==2.18.1
334
+ - nvidia-nvjitlink-cu12==12.3.101
335
+ - nvidia-nvtx-cu11==11.7.91
336
+ - nvidia-nvtx-cu12==12.1.105
337
+ - oauthlib==3.2.2
338
+ - ogb==1.3.6
339
+ - omegaconf==2.3.0
340
+ - openai==1.2.4
341
+ - opencv-python-headless==4.5.5.64
342
+ - opendatasets==0.1.22
343
+ - opendelta==0.3.2
344
+ - opennmt-py==3.4.1
345
+ - ordered-set==4.1.0
346
+ - orjson==3.9.10
347
+ - oss2==2.15.0
348
+ - outdated==0.2.2
349
+ - overrides==7.4.0
350
+ - packaging==23.1
351
+ - pandas==2.0.3
352
+ - pandocfilters==1.5.0
353
+ - paragraph2actions==1.5.0
354
+ - pathtools==0.1.2
355
+ - pathy==0.10.2
356
+ - peft==0.3.0
357
+ - pip==23.3.1
358
+ - pkgutil-resolve-name==1.3.10
359
+ - platformdirs==3.8.1
360
+ - plotly==5.15.0
361
+ - portalocker==2.7.0
362
+ - portend==3.2.0
363
+ - pre-commit==3.3.3
364
+ - preshed==3.0.8
365
+ - prometheus-client==0.17.1
366
+ - promise==2.3
367
+ - protobuf==3.19.6
368
+ - psutil==5.9.5
369
+ - pubchempy==1.0.4
370
+ - py-cpuinfo==9.0.0
371
+ - pyahocorasick==2.0.0
372
+ - pyarrow==12.0.1
373
+ - pyasn1==0.5.0
374
+ - pyasn1-modules==0.3.0
375
+ - pybind11==2.11.1
376
+ - pycocoevalcap==1.2
377
+ - pycocotools==2.0.6
378
+ - pycryptodome==3.18.0
379
+ - pydantic==1.10.11
380
+ - pydantic-core==2.14.3
381
+ - pydeck==0.8.1b0
382
+ - pydub==0.25.1
383
+ - pyjwt==2.7.0
384
+ - pymongo==4.6.0
385
+ - pympler==1.0.1
386
+ - pyonmttok==1.37.1
387
+ - python-editor==1.0.4
388
+ - python-json-logger==2.0.7
389
+ - python-levenshtein==0.23.0
390
+ - python-magic==0.4.27
391
+ - python-multipart==0.0.6
392
+ - python-slugify==8.0.1
393
+ - pytorch-lightning==2.0.0
394
+ - pytz==2023.3
395
+ - pytz-deprecation-shim==0.1.0.post0
396
+ - pywavelets==1.4.1
397
+ - pyyaml==6.0.1
398
+ - pyzmq==25.1.1
399
+ - qtconsole==5.4.4
400
+ - qtpy==2.4.0
401
+ - rapidfuzz==3.4.0
402
+ - rdkit==2023.3.3
403
+ - readchar==4.0.5
404
+ - referencing==0.29.1
405
+ - regex==2023.6.3
406
+ - requests==2.31.0
407
+ - requests-oauthlib==1.3.1
408
+ - responses==0.18.0
409
+ - rfc3339-validator==0.1.4
410
+ - rfc3986-validator==0.1.1
411
+ - rich==13.4.2
412
+ - rouge-score==0.1.2
413
+ - rpds-py==0.8.10
414
+ - rsa==4.9
415
+ - rxn-onmt-utils==1.1.0
416
+ - rxn-opennmt-py==1.1.5
417
+ - rxn-utils==1.6.0
418
+ - sacrebleu==2.3.1
419
+ - safetensors==0.3.1
420
+ - salesforce-lavis==1.0.0
421
+ - scikit-image==0.20.0
422
+ - scikit-learn==0.23.1
423
+ - scipy==1.4.1
424
+ - semantic-version==2.10.0
425
+ - send2trash==1.8.2
426
+ - sentencepiece==0.1.99
427
+ - sentry-sdk==1.31.0
428
+ - setproctitle==1.3.3
429
+ - shellingham==1.5.4
430
+ - smart-open==6.3.0
431
+ - smmap==5.0.1
432
+ - sniffio==1.3.0
433
+ - soupsieve==2.4.1
434
+ - spacy==3.7.2
435
+ - spacy-legacy==3.0.12
436
+ - spacy-loggers==1.0.4
437
+ - srsly==2.4.6
438
+ - starlette==0.27.0
439
+ - starsessions==1.3.0
440
+ - streamlit==1.22.0
441
+ - tabulate==0.9.0
442
+ - tempora==5.5.0
443
+ - tenacity==8.2.2
444
+ - tensorboard==2.14.0
445
+ - tensorboard-data-server==0.7.1
446
+ - terminado==0.17.1
447
+ - text-unidecode==1.3
448
+ - textdistance==4.6.0
449
+ - thinc==8.1.10
450
+ - tifffile==2023.7.10
451
+ - timm==0.4.12
452
+ - tinycss2==1.2.1
453
+ - tokenizers==0.13.3
454
+ - toml==0.10.2
455
+ - tomli==2.0.1
456
+ - tomlkit==0.12.0
457
+ - toolz==0.12.0
458
+ - torch==2.0.1
459
+ - torchmetrics==1.0.0
460
+ - torchtext==0.4.0
461
+ - tornado==6.3.3
462
+ - transformers==4.33.3
463
+ - triton==2.0.0
464
+ - typer==0.9.0
465
+ - tzdata==2023.3
466
+ - tzlocal==4.3.1
467
+ - ujson==5.8.0
468
+ - uri-template==1.3.0
469
+ - uvicorn==0.22.0
470
+ - validators==0.20.0
471
+ - virtualenv==20.23.1
472
+ - waitress==2.1.2
473
+ - wandb==0.15.5
474
+ - wasabi==1.1.2
475
+ - watchdog==3.0.0
476
+ - weasel==0.3.4
477
+ - web-py==0.62
478
+ - webcolors==1.13
479
+ - webdataset==0.2.48
480
+ - webencodings==0.5.1
481
+ - websocket-client==1.6.1
482
+ - websockets==11.0.3
483
+ - werkzeug==3.0.0
484
+ - widgetsnbextension==4.0.9
485
+ - xxhash==3.2.0
486
+ - yacs==0.1.8
487
+ - yarl==1.9.2
488
+ - zc-lockfile==3.0.post1
489
+ - zipp==3.16.0
figures/frameworks.jpg ADDED

Git LFS Details

  • SHA256: f278b9619d545ab56442b98e4c43272ed086c595aba314c30f02a76150b6aa1c
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB
gin_pretrained/graphcl_80.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc8f685ace701ad71e3d82330fa78add25c573287ebc2908d9f7fddf13bc745f
3
+ size 7454162
graph_gen.ipynb ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "from torch_geometric.data import Data\n",
11
+ "from ogb.utils import smiles2graph\n",
12
+ "import os\n",
13
+ "import json\n",
14
+ "from rdkit import RDLogger\n",
15
+ "from rdkit import Chem\n",
16
+ "RDLogger.DisableLog('rdApp.*')\n",
17
+ "from tqdm import tqdm\n",
18
+ "import multiprocessing\n",
19
+ "\n",
20
+ "def write_json(data, filename):\n",
21
+ " with open(filename, 'w') as f:\n",
22
+ " json.dump(data, f, indent=4, ensure_ascii=False)\n",
23
+ "\n",
24
+ "def read_json(filename):\n",
25
+ " with open(filename, 'r') as f:\n",
26
+ " data = json.load(f)\n",
27
+ " return data\n",
28
+ "\n",
29
+ "def smiles2data(smiles):\n",
30
+ " graph = smiles2graph(smiles)\n",
31
+ " x = torch.from_numpy(graph['node_feat'])\n",
32
+ " edge_index = torch.from_numpy(graph['edge_index'], )\n",
33
+ " edge_attr = torch.from_numpy(graph['edge_feat'])\n",
34
+ " data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n",
35
+ " return data\n"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "# make pretrain graphs\n",
45
+ "root = 'data/pretrain_data/'\n",
46
+ "mol_property_list = read_json(f'{root}/Abstract_property.json')\n",
47
+ "target_file = f'{root}/mol_graph_map.pt'\n",
48
+ "\n",
49
+ "if not os.path.exists(target_file):\n",
50
+ " mol_graph_map = {}\n",
51
+ " for mol_dict in tqdm(mol_property_list):\n",
52
+ " smiles = mol_dict['canon_smiles']\n",
53
+ " graph = smiles2data(smiles)\n",
54
+ " mol_graph_map[smiles] = graph\n",
55
+ " torch.save(mol_graph_map, target_file)"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "# make downstream (action prediction) graphs\n",
65
+ "root = 'data/action_data'\n",
66
+ "target_file = f'{root}/mol_graph_map.pt'\n",
67
+ "\n",
68
+ "if not os.path.exists(target_file):\n",
69
+ " all_mols = set()\n",
70
+ " reaction_list = read_json(f'{root}/processed.json')\n",
71
+ " rxn_keys = ['REACTANT', 'PRODUCT', 'CATALYST', 'SOLVENT']\n",
72
+ "\n",
73
+ " for rxn in reaction_list:\n",
74
+ " for key in rxn_keys:\n",
75
+ " for mol in rxn[key]:\n",
76
+ " if mol in all_mols:\n",
77
+ " continue\n",
78
+ " all_mols.add(mol)\n",
79
+ " mol_graph_map={}\n",
80
+ "\n",
81
+ " for smiles in all_mols:\n",
82
+ " graph = smiles2data(smiles)\n",
83
+ " mol_graph_map[smiles] = graph\n",
84
+ " torch.save(mol_graph_map, target_file)"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "# make downstream (retrosynthesis) graphs\n",
94
+ "root = 'data/synthesis_data'\n",
95
+ "\n",
96
+ "for folder in [\n",
97
+ " 'USPTO_50K_PtoR',\n",
98
+ " 'USPTO_50K_PtoR_aug20',\n",
99
+ " 'USPTO-MIT_PtoR_aug5',\n",
100
+ " 'USPTO-MIT_RtoP_aug5_mixed',\n",
101
+ " 'USPTO-MIT_RtoP_aug5_separated',\n",
102
+ " 'USPTO_full_pretrain_aug5_masked_token',\n",
103
+ " ]:\n",
104
+ " mol_graphid_file = f'{root}/{folder}/mol_graphid_map.json'\n",
105
+ " target_file = f'{root}/{folder}/idx_graph_map.pt'\n",
106
+ " if not os.path.exists(mol_graphid_file):\n",
107
+ " canon_idx_map = {}\n",
108
+ " mol_idx_map = {}\n",
109
+ " mol_set = set()\n",
110
+ " for mode in ['train', 'val', 'test']:\n",
111
+ " for file in ['src', 'tgt']:\n",
112
+ " if 'pretrain' in folder:\n",
113
+ " if file=='src':\n",
114
+ " continue\n",
115
+ " else:\n",
116
+ " if file=='tgt':\n",
117
+ " continue\n",
118
+ " file_path = f'{root}/{folder}/{mode}/{file}-{mode}.txt'\n",
119
+ " with open(file_path) as f:\n",
120
+ " lines = f.readlines()\n",
121
+ " for line in lines:\n",
122
+ " line = line.strip().replace(' ', '')\n",
123
+ " line = line.replace('<separated>', '.')\n",
124
+ " for smi in line.split('.'):\n",
125
+ " mol_set.add(smi)\n",
126
+ " smi_list = list(mol_set)\n",
127
+ " pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())\n",
128
+ " canon_list = pool.map(func=Chem.CanonSmiles,iterable=smi_list)\n",
129
+ " for smi, canon in zip(smi_list, canon_list):\n",
130
+ " if canon not in canon_idx_map:\n",
131
+ " canon_idx_map[canon] = len(canon_idx_map)\n",
132
+ " mol_idx_map[smi] = canon_idx_map[canon]\n",
133
+ " write_json(mol_idx_map, mol_graphid_file)\n",
134
+ " else:\n",
135
+ " mol_idx_map = read_json(mol_graphid_file)\n",
136
+ "\n",
137
+ " cid_graph_map = {}\n",
138
+ " for smiles, graph_id in mol_idx_map.items():\n",
139
+ " if graph_id in cid_graph_map:\n",
140
+ " continue\n",
141
+ " graph = smiles2data(smiles)\n",
142
+ " cid_graph_map[graph_id] = graph\n",
143
+ " torch.save(cid_graph_map, target_file)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 3,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "# make downstream (ChEBI-20) graphs\n",
153
+ "root = 'data/ChEBI-20_data'\n",
154
+ "target_file = f'{root}/cid_graph_map.pt'\n",
155
+ "\n",
156
+ "cid_graph_map = {}\n",
157
+ "if not os.path.exists(target_file):\n",
158
+ " for mode in ['train', 'validation', 'test']:\n",
159
+ " with open(f'{root}/{mode}.txt') as f:\n",
160
+ " lines = f.readlines()\n",
161
+ " for line in lines[1:]:\n",
162
+ " cid, smiles, _ = line.strip().split('\\t', maxsplit=2)\n",
163
+ " graph = smiles2data(smiles)\n",
164
+ " cid_graph_map[cid] = graph\n",
165
+ " torch.save(cid_graph_map, target_file)"
166
+ ]
167
+ }
168
+ ],
169
+ "metadata": {
170
+ "kernelspec": {
171
+ "display_name": "pth20v3",
172
+ "language": "python",
173
+ "name": "python3"
174
+ },
175
+ "language_info": {
176
+ "codemirror_mode": {
177
+ "name": "ipython",
178
+ "version": 3
179
+ },
180
+ "file_extension": ".py",
181
+ "mimetype": "text/x-python",
182
+ "name": "python",
183
+ "nbconvert_exporter": "python",
184
+ "pygments_lexer": "ipython3",
185
+ "version": "3.8.17"
186
+ }
187
+ },
188
+ "nbformat": 4,
189
+ "nbformat_minor": 2
190
+ }
lora_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model_name_or_path": null,
3
+ "bias": "none",
4
+ "fan_in_fan_out": false,
5
+ "inference_mode": false,
6
+ "init_lora_weights": true,
7
+ "lora_alpha": 16,
8
+ "lora_dropout": 0.1,
9
+ "target_modules": ["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
10
+ "peft_type": "LORA",
11
+ "r": 8,
12
+ "modules_to_save": null,
13
+ "task_type": "CAUSAL_LM"
14
+ }
main.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import warnings
5
+ import pytorch_lightning as pl
6
+ from pytorch_lightning import Trainer, strategies
7
+ import pytorch_lightning.callbacks as plc
8
+ from pytorch_lightning.loggers import CSVLogger
9
+ from pytorch_lightning.callbacks import TQDMProgressBar
10
+ from data_provider.pretrain_dm import PretrainDM
11
+ from data_provider.tune_dm import TuneDM
12
+ from model.opt_flash_attention import replace_opt_attn_with_flash_attn
13
+ from model.blip2_model import Blip2Model
14
+ from model.dist_funs import MyDeepSpeedStrategy
15
+
16
## Silence the "TypedStorage is deprecated" UserWarning (attributed by the authors to a PyG bug).
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
## Float32 matmul precision setting chosen by the authors for A5000 GPUs.
torch.set_float32_matmul_precision('medium') # can be medium (bfloat16), high (tensorfloat32), highest (float32)
20
+
21
# Define a DDP-spawn strategy that restores checkpoints non-strictly.
# The definition is best-effort: if the installed Lightning build does not
# expose ``strategies.DDPSpawnStrategy`` the class is simply not defined, and
# only the code path in ``main`` that selects it will fail.
try:
    class MyDDPSpawnStrategy(strategies.DDPSpawnStrategy):
        """``DDPSpawnStrategy`` that tolerates missing/unexpected checkpoint keys."""

        def load_model_state_dict(self, checkpoint):
            """Load ``checkpoint["state_dict"]`` into the module with ``strict=False``."""
            assert self.lightning_module is not None
            self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=False)
except AttributeError:
    # Only the missing-attribute case is expected here; anything else
    # (including KeyboardInterrupt) should surface instead of being swallowed.
    pass
28
+
29
def main(args):
    """Build the model, data module and Lightning trainer, then train or validate.

    Args:
        args: Parsed command-line namespace (see ``get_args``).

    Raises:
        NotImplementedError: If ``args.opt_model`` matches none of the handled
            language-model families, or ``args.mode`` is not a handled mode.
    """
    import ast  # local import: only needed to parse the --devices string below

    pl.seed_everything(args.seed)

    # model: always constructed the same way; optionally warm-started from a
    # checkpoint, loaded non-strictly so partial checkpoints are accepted.
    model = Blip2Model(args)
    if args.init_checkpoint:
        ckpt = torch.load(args.init_checkpoint, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'], strict=False)
        print(f"loaded model from {args.init_checkpoint}")

    print('total params:', sum(p.numel() for p in model.parameters()))

    # The tokenizer attribute name depends on which language-model wrapper is in use.
    if 'galactica' in args.opt_model or 't5' in args.opt_model:
        tokenizer = model.blip2opt.opt_tokenizer
    elif 'llama' in args.opt_model or 'vicuna' in args.opt_model:
        tokenizer = model.blip2opt.llm_tokenizer
    else:
        raise NotImplementedError

    # data
    if args.mode in {'pretrain', 'pretrain_eval'}:
        dm = PretrainDM(
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            root=args.root,
            text_max_len=args.text_max_len,
            rxn_max_len=args.rxn_max_len,
            smi_max_len=args.smi_max_len,
            tokenizer=tokenizer,
            args=args
        )
    elif args.mode in {'ft', 'eval'}:
        dm = TuneDM(
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            root=args.root,
            text_max_len=args.text_max_len,
            rxn_max_len=args.rxn_max_len,
            smi_max_len=args.smi_max_len,
            tokenizer=tokenizer,
            downstream_task=args.downstream_task,
            args=args
        )

    callbacks = [TQDMProgressBar(refresh_rate=args.tqdm_interval)]
    ## fixme save only used parameters
    callbacks.append(plc.ModelCheckpoint(dirpath="all_checkpoints/"+args.filename+"/",
                                         filename='{epoch:02d}',
                                         every_n_epochs=args.save_every_n_epochs,
                                         save_last=True,
                                         save_top_k=-1,
                                         save_on_train_epoch_end=True))

    if len(args.devices.split(',')) > 1:
        # Several comma-separated devices: pick a distributed strategy by name.
        if args.strategy_name == 'fsdp':
            strategy = strategies.DDPFullyShardedNativeStrategy()
        elif args.strategy_name == 'deepspeed':
            strategy = strategies.DeepSpeedStrategy(stage=3)
        elif args.strategy_name == 'mydeepspeed':
            strategy = MyDeepSpeedStrategy(stage=2)
        else:
            strategy = MyDDPSpawnStrategy(find_unused_parameters=True)
    else:
        strategy = None
        # Parse the device spec as a Python literal. ``ast.literal_eval``
        # accepts the same literal inputs as the previous ``eval`` call but
        # cannot execute arbitrary code supplied on the command line.
        args.devices = ast.literal_eval(args.devices)

    logger = CSVLogger(save_dir=f'./all_checkpoints/{args.filename}/')
    # Dataloaders are rebuilt every epoch during pretraining only.
    reload_freq = 1 if args.mode == 'pretrain' else 0
    trainer = Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,
        max_epochs=args.max_epochs,
        accumulate_grad_batches=args.accumulate_grad_batches,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        callbacks=callbacks,
        strategy=strategy,
        logger=logger,
        reload_dataloaders_every_n_epochs=reload_freq
    )

    if args.mode in {'pretrain', 'ft'}:
        trainer.fit(model, datamodule=dm, ckpt_path=args.ckpt_path)
    elif args.mode in {'eval', 'pretrain_eval'}:
        # Pre-set the trainer's completed-epoch counter before validating.
        # NOTE(review): presumably this makes epoch-gated evaluation logic in
        # the model run for ``caption_eval_epoch`` - confirm against the model.
        trainer.fit_loop.epoch_progress.current.completed = args.caption_eval_epoch - 1
        trainer.validate(model, datamodule=dm)
    else:
        raise NotImplementedError()
118
+
119
def get_args():
    """Parse the command line and return the resulting namespace.

    Options are registered here and by ``Blip2Model.add_model_specific_args``
    and ``PretrainDM.add_model_specific_args``, which each receive the parser
    and return it. As side effects, flash attention is patched in when
    ``--enable_flash`` is given, and every parsed option is printed in sorted
    order.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', type=str, default="main")
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    # MM settings
    parser.add_argument('--mode', type=str, default='pretrain', choices=['pretrain', 'ft', 'eval', 'pretrain_eval'])
    parser.add_argument('--strategy_name', type=str, default='mydeepspeed')
    parser.add_argument('--iupac_prediction', action='store_true', default=False)
    parser.add_argument('--ckpt_path', type=str, default=None)
    # parser = Trainer.add_argparse_args(parser)
    parser = Blip2Model.add_model_specific_args(parser) # add model args
    parser = PretrainDM.add_model_specific_args(parser)
    # Trainer / runtime settings
    parser.add_argument('--accelerator', type=str, default='gpu')
    # Comma-separated device spec; kept as a string here and interpreted in main().
    parser.add_argument('--devices', type=str, default='0,1,2,3')
    parser.add_argument('--precision', type=str, default='bf16-mixed')
    parser.add_argument('--downstream_task', type=str, default='action', choices=['action', 'synthesis', 'caption', 'chebi'])
    parser.add_argument('--max_epochs', type=int, default=10)
    parser.add_argument('--enable_flash', action='store_true', default=False)
    parser.add_argument('--disable_graph_cache', action='store_true', default=False)
    parser.add_argument('--predict_rxn_condition', action='store_true', default=False)
    parser.add_argument('--generate_restrict_tokens', action='store_true', default=False)
    parser.add_argument('--train_restrict_tokens', action='store_true', default=False)
    parser.add_argument('--smiles_type', type=str, default='default', choices=['default', 'canonical', 'restricted', 'unrestricted', 'r_smiles'])
    parser.add_argument('--accumulate_grad_batches', type=int, default=1)
    parser.add_argument('--tqdm_interval', type=int, default=50)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    args = parser.parse_args()

    # Patch the attention implementation before any model is constructed by the caller.
    if args.enable_flash:
        replace_opt_attn_with_flash_attn()
    # Echo the full configuration, sorted by option name.
    print("=========================================")
    for k, v in sorted(vars(args).items()):
        print(k, '=', v)
    print("=========================================")
    return args
154
+
155
# Script entry point: parse the command line, then run the selected mode.
if __name__ == '__main__':
    main(get_args())
157
+
model/allowed_words.json ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "30": "(",
3
+ "31": ")",
4
+ "57": "C",
5
+ "39": "1",
6
+ "89": "c",
7
+ "69": "O",
8
+ "40": "2",
9
+ "51": "=",
10
+ "2275": "CC",
11
+ "1030": "cc",
12
+ "19": "[START_SMILES]",
13
+ "20": "[END_SMILES]",
14
+ "221": "\n",
15
+ "68": "N",
16
+ "24552": "ccc",
17
+ "36": ".",
18
+ "81": "[",
19
+ "83": "]",
20
+ "41": "3",
21
+ "4162": "OC",
22
+ "60": "F",
23
+ "19321": "cccc",
24
+ "100": "n",
25
+ "2356": "Cl",
26
+ "11863": "nc",
27
+ "62": "H",
28
+ "35": "-",
29
+ "8183": "Br",
30
+ "6597": "NC",
31
+ "43888": "CCC",
32
+ "54": "@",
33
+ "29332": "@@",
34
+ "4015": "CN",
35
+ "42": "4",
36
+ "38985": "Nc",
37
+ "3027": "CO",
38
+ "73": "S",
39
+ "25": "#",
40
+ "33": "+",
41
+ "27312": "Oc",
42
+ "16288": "cn",
43
+ "8095": "nn",
44
+ "56": "B",
45
+ "1228": "sc",
46
+ "37": "/",
47
+ "63": "I",
48
+ "105": "s",
49
+ "408": "oc",
50
+ "1912": "SC",
51
+ "5965": "Si",
52
+ "43": "5",
53
+ "46183": "Cn",
54
+ "98": "l",
55
+ "101": "o",
56
+ "7662": "NS",
57
+ "5869": "NN",
58
+ "8314": "cs",
59
+ "5396": "CI",
60
+ "82": "\\",
61
+ "9835": "Sc",
62
+ "3470": "CS",
63
+ "32712": "Fc",
64
+ "2304": "OS",
65
+ "3330": "NO",
66
+ "6882": "FC",
67
+ "70": "P",
68
+ "13136": "Sn",
69
+ "12702": "Mg",
70
+ "3529": "no",
71
+ "2812": "co",
72
+ "14530": "SCC",
73
+ "6342": "rc",
74
+ "35011": "BrN",
75
+ "9677": "NH",
76
+ "283": "on",
77
+ "20938": "onc",
78
+ "37190": "COS",
79
+ "44": "6",
80
+ "17952": "OB",
81
+ "11004": "Zn",
82
+ "28819": "OO",
83
+ "2085": "ns",
84
+ "3696": "CP",
85
+ "5097": "CF",
86
+ "978": "con",
87
+ "6017": "non",
88
+ "34244": "CNS",
89
+ "4232": "occ",
90
+ "10907": "CON",
91
+ "8072": "Cu",
92
+ "13346": "CB",
93
+ "45": "7",
94
+ "16378": "sn",
95
+ "1513": "ON",
96
+ "46": "8",
97
+ "4939": "OP",
98
+ "6321": "SN",
99
+ "26505": "conc",
100
+ "6913": "Se",
101
+ "2636": "SS",
102
+ "422": "se",
103
+ "47": "9",
104
+ "48321": "SSC",
105
+ "47306": "SCN",
106
+ "15780": "CNN",
107
+ "48968": "OCI",
108
+ "27": "%",
109
+ "38": "0",
110
+ "6389": "FS",
111
+ "4864": "On",
112
+ "27133": "SCO",
113
+ "2001": "IC",
114
+ "0": "<s>",
115
+ "1": "<pad>",
116
+ "2": "</s>",
117
+ "3": "<unk>"
118
+ }
model/blip2.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from lavis.common.dist_utils import download_cached_file
15
+ from lavis.common.utils import is_url
16
+ from lavis.models.base_model import BaseModel
17
+ from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
18
+ from transformers import BertTokenizer
19
+ from model.gin_model import GNN
20
+
21
+
22
+
23
class Blip2Base(BaseModel):
    """Shared construction helpers for the graph / Q-Former / language-model wrappers.

    Provides factories for the SciBERT tokenizer, the Q-Former with its
    learnable query tokens, and the pretrained GIN graph encoder, plus an
    autocast helper and a checkpoint loader.
    """

    # SciBERT checkpoint used for the tokenizer, the Q-Former config and the
    # Q-Former weights. Point this at a local directory (e.g. 'bert_pretrained/')
    # to load from disk instead of the hub.
    _BERT_NAME = 'allenai/scibert_scivocab_uncased'

    @classmethod
    def init_tokenizer(cls):
        """Return the SciBERT tokenizer with ``[DEC]`` registered as the BOS token."""
        tokenizer = BertTokenizer.from_pretrained(cls._BERT_NAME)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context, or a no-op context when on CPU.

        Args:
            dtype: Autocast dtype used when the model is not on the CPU.
        """
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, model_name, num_query_token, graph_width, cross_attention_freq=2):
        """Build the Q-Former and its learnable query tokens.

        Args:
            model_name: Must be ``'scibert'`` (enforced with an assertion).
            num_query_token: Number of learnable query tokens.
            graph_width: Feature width of the graph encoder output that the
                cross-attention layers attend to.
            cross_attention_freq: Stored on the config as ``cross_attention_freq``.

        Returns:
            tuple: ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``.
        """
        assert model_name == 'scibert'
        print("bert load scibert")

        encoder_config = BertConfig.from_pretrained(cls._BERT_NAME)
        encoder_config.encoder_width = graph_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token

        Qformer = BertLMHeadModel.from_pretrained(
            cls._BERT_NAME, config=encoder_config
        )
        # Allocate as zeros, then draw from N(0, initializer_range) in place.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_graph_encoder(
        cls, gin_num_layers, gin_hidden_dim, gin_drop_ratio,
        gin_ckpt_path='gin_pretrained/graphcl_80.pth'):
        """Build the GIN graph encoder, load pretrained weights, and pair it with a LayerNorm.

        Args:
            gin_num_layers: Number of GNN layers.
            gin_hidden_dim: Embedding dimension of the GNN.
            gin_drop_ratio: Dropout ratio of the GNN.
            gin_ckpt_path: Path of the pretrained GNN state dict. Defaults to
                the checkpoint shipped with the repository.

        Returns:
            tuple: ``(graph_encoder, ln_graph)``.
        """
        graph_encoder = GNN(
            num_layer=gin_num_layers,
            emb_dim=gin_hidden_dim,
            gnn_type='gin',
            drop_ratio=gin_drop_ratio,
            JK='last',
        )
        ckpt = torch.load(gin_ckpt_path, map_location=torch.device('cpu'))
        # Non-strict load: report, rather than fail on, any key mismatch.
        missing_keys, unexpected_keys = graph_encoder.load_state_dict(ckpt, strict=False)
        if missing_keys or unexpected_keys:
            print(missing_keys)
            print(unexpected_keys)

        ln_graph = LayerNorm(graph_encoder.num_features)

        return graph_encoder, ln_graph

    def load_from_pretrained(self, url_or_filename):
        """Load model weights from a URL or a local file, non-strictly.

        Args:
            url_or_filename: URL (downloaded through the LAVIS cache helper) or
                path of a checkpoint whose ``"model"`` entry is a state dict.

        Returns:
            The result of ``load_state_dict(..., strict=False)``.

        Raises:
            RuntimeError: If the argument is neither a URL nor an existing file.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("load checkpoint from %s", url_or_filename)

        return msg
110
+
111
+
112
def disabled_train(self, mode=True):
    """No-op stand-in for ``Module.train``.

    Assign this over a module's ``train`` attribute to freeze its current
    train/eval mode: the requested ``mode`` is ignored and the first argument
    is handed back unchanged.
    """
    return self
116
+
117
+
118
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the dtype of its input."""

    def forward(self, x: torch.Tensor, mask=None):
        """Normalise ``x`` and return the result in ``x``'s original dtype.

        ``mask`` is accepted for call-site compatibility and is not used.
        """
        input_dtype = x.dtype
        normalized = super().forward(x)
        return normalized.type(input_dtype)
126
+
model/blip2_llama.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.cuda.amp import autocast as autocast
11
+ from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel
12
+
13
+ from lavis.models.blip2_models.blip2 import (
14
+ # Blip2Base,
15
+ disabled_train,
16
+ )
17
+ from model.blip2 import Blip2Base
18
+ from transformers import LlamaTokenizer
19
+ from model.modeling_llama import LlamaForCausalLM
20
+
21
+
22
+
23
# LLaMA checkpoint identifiers.
# NOTE(review): this list is not read anywhere else in this module; the
# Blip2Llama constructor default only repeats one of its entries - confirm
# whether the list is still needed.
llama_model_list = [
    "decapoda-research/llama-13b-hf",
    "decapoda-research/llama-7b-hf",
]
27
+
28
def mask_by_len(input, lens, fill_value=0):
    """Overwrite the first ``lens[i]`` entries of every row ``i`` with ``fill_value``.

    The tensor is modified in place and also returned.

    Args:
        input: Tensor of shape ``[N, D]``.
        lens: Tensor of shape ``[N]`` with the prefix length for each row.
        fill_value: Value written into the masked prefix positions.

    Returns:
        The same ``input`` tensor object, after modification.
    """
    column_ids = torch.arange(input.shape[1], device=input.device)[None, :]
    # True where the column index lies inside the row's prefix.
    prefix_mask = column_ids < lens.reshape(-1, 1)
    input[prefix_mask] = fill_value
    return input
38
+ # @registry.register_model("blip2")
39
+ # @registry.register_model("blip2_feature_extractor")
40
class Blip2Llama(Blip2Base):
    """Graph-conditioned LLaMA: graph encoder -> Q-Former -> linear projection -> LLaMA.

    Graph features are summarised by the Q-Former's query tokens, projected to
    the language model's hidden size, and prepended to the text token
    embeddings as a soft prompt.
    """
    def __init__(
        self,
        bert_name,
        gin_num_layers,
        gin_hidden_dim,
        gin_drop_ratio,
        tune_gnn=False,
        num_query_token=32,
        cross_attention_freq=2,
        lora_tuning=False,
        peft_dir='',
        llm_model="decapoda-research/llama-7b-hf",
        prompt="",
        args=None,
    ):
        """Build the graph encoder, Q-Former, LLaMA model and projection layer.

        Args:
            bert_name: Q-Former backbone name, forwarded to ``init_Qformer``.
            gin_num_layers: Number of graph encoder layers.
            gin_hidden_dim: Graph encoder embedding dimension.
            gin_drop_ratio: Graph encoder dropout ratio.
            tune_gnn: If False, the graph encoder is frozen and kept in eval mode.
            num_query_token: Number of Q-Former query tokens.
            cross_attention_freq: Forwarded to ``init_Qformer``.
            lora_tuning: If True, wrap the language model with LoRA adapters;
                otherwise all language-model parameters are frozen.
            peft_dir: Directory of saved adapters to resume from when
                ``lora_tuning`` is set; empty to create fresh adapters.
            llm_model: Name or path of the LLaMA checkpoint.
            prompt: Prompt text; when non-empty, prompt positions are excluded
                from the loss in ``forward``.
            args: Unused by this class.
        """
        super().__init__()
        self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
        self.tune_gnn = tune_gnn
        if not tune_gnn:
            for name, param in self.graph_encoder.named_parameters():
                param.requires_grad = False
            self.graph_encoder = self.graph_encoder.eval()
            # Replace ``train`` so later mode switches leave the encoder in eval mode.
            self.graph_encoder.train = disabled_train
            logging.info("freeze graph encoder")

        self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
        ### remove the unused parameters
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        ## initialize the language model and its tokenizer
        self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, padding_side='right')
        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
        self.llm_model = LlamaForCausalLM.from_pretrained(llm_model, torch_dtype=torch.bfloat16)
        # The tokenizer gained tokens above, so the embedding table must grow to match.
        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))

        self.lora_tuning = lora_tuning
        if lora_tuning:
            if peft_dir:
                self.llm_model = PeftModel.from_pretrained(self.llm_model, peft_dir, is_trainable=True)
            else:
                peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
                self.llm_model = get_peft_model(self.llm_model, peft_config)
                self.llm_model.print_trainable_parameters()
        else:
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = False

        ## fixme: this is different from the original BLIP2
        # Generation stops at the first token id produced for "\n".
        # NOTE(review): verify that this first id really is the newline piece
        # for this tokenizer rather than a leading whitespace piece.
        self.eos_token_id = self.llm_tokenizer(
            "\n", add_special_tokens=False
        ).input_ids[0]
        self.pad_token_id = self.llm_tokenizer.pad_token_id

        # Maps Q-Former outputs into the language model's embedding space.
        self.llm_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
        )

        ## fixme: no prompt yet
        self.prompt = prompt

    def forward(self, batch):
        """Compute the language-modelling loss for one batch.

        Args:
            batch: Tuple ``(graphs, text_tokens, prompt_lens)``. ``text_tokens``
                must expose ``input_ids`` and ``attention_mask``;
                ``prompt_lens`` holds the per-sample prompt length.

        Returns:
            dict: ``{"loss": loss}``.
        """
        graphs, text_tokens, prompt_lens = batch
        graph_embeds, graph_masks = self.graph_encoder(graphs)
        if not self.tune_gnn:
            graph_embeds = graph_embeds.detach()
        graph_embeds = self.ln_graph(graph_embeds, graph_masks)
        device = graph_embeds.device
        query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=graph_embeds,
            encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
            return_dict=True,
        )
        inputs_llm = self.llm_proj(query_output.last_hidden_state)
        atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long, device=device)

        # Padding positions get label -100 so they do not contribute to the loss.
        targets = text_tokens.input_ids.masked_fill(
            text_tokens.input_ids == self.llm_tokenizer.pad_token_id, -100
        )
        if self.prompt:
            targets = mask_by_len(targets, prompt_lens, -100) # do not apply loss to the prompt

        # The projected query tokens are inputs only, so their labels are all -100.
        empty_targets = torch.full(atts_llm.size(), -100, dtype=torch.long, device=device)
        targets = torch.cat([empty_targets, targets], dim=1)

        inputs_embeds = self.llm_model.get_input_embeddings()(text_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_llm, text_tokens.attention_mask], dim=1)

        outputs = self.llm_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss
        return {"loss": loss}

    @torch.no_grad()
    def generate(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """Generate text conditioned on graphs and a tokenized prompt.

        Args:
            samples (dict): Must contain ``'graphs'`` (input of the graph
                encoder) and ``'prompt_tokens'`` (tokenized prompt exposing
                ``input_ids`` and ``attention_mask``).
            do_sample (bool): Whether to sample instead of searching.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            length_penalty (float): Length penalty forwarded to the language model.
            num_captions (int): Number of sequences returned per input.
            temperature (float): Sampling temperature.
        Returns:
            list: Decoded, whitespace-stripped strings.
        """
        graphs = samples['graphs']
        prompt_tokens = samples['prompt_tokens']
        with self.maybe_autocast():
            graph_embeds, graph_masks = self.graph_encoder(graphs)
            graph_embeds = self.ln_graph(graph_embeds)

            query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=graph_embeds,
                encoder_attention_mask=graph_masks,
                return_dict=True,
            )

            device = graph_embeds.device
            inputs_llm = self.llm_proj(query_output.last_hidden_state)
            atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long, device=device)

            # Soft prompt (projected query tokens) followed by the prompt token embeddings.
            inputs_embeds = self.llm_model.get_input_embeddings()(prompt_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
            attention_mask = torch.cat([atts_llm, prompt_tokens.attention_mask], dim=1)

            outputs = self.llm_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                pad_token_id=self.pad_token_id,
                eos_token_id=self.eos_token_id,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
            )
            output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
            output_text = [text.strip() for text in output_text]
        return output_text
model/blip2_model.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict
3
+ import torch
4
+ from model.blip2_opt import Blip2OPT
5
+ from model.blip2_llama import Blip2Llama
6
+ from model.blip2_t5 import Blip2T5
7
+ import pytorch_lightning as pl
8
+ from torch import optim
9
+ from lavis.common.optims import LinearWarmupCosineLRScheduler, LinearWarmupStepLRScheduler
10
+ import json
11
+ from model.opt_flash_attention import replace_opt_attn_with_flash_attn, replace_opt_attn_with_original_attn
12
+ import torch.distributed as dist
13
+ from peft import LoraConfig, TaskType
14
+ from model.help_funcs import caption_evaluate, AttrDict
15
+ from transformers import Adafactor
16
+ from torch_ema import ExponentialMovingAverage
17
+
18
def load_ignore_unexpected(model, state_dict):
    """Load ``state_dict`` into ``model`` after discarding keys the model lacks.

    Only entries whose key appears in ``model.state_dict()`` are kept. The
    filtered mapping is then passed to ``model.load_state_dict`` with
    ``strict=True``.
    """
    expected = set(model.state_dict().keys())
    filtered = {}
    for name, value in state_dict.items():
        if name in expected:
            filtered[name] = value

    # NOTE: the discarded keys are not reported anywhere.
    model.load_state_dict(filtered, strict=True)
24
+
25
+
26
+ # def load_ignore_mismatch(model, state_dict):
27
+ # keys = set(model.state_dict().keys())
28
+ # extra_keys = set()
29
+ # for key in state_dict:
30
+ # if key not in keys:
31
+ # extra_keys.add(key)
32
+ # missing_keys = set()
33
+ # for key in keys:
34
+ # if key not in state_dict:
35
+ # missing_keys.add(key)
36
+ # ## try to print keys that are not included
37
+ # model.load_state_dict(state_dict, strict=False)
38
+
39
+
40
def get_module_state_dict(state_dict, module_name):
    """Extract the entries of ``state_dict`` that belong to ``module_name``.

    Args:
        state_dict: mapping from dotted parameter names to values.
        module_name: dotted name of the sub-module (or single parameter) to extract.

    Returns:
        If ``state_dict`` contains a key equal to ``module_name`` itself, that
        value is returned directly. Otherwise a dict of the entries under
        ``module_name + '.'`` is returned, with that prefix removed from each
        key (empty when nothing matches).
    """
    # Match on the full dotted component: a bare ``startswith(module_name)``
    # would also capture siblings such as ``ln_graph2.weight`` when asked for
    # ``ln_graph`` and store them under a corrupted key.
    prefix = module_name + '.'
    module_state_dict = {}
    for key, value in state_dict.items():
        if key == module_name:
            return value
        if key.startswith(prefix):
            module_state_dict[key[len(prefix):]] = value
    return module_state_dict
49
+ # peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
50
class Blip2Model(pl.LightningModule):
    """Lightning module wrapping a graph-to-text BLIP-2 style model.

    The wrapped network is stored in ``self.blip2opt`` and is one of
    ``Blip2OPT``, ``Blip2Llama`` or ``Blip2T5``, selected from the substring
    found in ``args.opt_model``. This class handles optimizer/scheduler
    construction, the training step, generation-based validation, and
    checkpoint trimming / LoRA / EMA saving.
    """

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Strip LLM weights from the checkpoint and optionally export LoRA adapters.

        Unless ``llm_tune == 'full'``, every state-dict key under
        ``blip2opt.opt_model`` / ``blip2opt.llm_model`` is removed. When LoRA
        tuning is active and the epoch matches ``args.save_every_n_epochs``,
        rank 0 additionally writes the adapter via ``save_pretrained``.
        """
        if self.llm_tune != 'full':
            llm_prefixes = ('blip2opt.opt_model', 'blip2opt.llm_model')
            to_be_removed = [key for key in checkpoint['state_dict'] if key.startswith(llm_prefixes)]
            for key in to_be_removed:
                checkpoint['state_dict'].pop(key)
        if isinstance(self.args.save_every_n_epochs, int) and self.args.save_every_n_epochs > 0:
            if self.llm_tune == 'lora' and (self.current_epoch + 1) % self.args.save_every_n_epochs == 0:
                if self.local_rank == 0:  # manually fix a bug in peft module
                    if self.args.peft_config:
                        peft_config = LoraConfig(**LoraConfig.from_json_file(self.args.peft_config))
                    else:
                        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=self.args.lora_r, lora_alpha=self.args.lora_alpha, lora_dropout=self.args.lora_dropout)
                    if hasattr(self.blip2opt, 'opt_model'):
                        self.blip2opt.opt_model.peft_config['default'] = peft_config
                        self.blip2opt.opt_model.save_pretrained(os.path.join(self.logger.save_dir, f'lora_epoch_{self.current_epoch}'))
                    elif hasattr(self.blip2opt, 'llm_model'):
                        self.blip2opt.llm_model.peft_config['default'] = peft_config
                        self.blip2opt.llm_model.save_pretrained(os.path.join(self.logger.save_dir, f'lora_epoch_{self.current_epoch}'))
        return super().on_save_checkpoint(checkpoint)

    def __init__(self, args):
        """Build the wrapped model from ``args`` (a namespace, or a dict converted to ``AttrDict``)."""
        super().__init__()
        if isinstance(args, dict):
            args = AttrDict(**args)

        self.args = args
        if not hasattr(args, 'do_sample'):
            args.do_sample = False
        self.caption_eval_epoch = args.caption_eval_epoch
        self.do_sample = args.do_sample
        self.num_beams = args.num_beams
        self.max_inference_len = args.max_inference_len
        self.min_inference_len = args.min_inference_len
        self.num_generate_captions = args.num_generate_captions
        self.reaction_weight = args.reaction_weight
        self.llm_tune = args.llm_tune
        self.enable_flash = args.enable_flash
        # The backbone wrapper is picked from a substring of the model name.
        if 'galactica' in args.opt_model:
            self.blip2opt = Blip2OPT(args.bert_name, args.gin_num_layers, args.gin_hidden_dim, args.drop_ratio, args.tune_gnn, not args.not_tune_qformer, args.num_query_token, args.cross_attention_freq, args.llm_tune, args.peft_dir, args.opt_model, args.prompt, args)
        elif 'llama' in args.opt_model or 'vicuna' in args.opt_model:
            self.blip2opt = Blip2Llama(args.bert_name, args.gin_num_layers, args.gin_hidden_dim, args.drop_ratio, args.tune_gnn, args.num_query_token, args.cross_attention_freq, args.llm_tune, args.peft_dir, args.opt_model, args.prompt, args)
        elif 't5' in args.opt_model:
            self.blip2opt = Blip2T5(args.bert_name, args.gin_num_layers, args.gin_hidden_dim, args.drop_ratio, args.tune_gnn, args.num_query_token, args.cross_attention_freq, args.llm_tune, args.peft_dir, args.opt_model, args.prompt, args)
        else:
            raise NotImplementedError()
        self.tokenizer = self.blip2opt.init_tokenizer()
        self.mode = args.mode
        self.downstream_task = args.downstream_task
        self.save_hyperparameters(args)
        self.save_ema_checkpoint = args.save_ema_checkpoint
        if self.save_ema_checkpoint:
            self.ema = ExponentialMovingAverage(self.parameters(), 0.99)
        self.save_on_steps = args.save_on_steps

    def load_from_stage1_checkpoint(self, path):
        """Initialise graph encoder, Q-Former, ``ln_graph`` and query tokens from a stage-1 checkpoint.

        Reads ``ckpt['state_dict']`` from ``path`` and copies the entries
        stored under the ``blip2qformer.*`` prefixes. Returns ``self``.
        """
        ckpt = torch.load(path, map_location='cpu')
        state_dict = ckpt['state_dict']
        graph_encoder_dict = get_module_state_dict(state_dict, 'blip2qformer.graph_encoder')
        qformer_dict = get_module_state_dict(state_dict, 'blip2qformer.Qformer')
        ln_graph_dict = get_module_state_dict(state_dict, 'blip2qformer.ln_graph')
        qs_weight = get_module_state_dict(state_dict, 'blip2qformer.query_tokens')
        load_ignore_unexpected(self.blip2opt.Qformer, qformer_dict)
        self.blip2opt.graph_encoder.load_state_dict(graph_encoder_dict)
        self.blip2opt.ln_graph.load_state_dict(ln_graph_dict)
        self.blip2opt.query_tokens.data.copy_(qs_weight)
        return self

    # def load_from_stage1_checkpoint(self, path):
    #     ckpt = torch.load(path, map_location='cpu')
    #     state_dict = ckpt['state_dict']
    #     state_dict = {k[13:]: v for k,v in state_dict.items()}
    #     load_ignore_mismatch(self.blip2opt, state_dict)
    #     return self

    def configure_optimizers(self):
        """Create the optimizer and store the (manually stepped) LR scheduler on ``self.scheduler``."""
        if self.args.optimizer == 'adafactor':
            print('Using adafactor optimizer')
            optimizer = Adafactor(
                self.parameters(),
                lr=1e-3,
                relative_step=False,
                scale_parameter=False,
                warmup_init=False
            )
            self.scheduler = None
        else:
            self.trainer.fit_loop.setup_data()
            # self.trainer.reset_train_dataloader()
            # Warm-up never lasts longer than one pass over the training loader.
            warmup_steps = min(len(self.trainer.train_dataloader), self.args.warmup_steps)
            optimizer = optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=self.args.weight_decay)
            if self.args.scheduler == 'linear_warmup_cosine_lr':
                self.scheduler = LinearWarmupCosineLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, warmup_steps, self.args.warmup_lr)
            elif self.args.scheduler == 'linear_warmup_step_lr':
                self.scheduler = LinearWarmupStepLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, self.args.lr_decay_rate, self.args.warmup_lr, warmup_steps)
            elif self.args.scheduler == 'None':
                self.scheduler = None
            else:
                raise NotImplementedError()
        return optimizer

    def _log_caption_metrics(self, predictions, targets):
        """Compute caption metrics with ``caption_evaluate`` and log each one (no cross-rank sync)."""
        ## fixme: I am not sure if the max length is the same as previous experiments
        bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = \
            caption_evaluate(predictions, targets, self.tokenizer, self.max_inference_len * 2)
        self.log("bleu2", bleu2, sync_dist=False)
        self.log("bleu4", bleu4, sync_dist=False)
        self.log("rouge_1", rouge_1, sync_dist=False)
        self.log("rouge_2", rouge_2, sync_dist=False)
        self.log("rouge_l", rouge_l, sync_dist=False)
        self.log("meteor_score", meteor_score, sync_dist=False)

    def test_epoch_end(self, outputs):
        """Gather per-rank test outputs, then save and score them on global rank 0."""
        print('test epoch end')
        list_ids, list_predictions, list_targets = zip(*outputs)
        predictions = [i for ii in list_predictions for i in ii]
        targets = [i for ii in list_targets for i in ii]

        all_ids = [None for _ in range(self.trainer.world_size)]
        all_predictions = [None for _ in range(self.trainer.world_size)]
        all_targets = [None for _ in range(self.trainer.world_size)]

        dist.all_gather_object(all_ids, list_ids)
        dist.all_gather_object(all_predictions, predictions)
        dist.all_gather_object(all_targets, targets)
        print(len(all_ids), len(all_predictions), len(all_targets))
        if self.global_rank == 0:
            print(f'saveing predictions to {self.logger.log_dir}')

            all_predictions = [i for ii in all_predictions for i in ii]
            all_targets = [i for ii in all_targets for i in ii]
            self.save_predictions(all_ids, all_predictions, all_targets)
            self._log_caption_metrics(all_predictions, all_targets)

    def save_predictions(self, rxn_ids, predictions, targets):
        """Write one JSON line per (index, prediction, target) to ``predictions.txt``.

        NOTE: currently disabled by the leading ``assert False``.
        """
        assert False
        assert len(rxn_ids) == len(targets)
        assert len(predictions) == len(targets)
        with open(os.path.join(self.logger.log_dir, 'predictions.txt'), 'w', encoding='utf8') as f:
            for i, p, t in zip(rxn_ids, predictions, targets):
                line = {'index': i, 'prediction': p, 'target': t}
                f.write(json.dumps(line, ensure_ascii=False) + '\n')

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Not implemented: always fails on ``assert False``."""
        assert False

    def gather_dict_results(self, dict_list):
        """Collect per-batch result dicts from all ranks and flatten them to one dict per sample.

        Each input dict maps a field name to a list of per-sample values.
        When more than one caption is generated per sample, the flat
        ``'predictions'`` list is regrouped into chunks of
        ``num_generate_captions`` so that it lines up with ``'index'``.
        """
        list_of_dict_list = [None for _ in range(self.trainer.world_size)]
        dist.all_gather_object(list_of_dict_list, dict_list)
        dict_list = [i for ii in list_of_dict_list for i in ii]  ## dict list, each dict has values that are lists of predictions, etc.
        keys = dict_list[0].keys()
        gathered_dict = {}  # each value is a list of predictions, etc.
        for key in keys:
            gathered_dict[key] = [i for d in dict_list for i in d[key]]
        if self.num_generate_captions > 1:
            M = self.num_generate_captions
            N = len(gathered_dict['index'])
            assert len(gathered_dict['predictions']) == N * M
            gathered_dict['predictions'] = [
                gathered_dict['predictions'][i * M:(i + 1) * M]
                for i in range(N)
            ]
        dict_list = []
        for i in range(len(gathered_dict['predictions'])):
            d = {k: gathered_dict[k][i] for k in keys}
            dict_list.append(d)
        return dict_list

    def save_results(self, dict_list, log_prefix=""):
        """Write ``dict_list`` as JSON lines to ``[<log_prefix>_]predictions.txt`` in the logger directory."""
        if log_prefix:
            name = f'{log_prefix}_predictions.txt'
        else:
            name = 'predictions.txt'
        with open(os.path.join(self.logger.log_dir, name), 'w', encoding='utf8') as f:
            for i in range(len(dict_list)):
                f.write(json.dumps(dict_list[i], ensure_ascii=True) + '\n')

    def on_validation_epoch_start(self):
        """Switch back to the original OPT attention (if flash is enabled) and reset the result buffer."""
        if self.enable_flash:
            replace_opt_attn_with_original_attn()
        self.saved_dict_list = []

    def on_validation_epoch_end(self):
        """Restore flash attention, then gather, save and (except for synthesis) score the predictions."""
        if self.enable_flash:
            replace_opt_attn_with_flash_attn()
        if (self.current_epoch + 1) % self.caption_eval_epoch != 0:
            return
        result_list = self.gather_dict_results(self.saved_dict_list)
        ## empty cache
        self.saved_dict_list = []
        if self.global_rank == 0:
            self.save_results(result_list, 'epoch_{}'.format(self.current_epoch))
            if self.downstream_task == 'synthesis':
                return
            all_predictions = [i['predictions'] for i in result_list]
            all_targets = [i['targets'] for i in result_list]
            self._log_caption_metrics(all_predictions, all_targets)

    @torch.no_grad()
    def validation_step(self, batch, batch_idx, dataloader_idx=1):
        """Generate predictions for one batch of dataloader 1 on evaluation epochs.

        Dataloader 0 is skipped; any other index raises ``NotImplementedError``.
        """
        if dataloader_idx == 0:
            return
        elif dataloader_idx == 1:
            if (self.current_epoch + 1) % self.caption_eval_epoch != 0:
                return
            rxn_ids, graphs, prompt_tokens, texts, inputs = batch
            ###============== Captioning Results ===================###
            samples = {'graphs': graphs, 'prompt_tokens': prompt_tokens}
            if self.mode in {'ft', 'eval', 'pretrain_eval'}:
                predictions = self.blip2opt.generate(
                    samples,
                    do_sample=self.do_sample,
                    num_beams=self.num_beams,
                    max_length=self.max_inference_len,
                    min_length=self.min_inference_len,
                    num_captions=self.num_generate_captions,
                    use_graph=not self.args.disable_graphs
                )
            else:
                raise NotImplementedError()
            self.saved_dict_list.append({
                'index': rxn_ids,
                'input': inputs,
                'predictions': predictions,
                'targets': texts
            })
        else:
            raise NotImplementedError

    def on_train_start(self):
        """Move the EMA shadow parameters (when present) to the module's device."""
        if hasattr(self, 'ema'):
            self.ema.to(self.device)

    def on_before_zero_grad(self, *args, **kwargs):
        """Update the EMA every 100 global steps and save a checkpoint on the requested steps."""
        if self.save_ema_checkpoint:
            if self.trainer.global_step % 100 == 0:
                self.ema.update(self.parameters())
        if self.trainer.global_step in self.save_on_steps:
            checkpoint_path = os.path.join(f"all_checkpoints/{self.args.filename}/", f'step{self.trainer.global_step}.ckpt')
            self.trainer.save_checkpoint(checkpoint_path)

    def on_train_epoch_end(self):
        """Save an EMA-weighted checkpoint every ``save_every_n_epochs`` epochs (or at ``max_epochs``)."""
        save_every_n_epochs = self.args.save_every_n_epochs if self.args.save_every_n_epochs > 0 else self.args.max_epochs
        if (self.current_epoch + 1) % save_every_n_epochs != 0:
            return
        if self.save_ema_checkpoint:
            with self.ema.average_parameters():
                checkpoint_path = os.path.join(f"all_checkpoints/{self.args.filename}/", f'ema_epoch{self.current_epoch}.ckpt')
                self.trainer.save_checkpoint(checkpoint_path)

    def training_step(self, batch, batch_idx):
        """Step the scheduler manually, compute the mode-specific loss, log it and return it."""
        if self.scheduler:
            self.scheduler.step(self.trainer.current_epoch, self.trainer.global_step)

        batch_size = batch[-1].input_ids.size(0)
        ###============== Overall Loss ===================###
        if self.mode == 'ft':
            loss = self.blip2opt.forward_action(batch, use_gragh=not self.args.disable_graphs)
        elif self.mode == 'pretrain':
            loss = self.blip2opt.forward_abstract(batch, use_gragh=not self.args.disable_graphs)
        else:
            raise NotImplementedError()
        self.log("molecule loss", float(loss['loss']), batch_size=batch_size, sync_dist=True, prog_bar=True)
        self.log("lr", self.trainer.optimizers[0].param_groups[0]['lr'], batch_size=batch_size, sync_dist=True, prog_bar=True)
        return loss['loss']

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's command-line options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("GINSimclr")
        # train mode
        # GIN
        parser.add_argument('--gin_hidden_dim', type=int, default=300)
        parser.add_argument('--gin_num_layers', type=int, default=5)
        parser.add_argument('--drop_ratio', type=float, default=0.0)
        parser.add_argument('--tune_gnn', action='store_true', default=False)
        parser.add_argument('--not_tune_qformer', action='store_true', default=False)
        parser.add_argument('--disable_graphs', action='store_true', default=False)
        # Bert
        parser.add_argument('--bert_hidden_dim', type=int, default=2048, help='')
        parser.add_argument('--bert_name', type=str, default='scibert')
        parser.add_argument('--cross_attention_freq', type=int, default=2)
        parser.add_argument('--num_query_token', type=int, default=8)
        # OPT
        parser.add_argument('--opt_model', type=str, default="facebook/galactica-1.3b")
        # parser.add_argument('--prompt', type=str, default='a molecule of ')
        parser.add_argument('--num_beams', type=int, default=5)
        parser.add_argument('--do_sample', action='store_true', default=False)
        parser.add_argument('--max_inference_len', type=int, default=512)
        parser.add_argument('--min_inference_len', type=int, default=8)
        parser.add_argument('--llm_tune', type=str, default='freeze')
        parser.add_argument('--peft_config', type=str, default=None)
        parser.add_argument('--peft_dir', type=str, default='')

        parser.add_argument('--save_every_n_epochs', type=int, default=0)
        ## quantization
        parser.add_argument('--load_in_8bit', action='store_true', default=False)

        ## lora config
        parser.add_argument('--lora_r', type=int, default=8)
        parser.add_argument('--lora_alpha', type=int, default=32)
        # Dropout is a probability: parsing it as int rejected values such as 0.05.
        parser.add_argument('--lora_dropout', type=float, default=0.1)

        # optimization
        parser.add_argument('--reaction_weight', type=float, default=1.0)
        parser.add_argument('--weight_decay', type=float, default=0.05, help='optimizer weight decay')
        parser.add_argument('--init_lr', type=float, default=1e-4, help='optimizer init learning rate')
        parser.add_argument('--min_lr', type=float, default=1e-5, help='optimizer min learning rate')
        parser.add_argument('--warmup_lr', type=float, default=1e-6, help='optimizer warmup learning rate')
        parser.add_argument('--warmup_steps', type=int, default=1000, help='optimizer warmup steps')
        parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='optimizer lr decay rate')
        parser.add_argument('--scheduler', type=str, default='linear_warmup_cosine_lr', help='type of scheduler')  # or linear_warmup_step_lr
        parser.add_argument('--optimizer', type=str, default='adamw', help='type of optimizer')
        parser.add_argument('--init_checkpoint', type=str, default='')
        parser.add_argument('--caption_eval_epoch', type=int, default=10)
        parser.add_argument('--num_generate_captions', type=int, default=1)

        # OPT Config
        parser.add_argument('--optconfig_attention_dropout', type=float, default=0.0)
        parser.add_argument('--optconfig_dropout', type=float, default=0.0)

        # others
        parser.add_argument('--save_ema_checkpoint', action='store_true', default=False)
        parser.add_argument('--save_on_steps', default=[], nargs='+', type=int)
        return parent_parser
model/blip2_opt.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.cuda.amp import autocast as autocast
11
+ from torch.nn import functional as F
12
+ from torch.nn import CrossEntropyLoss
13
+ from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel
14
+ from ogb.utils import smiles2graph
15
+ from torch_geometric.loader.dataloader import Collater
16
+ from torch_geometric.data import Data
17
+ import numpy as np
18
+ from lavis.models.blip2_models.blip2 import (
19
+ # Blip2Base,
20
+ disabled_train,
21
+ )
22
+ from model.blip2 import Blip2Base
23
+ from model.help_funcs import get_not_allowed_tokens_ids
24
+ from transformers import AutoTokenizer
25
+ from transformers import OPTForCausalLM, OPTConfig
26
+ # from opendelta import LoraModel
27
+ # from opendelta.delta_models.lora import LoraConfig
28
+ # from opendelta.delta_configs
29
+
30
# Galactica checkpoint identifiers (the same naming as the default
# ``opt_model`` argument of ``Blip2OPT``). NOTE(review): this list is not
# referenced in the visible part of this file — confirm whether it is still used.
opt_model_list = [
    "facebook/galactica-125m",
    "facebook/galactica-1.3b",
    "facebook/galactica-6.7b",
    "facebook/galactica-30b",
]
36
+
37
def mask_by_len(input, lens, fill_value=0):
    '''
    Overwrite, in place, the first ``lens[i]`` entries of row ``i`` with ``fill_value``.

    input: shape = [N, D]
    lens: shape = [N]

    Returns the same (mutated) ``input`` tensor.
    '''
    column_ids = torch.arange(input.shape[1], device=input.device).reshape(1, -1)
    row_lens = lens.reshape(-1, 1)
    # True wherever the column index lies strictly below the row's length.
    selected = column_ids < row_lens
    input[selected] = fill_value
    return input
46
+
47
+
48
def smiles2data(smiles):
    """Build a ``Data`` graph object from a SMILES string.

    ``smiles2graph`` supplies the ``node_feat``, ``edge_index`` and
    ``edge_feat`` arrays; each is wrapped as a tensor with
    ``torch.from_numpy`` and stored as ``x``, ``edge_index`` and ``edge_attr``.
    """
    mol_graph = smiles2graph(smiles)
    return Data(
        x=torch.from_numpy(mol_graph['node_feat']),
        edge_index=torch.from_numpy(mol_graph['edge_index']),
        edge_attr=torch.from_numpy(mol_graph['edge_feat']),
    )
55
+
56
import re
# Marker string that _insert_split_marker places before every character of a
# matched sequence and once more before its closing tag.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

# Matches "[START_X]...[END_X]" for X in DNA / SMILES / I_SMILES / AMINO.
# Groups: 1 = opening tag, 2 = X, 3 = the enclosed sequence (non-greedy),
# 4 = closing tag (must repeat the same X through the \2 back-reference).
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
60
+
61
+
62
def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].

    Parameters
    ----------
    m : re.Match
        Match with four groups: opening tag, tag kind, enclosed sequence,
        closing tag (the group layout of CUSTOM_SEQ_RE)

    Returns
    ----------
    str - the matched text with SPLIT_MARKER placed before every character of
    the sequence and before the closing tag
    """
    opening, _kind, payload, closing = m.groups()
    # DOTALL so that newline characters inside the sequence get a marker too.
    marked = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", payload, flags=re.DOTALL)
    return f"{opening}{marked}{SPLIT_MARKER}{closing}"
79
+
80
def escape_custom_split_sequence(text):
    """
    Applies the custom per-character splitting to every tagged sequence in the text

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    escaped = CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
    return escaped
94
+
95
def smiles_handler(text, mol_ph):
    """Collect tagged sequences from ``text`` and append ``mol_ph`` after each one.

    Every match of CUSTOM_SEQ_RE contributes its enclosed sequence to the
    returned list. The text is rewritten so that ``mol_ph`` follows each
    closing tag and is then passed through ``escape_custom_split_sequence``.

    Returns
    ----------
    (str, list) - the rewritten text and the sequences in order of appearance
    """
    smiles_list = [found.group(3) for found in CUSTOM_SEQ_RE.finditer(text)]

    with_placeholders = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
    escaped = escape_custom_split_sequence(with_placeholders)
    return escaped, smiles_list
104
+
105
+
106
class Blip2OPT(Blip2Base):
    """
    Graph encoder + Q-Former + OPT causal language model.

    Graphs are encoded, summarised by the Q-Former into ``num_query_token``
    vectors, projected to the language model's hidden size by ``opt_proj`` and
    written into the token embeddings at the positions flagged by
    ``is_mol_token``.
    """
    def __init__(
        self,
        bert_name,
        gin_num_layers,
        gin_hidden_dim,
        gin_drop_ratio,
        tune_gnn=False,
        tune_qformer=False,
        num_query_token=32,
        cross_attention_freq=2,
        llm_tune='freeze',
        peft_dir='',
        opt_model="facebook/galactica-1.3b",
        prompt="",
        args=None,
    ):
        super().__init__()
        self.args = args

        # ---- graph encoder -------------------------------------------------
        self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
        self.tune_gnn = tune_gnn
        self.tune_qformer = tune_qformer
        if tune_gnn:
            logging.info("tune graph encoder")
        else:
            for param in self.graph_encoder.parameters():
                param.requires_grad = False
            self.graph_encoder = self.graph_encoder.eval()
            self.graph_encoder.train = disabled_train
            logging.info("freeze graph encoder")

        # ---- Q-Former ------------------------------------------------------
        self.num_query_token = num_query_token
        self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
        if tune_qformer:
            logging.info("tune qformer encoder")
        else:
            for param in self.Qformer.parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            logging.info("freeze qformer encoder")
        ### remove the unused parameters
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for block in self.Qformer.bert.encoder.layer:
            block.output = None
            block.intermediate = None

        # ---- language model ------------------------------------------------
        # Every ``optconfig_<name>`` argument overrides ``<name>`` in the OPT config.
        config_prefix = "optconfig_"
        opt_config_params = {}
        for arg_name, arg_value in vars(args).items():
            if arg_name.startswith(config_prefix):
                opt_config_params[arg_name[len(config_prefix):]] = arg_value
        config = OPTConfig.from_pretrained(opt_model, **opt_config_params)
        ## initialize opt model
        self.opt_tokenizer = AutoTokenizer.from_pretrained(opt_model, use_fast=False, padding_side='right')
        self.opt_tokenizer.add_special_tokens({'pad_token': '<pad>'})
        self.opt_tokenizer.add_tokens('<mol>')  # molecule placeholder
        self.mol_token = '<mol>'
        self.opt_tokenizer.mol_token_id = self.opt_tokenizer("<mol>", add_special_tokens=False).input_ids[0]

        self.collater = Collater([], [])

        if opt_model == 'facebook/galactica-125m':
            self.opt_model = OPTForCausalLM.from_pretrained(opt_model, config=config)
        else:
            # Larger checkpoints are loaded in half precision, bf16 when available.
            half_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            self.opt_model = OPTForCausalLM.from_pretrained(opt_model, torch_dtype=half_dtype, config=config)
        self.opt_model.resize_token_embeddings(len(self.opt_tokenizer))  ## this will cause bug when full fine-tuning the opt model

        self.llm_tune = llm_tune
        if llm_tune == 'lora':
            if peft_dir:
                self.opt_model = PeftModel.from_pretrained(self.opt_model, peft_dir, is_trainable=True)
            else:
                if self.args.peft_config:
                    peft_config = LoraConfig(**LoraConfig.from_json_file(self.args.peft_config))
                else:
                    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
                self.peft_config = peft_config
                self.opt_model = get_peft_model(self.opt_model, peft_config)
            self.opt_model.print_trainable_parameters()
        elif llm_tune == 'freeze':
            for param in self.opt_model.parameters():
                param.requires_grad = False
        elif llm_tune == 'full':
            pass
        else:
            raise NotImplementedError()

        ## fixme: this is different from the original BLIP2
        if args.mode == 'pretrain_eval':
            # Full list of token ids: any of them ends generation.
            self.eos_token_id = self.opt_tokenizer(
                "[START_SMILES]\n", add_special_tokens=False
            ).input_ids
        else:
            self.eos_token_id = self.opt_tokenizer(
                "\n", add_special_tokens=False
            ).input_ids[0]

        self.opt_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.opt_model.config.hidden_size
        )

        ## fixme: no prompt yet
        self.prompt = prompt
        self.rxn_batch_size = args.rxn_batch_size
        self.generate_restrict_tokens = args.generate_restrict_tokens
        self.train_restrict_tokens = args.train_restrict_tokens
        if self.generate_restrict_tokens or self.train_restrict_tokens:
            self.bad_words_ids = get_not_allowed_tokens_ids(opt_model)
        # prompt_tokens = self.opt_tokenizer(self.prompt, return_tensors="pt")
        # self.prompt_length = prompt_tokens.attention_mask.sum(1)

    def opt_forward_v2(
        self,
        inputs_embeds,
        attention_mask,
        labels,
        bad_word_ids=None,
    ):
        """Run the language model and return a next-token cross-entropy loss.

        When ``bad_word_ids`` is given, the logits of those token ids are set
        to -100 before the loss is computed.
        """
        output = self.opt_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=labels,
        )
        logits = output.logits
        labels = labels.to(logits.device)

        if bad_word_ids:
            banned = torch.tensor(bad_word_ids, device=logits.device, dtype=torch.long).squeeze()
            logits[:, :, banned] = -100

        # Shift so that tokens < n predict n
        vocab_size = self.opt_model.config.vocab_size
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        return CrossEntropyLoss()(shift_logits, shift_labels)

    def _graphs_to_mol_tokens(self, graphs):
        """Encode ``graphs`` and return Q-Former outputs projected to the LM hidden size."""
        graph_embeds, graph_masks = self.graph_encoder(graphs)
        if not self.tune_gnn:
            graph_embeds = graph_embeds.detach()

        # graph embedding calculation
        graph_embeds = self.ln_graph(graph_embeds, graph_masks)
        query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=graph_embeds,
            encoder_attention_mask=graph_masks,  # fixme: check whether this mask is correct
            return_dict=True,
        )
        return self.opt_proj(query_output.last_hidden_state)  # graph_num x num_query_token x D

    def forward_action(self, batch, use_gragh=True):
        """Loss for a ``(rxn_ids, graphs, text_tokens)`` batch.

        Padding tokens, molecule-placeholder tokens and tokens whose
        ``token_type_ids`` is 0 are excluded from the loss (label -100).
        Returns ``{"loss": loss}``.
        """
        # batch unpack
        rxn_ids, graphs, text_tokens = batch
        mol_tokens = self._graphs_to_mol_tokens(graphs) if use_gragh else None

        input_ids = text_tokens.input_ids
        targets = input_ids.masked_fill(input_ids == self.opt_tokenizer.pad_token_id, -100)
        targets = targets.masked_fill(text_tokens.is_mol_token, -100)
        targets = targets.masked_fill(text_tokens.token_type_ids == 0, -100)

        inputs_embeds = self.opt_model.get_input_embeddings()(input_ids)
        if use_gragh:
            inputs_embeds[text_tokens.is_mol_token] = mol_tokens.flatten(0, 1)  # graph_num x emb_dim

        if self.train_restrict_tokens:
            loss = self.opt_forward_v2(
                inputs_embeds=inputs_embeds,
                attention_mask=text_tokens.attention_mask,
                labels=targets,
                bad_word_ids=self.bad_words_ids,
            )
            return {"loss": loss}

        outputs = self.opt_model(
            inputs_embeds=inputs_embeds,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
            labels=targets,
        )
        return {"loss": outputs.loss}

    def forward_abstract(self, batch, use_gragh=True):
        """Loss for a ``(graphs, text_tokens)`` batch.

        Padding tokens and molecule-placeholder tokens are excluded from the
        loss (label -100). Returns ``{"loss": loss}``.
        """
        # batch unpack
        graphs, text_tokens = batch
        mol_tokens = self._graphs_to_mol_tokens(graphs) if use_gragh else None

        input_ids = text_tokens.input_ids
        targets = input_ids.masked_fill(input_ids == self.opt_tokenizer.pad_token_id, -100)
        targets = targets.masked_fill(text_tokens.is_mol_token, -100)

        inputs_embeds = self.opt_model.get_input_embeddings()(input_ids)
        if use_gragh:
            inputs_embeds[text_tokens.is_mol_token] = mol_tokens.flatten(0, 1)  # graph_num x emb_dim

        outputs = self.opt_model(
            inputs_embeds=inputs_embeds,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
            labels=targets,
        )
        return {"loss": outputs.loss}

    @torch.no_grad()
    def generate(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
        use_graph=True,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - graphs: input passed to the graph encoder
                - prompt_tokens: tokenized prompts providing ``input_ids``,
                  ``attention_mask`` and ``is_mol_token``
            do_sample (bool): Forwarded to the language model's ``generate``.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            length_penalty (float): Forwarded to the language model's ``generate``.
            num_captions (int): Number of sequences returned for each prompt.
            temperature (float): Forwarded to the language model's ``generate``.
            use_graph (bool): When False the graph encoder is skipped and the
                placeholder embeddings are left untouched.
        Returns:
            captions (list): Decoded, whitespace-stripped strings.
        """
        graphs = samples['graphs']
        prompt_tokens = samples['prompt_tokens']
        # prompt_lens = samples['prompt_lens']
        # with self.maybe_autocast():
        if use_graph:
            graph_embeds, graph_masks = self.graph_encoder(graphs)
            graph_embeds = self.ln_graph(graph_embeds)

            query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=graph_embeds,
                encoder_attention_mask=graph_masks,
                return_dict=True,
            )
            mol_tokens = self.opt_proj(query_output.last_hidden_state)

        prompt_embeds = self.opt_model.get_input_embeddings()(prompt_tokens.input_ids)
        if use_graph:
            prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(dtype=prompt_embeds.dtype)

        generation_kwargs = dict(
            inputs_embeds=prompt_embeds,
            attention_mask=prompt_tokens.attention_mask,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            # pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
            # use_cache=False,
        )
        if self.generate_restrict_tokens:
            generation_kwargs['bad_words_ids'] = self.bad_words_ids

        outputs = self.opt_model.generate(**generation_kwargs)
        decoded = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [text.strip() for text in decoded]
model/blip2_t5.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.cuda.amp import autocast as autocast
11
+ from peft import get_peft_model, LoraConfig, TaskType, PeftModel
12
+ from lavis.models.blip2_models.blip2 import disabled_train
13
+ from model.blip2 import Blip2Base
14
+ # from model.smiles_t5_captioning
15
+ from lavis.models.blip2_models.modeling_t5 import T5ForConditionalGeneration
16
+ from transformers import AutoTokenizer, T5TokenizerFast
17
+ #, T5ForConditionalGeneration
18
+
19
+
20
+
21
+
22
class Blip2T5(Blip2Base):
    """
    Graph-conditioned T5 model: graph encoder -> Q-Former -> T5.

    Graphs are encoded by ``self.graph_encoder``, summarised by the Q-Former
    into ``num_query_token`` query embeddings, projected to the T5 hidden
    size by ``self.opt_proj`` and written into the prompt embeddings at the
    positions flagged by ``prompt_tokens.is_mol_token`` (the ``<mol>``
    placeholder tokens).
    """

    def __init__(
        self,
        bert_name,
        gin_num_layers,
        gin_hidden_dim,
        gin_drop_ratio,
        tune_gnn=False,
        num_query_token=32,
        cross_attention_freq=2,
        llm_tune='freeze',
        peft_dir='',
        opt_model="facebook/galactica-1.3b",
        prompt="",
        args=None,
    ):
        """
        Args:
            bert_name: Identifier forwarded to ``init_Qformer``.
            gin_num_layers, gin_hidden_dim, gin_drop_ratio: Forwarded to
                ``init_graph_encoder``.
            tune_gnn (bool): When False the graph encoder is frozen and kept
                in eval mode.
            num_query_token (int): Number of Q-Former query tokens.
            cross_attention_freq (int): Forwarded to ``init_Qformer``.
            llm_tune (str): ``'freeze'``, ``'full'`` or ``'lora'``.
            peft_dir (str): If non-empty (and ``llm_tune == 'lora'``), a
                trainable adapter is loaded from this directory.
            opt_model (str): Name/path given to the T5 tokenizer and model.
            prompt (str): Accepted for interface compatibility; not used here.
            args: Namespace providing ``peft_config`` and, when that is
                falsy, ``lora_r``, ``lora_alpha`` and ``lora_dropout``.
                Required when ``llm_tune == 'lora'`` and ``peft_dir`` is empty.

        Raises:
            ValueError: LoRA requested without ``peft_dir`` and without ``args``.
            NotImplementedError: Unknown ``llm_tune`` value.
        """
        super().__init__()
        self.args = args

        self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
        self.tune_gnn = tune_gnn
        if not tune_gnn:
            for param in self.graph_encoder.parameters():
                param.requires_grad = False
            self.graph_encoder = self.graph_encoder.eval()
            # Replace .train so later model.train() calls cannot re-enable training mode.
            self.graph_encoder.train = disabled_train
            logging.info("freeze graph encoder")

        self.num_query_token = num_query_token
        self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
        # Remove the Q-Former parameters that are never used by this model.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        # Language model and tokenizer; '<mol>' is the molecule placeholder token.
        self.opt_tokenizer = T5TokenizerFast.from_pretrained(opt_model)
        self.opt_tokenizer.add_tokens('<mol>')
        self.mol_token = '<mol>'
        self.opt_tokenizer.mol_token_id = self.opt_tokenizer("<mol>", add_special_tokens=False).input_ids[0]

        self.opt_model = T5ForConditionalGeneration.from_pretrained(opt_model, torch_dtype=torch.float32)
        # NOTE: per the original author, resizing causes a bug when fully fine-tuning the LM.
        self.opt_model.resize_token_embeddings(len(self.opt_tokenizer))

        self.llm_tune = llm_tune
        if llm_tune == 'lora':
            if peft_dir:
                self.opt_model = PeftModel.from_pretrained(self.opt_model, peft_dir, is_trainable=True)
            else:
                if args is None:
                    raise ValueError(
                        "llm_tune='lora' without peft_dir requires `args` "
                        "(peft_config or lora_r/lora_alpha/lora_dropout)"
                    )
                if self.args.peft_config:
                    peft_config = LoraConfig(**LoraConfig.from_json_file(self.args.peft_config))
                else:
                    # NOTE(review): T5 is an encoder-decoder model; confirm whether
                    # TaskType.SEQ_2_SEQ_LM is the intended task type here.
                    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
                self.peft_config = peft_config
                self.opt_model = get_peft_model(self.opt_model, peft_config)
            self.opt_model.print_trainable_parameters()
        elif llm_tune == 'freeze':
            for param in self.opt_model.parameters():
                param.requires_grad = False
        elif llm_tune == 'full':
            pass
        else:
            raise NotImplementedError()

        ## fixme: this is different from the original BLIP2 (which used "\n")
        self.eos_token_id = self.opt_tokenizer(
            "</s>", add_special_tokens=False
        ).input_ids[0]

        self.opt_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.opt_model.config.hidden_size
        )

    def _project_graphs(self, graph_embeds, graph_masks):
        """Run the Q-Former over encoded graphs and project its output with ``opt_proj``."""
        query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=graph_embeds,
            encoder_attention_mask=graph_masks,  # fixme: check whether this mask is correct
            return_dict=True,
        )
        return self.opt_proj(query_output.last_hidden_state)

    def _embed_prompt(self, prompt_tokens, mol_tokens=None):
        """Embed the prompt ids and, if given, write ``mol_tokens`` into the ``<mol>`` slots."""
        prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
        if mol_tokens is not None:
            prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)
        return prompt_embeds

    def _seq2seq_loss(self, graphs, prompt_tokens, text_tokens, use_graph):
        """Compute the T5 loss for ``text_tokens`` given the (optionally graph-augmented) prompt."""
        mol_tokens = None
        if use_graph:
            graph_embeds, graph_masks = self.graph_encoder(graphs)
            if not self.tune_gnn:
                graph_embeds = graph_embeds.detach()
            graph_embeds = self.ln_graph(graph_embeds, graph_masks)
            mol_tokens = self._project_graphs(graph_embeds, graph_masks)

        # Padding positions are set to -100 so they do not contribute to the loss.
        targets = text_tokens.input_ids.masked_fill(
            text_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
        )
        with self.maybe_autocast(torch.float32):
            prompt_embeds = self._embed_prompt(prompt_tokens, mol_tokens)
            outputs = self.opt_model(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_tokens.attention_mask,
                decoder_attention_mask=text_tokens.attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss
        return {"loss": loss}

    def _generate_text(self, samples, use_graph, **generation_kwargs):
        """Shared implementation of ``generate`` / ``generate_action``."""
        prompt_tokens = samples['prompt_tokens']
        mol_tokens = None
        if use_graph:
            graph_embeds, graph_masks = self.graph_encoder(samples['graphs'])
            # As in the original generation path, ln_graph is called without the mask.
            graph_embeds = self.ln_graph(graph_embeds)
            mol_tokens = self._project_graphs(graph_embeds, graph_masks)

        with self.maybe_autocast(torch.float32):
            prompt_embeds = self._embed_prompt(prompt_tokens, mol_tokens)
            outputs = self.opt_model.generate(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_tokens.attention_mask,
                eos_token_id=self.eos_token_id,
                **generation_kwargs,
            )
        output_text = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [text.strip() for text in output_text]

    def forward(self, batch):
        """
        Compute the training loss.

        Args:
            batch: ``(graphs, prompt_tokens, text_tokens)``.
        Returns:
            dict: ``{"loss": loss}``.
        """
        graphs, prompt_tokens, text_tokens = batch
        return self._seq2seq_loss(graphs, prompt_tokens, text_tokens, use_graph=True)

    def forward_action(self, batch, use_gragh=True):
        """
        Compute the training loss for a batch that also carries reaction ids.

        Args:
            batch: ``(rxn_ids, graphs, prompt_tokens, text_tokens)``; ``rxn_ids``
                is not used here.
            use_gragh (bool): When False the graphs are ignored and the prompt
                embeddings are used unchanged. (The parameter name is kept as is
                for backward compatibility.)
        Returns:
            dict: ``{"loss": loss}``.
        """
        _rxn_ids, graphs, prompt_tokens, text_tokens = batch
        return self._seq2seq_loss(graphs, prompt_tokens, text_tokens, use_graph=use_gragh)

    @torch.no_grad()
    def generate(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """
        Generate text conditioned on graphs and a prompt.

        Args:
            samples (dict): Must contain the keys ``'graphs'`` and
                ``'prompt_tokens'``.
            do_sample (bool): Forwarded to ``opt_model.generate``.
            num_beams (int): Number of beams for beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty.
            length_penalty (float): Forwarded to ``opt_model.generate``.
            num_captions (int): Passed as ``num_return_sequences``.
            temperature (float): Forwarded to ``opt_model.generate``.
        Returns:
            list[str]: Decoded, whitespace-stripped generations.
        """
        return self._generate_text(
            samples,
            use_graph=True,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )

    @torch.no_grad()
    def generate_action(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
        use_graph=True
    ):
        """
        Same as ``generate`` but the graph branch can be disabled.

        Args:
            samples (dict): Must contain ``'prompt_tokens'``; ``'graphs'`` is
                only read when ``use_graph`` is True.
            use_graph (bool): When False the prompt embeddings are used unchanged.
            (remaining arguments: see ``generate``)
        Returns:
            list[str]: Decoded, whitespace-stripped generations.
        """
        return self._generate_text(
            samples,
            use_graph=use_graph,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )
model/blip2qformer.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import logging
8
+ import os
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.nn as nn
12
+ from torch.cuda.amp import autocast as autocast
13
+ from torch.nn import functional as F
14
+
15
+ # from lavis.common.registry import registry
16
+ # from lavis.models.base_model import all_gather_with_grad, concat_all_gather
17
+ from lavis.models.blip2_models.blip2 import (
18
+ disabled_train,
19
+ )
20
+ from lavis.models.blip_models.blip_outputs import BlipOutput
21
+ from lavis.common.dist_utils import is_dist_avail_and_initialized
22
+ from model.blip2 import Blip2Base
23
+ from pytorch_lightning.utilities import distributed
24
+
25
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate the results on dim 0.

    When torch.distributed is unavailable or no process group has been
    initialized, ``tensor`` is returned unchanged.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # Single-process run: nothing to gather.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return tensor

    # One receive buffer per rank; all_gather overwrites their contents.
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)
43
+
44
@torch.no_grad()
def pl_concat_all_gather(tensor):
    """
    Gather ``tensor`` across processes via ``distributed.gather_all_tensors``
    and concatenate the pieces along dim 0.

    Outside of an initialized distributed run the input is returned as is.
    Runs under ``torch.no_grad()``.
    """
    if is_dist_avail_and_initialized():
        gathered = distributed.gather_all_tensors(tensor)
        return torch.cat(gathered, dim=0)
    return tensor
57
+
58
+
59
+ # @registry.register_model("blip2")
60
+ # @registry.register_model("blip2_feature_extractor")
61
+ class Blip2Qformer(Blip2Base):
62
+ """
63
+ BLIP2 first-stage model with Q-former and ViT.
64
+ Supported model types:
65
+ - pretrained: pretrained model with vit-g
66
+ - pretrain_vitL: pretrained model with vit-large
67
+ - coco: fintuned model on coco
68
+ Usage:
69
+ >>> from lavis.models import load_model
70
+ >>> model = load_model("blip2", "pretrain")
71
+ """
72
def __init__(
    self,
    gtm,
    lm,
    bert_name,
    temperature,
    gin_num_layers,
    gin_hidden_dim,
    gin_drop_ratio,
    tune_gnn=False,
    num_query_token=32,
    cross_attention_freq=2,
    embed_dim=256,
):
    """
    Build the graph encoder, the Q-Former and the projection / matching heads.

    Args:
        gtm: Truthy to enable the graph-text matching loss.
        lm: Truthy to enable the language-modeling loss.
        bert_name: Forwarded to ``init_Qformer``.
        temperature: Divisor applied to the contrastive similarities.
        gin_num_layers, gin_hidden_dim, gin_drop_ratio: Forwarded to
            ``init_graph_encoder``.
        tune_gnn (bool): When False the graph encoder is frozen.
        num_query_token (int): Number of Q-Former query tokens.
        cross_attention_freq (int): Forwarded to ``init_Qformer``.
        embed_dim (int): Output size of the graph/text projection layers.
    """
    super().__init__()
    self.gtm = gtm
    self.lm = lm
    self.temperature = temperature

    self.tokenizer = self.init_tokenizer()

    encoder, graph_norm = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
    self.graph_encoder, self.ln_graph = encoder, graph_norm
    self.tune_gnn = tune_gnn
    if not tune_gnn:
        # Freeze the encoder and pin it to eval mode.
        for weight in self.graph_encoder.parameters():
            weight.requires_grad = False
        self.graph_encoder = self.graph_encoder.eval()
        self.graph_encoder.train = disabled_train
        logging.info("freeze graph encoder")

    self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
    self.Qformer.resize_token_embeddings(len(self.tokenizer))
    # Initialise every "*_query" parameter from its counterpart without the suffix.
    qformer_state = self.Qformer.state_dict()
    for param_name, weight in self.Qformer.named_parameters():
        if "_query" not in param_name:
            continue
        weight.data.copy_(qformer_state[param_name.replace("_query", "")])

    qformer_hidden = self.Qformer.config.hidden_size
    self.graph_proj = nn.Linear(qformer_hidden, embed_dim)
    self.text_proj = nn.Linear(qformer_hidden, embed_dim)

    # Two-way classifier used by the graph-text matching objective.
    self.gtm_head = nn.Linear(qformer_hidden, 2)
115
+
116
+
117
def contrast(self, features_graph, features_text, return_sim=False):
    '''
    In-batch graph-text contrastive loss.

    features_graph: shape = [B, num_qs, D]
    features_text: shape = [B, D]

    Both inputs are L2-normalised here. For every (graph, text) pair the
    similarity is the maximum over the graph's query tokens of the cosine
    similarity with the text feature; the result is divided by
    ``self.temperature``.

    Returns the symmetric cross-entropy loss, or
    ``(logits_per_graph, logits_per_text, loss)`` when ``return_sim`` is True
    (both logits have shape [B, B]).
    '''
    batch_size = features_graph.size(0)

    # normalized features
    features_graph = F.normalize(features_graph, dim=-1)
    features_text = F.normalize(features_text, dim=-1)

    # cosine similarity as logits
    # [B, 1, num_qs, D] @ [B, D, 1] -> [B, B, num_qs, 1]; drop only the trailing
    # singleton so that B == 1 or num_qs == 1 keep their axes.
    sim_q2t = (features_graph.unsqueeze(1) @ features_text.unsqueeze(-1)).squeeze(-1)  # [B, B, num_qs]
    sim_g2t, _ = sim_q2t.max(-1)  # shape = [B, B]

    logits_per_graph = sim_g2t / self.temperature
    logits_per_text = logits_per_graph.t()

    # Matching pairs lie on the diagonal; labels have size B.
    labels = torch.arange(batch_size, dtype=torch.long, device=logits_per_graph.device)
    loss_graph = F.cross_entropy(logits_per_graph, labels)
    loss_text = F.cross_entropy(logits_per_text, labels)
    loss = (loss_graph + loss_text) / 2

    if return_sim:
        return logits_per_graph, logits_per_text, loss
    else:
        return loss
144
+
145
def contrast_global(self, features_graph, features_text, features_graph_all, features_text_all, return_sim=False):
    '''
    Graph-text contrastive loss against features gathered from all processes.

    features_graph: shape = [B, num_qs, D]
    features_text: shape = [B, D]
    features_text_all: shape = [B * num_gpus, D]
    features_graph_all: shape = [B * num_gpus, num_qs, D]

    Inputs are used as given (no normalisation is applied here). The positive
    for local item ``i`` is at column ``rank * B + i`` of the gathered axis.
    Without an initialized process group the rank is taken to be 0.

    Returns the symmetric cross-entropy loss, or, when ``return_sim`` is True,
    ``(logits_per_graph_local, logits_per_text_local, loss)`` where the logits
    are restricted to this rank's [B, B] block.
    '''
    bs = features_graph.size(0)

    # cosine similarity as logits
    # [B, 1, num_qs, D] @ [B * num_gpus, D, 1] -> [B, B * num_gpus, num_qs, 1];
    # squeeze only the trailing singleton so size-1 batch/query axes survive.
    sim_q2t = (features_graph.unsqueeze(1) @ features_text_all.unsqueeze(-1)).squeeze(-1)  # [B, B * num_gpus, num_qs]
    sim_g2t, _ = sim_q2t.max(-1)  # shape = [B, B * num_gpus]

    logits_per_graph = sim_g2t / self.temperature

    # [B, 1, 1, D] @ [B * num_gpus, D, num_qs] -> [B, B * num_gpus, 1, num_qs]
    sim_t2q = (features_text.unsqueeze(1).unsqueeze(1) @ features_graph_all.permute(0, 2, 1)).squeeze(-2)  # [B, B * num_gpus, num_qs]
    sim_t2g, _ = sim_t2q.max(-1)
    logits_per_text = sim_t2g / self.temperature

    # dist.get_rank() raises without a process group, so fall back to rank 0
    # for single-process runs.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    labels = torch.arange(rank * bs, rank * bs + bs, dtype=torch.long, device=logits_per_graph.device)

    loss_graph = F.cross_entropy(logits_per_graph, labels)
    loss_text = F.cross_entropy(logits_per_text, labels)
    loss = (loss_graph + loss_text) / 2

    if return_sim:
        return logits_per_graph[:, rank*bs:rank*bs+bs], logits_per_text[:, rank*bs:rank*bs+bs], loss
    else:
        return loss
177
+
178
def forward_old(self, batch):
    """
    Legacy (v1) training step: all losses use the local batch only.

    Args:
        batch: ``(graph, text, mask)`` - graph input for ``self.graph_encoder``,
            text token ids and the text attention mask.
    Returns:
        BlipOutput with ``loss`` (sum of the three terms), ``loss_itc``
        (contrastive), ``loss_itm`` (matching, 0 when ``self.gtm`` is falsy)
        and ``loss_lm`` (language modeling, 0 when ``self.lm`` is falsy).
    """
    ## v1: not gather results from all gpus
    ###============== Graph-text Contrastive ===================###
    graph, text, mask = batch
    batch_node, batch_mask = self.graph_encoder(graph)
    # Unlike ``forward``, the encoder output is always detached here,
    # regardless of ``self.tune_gnn``.
    batch_node = batch_node.detach()
    batch_size = batch_node.shape[0]

    batch_node = self.ln_graph(batch_node, batch_mask)
    query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
    query_output = self.Qformer.bert(
        query_embeds=query_tokens,
        encoder_hidden_states=batch_node,
        encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct
        use_cache=True,  # past_key_values are reused by the LM branch below
        return_dict=True,
    )
    graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D]
    text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D]
    # The hidden state at position 0 is used as the text representation.
    text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
    sim_g2t, sim_t2g, loss_gtc = self.contrast(graph_feats, text_feats, return_sim=True)


    ###============== Graph-text Matching ===================###
    loss_gtm = 0
    if self.gtm:
        g_emb = batch_node
        g_mask = batch_mask
        text_ids = text.clone()
        with torch.no_grad():
            # Sampling weights for negatives: softmax similarity plus a small
            # floor, with the positive (diagonal) entry zeroed out.
            # NOTE(review): with a batch of one every weight row is all zeros
            # after the diagonal is cleared - confirm torch.multinomial's
            # behaviour for that case.
            weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4
            weights_t2g.fill_diagonal_(0)
            weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4
            weights_g2t.fill_diagonal_(0)

        # select a negative graph for each text
        graph_embeds_neg = []
        graph_mask_neg = []
        for b in range(batch_size):
            neg_idx = torch.multinomial(weights_t2g[b], 1).item()
            graph_embeds_neg.append(g_emb[neg_idx])
            graph_mask_neg.append(g_mask[neg_idx])

        graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0)
        graph_mask_neg = torch.stack(graph_mask_neg, dim=0)

        # select a negative text for each graph
        text_ids_neg = []
        text_atts_neg = []
        for b in range(batch_size):
            neg_idx = torch.multinomial(weights_g2t[b], 1).item()
            text_ids_neg.append(text_ids[neg_idx])
            text_atts_neg.append(mask[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        # Three groups of B pairs: (text, graph), (text, negative graph),
        # (negative text, graph).
        text_ids_all = torch.cat(
            [text_ids, text_ids, text_ids_neg], dim=0
        )  # pos, pos, neg
        text_atts_all = torch.cat(
            [mask, mask, text_atts_neg],
            dim=0,
        )

        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text.device)
        # Attention mask covers the query tokens followed by the text tokens.
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

        graph_embeds_all = torch.cat([g_emb, graph_embeds_neg, g_emb], dim=0)  # pos, neg, pos
        graph_atts_all = torch.cat([g_mask, graph_mask_neg, g_mask], dim=0)

        output_itm = self.Qformer.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=graph_embeds_all,
            encoder_attention_mask=graph_atts_all,
            return_dict=True,
        )

        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only
        vl_output = self.gtm_head(vl_embeddings)
        # Average the two-way logits over the query tokens.
        logits = vl_output.mean(dim=1)

        # First B pairs are matches (label 1); the remaining 2B are mismatches (label 0).
        itm_labels = torch.cat(
            [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)],
            dim=0,
        ).to(text.device)
        loss_gtm = F.cross_entropy(logits, itm_labels)

    ##================= Graph Captioning ========================##
    loss_lm = 0
    if self.lm:
        decoder_input_ids = text.clone()
        # The first token of every sequence is overwritten with the BOS id.
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions get label -100.
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device)

        attention_mask = torch.cat([query_atts, mask], dim=1)
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        loss_lm = lm_output.loss

    return BlipOutput(
        loss=loss_gtc + loss_gtm + loss_lm,
        loss_itc=loss_gtc,
        loss_itm=loss_gtm,
        loss_lm=loss_lm,
    )
296
+
297
+
298
def forward(self, batch):
    """
    Training step (v2): the contrastive loss uses features gathered from all
    processes; matching negatives are still sampled from the local batch.

    Args:
        batch: ``(graph, text, mask)`` - graph input for ``self.graph_encoder``,
            text token ids and the text attention mask.
    Returns:
        BlipOutput with ``loss`` (sum of the three terms), ``loss_itc``
        (contrastive), ``loss_itm`` (matching, 0 when ``self.gtm`` is falsy)
        and ``loss_lm`` (language modeling, 0 when ``self.lm`` is falsy).
    """
    ## v2: gather results from all gpus
    ###============== Graph-text Contrastive ===================###
    graph, text, mask = batch
    batch_node, batch_mask = self.graph_encoder(graph)
    if not self.tune_gnn:
        batch_node = batch_node.detach()
    batch_size = batch_node.shape[0]

    batch_node = self.ln_graph(batch_node, batch_mask)
    query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
    query_output = self.Qformer.bert(
        query_embeds=query_tokens,
        encoder_hidden_states=batch_node,
        encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct
        use_cache=True,  # past_key_values are reused by the LM branch below
        return_dict=True,
    )
    graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D]
    text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D]
    # The hidden state at position 0 is used as the text representation.
    text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])

    # contrast_global does not normalise, so features are L2-normalised here.
    text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1)
    text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats) # text: [B * num_gpus, D]; graph presumably [B * num_gpus, num_q, D]
    # With return_sim=True the returned similarities are this rank's local block.
    sim_g2t, sim_t2g, loss_gtc = self.contrast_global(graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True)


    ###============== Graph-text Matching ===================###
    loss_gtm = 0
    if self.gtm:
        ## not aggregate global tensor because of their different shapes
        # Despite the "_world" suffix these are the local (per-process) tensors.
        g_emb_world = batch_node
        g_mask_world = batch_mask
        text_ids_world = text
        text_mask_world = mask
        with torch.no_grad():
            # Sampling weights for negatives: softmax similarity plus a small
            # floor, with the positive (diagonal) entry zeroed out.
            # NOTE(review): with a batch of one every weight row is all zeros
            # after the diagonal is cleared - confirm torch.multinomial's
            # behaviour for that case.
            weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4
            weights_t2g.fill_diagonal_(0)
            weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4
            weights_g2t.fill_diagonal_(0)

        # select a negative graph for each text
        graph_embeds_neg = []
        graph_mask_neg = []
        for b in range(batch_size):
            neg_idx = torch.multinomial(weights_t2g[b], 1).item()
            graph_embeds_neg.append(g_emb_world[neg_idx])
            graph_mask_neg.append(g_mask_world[neg_idx])

        graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0)
        graph_mask_neg = torch.stack(graph_mask_neg, dim=0)

        # select a negative text for each graph
        text_ids_neg = []
        text_atts_neg = []
        for b in range(batch_size):
            neg_idx = torch.multinomial(weights_g2t[b], 1).item()
            text_ids_neg.append(text_ids_world[neg_idx])
            text_atts_neg.append(text_mask_world[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        # Three groups of B pairs: (text, graph), (text, negative graph),
        # (negative text, graph).
        text_ids_all = torch.cat(
            [text, text, text_ids_neg], dim=0
        )  # pos, pos, neg
        text_atts_all = torch.cat(
            [mask, mask, text_atts_neg],
            dim=0,
        )

        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text.device)
        # Attention mask covers the query tokens followed by the text tokens.
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

        graph_embeds_all = torch.cat([batch_node, graph_embeds_neg, batch_node], dim=0)  # pos, neg, pos
        graph_atts_all = torch.cat([batch_mask, graph_mask_neg, batch_mask], dim=0)

        output_itm = self.Qformer.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=graph_embeds_all,
            encoder_attention_mask=graph_atts_all,
            return_dict=True,
        )

        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only
        vl_output = self.gtm_head(vl_embeddings)
        # Average the two-way logits over the query tokens.
        logits = vl_output.mean(dim=1)

        # First B pairs are matches (label 1); the remaining 2B are mismatches (label 0).
        itm_labels = torch.cat(
            [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)],
            dim=0,
        ).to(text.device)
        loss_gtm = F.cross_entropy(logits, itm_labels)

    ##================= Graph Captioning ========================##
    loss_lm = 0
    if self.lm:
        decoder_input_ids = text.clone()
        # The first token of every sequence is overwritten with the BOS id.
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions get label -100.
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device)

        attention_mask = torch.cat([query_atts, mask], dim=1)
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        loss_lm = lm_output.loss

    return BlipOutput(
        loss=loss_gtc + loss_gtm + loss_lm,
        loss_itc=loss_gtc,
        loss_itm=loss_gtm,
        loss_lm=loss_lm,
    )
422
+
423
+ def forward_v3(self, batch):
424
+ ## v3: use smiles instruction
425
+ ###============== Image-text Contrastive ===================###
426
+ graphs, text_tokens, prompt_tokens = batch
427
+ graph_embeds, graph_masks = self.graph_encoder(graphs)
428
+ if not self.tune_gnn:
429
+ graph_embeds = graph_embeds.detach()
430
+ graph_embeds = self.ln_graph(graph_embeds, graph_masks)
431
+
432
+ device = text_tokens.input_ids.device
433
+ batch_size = graph_embeds.shape[0]
434
+
435
+ ##
436
+ query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
437
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=device)
438
+ attention_mask_gtc = torch.cat([query_atts, prompt_tokens.attention_mask], dim=1)
439
+ query_output = self.Qformer.bert(
440
+ input_ids=prompt_tokens,
441
+ query_embeds=query_tokens,
442
+ attention_mask=attention_mask_gtc,
443
+ encoder_hidden_states=graph_embeds,
444
+ encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct
445
+ use_cache=True,
446
+ return_dict=True,
447
+ )
448
+
449
+ query_output = query_output.last_hidden_state[:, : query_tokens.size(1), :] # keep query tokens only
450
+ graph_feats = self.graph_proj(query_output) # shape = [B, num_q, D]
451
+ text_output = self.Qformer.bert(text_tokens.input_ids, attention_mask=text_tokens.attention_mask, return_dict=True) # shape = [B, n_max, D]
452
+ text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
453
+
454
+ text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1)
455
+ text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats) # shape = [B * num_gpus, D]
456
+ sim_g2t, sim_t2g, loss_gtc = self.contrast_global(graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True)
457
+
458
+
459
+ ###============== Image-text Matching ===================###
460
+ loss_gtm = 0
461
+ if self.gtm:
462
+ ## not aggregate global tensor because of their different shapes
463
+ g_emb_world = graph_embeds
464
+ g_mask_world = graph_masks
465
+ text_ids_world = text_tokens.input_ids
466
+ text_mask_world = text_tokens.attention_mask
467
+ with torch.no_grad():
468
+ weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4
469
+ weights_t2g.fill_diagonal_(0)
470
+ weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4
471
+ weights_g2t.fill_diagonal_(0)
472
+
473
+ # select a negative graph for each text
474
+ graph_embeds_neg = []
475
+ graph_mask_neg = []
476
+ for b in range(batch_size):
477
+ neg_idx = torch.multinomial(weights_t2g[b], 1).item()
478
+ graph_embeds_neg.append(g_emb_world[neg_idx])
479
+ graph_mask_neg.append(g_mask_world[neg_idx])
480
+
481
+ graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0)
482
+ graph_mask_neg = torch.stack(graph_mask_neg, dim=0)
483
+
484
+ # select a negative text for each image
485
+ text_ids_neg = []
486
+ text_atts_neg = []
487
+ for b in range(batch_size):
488
+ neg_idx = torch.multinomial(weights_g2t[b], 1).item()
489
+ text_ids_neg.append(text_ids_world[neg_idx])
490
+ text_atts_neg.append(text_mask_world[neg_idx])
491
+
492
+ text_ids_neg = torch.stack(text_ids_neg, dim=0)
493
+ text_atts_neg = torch.stack(text_atts_neg, dim=0)
494
+
495
+ text_ids_all = torch.cat(
496
+ [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
497
+ ) # pos, pos, neg
498
+ text_atts_all = torch.cat(
499
+ [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
500
+ dim=0,
501
+ )
502
+
503
+ query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
504
+ query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text_tokens.input_ids.device)
505
+ attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
506
+
507
+ graph_embeds_all = torch.cat([graph_embeds, graph_embeds_neg, graph_embeds], dim=0) # pos, neg, pos
508
+ graph_atts_all = torch.cat([graph_masks, graph_mask_neg, graph_masks], dim=0)
509
+
510
+ output_itm = self.Qformer.bert(
511
+ text_ids_all,
512
+ query_embeds=query_tokens_itm,
513
+ attention_mask=attention_mask_all,
514
+ encoder_hidden_states=graph_embeds_all,
515
+ encoder_attention_mask=graph_atts_all,
516
+ return_dict=True,
517
+ )
518
+
519
+ vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only
520
+ vl_output = self.gtm_head(vl_embeddings)
521
+ logits = vl_output.mean(dim=1)
522
+
523
+ itm_labels = torch.cat(
524
+ [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)],
525
+ dim=0,
526
+ ).to(text_tokens.input_ids.device)
527
+ loss_gtm = F.cross_entropy(logits, itm_labels)
528
+
529
+ ##================= Image Captioning ========================##
530
+ loss_lm = 0
531
+ if self.lm:
532
+ decoder_input_ids = text_tokens.input_ids.clone()
533
+ decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
534
+ labels = decoder_input_ids.masked_fill(
535
+ decoder_input_ids == self.tokenizer.pad_token_id, -100
536
+ )
537
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text_tokens.input_ids.device)
538
+
539
+ attention_mask = torch.cat([query_atts, prompt_tokens.attention_mask, text_tokens.attention_mask], dim=1)
540
+ lm_output = self.Qformer(
541
+ decoder_input_ids,
542
+ attention_mask=attention_mask,
543
+ past_key_values=query_output.past_key_values,
544
+ return_dict=True,
545
+ labels=labels,
546
+ )
547
+
548
+ loss_lm = lm_output.loss
549
+
550
+ return BlipOutput(
551
+ loss=loss_gtc + loss_gtm + loss_lm,
552
+ loss_itc=loss_gtc,
553
+ loss_itm=loss_gtm,
554
+ loss_lm=loss_lm,
555
+ )
556
+
557
def graph_forward(self, graph):
    """Encode a batch of graphs into L2-normalised Q-Former query features.

    The graph encoder output is passed through ``ln_graph`` and then used as
    the cross-attention memory for the learned query tokens.

    Returns:
        A tuple ``(graph_feats, node_states, node_mask)``: the projected,
        unit-norm query outputs, followed by the ``ln_graph``-processed encoder
        states and the mask returned by the graph encoder.
    """
    node_states, node_mask = self.graph_encoder(graph)
    node_states = self.ln_graph(node_states, node_mask)

    # One copy of the learned queries for every graph in the batch.
    num_graphs = node_states.shape[0]
    queries = self.query_tokens.expand(num_graphs, -1, -1)

    qformer_out = self.Qformer.bert(
        query_embeds=queries,
        encoder_hidden_states=node_states,
        encoder_attention_mask=node_mask,  # fixme: check whether this mask is correct
        use_cache=False,
        return_dict=True,
    )
    projected = self.graph_proj(qformer_out.last_hidden_state)  # shape = [B, num_q, D]
    return F.normalize(projected, p=2, dim=-1), node_states, node_mask
571
+
572
def text_forward(self, text, mask):
    """Embed tokenised text and return unit-norm text features.

    ``text`` and ``mask`` are fed to ``self.Qformer.bert``; the hidden state at
    sequence position 0 is projected with ``text_proj`` and L2-normalised.
    """
    encoded = self.Qformer.bert(text, attention_mask=mask, return_dict=True)  # shape = [B, n_max, D]
    first_token_state = encoded.last_hidden_state[:, 0, :]
    return F.normalize(self.text_proj(first_token_state), dim=-1, p=2)
577
+
578
def compute_gtm(self, batch_node, batch_mask, text_ids, text_atts):
    '''
    Score graph/text pairs with the matching head and return the positive-class logit.

    batch_node shape = [B, N, D]
    batch_mask shape = [B, N]
    text_ids shape = [B, N]
    text_atts shape = [B, N]
    '''
    num_pairs = batch_node.shape[0]
    queries = self.query_tokens.expand(num_pairs, -1, -1)  # shape = [B, Nq, D]
    num_queries = queries.size(1)

    # Query positions are always attended to; text positions follow text_atts.
    query_mask = torch.ones(queries.size()[:-1], dtype=torch.long).to(batch_node.device)  # shape = [B, Nq]
    joint_mask = torch.cat([query_mask, text_atts], dim=1)  # shape = [B, Nq + N]

    fused = self.Qformer.bert(
        text_ids,
        query_embeds=queries,
        attention_mask=joint_mask,
        encoder_hidden_states=batch_node,
        encoder_attention_mask=batch_mask,
        return_dict=True,
    )
    query_states = fused.last_hidden_state[:, :num_queries, :]  # shape = [B, Nq, D]
    # Average the head output over the query axis.
    pair_logits = self.gtm_head(query_states).mean(dim=1)
    # Raw logit (no softmax) on the axis of the positive class.
    return pair_logits[:, 1]
603
+
model/dist_funs.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union, Dict
3
+ from pytorch_lightning import strategies
4
+ from lightning_fabric.utilities.types import _PATH
5
+ from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
6
+
7
+
8
+ '''
9
+ overwrite the function in deepspeed
10
+ '''
11
+
12
+ ### start overwrite ###
13
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
    """Return the wrapped module's state dict, optionally without frozen entries.

    With ``exclude_frozen_parameters`` set, every key that is either a
    parameter with ``requires_grad == False`` or cannot be resolved through
    ``get_parameter`` (it raises ``AttributeError``) is dropped.
    """
    state = self.module.state_dict(destination, prefix, keep_vars)

    if exclude_frozen_parameters:
        def _should_drop(name):
            try:
                return not self.module.get_parameter(name).requires_grad
            except AttributeError:
                # Not resolvable as a parameter: dropped as well.
                return True

        # Collect first, then pop, so the dict is not mutated while iterating.
        for name in [key for key in state if _should_drop(key)]:
            state.pop(name)

    if self.random_ltd_enabled():
        state = remove_random_ltd_state_dict(state)
    return state
29
+ from deepspeed import DeepSpeedEngine
30
+ DeepSpeedEngine.module_state_dict = module_state_dict
31
+ ### end overwrite ###
32
+
33
class MyDeepSpeedStrategy(strategies.DeepSpeedStrategy):
    """DeepSpeed strategy that saves checkpoints without frozen parameters and loads state dicts non-strictly."""

    def save_checkpoint_v1(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ):
        """Write the checkpoint through the ``CheckpointIO`` plugin, on the global-zero rank only.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        if not self.is_global_zero:
            return
        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    def load_model_state_dict(self, checkpoint):
        """Load ``checkpoint["state_dict"]`` into the lightning module with ``strict=False``."""
        module = self.lightning_module
        assert module is not None
        module.load_state_dict(checkpoint["state_dict"], strict=False)

    def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save the training state with DeepSpeed's own checkpointing, excluding frozen parameters.

        Args:
            checkpoint: The checkpoint state dictionary
            filepath: write-target file's path
            storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        # Every rank must write under the path chosen by rank 0.
        filepath = self.broadcast(filepath)
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
            )

        sharded_save = self.zero_stage_3 and self._multi_device and self.is_global_zero
        if sharded_save:
            print(
                "Warning: When saving the DeepSpeed Stage 3 checkpoint, "
                "each worker will save a shard of the checkpoint within a directory. "
                "If a single file is required after training, "
                "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
                "deepspeed-zero-stage-3-single-file for instructions."
            )

        # Model and optimizer states are written by the DeepSpeed engine itself;
        # everything else travels as client state.
        engine_managed = ("state_dict", "optimizer_states")
        client_state = {key: value for key, value in checkpoint.items() if key not in engine_managed}
        self.deepspeed_engine.save_checkpoint(
            filepath, client_state=client_state, tag="checkpoint", exclude_frozen_parameters=True
        )
model/gin_model.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.nn import MessagePassing
3
+ from torch_geometric.utils import add_self_loops, degree, softmax, to_dense_batch
4
+ from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
5
+ import torch.nn.functional as F
6
+ # from torch_scatter import scatter_add
7
+ from torch_geometric.nn.inits import glorot, zeros
8
+
9
+ num_atom_type = 120 #including the extra mask tokens
10
+ num_chirality_tag = 3
11
+
12
+ num_bond_type = 6 #including aromatic and self-loop edge, and extra masked tokens
13
+ num_bond_direction = 3
14
+
15
class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information: the bond-type
    and bond-direction embeddings of an edge are added to the neighbour message.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbourhood aggregation scheme passed to ``MessagePassing``.

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr="add"):
        # Hand the requested scheme to the base class. It used to be hard-coded
        # to "add" here, so any other value of ``aggr`` was not forwarded.
        super(GINConv, self).__init__(aggr=aggr)
        # multi-layer perceptron applied to the aggregated messages
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        # Add self loops in the edge space; their attribute columns are filled with 0.
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, fill_value=0, num_nodes=x.size(0))

        # Column 0 indexes the bond-type table, column 1 the bond-direction table.
        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)
63
+
64
+
65
class GCNConv(MessagePassing):
    """
    GCN layer with symmetric degree normalisation whose messages also carry
    bond-type and bond-direction edge embeddings.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbourhood aggregation scheme passed to ``MessagePassing``.
    """

    def __init__(self, emb_dim, aggr="add"):
        super(GCNConv, self).__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def norm(self, edge_index, num_nodes, dtype):
        """Return the per-edge coefficient ``deg[row]^-1/2 * deg[col]^-1/2``."""
        ### assuming that self-loops have been already added in edge_index
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        # Degree = sum of the unit edge weights per row index. Done with plain
        # torch because ``scatter_add`` is not imported in this module.
        deg = torch.zeros(num_nodes, dtype=dtype, device=edge_index.device)
        deg.scatter_add_(0, row, edge_weight)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        # Add self loops in the edge space. ``add_self_loops`` returns an
        # (edge_index, edge_attr) pair; no attributes are passed in, so only
        # the index is kept.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Add features corresponding to the appended self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        # The aggregation scheme is configured in __init__; propagate takes the
        # edge index first, as GINConv in this module already does.
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * (x_j + edge_attr)
112
+
113
+
114
class GATConv(MessagePassing):
    """
    Multi-head graph attention layer whose messages also carry bond-type and
    bond-direction edge embeddings; head outputs are averaged.

    Args:
        emb_dim (int): dimensionality of the node embeddings.
        heads (int): number of attention heads.
        negative_slope (float): slope of the LeakyReLU applied to attention scores.
        aggr (str): neighbourhood aggregation scheme passed to ``MessagePassing``.
    """
    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):
        # node_dim=0: node features are reshaped to [N, heads, emb_dim] in
        # forward(), so nodes live on the first axis.
        super(GATConv, self).__init__(aggr=aggr, node_dim=0)

        self.emb_dim = emb_dim
        self.heads = heads
        self.negative_slope = negative_slope

        self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
        self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))

        self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))

        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_attr):
        # Add self loops in the edge space. ``add_self_loops`` returns an
        # (edge_index, edge_attr) pair; only the index is needed here.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Add features corresponding to the appended self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
        # The aggregation scheme is configured in __init__; propagate takes the
        # edge index first, as GINConv in this module already does.
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, edge_index, x_i, x_j, edge_attr):
        edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
        # Out-of-place add: do not mutate the tensor handed in by propagate().
        x_j = x_j + edge_attr

        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        # NOTE(review): normalise over the edges aggregated into the same node.
        # With the default "source_to_target" flow that is edge_index[1]; the
        # earlier edge_index[0] matched the old propagate API - confirm.
        alpha = softmax(alpha, edge_index[1])

        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        # Average the heads, then add the shared bias.
        aggr_out = aggr_out.mean(dim=1)
        aggr_out = aggr_out + self.bias

        return aggr_out
173
+
174
+
175
class GraphSAGEConv(MessagePassing):
    """
    GraphSAGE-style layer whose messages also carry bond-type and
    bond-direction edge embeddings; the aggregated output is L2-normalised.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbourhood aggregation scheme passed to ``MessagePassing``.
    """
    def __init__(self, emb_dim, aggr="mean"):
        # The scheme must reach the base class: it is no longer an argument of
        # propagate(), and the base-class default is not "mean".
        super(GraphSAGEConv, self).__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        # Add self loops in the edge space. ``add_self_loops`` returns an
        # (edge_index, edge_attr) pair; only the index is needed here.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Add features corresponding to the appended self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        x = self.linear(x)

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return F.normalize(aggr_out, p=2, dim=-1)
210
+
211
+
212
+
213
class GNN(torch.nn.Module):
    """
    Stack of message-passing layers producing per-node and per-graph features.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        If ``cat_grep`` is set (the default): ``(batch_node, batch_mask)`` where
        the mean-pooled graph vector is prepended to the dense node features.
        Otherwise ``(batch_node, batch_mask, h_graph)``.
    """
    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0, gnn_type="gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # Fail at construction time instead of building an empty layer list
        # that only breaks later inside forward().
        layer_factories = {
            "gin": lambda: GINConv(emb_dim, aggr="add"),
            "gcn": lambda: GCNConv(emb_dim),
            "gat": lambda: GATConv(emb_dim),
            "graphsage": lambda: GraphSAGEConv(emb_dim),
        }
        if gnn_type not in layer_factories:
            raise ValueError("Invalid GNN type: {!r}.".format(gnn_type))

        ###List of message-passing layers
        self.gnns = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.gnns.append(layer_factories[gnn_type]())

        self.pool = global_mean_pool

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
        self.num_features = emb_dim
        self.cat_grep = True

    def forward(self, *argv):
        """Accept ``(data)``, ``(x, edge_index, edge_attr)`` or ``(x, edge_index, edge_attr, batch)``."""
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
            batch = None
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        if batch is None:
            # No assignment vector given: treat all nodes as one graph. The
            # 3-argument call used to reach the pooling code with ``batch``
            # undefined.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            node_representation = torch.stack(h_list, dim=0).max(dim=0)[0]
        elif self.JK == "sum":
            # sum() returns a tensor, not a (values, indices) pair; the former
            # trailing [0] kept only the first node's row.
            node_representation = torch.stack(h_list, dim=0).sum(dim=0)
        else:
            raise ValueError("Invalid JK type: {!r}.".format(self.JK))

        h_graph = self.pool(node_representation, batch)  # shape = [B, D]
        batch_node, batch_mask = to_dense_batch(node_representation, batch)  # shape = [B, n_max, D],
        batch_mask = batch_mask.bool()

        if self.cat_grep:
            # Prepend the pooled graph vector as an extra, always-valid position.
            batch_node = torch.cat((h_graph.unsqueeze(1), batch_node), dim=1)  # shape = [B, n_max+1, D]
            batch_mask = torch.cat([torch.ones((batch_mask.shape[0], 1), dtype=torch.bool, device=batch.device), batch_mask], dim=1)
            return batch_node, batch_mask
        else:
            return batch_node, batch_mask, h_graph
312
+
313
+
314
class GNN_graphpred(torch.nn.Module):
    """
    Graph-level prediction head on top of the ``GNN`` encoder of this module.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set
        gnn_type: gin, gcn, graphsage, gat

    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """
    def __init__(self, num_layer, emb_dim, num_tasks, JK="last", drop_ratio=0, graph_pooling="mean", gnn_type="gin"):
        super(GNN_graphpred, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type=gnn_type)

        # Width of the node features produced by the encoder.
        if self.JK == "concat":
            node_dim = (self.num_layer + 1) * emb_dim
        else:
            node_dim = emb_dim

        #Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Linear(node_dim, 1))
        elif graph_pooling[:-1] == "set2set":
            # The trailing digit is the number of Set2Set processing steps.
            set2set_iter = int(graph_pooling[-1])
            self.pool = Set2Set(node_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        #For graph-level binary classification
        if graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1

        self.graph_pred_linear = torch.nn.Linear(self.mult * node_dim, self.num_tasks)

    def from_pretrained(self, model_file):
        """Load encoder weights from ``model_file`` and print the key mismatches."""
        # map_location="cpu" lets the file load on hosts without the device it
        # was saved from; load_state_dict copies values into the existing parameters.
        missing_keys, unexpected_keys = self.gnn.load_state_dict(torch.load(model_file, map_location="cpu"))
        print(missing_keys)
        print(unexpected_keys)

    def forward(self, *argv):
        """Accept ``(data)`` or ``(x, edge_index, edge_attr, batch)`` and return task logits."""
        from types import SimpleNamespace

        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        # The encoder needs the batch vector and returns dense, padded features
        # rather than a flat node tensor. Call it through its single-argument
        # form, which reads .x/.edge_index/.edge_attr/.batch.
        graph = SimpleNamespace(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
        encoded = self.gnn(graph)
        batch_node, batch_mask = encoded[0], encoded[1]
        if getattr(self.gnn, "cat_grep", False):
            # Position 0 holds the pooled graph vector, not a node: drop it.
            batch_node, batch_mask = batch_node[:, 1:], batch_mask[:, 1:]
        # Undo the dense padding to get one row per real node again.
        node_representation = batch_node[batch_mask]

        return self.graph_pred_linear(self.pool(node_representation, batch))
393
+
394
+
395
+ if __name__ == "__main__":
396
+ pass
397
+
model/help_funcs.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nltk.translate.bleu_score import corpus_bleu
2
+ from nltk.translate.meteor_score import meteor_score
3
+ from rouge_score import rouge_scorer
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import json
7
+ from transformers import AutoTokenizer
8
+
9
def _tokenize_without_special(tokenizer, text, text_trunc_length):
    """Tokenize ``text`` (truncated/padded to ``text_trunc_length``) and drop [PAD]/[CLS]/[SEP]."""
    tokens = tokenizer.tokenize(text, truncation=True, max_length=text_trunc_length,
                                padding='max_length')
    return [tok for tok in tokens if tok not in ('[PAD]', '[CLS]', '[SEP]')]


def caption_evaluate(predictions, targets, tokenizer, text_trunc_length):
    """Print and return BLEU-2/4, ROUGE-1/2/L and METEOR, each scaled by 100.

    Args:
        predictions: generated texts, aligned with ``targets``.
        targets: reference texts.
        tokenizer: object whose ``tokenize`` method is used for BLEU/METEOR.
        text_trunc_length: ``max_length`` handed to the tokenizer.

    Returns:
        ``(bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor)``.
    """
    meteor_scores = []
    references = []
    hypotheses = []
    for gt, out in tqdm(zip(targets, predictions)):
        gt_tokens = _tokenize_without_special(tokenizer, gt, text_trunc_length)
        out_tokens = _tokenize_without_special(tokenizer, out, text_trunc_length)

        references.append([gt_tokens])
        hypotheses.append(out_tokens)

        meteor_scores.append(meteor_score([gt_tokens], out_tokens))

    bleu2 = corpus_bleu(references, hypotheses, weights=(.5, .5))
    bleu4 = corpus_bleu(references, hypotheses, weights=(.25, .25, .25, .25))
    bleu2 *= 100
    bleu4 *= 100

    print('BLEU-2 score:', bleu2)
    print('BLEU-4 score:', bleu4)
    _meteor_score = np.mean(meteor_scores)
    _meteor_score *= 100
    print('Average Meteor score:', _meteor_score)

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])

    rouge_scores = []
    for gt, out in tqdm(zip(targets, predictions)):
        # RougeScorer.score takes (target, prediction). The swapped order used
        # before left the F-measure unchanged but mislabelled precision/recall.
        rouge_scores.append(scorer.score(gt, out))

    print('ROUGE score:')
    rouge_1 = np.mean([rs['rouge1'].fmeasure for rs in rouge_scores]) * 100
    rouge_2 = np.mean([rs['rouge2'].fmeasure for rs in rouge_scores]) * 100
    rouge_l = np.mean([rs['rougeL'].fmeasure for rs in rouge_scores]) * 100
    print('rouge1:', rouge_1)
    print('rouge2:', rouge_2)
    print('rougeL:', rouge_l)
    return bleu2, bleu4, rouge_1, rouge_2, rouge_l, _meteor_score
62
+
63
+
64
class AttrDict(dict):
    """Dictionary whose items are also reachable as attributes (``d.key`` is ``d["key"]``)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so attribute and
        # item access share one storage.
        self.__dict__ = self
68
+
69
def get_tokens_as_list(tokenizer, word_list):
    """Converts a sequence of words into a flat list of token ids.

    Each word is tokenized on its own with ``add_special_tokens=False`` and
    the resulting ids are concatenated in input order.

    Source: https://huggingface.co/docs/transformers/internal/generation_utils
    """
    return [
        token_id
        for word in word_list
        for token_id in tokenizer([word], add_special_tokens=False).input_ids[0]
    ]
77
+
78
def get_not_allowed_tokens_ids(tokenizer_name, allowed_words_file='model/allowed_words.json'):
    """Return every vocabulary id that is not produced by the allowed words.

    Args:
        tokenizer_name: name or path given to ``AutoTokenizer.from_pretrained``.
        allowed_words_file: JSON file holding an object; its values are the allowed words.

    Returns:
        A list of single-element lists ``[[token_id], ...]``, in ascending id
        order (presumably the ``bad_words_ids`` format for generation; verify
        against the caller).
    """
    tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(tokenizer_name, add_prefix_space=True)
    with open(allowed_words_file, 'r') as f:
        allowed_words = list(json.load(f).values())
    # A set makes each membership test O(1); the vocabulary scan below
    # otherwise re-walks the whole allowed-id list for every id.
    allowed_tokens_ids = set(get_tokens_as_list(tokenizer_with_prefix_space, allowed_words))
    return [
        [token_id]
        for token_id in range(tokenizer_with_prefix_space.vocab_size)
        if token_id not in allowed_tokens_ids
    ]
model/modeling_llama.py ADDED
@@ -0,0 +1,888 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
+ from transformers.models.llama.configuration_llama import LlamaConfig
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = "LlamaConfig"
39
+
40
+
41
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
42
+ def _make_causal_mask(
43
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
44
+ ):
45
+ """
46
+ Make causal mask used for bi-directional self-attention.
47
+ """
48
+ bsz, tgt_len = input_ids_shape
49
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
50
+ mask_cond = torch.arange(mask.size(-1), device=device)
51
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
52
+ mask = mask.to(dtype)
53
+
54
+ if past_key_values_length > 0:
55
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
74
class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel scale and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: size of the last (normalized) dimension.
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Scale `hidden_states` by the reciprocal RMS of its last dimension, then by `self.weight`."""
        # The mean square is accumulated in float32 regardless of the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
92
+
93
+
94
class LlamaRotaryEmbedding(torch.nn.Module):
    """Cached cosine/sine tables for rotary position embeddings.

    The tables have shape `[1, 1, max_seq_len_cached, dim]` and are rebuilt with a larger length whenever
    `forward` is asked for more positions than are currently cached.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: size of the rotated (per-head) dimension.
            max_position_embeddings: number of positions to precompute.
            base: base of the geometric progression of inverse frequencies.
            device: device on which the inverse frequencies are created.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(max_position_embeddings, device=self.inv_freq.device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build the cos/sin tables for `seq_len` positions on `device`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        # Move inv_freq alongside `t` so the outer product cannot mix devices.
        freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        """
        Return the `(cos, sin)` tables for the first `seq_len` positions, cast to `x.dtype`.

        Args:
            x: tensor laid out as `[bs, num_attention_heads, seq_len, head_size]`; supplies the dtype and
                device of the result.
            seq_len: number of positions required. Defaults to `x.shape[-2]` when omitted.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        # Unlikely to run after the tables are built in `__init__`; kept so longer inputs still work.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
124
+
125
+
126
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at its midpoint into `(x1, x2)` and returned as `(-x2, x1)`.
    """
    midpoint = x.shape[-1] // 2
    first_half = x[..., :midpoint]
    second_half = x[..., midpoint:]
    return torch.cat((-second_half, first_half), dim=-1)
131
+
132
+
133
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to the query and key tensors.

    Args:
        q, k: query/key tensors laid out as `[bs, num_heads, seq_len, head_dim]`.
        cos, sin: tables of shape `[1, 1, table_len, head_dim]` as returned by `LlamaRotaryEmbedding`.
        position_ids: integer tensor `[bs, seq_len]` selecting the table row for every token.

    Returns:
        The rotated `(q, k)` pair, with the same shapes as the inputs.
    """
    # The two leading table dims are always 1; drop them and look the rows up directly instead of
    # repeating the whole table once per batch element and gathering.
    cos = cos.squeeze(1).squeeze(0)  # [table_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)  # [table_len, head_dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
141
+
142
+
143
class LlamaMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        """
        Args:
            hidden_size: input and output feature size.
            intermediate_size: feature size of the gate/up projections.
            hidden_act: key into `ACT2FN` selecting the gate activation.
        """
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        """Return the gated projection of `x`; the last dimension stays `hidden_size`."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
158
+
159
+
160
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `[bsz, seq_len, hidden]` into a contiguous `[bsz, num_heads, seq_len, head_dim]`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states: input of shape `(bsz, q_len, hidden_size)`.
            attention_mask: optional additive mask of shape `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: `(bsz, q_len)` indices into the rotary tables; each must be below `kv_seq_len`.
            past_key_value: cached `(key, value)` pair of shape `(bsz, num_heads, past_len, head_dim)`.
            output_attentions: whether to return the attention probabilities.
            use_cache: whether to return the updated `(key, value)` pair.

        Returns:
            `(attn_output, attn_weights or None, present_key_value or None)`.

        Raises:
            ValueError: if the attention scores, the mask, or the attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        # `value_states` only supplies the dtype/device of the returned tables.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Summing the causal and padding masks can push entries below the dtype minimum (to -inf);
            # clamp them back, building the bound on the same device as the scores.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
249
+
250
+
251
class LlamaDecoderLayer(nn.Module):
    """One pre-norm decoder block: RMSNorm -> self-attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): position indices forwarded unchanged to the
                self-attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            Tuple starting with the new hidden states, followed by the attention weights when
            `output_attentions` is set, followed by the present key/value pair when `use_cache` is set.
        """

        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
317
+
318
+
319
# Shared class-level docstring preamble, attached to the model classes via `add_start_docstrings`.
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
334
+
335
+
336
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    """Common base for the LLaMA models: configuration class, weight initialization and checkpointing hook."""

    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        """Initialize `nn.Linear` / `nn.Embedding` weights from N(0, `config.initializer_range`).

        Linear biases and the embedding row at `padding_idx` (when present) are zeroed.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on `module` when it is a `LlamaModel`."""
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value
361
+
362
+
363
# Shared `forward` argument documentation, attached via `add_start_docstrings_to_model_forward`.
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`); the mask must then also cover the cached positions.

            If you want to change padding behavior, you should read `_prepare_decoder_attention_mask` and modify it
            to your needs.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens, used to select the rotary position embeddings. Each
            index must be smaller than the total number of cached plus current positions.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors
            of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
425
+
426
+
427
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine the causal mask and the padding mask into one additive mask.

        Returns a tensor of shape `[bsz, 1, tgt_seq_len, src_seq_len]`, or `None` when there is a single query
        position and no `attention_mask` was given.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            # Cached keys are laid out as [bsz, num_heads, past_len, head_dim].
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Default positions continue right after the cached ones.
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # `inputs` already ends with past_key_value=None; the trailing None here fills
                        # the layer's `use_cache` argument.
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the attention weights when those are returned too.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
609
+
610
+
611
class LlamaForCausalLM(LlamaPreTrainedModel):
    """`LlamaModel` with a bias-free linear language-modeling head on top."""

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            reduction (`str`, *optional*, defaults to `"mean"`):
                Reduction passed to `CrossEntropyLoss`. With `"none"` the per-token losses are averaged over the
                sequence dimension, giving one loss value per batch row.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; use the real logits width so a head whose size differs from
            # `config.vocab_size` still reshapes correctly.
            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if reduction == "none":
                # One value per batch row. The mean runs over every shifted position; positions whose
                # label is the ignore index contribute a loss of zero but still count in the divisor.
                loss = loss.view(logits.size(0), -1).mean(1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        With a non-empty cache only the last token (and its position) is fed. `inputs_embeds` are used only
        while `past_key_values` is `None`, i.e. on the first step.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked slots get a placeholder position of 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along the batch dimension according to `beam_idx`."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )
767
+
768
+
769
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    # Documentation for this class is supplied through the `add_start_docstrings` decorator above.
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # Bias-free linear head mapping each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Token embedding module of the wrapped base model.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the token embedding module of the wrapped base model.
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None or input_ids is None:
            # Padding cannot be located (no pad id, or only embeddings given): use the last position.
            last_token_index = -1
        else:
            # Index of the right-most non-padding token in every row. Counting non-pad tokens
            # (the previous approach) is only correct for right-padded batches; this form gives
            # the same result there and also handles left padding.
            # NOTE: a row consisting solely of padding resolves to position 0.
            non_pad = torch.ne(input_ids, pad_token_id).to(logits.device)
            token_positions = torch.arange(input_ids.shape[-1], device=logits.device)
            last_token_index = (token_positions * non_pad).argmax(-1)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_token_index]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once from the label count / label dtype and store it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
model/modeling_opt.py ADDED
@@ -0,0 +1,1223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch OPT model."""
16
+ import random
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from transformers.activations import ACT2FN
24
+ # from ...activations import ACT2FN
25
+ from transformers.modeling_outputs import (
26
+ BaseModelOutputWithPast,
27
+ CausalLMOutputWithPast,
28
+ QuestionAnsweringModelOutput,
29
+ SequenceClassifierOutputWithPast,
30
+ )
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (
33
+ add_code_sample_docstrings,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+ from transformers.models.opt.configuration_opt import OPTConfig
40
+ # from .configuration_opt
41
+
42
+
43
# Module-level logger, created through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)

# Checkpoint / config names for documentation; presumably consumed by docstring decorators
# further down the file (not visible in this chunk) -- TODO confirm.
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"

# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]

# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
_SEQ_CLASS_EXPECTED_LOSS = 1.71
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"

# Hub identifiers of the OPT checkpoints listed for this module.
OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/opt-125m",
    "facebook/opt-350m",
    "facebook/opt-1.3b",
    "facebook/opt-2.7b",
    "facebook/opt-6.7b",
    "facebook/opt-13b",
    "facebook/opt-30b",
    # See all OPT models at https://huggingface.co/models?filter=opt
]
66
+
67
+
68
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
69
+ def _make_causal_mask(
70
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
71
+ ):
72
+ """
73
+ Make causal mask used for bi-directional self-attention.
74
+ """
75
+ bsz, tgt_len = input_ids_shape
76
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
77
+ mask_cond = torch.arange(mask.size(-1), device=device)
78
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
79
+ mask = mask.to(dtype)
80
+
81
+ if past_key_values_length > 0:
82
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
83
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
84
+
85
+
86
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
87
+ """
88
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
89
+ """
90
+ bsz, src_len = mask.size()
91
+ tgt_len = tgt_len if tgt_len is not None else src_len
92
+
93
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
94
+
95
+ inverted_mask = 1.0 - expanded_mask
96
+
97
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
98
+
99
+
100
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """
    Learned positional embedding table with a fixed number of positions.

    The table holds `num_embeddings + 2` rows and every position id is shifted by 2 before lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """Look up position embeddings derived from `attention_mask`.

        Positions are the running count of unmasked entries along dim 1 (0-based); masked
        entries get position -1 before the offset is added. The first `past_key_values_length`
        columns are dropped from the result.
        """
        mask = attention_mask.long()

        # Running count of attended tokens, zeroed on masked slots, shifted to start at 0.
        positions = mask.cumsum(dim=1) * mask - 1

        # Only positions after the cached prefix are needed.
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)
122
+
123
+
124
class OPTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        """
        Args:
            embed_dim: model width; must be divisible by `num_heads`.
            num_heads: number of attention heads.
            dropout: dropout probability applied to the attention probabilities.
            is_decoder: when true, `forward` returns the key/value states for caching.
            bias: whether the four linear projections carry a bias.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Queries are pre-scaled by 1/sqrt(head_dim).
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: query source of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states: if given, keys/values are projected from it (cross-attention).
            past_key_value: cached `(key, value)` pair, each `(bsz, num_heads, past_len, head_dim)`.
            attention_mask: additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            layer_head_mask: per-head multiplier of shape `(num_heads,)`.
            output_attentions: whether to return the attention probabilities.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`; the last item is the updated
            `(key, value)` pair when `is_decoder` is set, otherwise the value passed in.

        Raises:
            ValueError: if an intermediate tensor or a supplied mask has an unexpected shape.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold the head axis into the batch axis so a single bmm handles all heads.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            # Clamp to the dtype minimum so that adding the mask cannot produce -inf.
            # The bound is created on the weights' device instead of the default (CPU) device.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # Report the shape that is actually being checked (heads folded into the batch axis).
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
274
+
275
+
276
class OPTDecoderLayer(nn.Module):
    """One decoder block: a self-attention sub-block followed by a two-layer feed-forward sub-block.

    `config.do_layer_norm_before` selects whether each layer norm runs before its sub-block
    or after the residual addition.
    """

    def __init__(self, config: OPTConfig):
        super().__init__()
        width = config.hidden_size
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        self.embed_dim = width
        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            bias=use_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]

        self.self_attn_layer_norm = nn.LayerNorm(width, elementwise_affine=affine)
        self.fc1 = nn.Linear(width, config.ffn_dim, bias=use_bias)
        self.fc2 = nn.Linear(config.ffn_dim, width, bias=use_bias)
        self.final_layer_norm = nn.LayerNorm(width, elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): additive mask of size
                `(batch, 1, tgt_len, src_len)`; padding elements carry very large negative values.
            layer_head_mask (`torch.FloatTensor`, *optional*): per-head mask for this layer, forwarded
                to the attention module.
            output_attentions (`bool`, *optional*):
                When true, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*):
                When true, the key/value states are appended to the returned tuple.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key and value projection states

        Returns:
            Tuple starting with the new hidden states, then (optionally) the attention weights and
            (optionally) the present key/value pair, in that order.
        """
        pre_norm = self.do_layer_norm_before

        # ---- self-attention sub-block ----
        shortcut = hidden_states
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = shortcut + hidden_states

        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- feed-forward sub-block (runs on a flattened (tokens, features) view) ----
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
        shortcut = hidden_states

        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = (shortcut + hidden_states).view(original_shape)

        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        # ---- assemble outputs ----
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
374
+
375
+
376
# Shared introductory documentation text; passed to `add_start_docstrings` for `OPTPreTrainedModel`.
OPT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.

Parameters:
config ([`OPTConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
391
+
392
+
393
@add_start_docstrings(
    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
class OPTPreTrainedModel(PreTrainedModel):
    # Documentation for this class is supplied through the `add_start_docstrings` decorator above.
    # Class-level settings read by the `PreTrainedModel` machinery.
    config_class = OPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OPTDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        # Normal(0, init_std) weights for Linear and Embedding modules; Linear biases and the
        # embedding row at `padding_idx` are zeroed. Other module types are left untouched.
        std = self.config.init_std
        is_linear = isinstance(module, nn.Linear)
        if not is_linear and not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Only decoder modules carry the `gradient_checkpointing` switch.
        if not isinstance(module, OPTDecoder):
            return
        module.gradient_checkpointing = value
418
+
419
+
420
# Shared description of the forward-pass arguments; presumably attached to model `forward`
# methods via `add_start_docstrings_to_model_forward` later in the file -- TODO confirm.
OPT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.

Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.

[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)

Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.

If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).

If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
480
+
481
+
482
+ class OPTDecoder(OPTPreTrainedModel):
483
+ """
484
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
485
+
486
+ Args:
487
+ config: OPTConfig
488
+ """
489
+
490
def __init__(self, config: OPTConfig):
    """Build embeddings, optional in/out projections, optional final layer norm and the layer stack."""
    super().__init__(config)
    self.dropout = config.dropout
    self.layerdrop = config.layerdrop
    self.padding_idx = config.pad_token_id
    self.max_target_positions = config.max_position_embeddings
    self.vocab_size = config.vocab_size

    embed_width = config.word_embed_proj_dim
    hidden_width = config.hidden_size

    self.embed_tokens = nn.Embedding(config.vocab_size, embed_width, self.padding_idx)
    self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, hidden_width)

    # Projections exist only when the token-embedding width differs from the hidden width.
    # `project_out` is created before `project_in`, as in the original construction order.
    needs_projection = embed_width != hidden_width
    self.project_out = nn.Linear(hidden_width, embed_width, bias=False) if needs_projection else None
    self.project_in = nn.Linear(embed_width, hidden_width, bias=False) if needs_projection else None

    # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
    # with checkpoints that have been fine-tuned before transformers v4.20.1
    # see https://github.com/facebookresearch/metaseq/pull/164
    wants_final_norm = config.do_layer_norm_before and not config._remove_final_layer_norm
    self.final_layer_norm = (
        nn.LayerNorm(hidden_width, elementwise_affine=config.layer_norm_elementwise_affine)
        if wants_final_norm
        else None
    )

    self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()
526
+
527
def get_input_embeddings(self):
    """Return the token embedding module stored in `self.embed_tokens`."""
    return self.embed_tokens
529
+
530
def set_input_embeddings(self, value):
    """Replace the token embedding module stored in `self.embed_tokens` with `value`."""
    self.embed_tokens = value
532
+
533
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
534
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Combine the causal mask and the expanded padding mask into one additive mask.

    Returns `None` when neither applies (single-token input and no `attention_mask`).
    Shapes go from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    tgt_len = input_shape[-1]

    # A causal mask is only built when more than one new position is present.
    causal_mask = None
    if tgt_len > 1:
        causal_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is None:
        return causal_mask

    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    padding_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=tgt_len).to(inputs_embeds.device)
    if causal_mask is None:
        return padding_mask
    return padding_mask + causal_mask
556
+
557
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    r"""
    Run the decoder layers over token ids or pre-computed embeddings.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Mutually exclusive with `inputs_embeds`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            When omitted, an all-ones mask covering the cached plus the current positions is created.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

            Its first dimension must equal the number of decoder layers.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Cached per-layer states from a previous call, indexed by layer. The number of already processed
            positions is read from dimension 2 of `past_key_values[0][0]`. Entry `idx` is handed to decoder
            layer `idx` as `past_key_value` (not used on the gradient-checkpointing path).
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            Whether to collect and return the per-layer cache. Falls back to `config.use_cache`; forced to
            `False` while training with gradient checkpointing.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

    Raises:
        ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `head_mask` does not
            have one entry per decoder layer.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if use_cache is None:
        use_cache = cfg.use_cache
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # Exactly one of input_ids / inputs_embeds must be supplied.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    if input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        inputs_embeds = self.embed_tokens(input_ids)
    else:
        input_shape = inputs_embeds.size()[:-1]

    batch_size, seq_length = input_shape
    past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
    # The mask has to span the cached positions as well as the new ones.
    mask_seq_length = past_key_values_length + seq_length

    if attention_mask is None:
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
    causal_attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )
    pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

    if self.project_in is not None:
        inputs_embeds = self.project_in(inputs_embeds)

    hidden_states = inputs_embeds + pos_embeds

    checkpointing = self.gradient_checkpointing and self.training
    if checkpointing:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # Optional accumulators; each stays None unless the matching flag is set.
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    # head_mask, when given, needs one row per decoder layer.
    if head_mask is not None:
        if head_mask.size()[0] != (len(self.layers)):
            raise ValueError(
                f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )

    def _checkpointable(module):
        # Wraps a layer so checkpoint() can call it positionally; the trailing None is for past_key_value.
        def _run(*inputs):
            return module(*inputs, output_attentions, None)

        return _run

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # LayerDrop (see https://arxiv.org/abs/1909.11556 for description): randomly skip layers in training.
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):
            continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None
        layer_head_mask = head_mask[idx] if head_mask is not None else None

        if checkpointing:
            layer_outputs = torch.utils.checkpoint.checkpoint(
                _checkpointable(decoder_layer),
                hidden_states,
                causal_attention_mask,
                layer_head_mask,
                None,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_attention_mask,
                layer_head_mask=layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache entry follows the attention weights when those are returned.
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    if self.final_layer_norm is not None:
        hidden_states = self.final_layer_norm(hidden_states)

    if self.project_out is not None:
        hidden_states = self.project_out(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        candidates = [hidden_states, next_cache, all_hidden_states, all_self_attns]
        return tuple(item for item in candidates if item is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
741
+
742
+
743
# Thin wrapper that owns a single OPTDecoder and forwards every call to it.
@add_start_docstrings(
    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
class OPTModel(OPTPreTrainedModel):
    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding lives on the wrapped decoder.
        decoder = self.decoder
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        # Swap the wrapped decoder's token embedding for `value`.
        decoder = self.decoder
        decoder.embed_tokens = value

    def get_decoder(self):
        # Expose the wrapped OPTDecoder instance.
        return self.decoder

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Resolve every unset flag from the model config before delegating.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            # Re-wrap the decoder result field by field.
            return BaseModelOutputWithPast(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
            )

        # Tuple mode: hand the decoder's tuple back untouched.
        return decoder_outputs
811
+
812
+
813
class OPTForCausalLM(OPTPreTrainedModel):
    # OPT decoder with a linear language-modeling head projecting to the vocabulary.
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding (`self.model.decoder.embed_tokens`)."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding with `value`."""
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder with `decoder`."""
        self.model.decoder = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the causal language modeling loss. They are shifted inside this method so that
                the logits at position `i` are scored against the label at position `i + 1`. Indices should either
                be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to
                `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ...,
                config.vocab_size]`. Labels are moved to the device of the logits before the loss is computed.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0]).contiguous()

        loss = None
        if labels is not None:
            # The logits may live on a different device than the caller's labels (e.g. when the
            # model is split across devices); the loss needs both on the same one.
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Build the keyword arguments for one generation step.

        With a non-empty cache only the last token id is fed. `inputs_embeds` is used only when
        `past_key_values` is None; otherwise `input_ids` is used.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select rows `beam_idx` along dimension 0 of every cached tensor, layer by layer."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )
1001
+
1002
+
1003
@add_start_docstrings(
    """
    The OPT Model transformer with a sequence classification head on top (linear layer).

    [`OPTForSequenceClassification`] applies the linear head to every position and then averages the per-position
    logits over the positions selected by `attention_mask` (masked mean pooling). When no `attention_mask` is passed,
    all positions are averaged.

    This implementation does not compute a loss: `labels` is accepted for interface compatibility only and the
    returned `loss` is always `None`.
    """,
    OPT_START_DOCSTRING,
)
class OPTForSequenceClassification(OPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)
        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Accepted for interface compatibility only. No loss is computed here; the returned `loss` is always
            `None` and callers are expected to compute their own loss from the pooled logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,  # shape = [B, max_len]
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)  # shape = [B, max_len, num_labels]

        if attention_mask is None:
            # No mask given: every position takes part in the average.
            pooling_mask = torch.ones(logits.shape[:2], dtype=logits.dtype, device=logits.device)
        else:
            # Keep the mask on the same device as the logits it is multiplied with.
            pooling_mask = attention_mask.to(logits.device)

        # Number of kept positions per row, shape = [B, 1]; clamped so that a fully masked
        # row yields zeros instead of a division by zero.
        denom = torch.sum(pooling_mask, -1, keepdim=True).clamp(min=1)
        pooled_logits = torch.sum(logits * pooling_mask.unsqueeze(-1), dim=1)  # shape = [B, num_labels]
        pooled_logits = pooled_logits / denom

        loss = None
        if not return_dict:
            # Tuple mode: the wrapped model returned a plain tuple, so mirror that layout
            # with the pooled logits in place of the hidden states.
            return (pooled_logits,) + transformer_outputs[1:]

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        """Return the decoder's token embedding (`self.model.decoder.embed_tokens`)."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding with `value`."""
        self.model.decoder.embed_tokens = value
1091
+
1092
+
1093
@add_start_docstrings(
    """
    The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
    (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    OPT_START_DOCSTRING,
)
class OPTForQuestionAnswering(OPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.model = OPTModel(config)
        # Two outputs per position: one start logit and one end logit.
        self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
        >>> import torch

        >>> torch.manual_seed(4)  # doctest: +IGNORE_RESULT
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"

        >>> inputs = tokenizer(question, text, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> answer_offset = len(tokenizer(question)[0])

        >>> predict_answer_tokens = inputs.input_ids[
        ...     0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
        ... ]
        >>> predicted = tokenizer.decode(predict_answer_tokens)
        >>> predicted
        ' a nice puppet'
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = transformer_outputs[0]

        # Split the two-channel head output into separate start / end logits per position.
        span_logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions beyond the input are clamped to `ignored_index`, which the loss then skips.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            span_criterion = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = span_criterion(start_logits, start_positions)
            end_loss = span_criterion(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + transformer_outputs[2:]
            if total_loss is None:
                return output
            return (total_loss,) + output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        """Return the decoder's token embedding (`self.model.decoder.embed_tokens`)."""
        decoder = self.model.decoder
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding with `value`."""
        decoder = self.model.decoder
        decoder.embed_tokens = value
model/opt_flash_attention.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+ import logging
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ import transformers
8
+ from einops import rearrange
9
+
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
11
+ from flash_attn.bert_padding import unpad_input, pad_input
12
+ from transformers.models.opt.modeling_opt import _make_causal_mask, _expand_mask
13
+
14
+
15
def _prepare_decoder_attention_mask_original(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Build the combined (causal + padding) decoder attention mask.

    The causal part is only created when the target length
    (``input_shape[-1]``) is greater than one; the padding part is only
    created when ``attention_mask`` is given.  Both are expanded from
    ``[bsz, seq_len]`` to ``[bsz, 1, tgt_seq_len, src_seq_len]`` and summed
    when both exist.  Returns ``None`` when neither part applies.
    """
    tgt_len = input_shape[-1]
    mask_dtype = inputs_embeds.dtype
    mask_device = inputs_embeds.device

    # Causal mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    causal_mask = None
    if tgt_len > 1:
        causal_mask = _make_causal_mask(
            input_shape,
            mask_dtype,
            device=mask_device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is None:
        return causal_mask

    # Padding mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    padding_mask = _expand_mask(attention_mask, mask_dtype, tgt_len=tgt_len).to(mask_device)
    if causal_mask is None:
        return padding_mask
    return padding_mask + causal_mask
37
+
38
def forward_original(
    self,
    hidden_states: torch.Tensor,
    key_value_states: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Standard (non-flash) attention forward pass.

    Input shape: Batch x Time x Channel.

    This is the function that ``replace_opt_attn_with_original_attn`` installs
    as ``OPTAttention.forward``.

    Args:
        hidden_states: query-side input, ``[bsz, tgt_len, embed_dim]``.
        key_value_states: when given, the layer acts as cross-attention and
            keys/values are projected from this tensor instead.
        past_key_value: cached ``(key, value)`` pair; concatenated along the
            sequence axis (dim 2) for self-attention, reused as-is for
            cross-attention.
        attention_mask: additive mask of shape ``[bsz, 1, tgt_len, src_len]``.
        layer_head_mask: per-head multiplier of shape ``[num_heads]``.
        output_attentions: when true, the attention weights are returned.

    Returns:
        ``(attn_output, attn_weights_or_None, past_key_value)``.

    Raises:
        ValueError: if the attention weights, attention mask, head mask or
            attention output do not have the expected shape.
    """
    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None

    bsz, tgt_len, _ = hidden_states.size()

    # get query proj (the scaling is folded into the query here, not into the scores)
    query_states = self.q_proj(hidden_states) * self.scaling
    # get key, value proj
    if is_cross_attention and past_key_value is not None:
        # reuse k,v, cross_attentions
        key_states = past_key_value[0]
        value_states = past_key_value[1]
    elif is_cross_attention:
        # cross_attentions
        key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
        value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
    elif past_key_value is not None:
        # reuse k, v, self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    else:
        # self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

    if self.is_decoder:
        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
        # Further calls to cross_attention layer can then reuse all cross-attention
        # key/value_states (first "if" case)
        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
        # all previous decoder key/value_states. Further calls to uni-directional self-attention
        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
        # if encoder bi-directional self-attention `past_key_value` is always `None`
        past_key_value = (key_states, value_states)

    # Fold the head axis into the batch axis so a single bmm handles every head.
    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
    query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
    key_states = key_states.view(*proj_shape)
    value_states = value_states.view(*proj_shape)

    src_len = key_states.size(1)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        # Clamp from below to the dtype's most negative finite value.
        attn_weights = torch.max(
            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
        )
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
    if attn_weights.dtype == torch.float16:
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
    else:
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if layer_head_mask is not None:
        if layer_head_mask.size() != (self.num_heads,):
            raise ValueError(
                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
        attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if output_attentions:
        # this operation is a bit awkward, but it's required to
        # make sure that attn_weights keeps its gradient.
        # In order to do so, attn_weights have to be reshaped
        # twice and have to be reused in the following
        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    attn_output = torch.bmm(attn_probs, value_states)

    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    # Undo the head/batch folding: [bsz * heads, tgt, head_dim] -> [bsz, tgt, heads, head_dim]
    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
    attn_output = attn_output.transpose(1, 2)

    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
    # partitioned across GPUs when using tensor-parallelism.
    attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights_reshaped, past_key_value
156
+
157
+
158
def forward(
    self,
    hidden_states: torch.Tensor,
    key_value_states: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention replacement for the attention forward pass.

    Input shape: Batch x Time x Channel.

    Only causal self-attention without a key/value cache is supported.

    Args:
        hidden_states: input of shape ``[bsz, tgt_len, embed_dim]``.
        key_value_states: must be ``None`` (cross-attention is unsupported).
        past_key_value: must be ``None`` (cached decoding is unsupported).
        attention_mask: required key-padding mask handed to ``unpad_input``;
            the patched ``_prepare_decoder_attention_mask`` in this file
            passes the ``[bsz, seq_len]`` mask through unchanged.
        layer_head_mask: accepted for signature compatibility but never read,
            so a head mask has no effect on this path.
        output_attentions: must be false (weights are never materialised).

    Returns:
        ``(attn_output, None, past_key_value)`` where ``past_key_value`` is the
        current ``(key, value)`` pair when ``self.is_decoder`` is true and
        ``None`` otherwise.

    Raises:
        AssertionError: if an unsupported option is requested or
            ``attention_mask`` is ``None``.
    """
    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None
    assert not is_cross_attention, "Cross attention is not supported for flash attention"
    assert past_key_value is None, "past_key_value is not None is not supported for flash attention"
    assert not output_attentions, "output_attentions is not supported for flash attention"

    # Fail before doing any projection work if the padding mask is missing.
    key_padding_mask = attention_mask
    assert key_padding_mask is not None

    bsz, tgt_len, _ = hidden_states.size()

    # The scaling is folded into the query, which is why softmax_scale=1 is
    # passed to the kernel below.
    query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
    # With no cache, keys and values always come from the current hidden states.
    key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
    value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

    if self.is_decoder:
        # Returned so the caller's bookkeeping matches the original forward.
        past_key_value = (key_states, value_states)

    ## for flash attention
    # `_shape` yields [bsz, num_heads, tgt_len, head_dim] for each of q, k, v.
    qkv = torch.stack([query_states, key_states, value_states], dim=2)  # shape = [bsz, num_heads, 3, tgt_len, head_dim]
    qkv = qkv.transpose(1, 3)  # [bsz, tgt_len, 3, num_heads, head_dim]

    # Drop padded positions, run the variable-length kernel, then re-pad.
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
    x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=self.num_heads)
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad, cu_seqlens, max_s, self.dropout if self.training else 0.0,
        softmax_scale=1, causal=True, return_attn_probs=False
    )

    output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                 indices, bsz, tgt_len),
                       'b s (h d) -> b s h d', h=self.num_heads)

    attn_output = self.out_proj(rearrange(output, "b s h d -> b s (h d)"))
    return attn_output, None, past_key_value
234
+
235
+
236
# Disable the transformation of the attention mask in OPTDecoder, as the flash attention
# path requires the attention mask to be the same as the key_padding_mask.
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    """Return ``attention_mask`` unchanged; the remaining arguments are ignored."""
    # [bsz, seq_len]
    return attention_mask
243
+
244
+
245
def replace_opt_attn_with_flash_attn():
    """Monkey-patch the OPT implementation to use the flash-attention forward.

    Replaces ``OPTDecoder._prepare_decoder_attention_mask`` with the
    pass-through version and ``OPTAttention.forward`` with the flash-attention
    ``forward`` defined in this module.  A warning is logged when the current
    CUDA device reports a compute-capability major version below 8.
    """
    cuda_major, _ = torch.cuda.get_device_capability()
    if cuda_major < 8:
        # The two literals are concatenated, so the first must end with a space.
        logging.warning(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    transformers.models.opt.modeling_opt.OPTDecoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    transformers.models.opt.modeling_opt.OPTAttention.forward = forward
254
+
255
def replace_opt_attn_with_original_attn():
    """Restore the non-flash attention path on the OPT implementation.

    Installs ``forward_original`` as ``OPTAttention.forward`` and
    ``_prepare_decoder_attention_mask_original`` as
    ``OPTDecoder._prepare_decoder_attention_mask``.
    """
    opt_module = transformers.models.opt.modeling_opt
    opt_module.OPTAttention.forward = forward_original
    opt_module.OPTDecoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_original
258
+
259
if __name__ == '__main__':
    ## generate tests to verify the equivalence between forward_original and forward
    # Manual smoke test: every tensor is placed on 'cuda:0', so a CUDA device is required.
    import torch.nn as nn
    import math
    class FakeNN(nn.Module):
        # Minimal stand-in exposing only the attributes and helpers that
        # `forward` and `forward_original` read from `self`.
        def __init__(self, ):
            super().__init__()
            self.scaling = 1 / math.sqrt(2048)
            if False:
                # Disabled alternative: real linear projections.
                self.q_proj = nn.Linear(2048, 2048)
                self.k_proj = nn.Linear(2048, 2048)
                self.v_proj = nn.Linear(2048, 2048)
                self.out_proj = nn.Linear(2048, 2048)
            else:
                # Identity projections, so the two attention paths are compared directly.
                self.q_proj = nn.Identity()
                self.k_proj = nn.Identity()
                self.v_proj = nn.Identity()
                self.out_proj = nn.Identity()

            self.is_decoder = True
            self.num_heads = 2
            self.head_dim = 128
            self.embed_dim = 256
            self.dropout = 0

        def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
            # Same logic as the module-level `_prepare_decoder_attention_mask_original`.
            # create causal mask
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            combined_attention_mask = None
            if input_shape[-1] > 1:
                combined_attention_mask = _make_causal_mask(
                    input_shape,
                    inputs_embeds.dtype,
                    device=inputs_embeds.device,
                    past_key_values_length=past_key_values_length,
                )

            if attention_mask is not None:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                    inputs_embeds.device
                )
                combined_attention_mask = (
                    expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
                )

            return combined_attention_mask

        def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
            # [bsz, seq_len, embed_dim] -> [bsz, num_heads, seq_len, head_dim]
            return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    fakenn = FakeNN().to(torch.bfloat16).to('cuda:0')

    t_len = 3
    fake_input = torch.randn(2, t_len, fakenn.embed_dim).to(torch.bfloat16).to('cuda:0')
    if False:
        # Disabled alternative: length-based mask (valid prefix, padded suffix).
        fake_lens = torch.randint(0, t_len, (2,)).to('cuda:0')
        fake_lens = torch.LongTensor([3, 2]).to('cuda:0')
        # fake_lens = torch.ones((2,)).to('cuda:0') * 3
        fake_mask = torch.arange(t_len).unsqueeze(0).to('cuda:0') < fake_lens.unsqueeze(1)
    else:
        # Random boolean mask: zero draws become False, non-zero draws become True.
        fake_mask = torch.randint(0, t_len, (2, t_len)).bool().to('cuda:0')

    # The reference path takes the expanded 4-D mask; the flash path takes the raw 2-D mask.
    fake_mask2 = fakenn._prepare_decoder_attention_mask(fake_mask, (2,t_len), fake_input, 0)
    attn_output0, _, _ = forward_original(fakenn, fake_input, None, None, fake_mask2, None, False)
    attn_output1, _, _ = forward(fakenn, fake_input, None, None, fake_mask, None, False) # shape = [2, 3, 256]
    # Zero the reference output at masked positions before comparing
    # (presumably because the flash path's re-padding leaves zeros there; TODO confirm).
    attn_output0 = attn_output0 * fake_mask.unsqueeze(-1)

    print(torch.isclose(attn_output0, attn_output1).all())
    print(attn_output0.shape, attn_output1.shape)
    difference = (attn_output0- attn_output1).abs()
    print(difference)
    print(difference.sum())
read_results/baselines.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ import torch
3
+ from rxnfp.transformer_fingerprints import (
4
+ RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints
5
+ )
6
+
7
class Reaction_model:
    """Random and nearest-neighbour baselines that predict a reaction's ``actions``.

    ``train_list`` and ``test_list`` are lists of reaction dicts.  The methods
    below read the keys ``'actions'``, ``'extracted_molecules'``, ``'index'``,
    ``'REACTANT'`` and ``'PRODUCT'`` from them.
    """

    def __init__(self, train_list, test_list):
        """Store both splits and build the RXNBERT fingerprint generator."""
        self.train_list = train_list
        self.test_list = test_list

        model, tokenizer = get_default_model_and_tokenizer()
        self.rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)

    @time_it
    def generate_random(self):
        """Return the actions of ``len(test_list)`` training reactions sampled without replacement."""
        pred = random.sample(self.train_list, k=len(self.test_list))
        pred = [i['actions'] for i in pred]
        return pred

    @time_it
    def generate_random_compatible_old(self):
        """Older variant of ``generate_random_compatible``.

        For each test reaction, draws a training reaction uniformly among those
        whose ``len(extracted_molecules) - 1`` is less than or equal to the
        test reaction's value.
        """
        pred_list = []
        len_id_map = defaultdict(list)
        for train_rxn in self.train_list:
            len_id_map[len(train_rxn['extracted_molecules'])-1].append(train_rxn['index'])

        # Cumulative number of training reactions up to (and including) each key.
        keys = sorted(k for k in len_id_map.keys())
        accumulated_counts = {}
        count = 0
        for key in keys:
            count += len(len_id_map[key])
            accumulated_counts[key] = count

        for rxn in self.test_list:
            test_token_num = len(rxn['extracted_molecules'])-1
            idx = random.randint(0, accumulated_counts[test_token_num] - 1)
            # Walk the buckets in key order to locate the idx-th reaction.
            for key in keys:
                if idx < len(len_id_map[key]):
                    pred_list.append(self.train_list[len_id_map[key][idx]]['actions'])
                    break
                else:
                    idx -= len(len_id_map[key])
        return pred_list

    @time_it
    def generate_random_compatible(self):
        """Predict by sampling a training reaction with the same molecule count.

        Training reactions are bucketed by ``len(extracted_molecules) - 1``; for
        each test reaction one member of the matching bucket is drawn at random.
        The stored ``'index'`` is used as a position in ``train_list``
        (assumes ``index`` equals the list position; verify against the data).
        """
        pred_list = []
        len_id_map = defaultdict(list)
        for train_rxn in self.train_list:
            len_id_map[len(train_rxn['extracted_molecules'])-1].append(train_rxn['index'])

        for rxn in self.test_list:
            mole_num = len(rxn['extracted_molecules'])-1
            pred_list.append(self.train_list[random.choice(len_id_map[mole_num])]['actions'])
        return pred_list

    @time_it
    def generate_nn(self, batch_size=2048):
        """Predict with the training reaction whose fingerprint is most similar.

        Reactions are rendered as ``"<reactants joined by '.'>>><first product>"``,
        fingerprinted in batches of ``batch_size``, and matched by cosine
        similarity.
        """
        train_rxns = [f"{'.'.join(rxn['REACTANT'])}>>{rxn['PRODUCT'][0]}" for rxn in self.train_list]
        test_rxns = [f"{'.'.join(rxn['REACTANT'])}>>{rxn['PRODUCT'][0]}" for rxn in self.test_list]

        train_rxns_batches = [train_rxns[i:i+batch_size] for i in range(0, len(train_rxns), batch_size)]
        test_rxns_batches = [test_rxns[i:i+batch_size] for i in range(0, len(test_rxns), batch_size)]

        device = torch.device("cuda")
        train_fps = []
        for batch in tqdm(train_rxns_batches, desc='Generating fingerprints for training reactions'):
            batch_fps = self.rxnfp_generator.convert_batch(batch)
            train_fps.extend(batch_fps)
        train_fps = torch.tensor(train_fps, device=device) # N x 256
        # Normalise the training side too. Without this the argmax below ranks
        # by raw dot product, which favours training fingerprints with a large
        # norm instead of the most similar direction.
        train_fps = train_fps / torch.norm(train_fps, dim=1, keepdim=True)

        most_similar_indices = []
        for batch in tqdm(test_rxns_batches, desc='Generating fingerprints for test reactions'):
            batch_fps = self.rxnfp_generator.convert_batch(batch)
            batch_fps = torch.tensor(batch_fps, device=device) # BS x 256
            batch_fps = batch_fps / torch.norm(batch_fps, dim=1, keepdim=True)

            similarity_matrix = torch.matmul(train_fps, batch_fps.T) # N x BS
            most_similar_indices.extend(torch.argmax(similarity_matrix, dim=0).tolist())

        return [self.train_list[i]['actions'] for i in most_similar_indices]

    def save_results(self, gt_list, pred_list, target_file):
        """Write paired targets and predictions to ``target_file`` as indented JSON.

        Each record has the keys ``"targets"``, ``"indices"`` (the position in
        the zipped lists) and ``"predictions"``.  The parent directory is
        created when it does not exist yet.
        """
        import os  # local import: this file's visible import block does not bring in `os`

        text_dict_list = [{
            "targets": gt,
            "indices": i,
            "predictions": pred,
        } for i, (gt, pred) in enumerate(zip(gt_list, pred_list))]

        target_dir = os.path.dirname(target_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        with open(target_file, 'w') as f:
            json.dump(text_dict_list, f, indent=4)
93
+
94
def parse_args():
    """Parse the command-line options of the baselines script.

    Options: ``--name`` (default ``'none'``), ``--train_file``, ``--test_file``
    (both default ``None``) and the boolean flag ``--use_tok``.
    """
    cli = argparse.ArgumentParser(description="A simple argument parser")

    string_options = (
        ('--name', 'none'),
        ('--train_file', None),
        ('--test_file', None),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, default=fallback, type=str)
    cli.add_argument('--use_tok', default=False, action='store_true')
    return cli.parse_args()
103
+
104
def read_dataset(args):
    """Load the training and test splits from the JSON files named in ``args``.

    Reads ``args.train_file`` first, then ``args.test_file``, printing the path
    and the number of samples for each.  Returns ``(train_ds, test_ds)``.
    """
    def _load(path):
        # Announce, parse the whole JSON document, then report its length.
        print(f'Reading {path}...')
        with open(path, 'r', encoding='utf-8') as handle:
            samples = json.load(handle)
        print(f'{len(samples)} samples read.')
        return samples

    train_ds = _load(args.train_file)
    test_ds = _load(args.test_file)
    return train_ds, test_ds
114
+
115
def run_baselines(args):
    """Run every baseline, score it against the test actions and save its output.

    The random seed is fixed to 0.  For each baseline (random, random with a
    compatible pattern, nearest neighbour) the predictions are scored with
    ``Metric_calculator`` and written under ``results/<args.name>/``.
    """
    set_random_seed(0)

    train_ds, test_ds = read_dataset(args)
    model = Reaction_model(train_ds, test_ds)
    calculator = Metric_calculator()
    gt_list = [sample['actions'] for sample in test_ds]

    # (heading printed before the run, prediction method, output file name)
    baselines = (
        ('Random:', model.generate_random, 'random.json'),
        ('Random (compatible pattern):', model.generate_random_compatible, 'random_compatible.json'),
        ('Nearest neighbor:', model.generate_nn, 'nn.json'),
    )
    for heading, predict, file_name in baselines:
        print(heading)
        pred_list = predict()
        calculator(gt_list, pred_list, args.use_tok)
        model.save_results(gt_list, pred_list, f'results/{args.name}/{file_name}')
138
+
139
# Script entry point: parse the command-line options and run all baselines.
if __name__ == "__main__":
    args=parse_args()
    run_baselines(args)
read_results/read_results.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+
3
+ def parse_args():
4
+ parser = argparse.ArgumentParser(description="A simple argument parser")
5
+
6
+ parser.add_argument('--name', default='none', type=str)
7
+ parser.add_argument('--path', default=None, type=str)
8
+ parser.add_argument('--use_tok', default=False, action='store_true')
9
+ args = parser.parse_args()
10
+ return args
11
+
12
def read_dataset(args):
    """Read a JSON-lines prediction file and split it into targets and predictions.

    Every line of ``args.path`` is parsed as one JSON object.  Returns
    ``(gt_list, pred_list)`` built from each object's ``'targets'`` and
    ``'predictions'`` entries, in file order.
    """
    print(f'Reading {args.path}...')
    with open(args.path, 'r', encoding='utf-8') as handle:
        records = list(map(json.loads, handle))
    print(f'{len(records)} samples read.')

    targets = [record['targets'] for record in records]
    predictions = [record['predictions'] for record in records]
    return targets, predictions
20
+
21
def read_result(args):
    """Load targets and predictions from ``args.path`` and score them with ``Metric_calculator``."""
    references, hypotheses = read_dataset(args)
    scorer = Metric_calculator()
    scorer(references, hypotheses, args.use_tok)
25
+
26
# Script entry point: parse the command-line options and score the prediction file.
if __name__ == "__main__":
    args=parse_args()
    read_result(args)
read_results/score.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit import Chem
2
+ import os
3
+ import argparse
4
+ from tqdm import tqdm
5
+ import multiprocessing
6
+ import pandas as pd
7
+ from rdkit import RDLogger
8
+ import re
9
+ from utils import *
10
+
11
# Raise the RDKit logger threshold to CRITICAL so lower-severity messages are not printed.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
13
+
14
+
15
def extract_smiles(s):
    """Return the text between ``[START_SMILES]`` and ``[END_SMILES]``, stripped.

    The end token is searched for after the start token.  If either token is
    missing, ``s`` is returned unchanged.
    """
    start_token = "[START_SMILES]"
    end_token = "[END_SMILES]"

    # Test the raw `find` result: adding len(start_token) first would turn the
    # -1 "not found" sentinel into a positive offset and hide the miss.
    token_pos = s.find(start_token)
    if token_pos == -1:
        return s
    start_index = token_pos + len(start_token)

    end_index = s.find(end_token, start_index)
    if end_index == -1:
        return s
    return s[start_index:end_index].strip()
23
+
24
def canonicalize_smiles_clear_map(smiles,return_max_frag=True):
    """Canonicalise a SMILES string after removing atom-map numbers.

    Parsing uses ``sanitize=not opt.synthon``, where ``opt`` is a module-level
    options object that must exist before this function is called.

    Args:
        smiles: SMILES string to canonicalise.
        return_max_frag: when true, also return the canonical SMILES of the
            fragment with the most atoms (fragments are separated by ``'.'``).

    Returns:
        ``(smi, max_frag_smi)`` when ``return_max_frag`` is true, otherwise
        ``smi``.  Empty strings are returned in place of any value that could
        not be produced (unparsable input or a failed conversion).
    """
    failure = ('', '') if return_max_frag else ''

    mol = Chem.MolFromSmiles(smiles,sanitize=not opt.synthon)
    if mol is None:
        return failure

    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    try:
        smi = Chem.MolToSmiles(mol, isomericSmiles=False)
    except Exception:
        # Conversion failures are reported as "invalid", like unparsable input.
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        return failure

    if not return_max_frag:
        return smi

    sub_smi = smi.split(".")
    sub_mol = [Chem.MolFromSmiles(frag,sanitize=not opt.synthon) for frag in sub_smi]
    sub_mol_size = [(sub_smi[i], len(m.GetAtoms())) for i, m in enumerate(sub_mol) if m is not None]
    if not sub_mol_size:
        return smi, ''

    # `max` keeps the first of several equally large fragments, matching a
    # stable descending sort followed by taking element 0.
    largest_frag = max(sub_mol_size, key=lambda x: x[1])[0]
    return smi, canonicalize_smiles_clear_map(largest_frag, return_max_frag=False)
50
+
51
+
52
def compute_rank(input_smiles, prediction,raw=False,alpha=1.0):
    """Score the candidate predictions of one sample.

    Args:
        input_smiles: canonical input SMILES; only its length is used.
        prediction: list (one entry per augmentation) of candidate lists.  Each
            candidate is a ``(smiles, max_fragment_smiles)`` tuple and an empty
            ``smiles`` marks an invalid candidate.  Each inner list is replaced
            in place by its valid candidates, de-duplicated in first-seen order.
        raw: when true, exactly one augmentation is expected and a candidate's
            score is ``1 / (alpha * k + 1)`` for its position ``k`` in the
            de-duplicated list.
        alpha: position decay factor used in ``1 / (alpha * k + 1)``.

    Returns:
        ``(rank, invalid_rates)``: ``rank`` maps each candidate tuple to its
        score, and ``invalid_rates[k]`` counts the augmentations whose k-th
        original candidate was invalid.

    When ``raw`` is false, scores are summed over augmentations and then
    penalised by the candidate's best position, by ``0.2`` times the absolute
    length difference to ``input_smiles``, and by ``0.2`` times its own length.
    """
    invalid_rates = [0 for _ in range(len(prediction[0]))]
    rank = {}
    highest = {}  # best (smallest) position of each candidate across augmentations

    if raw:
        # no test augmentation
        assert len(prediction) == 1

    for j, candidates in enumerate(prediction):
        # error detection: count invalid slots and keep the valid candidates in
        # their original order.  The former sort on a per-slot validity score
        # produced exactly this order, and was the only use of global `opt`.
        valid = []
        for k, candidate in enumerate(candidates):
            if candidate[0] == "":
                invalid_rates[k] += 1
            else:
                valid.append(candidate)

        # deduplication, first occurrence wins (dicts preserve insertion order)
        prediction[j] = list(dict.fromkeys(valid))

        for k, data in enumerate(prediction[j]):
            score = 1 / (alpha * k + 1)
            if raw:
                rank[data] = score
                continue
            if data in rank:
                rank[data] += score
            else:
                rank[data] = score
            highest[data] = min(k, highest[data]) if data in highest else k

    if not raw:
        # `highest` is only filled in this mode, so the penalties apply only here.
        for key in rank:
            rank[key] += highest[key] * -1
            rank[key] += abs(len(key[0])-len(input_smiles)) * -0.2
            rank[key] += len(key[0]) * -0.2
    return rank,invalid_rates
96
+
97
def read_dataset(opt):
    """Load a JSON-lines prediction file and return cleaned inputs, targets and predictions.

    Every line of ``opt.path`` is one JSON object.  When ``opt.raw`` is set,
    only every ``opt.augmentation``-th record is kept.  Records are then
    de-duplicated on ``'ds_idx'`` (or ``'index'`` when the first record has no
    ``'ds_idx'``), keeping the first occurrence, and sorted by that key.

    Returns:
        ``(input_list, gt_list, pred_list)`` where inputs go through
        ``extract_smiles``, targets have the marker tokens removed, and both
        targets and predictions have inner spaces turned into ``'.'``.
    """
    print(f'Reading {opt.path}...')
    with open(opt.path, 'r', encoding='utf-8') as handle:
        records = list(map(json.loads, handle))
    if opt.raw:
        records = records[::opt.augmentation]

    idx_key = 'index'
    if 'ds_idx' in records[0]:
        idx_key = 'ds_idx'

    # Keep only the first record seen for each index value.
    first_seen = {}
    for record in records:
        first_seen.setdefault(record[idx_key], record)
    records = sorted(first_seen.values(), key=lambda record: record[idx_key])
    print(f'{len(records)} samples read.')

    def _clean_target(text):
        # Drop the marker tokens, then join the space-separated molecules with '.'.
        for marker in ('[START_SMILES]', '[END_SMILES]', 'SPL1T-TH1S-Pl3A5E'):
            text = text.replace(marker, '')
        return text.strip().replace(' ', '.')

    input_list = [extract_smiles(record['input']) for record in records]
    gt_list = [_clean_target(record['targets']) for record in records]
    pred_list = [
        [smi.strip().replace(' ', '.') for smi in record['predictions']]
        for record in records
    ]
    return input_list, gt_list, pred_list
115
+
116
+ def main(opt):
117
+ input_list, gt_list, pred_list = read_dataset(opt)
118
+ if opt.raw:
119
+ opt.augmentation=1
120
+ print('Reading predictions from file ...')
121
+
122
+ # inputs
123
+ print("Input Length", len(gt_list))
124
+ ras_src_smiles = input_list[::opt.augmentation]
125
+ with multiprocessing.Pool(processes=opt.process_number) as pool:
126
+ ras_src_smiles = pool.map(func=canonicalize_smiles_clear_map,iterable=ras_src_smiles)
127
+ ras_src_smiles = [i[0] for i in ras_src_smiles]
128
+
129
+ # predictions
130
+ print("Prediction Length", len(pred_list))
131
+ pred_lines = [i.split('>')[0] for d in pred_list for i in d]
132
+ data_size = len(pred_lines) // (opt.augmentation * opt.beam_size) if opt.length == -1 else opt.length
133
+ pred_lines = pred_lines[:data_size * (opt.augmentation * opt.beam_size)]
134
+ print("Canonicalizing predictions using Process Number ",opt.process_number)
135
+ with multiprocessing.Pool(processes=opt.process_number) as pool:
136
+ raw_predictions = pool.map(func=canonicalize_smiles_clear_map,iterable=pred_lines)
137
+
138
+ predictions = [[[] for j in range(opt.augmentation)] for i in range(data_size)] # data_len x augmentation x beam_size
139
+ for i, line in enumerate(raw_predictions):
140
+ predictions[i // (opt.beam_size * opt.augmentation)][i % (opt.beam_size * opt.augmentation) // opt.beam_size].append(line)
141
+
142
+ # ground truth
143
+ print("Origin Length", len(gt_list))
144
+ targets = [''.join(gt_list[i].strip().split(' ')) for i in tqdm(range(0,data_size * opt.augmentation,opt.augmentation))]
145
+ with multiprocessing.Pool(processes=opt.process_number) as pool:
146
+ targets = pool.map(func=canonicalize_smiles_clear_map, iterable=targets)
147
+
148
+ print("predictions Length", len(predictions), len(predictions[0]), len(predictions[0][0]))
149
+ print("Target Length", len(targets))
150
+
151
+ ground_truth = targets
152
+ print("Origin Target Lentgh, ", len(ground_truth))
153
+ print("Cutted Length, ",data_size)
154
+ print('\n')
155
+ accuracy = [0 for j in range(opt.n_best)]
156
+ topn_accuracy_chirality = [0 for _ in range(opt.n_best)]
157
+ topn_accuracy_wochirality = [0 for _ in range(opt.n_best)]
158
+ topn_accuracy_ringopening = [0 for _ in range(opt.n_best)]
159
+ topn_accuracy_ringformation = [0 for _ in range(opt.n_best)]
160
+ topn_accuracy_woring = [0 for _ in range(opt.n_best)]
161
+ total_chirality = 0
162
+ total_ringopening = 0
163
+ total_ringformation = 0
164
+ atomsize_topk = []
165
+ accurate_indices = [[] for j in range(opt.n_best)]
166
+ max_frag_accuracy = [0 for j in range(opt.n_best)]
167
+ invalid_rates = [0 for j in range(opt.beam_size)]
168
+ sorted_invalid_rates = [0 for j in range(opt.beam_size)]
169
+ unique_rates = 0
170
+ ranked_results = []
171
+
172
+ for i in tqdm(range(len(predictions))):
173
+ accurate_flag = False
174
+ if opt.detailed:
175
+ chirality_flag = False
176
+ ringopening_flag = False
177
+ ringformation_flag = False
178
+ pro_mol = Chem.MolFromSmiles(ras_src_smiles[i])
179
+ rea_mol = Chem.MolFromSmiles(ground_truth[i][0])
180
+ try:
181
+ pro_ringinfo = pro_mol.GetRingInfo()
182
+ rea_ringinfo = rea_mol.GetRingInfo()
183
+ pro_ringnum = pro_ringinfo.NumRings()
184
+ rea_ringnum = rea_ringinfo.NumRings()
185
+ size = len(rea_mol.GetAtoms()) - len(pro_mol.GetAtoms())
186
+ # if (int(ras_src_smiles[i].count("@") > 0) + int(ground_truth[i][0].count("@") > 0)) == 1:
187
+ if ras_src_smiles[i].count("@") > 0 or ground_truth[i][0].count("@") > 0:
188
+ total_chirality += 1
189
+ chirality_flag = True
190
+ if pro_ringnum < rea_ringnum:
191
+ total_ringopening += 1
192
+ ringopening_flag = True
193
+ if pro_ringnum > rea_ringnum:
194
+ total_ringformation += 1
195
+ ringformation_flag = True
196
+ except:
197
+ pass
198
+ # continue
199
+
200
+ inputs = input_list[i*opt.augmentation:(i+1)*opt.augmentation]
201
+ gt = gt_list[i*opt.augmentation:(i+1)*opt.augmentation]
202
+ rank, invalid_rate = compute_rank(ras_src_smiles[i], predictions[i], raw=opt.raw,alpha=opt.score_alpha)
203
+
204
+ rank_ = {k[0]: v for k, v in sorted(rank.items(), key=lambda item: item[1], reverse=True)}
205
+ if opt.detailed:
206
+ print('Index', i)
207
+ print('inputs', json.dumps(inputs, indent=4))
208
+ print('targets', json.dumps(gt, indent=4))
209
+ print('input', ras_src_smiles[i])
210
+ print('target', targets[i][0])
211
+ print('rank', json.dumps(rank_,indent=4))
212
+ print('invalid_rate', json.dumps(invalid_rate,indent=4))
213
+ print('\n')
214
+ for j in range(opt.beam_size):
215
+ invalid_rates[j] += invalid_rate[j]
216
+ rank = list(zip(rank.keys(),rank.values()))
217
+ rank.sort(key=lambda x:x[1],reverse=True)
218
+ rank = rank[:opt.n_best]
219
+ ranked_results.append([item[0][0] for item in rank])
220
+
221
+ for j, item in enumerate(rank):
222
+ if item[0][0] == ground_truth[i][0]:
223
+ if not accurate_flag:
224
+ accurate_flag = True
225
+ accurate_indices[j].append(i)
226
+ for k in range(j, opt.n_best):
227
+ accuracy[k] += 1
228
+ if opt.detailed:
229
+ atomsize_topk.append((size,j))
230
+ if chirality_flag:
231
+ for k in range(j,opt.n_best):
232
+ topn_accuracy_chirality[k] += 1
233
+ else:
234
+ for k in range(j,opt.n_best):
235
+ topn_accuracy_wochirality[k] += 1
236
+ if ringopening_flag:
237
+ for k in range(j,opt.n_best):
238
+ topn_accuracy_ringopening[k] += 1
239
+ if ringformation_flag:
240
+ for k in range(j,opt.n_best):
241
+ topn_accuracy_ringformation[k] += 1
242
+ if not ringopening_flag and not ringformation_flag:
243
+ for k in range(j,opt.n_best):
244
+ topn_accuracy_woring[k] += 1
245
+
246
+ if opt.detailed and not accurate_flag:
247
+ atomsize_topk.append((size,opt.n_best))
248
+ for j, item in enumerate(rank):
249
+ if item[0][1] == ground_truth[i][1]:
250
+ for k in range(j,opt.n_best):
251
+ max_frag_accuracy[k] += 1
252
+ break
253
+ for j in range(len(rank),opt.beam_size):
254
+ sorted_invalid_rates[j] += 1
255
+ unique_rates += len(rank)
256
+
257
+ for i in range(opt.n_best):
258
+ if i in [0,1,2,3,4,5,6,7,8,9,19,49]:
259
+ # if i in range(10):
260
+ print("Top-{} Acc:{:.3f}%, MaxFrag {:.3f}%,".format(i+1,accuracy[i] / data_size * 100,max_frag_accuracy[i] / data_size * 100),
261
+ " Invalid SMILES:{:.3f}% Sorted Invalid SMILES:{:.3f}%".format(invalid_rates[i] / data_size / opt.augmentation * 100,sorted_invalid_rates[i] / data_size / opt.augmentation * 100))
262
+ print(' '.join([f'{accuracy[i] / data_size * 100:.3f}' for i in [0,2,4,9]]))
263
+ print("Unique Rates:{:.3f}%".format(unique_rates / data_size / opt.beam_size * 100))
264
+
265
+ if opt.detailed:
266
+ print_topk = [1,3,5,10]
267
+ save_dict = {}
268
+ atomsize_topk.sort(key=lambda x:x[0])
269
+ differ_now = atomsize_topk[0][0]
270
+ topn_accuracy_bydiffer = [0 for _ in range(opt.n_best)]
271
+ total_bydiffer = 0
272
+ for i,item in enumerate(atomsize_topk):
273
+ if differ_now < 11 and differ_now != item[0]:
274
+ for j in range(opt.n_best):
275
+ if (j+1) in print_topk:
276
+ save_dict[f'top-{j+1}_size_{differ_now}'] = topn_accuracy_bydiffer[j] / total_bydiffer * 100
277
+ print("Top-{} Atom differ size {} Acc:{:.3f}%, Number:{:.3f}%".format(j+1,
278
+ differ_now,
279
+ topn_accuracy_bydiffer[j] / total_bydiffer * 100,
280
+ total_bydiffer/data_size * 100))
281
+ total_bydiffer = 0
282
+ topn_accuracy_bydiffer = [0 for _ in range(opt.n_best)]
283
+ differ_now = item[0]
284
+ for k in range(item[1],opt.n_best):
285
+ topn_accuracy_bydiffer[k] += 1
286
+ total_bydiffer += 1
287
+ for j in range(opt.n_best):
288
+ if (j + 1) in print_topk:
289
+ print("Top-{} Atom differ size {} Acc:{:.3f}%, Number:{:.3f}%".format(j + 1,
290
+ differ_now,
291
+ topn_accuracy_bydiffer[j] / total_bydiffer * 100,
292
+ total_bydiffer / data_size * 100))
293
+ save_dict[f'top-{j+1}_size_{differ_now}'] = topn_accuracy_bydiffer[j] / total_bydiffer * 100
294
+
295
+ for i in range(opt.n_best):
296
+ if (i+1) in print_topk:
297
+ if total_chirality > 0:
298
+ print("Top-{} Accuracy with chirality:{:.3f}%".format(i + 1, topn_accuracy_chirality[i] / total_chirality * 100))
299
+ save_dict[f'top-{i+1}_chilarity'] = topn_accuracy_chirality[i] / total_chirality * 100
300
+ print("Top-{} Accuracy without chirality:{:.3f}%".format(i + 1, topn_accuracy_wochirality[i] / (data_size - total_chirality) * 100))
301
+ save_dict[f'top-{i+1}_wochilarity'] = topn_accuracy_wochirality[i] / (data_size - total_chirality) * 100
302
+ if total_ringopening > 0:
303
+ print("Top-{} Accuracy ring Opening:{:.3f}%".format(i + 1, topn_accuracy_ringopening[i] / total_ringopening * 100))
304
+ save_dict[f'top-{i+1}_ringopening'] = topn_accuracy_ringopening[i] / total_ringopening * 100
305
+ if total_ringformation > 0:
306
+ print("Top-{} Accuracy ring Formation:{:.3f}%".format(i + 1, topn_accuracy_ringformation[i] / total_ringformation * 100))
307
+ save_dict[f'top-{i+1}_ringformation'] = topn_accuracy_ringformation[i] / total_ringformation * 100
308
+ print("Top-{} Accuracy without ring:{:.3f}%".format(i + 1, topn_accuracy_woring[i] / (data_size - total_ringopening - total_ringformation) * 100))
309
+ save_dict[f'top-{i+1}_wocring'] = topn_accuracy_woring[i] / (data_size - total_ringopening - total_ringformation)* 100
310
+ print(total_chirality)
311
+ print(total_ringformation)
312
+ print(total_ringopening)
313
+ # df = pd.DataFrame(list(save_dict.items()))
314
+ df = pd.DataFrame(save_dict,index=[0])
315
+ df.to_csv("detailed_results.csv")
316
+ if opt.save_accurate_indices != "":
317
+ with open(opt.save_accurate_indices, "w") as f:
318
+ total_accurate_indices = []
319
+ for indices in accurate_indices:
320
+ total_accurate_indices.extend(indices)
321
+ total_accurate_indices.sort()
322
+
323
+ # for index in total_accurate_indices:
324
+ for index in accurate_indices[0]:
325
+ f.write(str(index))
326
+ f.write("\n")
327
+
328
+ if opt.save_file != "":
329
+ with open(opt.save_file,"w") as f:
330
+ for res in ranked_results:
331
+ for smi in res:
332
+ f.write(smi)
333
+ f.write("\n")
334
+ for i in range(len(res),opt.n_best):
335
+ f.write("")
336
+ f.write("\n")
337
+
338
+
339
if __name__ == "__main__":
    # Command-line entry point: build the option namespace and hand it to main().
    arg_parser = argparse.ArgumentParser(
        description='score.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument('--beam_size', type=int, default=10, help='Beam size')
    arg_parser.add_argument('--n_best', type=int, default=10, help='n best')
    arg_parser.add_argument(
        '--path', type=str, required=True,
        help="Path to file containing the predictions and ground truth.",
    )
    arg_parser.add_argument('--augmentation', type=int, default=20)
    arg_parser.add_argument('--score_alpha', type=float, default=1.0)
    arg_parser.add_argument('--length', type=int, default=-1)
    arg_parser.add_argument('--process_number', type=int, default=multiprocessing.cpu_count())
    # Boolean switches, all off unless given on the command line.
    for switch in ('--synthon', '--detailed', '--raw'):
        arg_parser.add_argument(switch, action="store_true", default=False)
    # Optional output files; an empty string disables the corresponding dump.
    for out_option in ('--save_file', '--save_accurate_indices'):
        arg_parser.add_argument(out_option, type=str, default="")

    opt = arg_parser.parse_args()
    print(opt)
    main(opt)
read_results/t_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ import scipy.stats as stats
3
+
4
def parse_args():
    """Parse the command-line options of the t-test script.

    Returns:
        argparse.Namespace with ``name``, ``path_exp``, ``path_ref`` (strings,
        the two paths default to None) and the boolean ``use_tok``.
    """
    parser = argparse.ArgumentParser(description="A simple argument parser")

    string_options = (
        ('--name', 'none'),
        ('--path_exp', None),
        ('--path_ref', None),
    )
    for flag, default in string_options:
        parser.add_argument(flag, default=default, type=str)
    parser.add_argument('--use_tok', default=False, action='store_true')
    return parser.parse_args()
13
+
14
def read_dataset(data_path):
    """Load targets and predictions from a JSON-lines result file.

    Each non-blank line must be a JSON object with the keys ``'targets'`` and
    ``'predictions'``.

    Args:
        data_path: path of the UTF-8 encoded JSON-lines file.

    Returns:
        Tuple ``(gt_list, pred_list)`` holding the ``'targets'`` and
        ``'predictions'`` values of every record, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``'targets'`` or ``'predictions'``.
    """
    print(f'Reading {data_path}...')
    with open(data_path, 'r', encoding='utf-8') as f:
        # Iterate lazily and skip blank lines (e.g. a trailing empty line),
        # which json.loads would otherwise reject.
        test_tgt = [json.loads(line) for line in f if line.strip()]
    print(f'{len(test_tgt)} samples read.')
    gt_list = [sample['targets'] for sample in test_tgt]
    pred_list = [sample['predictions'] for sample in test_tgt]
    return gt_list, pred_list
22
+
23
def t_test(mean_exp, std_exp, mean_ref, std_ref, n):
    """Welch's t-test computed from summary statistics of two equal-sized samples.

    Args:
        mean_exp, std_exp: mean and standard deviation of the experimental sample.
        mean_ref, std_ref: mean and standard deviation of the reference sample.
        n: number of observations in each sample.

    Returns:
        Tuple ``(t_statistic, p_value)`` where ``p_value`` is two-sided.
    """
    # Squared standard error of each sample mean.
    sem_sq_exp = std_exp**2 / n
    sem_sq_ref = std_ref**2 / n
    sem_sq_total = sem_sq_exp + sem_sq_ref

    t_statistic = (mean_exp - mean_ref) / np.sqrt(sem_sq_total)
    # Welch-Satterthwaite approximation of the degrees of freedom.
    df = (sem_sq_total**2) / ((sem_sq_exp**2 / (n-1)) + (sem_sq_ref**2 / (n-1)))

    # Two-sided p-value: twice the upper tail beyond |t|.
    p_value = 2 * stats.t.sf(np.abs(t_statistic), df)
    return t_statistic, p_value
31
+
32
def read_result(args):
    """Compare per-sample metrics of two result files with a two-sample t-test.

    Reads ``args.path_exp`` and ``args.path_ref``, computes per-sample metric
    lists for both, and for every metric that is a list prints the
    experimental mean, the Levene p-value, the t statistic and its p-value.
    Equal variances are assumed only when the Levene p-value exceeds 0.05.
    """
    metric_keys = ('bleu2', 'bleu4', 'rouge_1', 'rouge_2', 'rouge_l', 'lev_score', 'meteor_score')

    gt_list_exp, pred_list_exp = read_dataset(args.path_exp)
    gt_list_ref, pred_list_ref = read_dataset(args.path_ref)
    calculator = Metric_calculator()
    result_exp = calculator.get_result_list(gt_list_exp, pred_list_exp, args.use_tok)
    result_ref = calculator.get_result_list(gt_list_ref, pred_list_ref, args.use_tok)

    for key in metric_keys:
        exp_scores = result_exp[key]
        # Metrics that were not computed per sample are not lists; skip them.
        if not isinstance(exp_scores, list):
            continue
        ref_scores = result_ref[key]
        _, levene_p = stats.levene(exp_scores, ref_scores)
        same_variance = levene_p > 0.05
        t_stat, p_val = stats.ttest_ind(exp_scores, ref_scores, equal_var=same_variance)
        print(f'{key} (mean={float(np.mean(exp_scores)):.4f}, levene p={levene_p:.3f}):\t{t_stat:.6f}\t{p_val}')
45
+
46
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the comparison.
    read_result(parse_args())
read_results/utils.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Levenshtein import distance as lev_distance
2
+ import random
3
+ import json
4
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction, corpus_bleu
5
+ from nltk.translate.meteor_score import meteor_score
6
+ from rouge_score import rouge_scorer
7
+ from tqdm import tqdm
8
+ import random
9
+ import numpy as np
10
+ import argparse
11
+ from paragraph2actions.readable_converter import ReadableConverter
12
+ import re
13
+ from transformers import AutoTokenizer
14
+ from collections import defaultdict
15
+ import time
16
+ from functools import wraps
17
+ import os
18
+ import torch
19
+ import textdistance
20
+ from typing import List
21
+
22
def levenshtein_similarity(truth: List[str], pred: List[str]) -> List[float]:
    """Return the normalized Levenshtein similarity of each (truth, pred) pair.

    Args:
        truth: reference strings.
        pred: predicted strings, aligned index-by-index with ``truth``.

    Returns:
        One ``textdistance.levenshtein.normalized_similarity`` score per pair.

    Raises:
        ValueError: if the two lists differ in length.
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # after which zip() would silently drop the unmatched tail.
    if len(truth) != len(pred):
        raise ValueError(
            f'truth and pred must have the same length, got {len(truth)} and {len(pred)}'
        )
    return [
        textdistance.levenshtein.normalized_similarity(t, p)
        for t, p in zip(truth, pred)
    ]
29
+
30
def modified_bleu(truth: List[str], pred: List[str], bleu_n=4) -> float:
    """
    Calculates the corpus BLEU-n score of a translation, with a small
    modification in order not to penalize sentences with fewer than
    ``bleu_n`` words.

    Args:
        truth: reference sentences (split on whitespace).
        pred: candidate sentences, aligned with ``truth``.
        bleu_n: highest n-gram order; uniform weights ``1 / bleu_n`` are used.

    Returns:
        value between 0 and 100 (``corpus_bleu`` scaled by 100).

    Raises:
        ValueError: if ``bleu_n`` is smaller than 1.
    """
    if bleu_n < 1:
        raise ValueError(f'bleu_n must be a positive integer, got {bleu_n}')

    references = [sentence.split() for sentence in truth]
    candidates = [sentence.split() for sentence in pred]

    # BLEU penalizes sentences with only one word. Even correct translations get a score of zero.
    # Pad short sentences with empty tokens up to bleu_n tokens.
    references = [r + max(0, bleu_n - len(r)) * [""] for r in references]
    candidates = [c + max(0, bleu_n - len(c)) * [""] for c in candidates]

    # references must have a larger depth because it supports multiple choices
    refs = [[r] for r in references]
    # Uniform n-gram weights; equals (0.5, 0.5) for n=2 and (0.25,)*4 for n=4.
    weights = (1.0 / bleu_n,) * bleu_n
    return 100*corpus_bleu(refs, candidates, weights=weights)  # type: ignore[no-any-return]
52
+
53
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` with the seed value.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # CUDA generators: the current device, then every device (multi-GPU).
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and ask for deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
62
+
63
def time_it(func):
    """Decorator that prints the wall-clock duration of every call to ``func``.

    The wrapped function's return value is passed through unchanged.
    """
    @wraps(func)
    def timed(*args, **kwargs):
        started = time.time()
        outcome = func(*args, **kwargs)
        finished = time.time()
        print(f"Function {func.__name__} finished in {finished - started:.5f} seconds.\n")
        return outcome
    return timed
72
+
73
def accuracy_score(score_list, threshold):
    """Return the fraction of scores that are at least ``threshold``.

    Args:
        score_list: sized collection of numeric scores.
        threshold: minimum score counted as a match (inclusive).

    Returns:
        Fraction in [0, 1]; 0.0 for an empty ``score_list`` (instead of
        dividing by zero).
    """
    total = len(score_list)
    if total == 0:
        return 0.0
    matches = sum(score >= threshold for score in score_list)
    return matches / total
77
+
78
def extract_tokenized_entities(text):
    """Return every non-empty ``$...$``, ``#...#`` or ``@...@`` span in ``text``.

    Matches are returned in order of appearance, delimiters included.
    """
    delimited_span = re.compile(r'\$[^\$]+\$|#[^#]+#|@[^\@]+@')
    return delimited_span.findall(text)
81
+
82
def extract_reactant_cnt(text):
    """Return the largest integer id among whitespace-separated ``$<int>$`` tokens.

    Tokens that start and end with ``$`` but do not hold an integer are
    ignored. Returns 0 when no integer token is present.
    """
    found_ids = []
    for token in text.split():
        if not (token.startswith('$') and token.endswith('$')):
            continue
        try:
            found_ids.append(int(token.strip('$')))
        except ValueError:
            # Not an integer placeholder (e.g. "$x$" or a lone "$").
            continue
    return max(found_ids, default=0)
95
+
96
class Metric_calculator:
    """Score predicted strings against ground-truth strings.

    Bundles BLEU, METEOR, ROUGE, Levenshtein-threshold accuracy and a
    validity check so that the converter, tokenizer and ROUGE scorer are
    constructed once and reused.
    """

    # Special tokens removed from tokenizer output before scoring.
    _SPECIAL_TOKENS = ('<pad>', '[PAD]', '[CLS]', '[SEP]')

    def __init__(self, text_trunc_length=1024):
        """Build the converter, the SciBERT tokenizer and the ROUGE scorer.

        Args:
            text_trunc_length: stored on the instance; no method of this
                class reads it.
        """
        self.converter = ReadableConverter(separator=' ; ')
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', use_fast=False, padding_side='right')
        self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
        self.text_trunc_length = text_trunc_length
        self.scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])

    def _tokenize_text(self, text):
        """Tokenize ``text`` and drop padding / CLS / SEP tokens."""
        return [
            token for token in self.tokenizer.tokenize(text)
            if token not in self._SPECIAL_TOKENS
        ]

    def tokenize(self, gt_list, pred_list):
        """Tokenize aligned ground-truth and predicted strings.

        Returns:
            ``(references, hypotheses)``: ``references[i]`` is a one-element
            list holding the token list of ``gt_list[i]`` (the multi-reference
            layout expected by nltk), ``hypotheses[i]`` is the token list of
            ``pred_list[i]``.
        """
        references = []
        hypotheses = []

        for gt, out in tqdm(zip(gt_list, pred_list)):
            references.append([self._tokenize_text(gt)])
            hypotheses.append(self._tokenize_text(out))
        return references, hypotheses

    @time_it
    def __call__(self, gt_list, pred_list, use_tokenizer=False):
        """Compute, print and return the corpus-level metrics.

        Args:
            gt_list: ground-truth strings.
            pred_list: predicted strings, aligned with ``gt_list``.
            use_tokenizer: when True, BLEU and METEOR are computed on
                tokenizer output; otherwise BLEU uses ``modified_bleu`` on
                whitespace-split text and METEOR is reported as 0.

        Returns:
            Dict with the keys ``bleu2``, ``bleu4``, ``rouge_1``, ``rouge_2``,
            ``rouge_l``, ``meteor_score``, ``validity``, ``acc_100``,
            ``acc_90``, ``acc_75`` and ``acc_50``.
        """
        gt_list = [gt.strip() for gt in gt_list]
        pred_list = [pred.strip() for pred in pred_list]

        if use_tokenizer:
            references, hypotheses = self.tokenize(gt_list, pred_list)
            bleu2, bleu4 = self.bleu(references, hypotheses)
            _meteor_score = self.meteor(references, hypotheses)
        else:
            bleu2 = modified_bleu(gt_list, pred_list, bleu_n=2)
            bleu4 = modified_bleu(gt_list, pred_list, bleu_n=4)
            _meteor_score = 0  # METEOR is only computed on tokenizer output
        rouge_1, rouge_2, rouge_l = self.rouge(gt_list, pred_list)

        validity = self.validity(gt_list, pred_list)
        acc_100, acc_90, acc_75, acc_50 = self.accuracy(gt_list, pred_list)

        print('BLEU-2 score:', bleu2)
        print('BLEU-4 score:', bleu4)
        print('Average Meteor score:', _meteor_score)
        print('rouge1:', rouge_1)
        print('rouge2:', rouge_2)
        print('rougeL:', rouge_l)

        print(f'Validity: {validity:.6f}')
        print(f'Accuracy (100): {acc_100:.6f}')
        print(f'Accuracy (90): {acc_90:.6f}')
        print(f'Accuracy (75): {acc_75:.6f}')
        print(f'Accuracy (50): {acc_50:.6f}')

        # One space-separated summary line, convenient for copy-pasting.
        line = ''
        for score in [validity, bleu2, bleu4, acc_100, acc_90, acc_75, acc_50, rouge_1, rouge_2, rouge_l, _meteor_score]:
            line += f'{score:.6f} '
        print(line)

        return {
            'bleu2': bleu2,
            'bleu4': bleu4,
            'rouge_1': rouge_1,
            'rouge_2': rouge_2,
            'rouge_l': rouge_l,
            'meteor_score': _meteor_score,
            'validity': validity,
            'acc_100': acc_100,
            'acc_90': acc_90,
            'acc_75': acc_75,
            'acc_50': acc_50,
        }

    def get_result_list(self, gt_list, pred_list, use_tokenizer=False):
        """Compute per-sample metric lists (one score per gt/pred pair).

        NOTE: with ``use_tokenizer=True`` the BLEU entries are raw
        ``corpus_bleu`` values, whereas the other branch goes through
        ``modified_bleu``, which multiplies by 100. ``meteor_score`` is the
        scalar 0 (not a list) when ``use_tokenizer`` is False.

        Returns:
            Dict with the keys ``bleu2``, ``bleu4``, ``rouge_1``, ``rouge_2``,
            ``rouge_l``, ``meteor_score`` and ``lev_score``.
        """
        gt_list = [gt.strip() for gt in gt_list]
        pred_list = [pred.strip() for pred in pred_list]

        if use_tokenizer:
            references, hypotheses = self.tokenize(gt_list, pred_list)
            bleu2 = [corpus_bleu([gt], [pred], weights=(.5,.5)) for gt, pred in zip(references, hypotheses)]
            bleu4 = [corpus_bleu([gt], [pred], weights=(.25,.25,.25,.25)) for gt, pred in zip(references, hypotheses)]
            _meteor_score = [meteor_score(gt, out) for gt, out in zip(references, hypotheses)]
        else:
            bleu2 = [modified_bleu([gt], [pred], bleu_n=2) for gt, pred in zip(gt_list, pred_list)]
            bleu4 = [modified_bleu([gt], [pred], bleu_n=4) for gt, pred in zip(gt_list, pred_list)]
            _meteor_score = 0
        rouge_1, rouge_2, rouge_l = self.rouge(gt_list, pred_list, return_list=True)

        lev_score = levenshtein_similarity(gt_list, pred_list)

        return {
            'bleu2': bleu2,
            'bleu4': bleu4,
            'rouge_1': rouge_1,
            'rouge_2': rouge_2,
            'rouge_l': rouge_l,
            'meteor_score': _meteor_score,
            'lev_score': lev_score,
        }

    def bleu(self, references, hypotheses):
        """Return corpus BLEU-2 and BLEU-4, each scaled by 100."""
        bleu2 = corpus_bleu(references, hypotheses, weights=(.5,.5))
        bleu4 = corpus_bleu(references, hypotheses, weights=(.25,.25,.25,.25))
        bleu2 *= 100
        bleu4 *= 100
        return bleu2, bleu4

    def meteor(self, references, hypotheses):
        """Return the mean per-sample METEOR score, scaled by 100."""
        meteor_scores = []
        for gt, out in zip(references, hypotheses):
            mscore = meteor_score(gt, out)
            meteor_scores.append(mscore)
        _meteor_score = np.mean(meteor_scores)
        _meteor_score *= 100
        return _meteor_score

    def rouge(self, targets, predictions, return_list=False):
        """Compute ROUGE-1 / ROUGE-2 / ROUGE-L F-measures.

        Args:
            targets: ground-truth strings.
            predictions: predicted strings, aligned with ``targets``.
            return_list: when True, return the three per-sample F-measure
                lists; otherwise return their means scaled by 100.
        """
        rouge_scores = []
        for gt, out in zip(targets, predictions):
            # RougeScorer.score takes (target, prediction). Only the
            # F-measure is read below, and it does not depend on which side
            # is treated as the target.
            rs = self.scorer.score(gt, out)
            rouge_scores.append(rs)

        rouge_1 = [rs['rouge1'].fmeasure for rs in rouge_scores]
        rouge_2 = [rs['rouge2'].fmeasure for rs in rouge_scores]
        rouge_l = [rs['rougeL'].fmeasure for rs in rouge_scores]
        if return_list:
            return rouge_1, rouge_2, rouge_l

        rouge_1 = np.mean(rouge_1) * 100
        rouge_2 = np.mean(rouge_2) * 100
        rouge_l = np.mean(rouge_l) * 100
        return rouge_1, rouge_2, rouge_l

    def validity(self, gt_list, pred_list):
        """Return the percentage of predictions considered valid.

        A prediction counts as valid when ``converter.string_to_actions``
        accepts it without raising and its largest ``$<int>$`` id does not
        exceed the largest id in the matching ground truth.

        Returns:
            Percentage in [0, 100] relative to ``len(pred_list)``; 0.0 when
            ``pred_list`` is empty.
        """
        n = len(pred_list)
        if n == 0:
            return 0.0

        num_valid = 0
        for pred, gt in zip(pred_list, gt_list):
            try:
                # Called only for its exception: a failure marks the prediction invalid.
                self.converter.string_to_actions(pred)
                # Plain comparison instead of `assert`, which is stripped
                # under `python -O` and would count every parsable
                # prediction as valid.
                if extract_reactant_cnt(gt) >= extract_reactant_cnt(pred):
                    num_valid += 1
            except Exception:
                # Best effort: any failure just means "not valid". Unlike a
                # bare `except:`, KeyboardInterrupt/SystemExit still propagate.
                pass
        return 100*(num_valid / n)

    def accuracy(self, gt_list, pred_list):
        """Return accuracy (in percent) at similarity thresholds 1.0, 0.9, 0.75 and 0.5.

        A sample is counted at a threshold when its normalized Levenshtein
        similarity is at least that threshold.
        """
        score_list = levenshtein_similarity(gt_list, pred_list)
        acc_100 = 100*accuracy_score(score_list, 1.0)
        acc_90 = 100*accuracy_score(score_list, 0.90)
        acc_75 = 100*accuracy_score(score_list, 0.75)
        acc_50 = 100*accuracy_score(score_list, 0.50)
        return acc_100, acc_90, acc_75, acc_50
visualize_context_gen.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_provider.context_gen import *
2
+
3
def parse_args():
    """Parse the command-line options of the visualization script.

    Returns:
        argparse.Namespace with ``name``, ``seed``, ``epochs``,
        ``chunk_size``, ``rxn_num``, ``k`` and ``root``.
    """
    # (flag, default, type) for every script argument, in help order.
    option_table = (
        ('--name', 'none', str),
        ('--seed', 0, int),
        ('--epochs', 100, int),
        ('--chunk_size', 100, int),
        ('--rxn_num', 50000, int),
        ('--k', 4, int),
        ('--root', 'data/pretrain_data', str),
    )

    parser = argparse.ArgumentParser(description="A simple argument parser")
    for flag, default, value_type in option_table:
        parser.add_argument(flag, default=default, type=value_type)
    return parser.parse_args()
17
+
18
def pad_shorter_array(arr1, arr2):
    """Zero-pad the shorter of two arrays at its end so both share one length.

    Lengths are compared along the first axis. Returns ``(arr1, arr2)``; the
    longer array, or both when the lengths already match, is returned as-is.
    """
    gap = arr1.shape[0] - arr2.shape[0]
    if gap > 0:
        arr2 = np.pad(arr2, (0, gap), 'constant')
    elif gap < 0:
        arr1 = np.pad(arr1, (0, -gap), 'constant')
    return arr1, arr2
26
+
27
def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue'):
    """Save a bar chart of chunk-averaged values, sorted in descending order.

    ``values`` is cut into consecutive chunks of ``chunk_size`` items (an
    incomplete trailing chunk is dropped), each chunk is averaged, and the
    averages are sorted from largest to smallest before plotting.

    Args:
        values: 1-D array-like of numbers.
        target_path: file the figure is written to.
        x_lim: optional ``(low, high)`` x-axis limits, in chunk units.
        y_lim: optional ``(low, high)`` y-axis limits.
        chunk_size: number of consecutive values averaged into one bar.
        color: bar colour.
    """
    values = np.asarray(values)
    num_full_chunks = len(values) // chunk_size
    chunk_means = np.mean(values[:num_full_chunks*chunk_size].reshape(-1, chunk_size), axis=1)
    chunk_means = np.sort(chunk_means)[::-1]

    fig = plt.figure(figsize=(10, 4), dpi=100)
    try:
        x = np.arange(len(chunk_means))
        plt.bar(x, chunk_means, color=color)
        # Tick labels are item indices; positions are converted to chunk units.
        current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
        plt.xticks((current_values/chunk_size).astype(int), current_values)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim:
            plt.xlim(*x_lim)
        if y_lim:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=12)
        plt.tight_layout(pad=0.5)
        plt.savefig(target_path)
        print(f'Figure saved to {target_path}')
    finally:
        # Close (not just clear) the figure: every call creates a new one,
        # and plt.clf() would leave them all open.
        plt.close(fig)
46
+
47
def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None, labels=('Random', 'Ours'), colors=('blue', 'orange'), chunk_size=100):
    """Save an overlaid bar chart comparing two chunk-averaged distributions.

    The shorter input is zero-padded to the length of the longer one. Both
    are then cut into chunks of ``chunk_size`` items (an incomplete trailing
    chunk is dropped), averaged per chunk, and sorted in descending order
    independently of each other.

    Args:
        list1, list2: 1-D array-likes of numbers.
        target_path: file the figure is written to.
        x_lim: optional ``(low, high)`` x-axis limits, in chunk units.
        y_lim: optional ``(low, high)`` y-axis limits.
        labels: legend labels for ``list1`` and ``list2``.
        colors: bar colours for ``list1`` and ``list2``.
        chunk_size: number of consecutive values averaged into one bar.
    """
    list1, list2 = pad_shorter_array(np.asarray(list1), np.asarray(list2))
    # Count chunks after padding so the tail of a longer list2 is not dropped.
    num_full_chunks = len(list1) // chunk_size
    values1, values2 = [
        np.sort(np.mean(values[:num_full_chunks*chunk_size].reshape(-1, chunk_size), axis=1))[::-1]
        for values in (list1, list2)]

    fig = plt.figure(figsize=(10, 6), dpi=100)
    try:
        x = np.arange(len(values1))
        # Semi-transparent bars so both distributions stay visible.
        plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
        plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
        # Tick labels are item indices; positions are converted to chunk units.
        current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
        plt.xticks((current_values/chunk_size).astype(int), current_values)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim:
            plt.xlim(*x_lim)
        if y_lim:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.tight_layout(pad=0.5)
        plt.legend(fontsize=24, loc='upper right')
        plt.savefig(target_path)
        print(f'Figure saved to {target_path}')
    finally:
        # Close (not just clear) the figure: every call creates a new one.
        plt.close(fig)
71
+
72
def statistics(args):
    """Print dataset statistics for the reaction cluster stored under ``args.root``.

    Reports the number of reactions, the number of distinct molecules across
    the REACTANT / CATALYST / SOLVENT / PRODUCT roles, and how many property
    records carry an abstract, properties, experimental properties and
    computed properties. A seed of 0 leaves the RNGs unseeded.
    """
    if args.seed:
        set_random_seed(args.seed)
    # 1141864 rxns from ord
    # 1120773 rxns from uspto
    cluster = Reaction_Cluster(args.root)
    rxn_num = len(cluster.reaction_data)

    # Distinct molecules over all roles of all reactions.
    mol_set = set()
    for rxn_dict in cluster.reaction_data:
        for role in ('REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT'):
            mol_set.update(rxn_dict[role])
    mol_num = len(mol_set)

    abstract_num = 0
    property_num = 0
    calculated_property_num = 0
    experimental_property_num = 0
    avg_calculated_property_len = 0
    avg_experimental_property_len = 0
    for mol_dict in cluster.property_data:
        if 'abstract' in mol_dict:
            abstract_num += 1
        if 'property' not in mol_dict:
            continue
        property_num += 1
        properties = mol_dict['property']
        if 'Experimental Properties' in properties:
            experimental_property_num += 1
            avg_experimental_property_len += len(properties['Experimental Properties'])
        if 'Computed Properties' in properties:
            calculated_property_num += 1
            avg_calculated_property_len += len(properties['Computed Properties'])

    print(f'Reaction Number: {rxn_num}')
    print(f'Molecule Number: {mol_num}')
    print(f'Abstract Number: {abstract_num}/{mol_num}({abstract_num/mol_num*100:.2f}%)')
    print(f'Property Number: {property_num}/{mol_num}({property_num/mol_num*100:.2f}%)')
    print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}({experimental_property_num/property_num*100:.2f}%), {avg_experimental_property_len/mol_num:.2f} items per molecule')
    print(f'- Computed Properties: {calculated_property_num}/{property_num}({calculated_property_num/property_num*100:.2f}%), {avg_calculated_property_len/mol_num:.2f} items per molecule')
111
+
112
def visualize(args):
    """Plot molecule / reaction distributions, adjusted and random, under ``results/<name>/``.

    Saves four single-distribution figures and two comparison figures.
    A seed of 0 leaves the RNGs unseeded.
    """
    if args.seed:
        set_random_seed(args.seed)
    cluster = Reaction_Cluster(args.root)
    prob_values, rxn_weights = cluster.visualize_mol_distribution()
    rand_prob_values, rand_rxn_weights = cluster._randomly(
        cluster.visualize_mol_distribution
    )
    fig_root = f'results/{args.name}/'

    # One figure per distribution.
    single_figures = (
        (prob_values, 'mol_distribution.pdf'),
        (rxn_weights, 'rxns_distribution.pdf'),
        (rand_prob_values, 'mol_distribution_random.pdf'),
        (rand_rxn_weights, 'rxns_distribution_random.pdf'),
    )
    for values, file_name in single_figures:
        plot_distribution(values, fig_root+file_name)

    # Overlay of each distribution and its random counterpart.
    plot_compare_distribution(prob_values, rand_prob_values, fig_root+'Compare_mol.pdf', y_lim=(-0.5,15.5))
    plot_compare_distribution(rxn_weights, rand_rxn_weights, fig_root+'Compare_rxns.pdf')
129
+
130
+
131
def visualize_frequency(args):
    """Plot molecule sampling frequencies, adjusted versus random, under ``results/<name>/``.

    The four frequency arrays (molecule / reaction, adjusted / random) are
    cached in a ``.npy`` file keyed by epochs, reaction number and k; a later
    run with the same settings reloads them instead of recomputing. Only the
    molecule frequencies are plotted. A seed of 0 leaves the RNGs unseeded.
    """
    if args.seed:
        set_random_seed(args.seed)
    fig_root = f'results/{args.name}/'
    # Create the output directory up front so that neither np.save nor
    # savefig fails after the expensive frequency computation.
    os.makedirs(fig_root, exist_ok=True)
    name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
    cache_path = f'{fig_root}/freq_{name_suffix}.npy'
    if os.path.exists(cache_path):
        # NOTE: allow_pickle=True unpickles the file's content; only load
        # cache files written by this script.
        mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
    else:
        cluster = Reaction_Cluster(args.root)
        mol_freq, rxn_freq = cluster.visualize_mol_frequency(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
        rand_mol_freq, rand_rxn_freq = cluster._randomly(
            cluster.visualize_mol_frequency,
            rxn_num=args.rxn_num, k=args.k, epochs=args.epochs
        )
        np.save(cache_path, np.array([mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq], dtype=object), allow_pickle=True)

    color1 = '#FA7F6F'
    color2 = '#80AFBF'
    # Shared axis limits so the individual figures are directly comparable.
    x_lim = (-50000//args.chunk_size, 1200000//args.chunk_size)
    y_lim = (-2, 62)

    plot_distribution(mol_freq, fig_root+f'mol_frequency_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    plot_distribution(rand_mol_freq, fig_root+f'mol_frequency_random_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)

    plot_compare_distribution(rand_mol_freq, mol_freq, fig_root+f'Compare_mol_{name_suffix}.pdf', y_lim=y_lim, labels=['Before Adjustment', 'After Adjustment'], colors=[color1, color2], chunk_size=args.chunk_size)
158
+
159
if __name__ == '__main__':
    # Script entry point. statistics(...) and visualize(...) are alternative
    # entry points defined in this file; only the frequency plots run here.
    cli_args = parse_args()
    print(cli_args, flush=True)
    visualize_frequency(cli_args)