Clemspace commited on
Commit
32fe622
·
1 Parent(s): cb9e677

added inference + api wrapper

Browse files
chemistral_api.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.middleware.cors import CORSMiddleware
2
+ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
3
+ from fastapi.responses import JSONResponse, FileResponse
4
+ from pydantic import BaseModel
5
+ from typing import Optional
6
+ import subprocess
7
+ import os
8
+ import logging
9
+ from inference_transform import process_smiles, process_pdb, process_sdf, extract_and_convert_to_sdf, is_valid_smiles
10
+
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ app = FastAPI()
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=['*'],
19
+ allow_credentials=True,
20
+ allow_methods=['*'],
21
+ allow_headers=['*']
22
+ )
23
+
24
+ sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
25
+
26
+ class InferenceRequest(BaseModel):
27
+ prompt: str
28
+ max_tokens: int = 256
29
+ temperature: float = 1.0
30
+
31
+ @app.post("/predict_base")
32
+ async def predict_base(
33
+ prompt: str = Form(...),
34
+ max_tokens: int = Form(256),
35
+ temperature: float = Form(1.0),
36
+ file: Optional[UploadFile] = File(None)
37
+ ):
38
+ try:
39
+ if file:
40
+ file_path = f"/tmp/{file.filename}"
41
+ with open(file_path, "wb") as f:
42
+ f.write(file.file.read())
43
+ if file.filename.endswith(".pdb"):
44
+ prompt += f" {process_pdb(file_path)}"
45
+ elif file.filename.endswith(".sdf"):
46
+ prompt += f" {process_sdf(file_path)}"
47
+ else:
48
+ try:
49
+ sdf_file = extract_and_convert_to_sdf(prompt)
50
+ if sdf_file:
51
+ prompt += f" {sdf_file}"
52
+ except ValueError as e:
53
+ logger.info(str(e))
54
+
55
+ command = [
56
+ "python",
57
+ "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
58
+ "/root/mistral_models/7B-v0.3/",
59
+ prompt,
60
+ f"--max_tokens={max_tokens}",
61
+ f"--temperature={temperature}",
62
+ "--instruct"
63
+ ]
64
+
65
+ logger.info(f"Running command: {' '.join(command)}")
66
+ result = subprocess.run(command, capture_output=True, text=True)
67
+
68
+ if result.returncode != 0:
69
+ logger.error(f"Command failed with return code {result.returncode}")
70
+ logger.error(f"stderr: {result.stderr}")
71
+ raise HTTPException(status_code=500, detail=result.stderr)
72
+
73
+ response = result.stdout.strip()
74
+ sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
75
+
76
+ return {
77
+ "response": response,
78
+ "sdf_file_path": sdf_file_path
79
+ }
80
+ except Exception as e:
81
+ logger.exception("Exception occurred during inference.")
82
+ raise HTTPException(status_code=500, detail=str(e))
83
+
84
+ @app.post("/predict")
85
+ async def predict_alternative(
86
+ prompt: str = Form(...),
87
+ max_tokens: int = Form(256),
88
+ temperature: float = Form(1.0),
89
+ file: Optional[UploadFile] = File(None)
90
+ ):
91
+ try:
92
+ if file:
93
+ file_path = f"/tmp/{file.filename}"
94
+ with open(file_path, "wb") as f:
95
+ f.write(await file.read())
96
+ if file.filename.endswith(".pdb"):
97
+ prompt += f" {process_pdb(file_path)}"
98
+ elif file.filename.endswith(".sdf"):
99
+ prompt += f" {process_sdf(file_path)}"
100
+ else:
101
+ try:
102
+ sdf_file = extract_and_convert_to_sdf(prompt)
103
+ if sdf_file:
104
+ prompt += f" {sdf_file}"
105
+ except ValueError as e:
106
+ logger.info(str(e))
107
+
108
+ command = [
109
+ "python",
110
+ "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
111
+ "/root/mistral_models/7B-v0.3/",
112
+ prompt,
113
+ f"--max_tokens={max_tokens}",
114
+ f"--temperature={temperature}",
115
+ "--instruct",
116
+ "--lora_path=/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/consolidated/lora.safetensors"
117
+ ]
118
+ logger.info(f"Running command: {' '.join(command)}")
119
+ result = subprocess.run(command, capture_output=True, text=True)
120
+ if result.returncode != 0:
121
+ logger.error(f"Command failed with return code {result.returncode}")
122
+ logger.error(f"stderr: {result.stderr}")
123
+ raise HTTPException(status_code=500, detail=result.stderr)
124
+
125
+ response = result.stdout.strip()
126
+ sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
127
+
128
+ # Return the file as a direct download
129
+ return FileResponse(sdf_file_path, media_type='chemical/x-mdl-sdfile', filename="Conformer3D_COMPOUND_CID_240.sdf")
130
+
131
+ except Exception as e:
132
+ logger.exception("Exception occurred during inference.")
133
+ raise HTTPException(status_code=500, detail=str(e))
134
+
135
+ # @app.post("/predict")
136
+ # async def predict_alternative(
137
+ # prompt: str = Form(...),
138
+ # max_tokens: int = Form(256),
139
+ # temperature: float = Form(1.0),
140
+ # file: Optional[UploadFile] = File(None)
141
+ # ):
142
+ # try:
143
+ # global sdf_file_path
144
+ # if file:
145
+ # file_path = f"/tmp/{file.filename}"
146
+ # with open(file_path, "wb") as f:
147
+ # f.write(file.file.read())
148
+ # if file.filename.endswith(".pdb"):
149
+ # prompt += f" {process_pdb(file_path)}"
150
+ # elif file.filename.endswith(".sdf"):
151
+ # prompt += f" {process_sdf(file_path)}"
152
+ # else:
153
+ # try:
154
+ # sdf_file = extract_and_convert_to_sdf(prompt)
155
+ # if sdf_file:
156
+ # prompt += f" {sdf_file}"
157
+ # except ValueError as e:
158
+ # logger.info(str(e))
159
+
160
+ # command = [
161
+ # "python",
162
+ # "/root/CHEMISTral7Bv0.3/mistral_chat_script.py",
163
+ # "/root/mistral_models/7B-v0.3/",
164
+ # prompt,
165
+ # f"--max_tokens={max_tokens}",
166
+ # f"--temperature={temperature}",
167
+ # "--instruct",
168
+ # "--lora_path=/root/CHEMISTral7Bv0.3/runs/checkpoints/checkpoint_000300/consolidated/lora.safetensors"
169
+ # ]
170
+
171
+ # logger.info(f"Running command: {' '.join(command)}")
172
+ # result = subprocess.run(command, capture_output=True, text=True)
173
+
174
+ # if result.returncode != 0:
175
+ # logger.error(f"Command failed with return code {result.returncode}")
176
+ # logger.error(f"stderr: {result.stderr}")
177
+ # raise HTTPException(status_code=500, detail=result.stderr)
178
+
179
+ # response = result.stdout.strip()
180
+ # sdf_file_path = "/root/CHEMISTral7Bv0.3/example/Conformer3D_COMPOUND_CID_240.sdf"
181
+
182
+ # return {
183
+ # "response": response,
184
+ # "sdf_file_path": sdf_file_path
185
+ # }
186
+ # except Exception as e:
187
+ # logger.exception("Exception occurred during inference.")
188
+ # raise HTTPException(status_code=500, detail=str(e))
189
+
190
+ @app.get("/download_sdf")
191
+ async def download_sdf():
192
+ try:
193
+ return FileResponse(path=sdf_file_path, filename="Conformer3D_COMPOUND_CID_240.sdf")
194
+ except Exception as e:
195
+ logger.exception("Exception occurred while sending SDF file.")
196
+ raise HTTPException(status_code=500, detail=str(e))
197
+
198
+ if __name__ == "__main__":
199
+ import uvicorn
200
+ uvicorn.run(app, host="0.0.0.0", port=8000)
example/Conformer3D_COMPOUND_CID_240.sdf ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 240
2
+ -OEChem-03012409263D
3
+
4
+ 14 14 0 0 0 0 0 0 0999 V2000
5
+ 2.8466 -0.3870 0.0002 O 0 0 0 0 0 0 0 0 0 0 0 0
6
+ 0.5644 0.2371 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
7
+ -0.3437 1.2960 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
8
+ 0.1013 -1.0787 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
9
+ -1.7147 1.0393 0.0001 C 0 0 0 0 0 0 0 0 0 0 0 0
10
+ -1.2698 -1.3354 -0.0001 C 0 0 0 0 0 0 0 0 0 0 0 0
11
+ -2.1777 -0.2764 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
12
+ 1.9937 0.5050 -0.0003 C 0 0 0 0 0 0 0 0 0 0 0 0
13
+ 0.0016 2.3267 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
14
+ 0.7902 -1.9194 -0.0001 H 0 0 0 0 0 0 0 0 0 0 0 0
15
+ -2.4218 1.8637 0.0001 H 0 0 0 0 0 0 0 0 0 0 0 0
16
+ -1.6308 -2.3599 -0.0001 H 0 0 0 0 0 0 0 0 0 0 0 0
17
+ -3.2452 -0.4764 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
18
+ 2.2986 1.5653 -0.0006 H 0 0 0 0 0 0 0 0 0 0 0 0
19
+ 1 8 2 0 0 0 0
20
+ 2 3 2 0 0 0 0
21
+ 2 4 1 0 0 0 0
22
+ 2 8 1 0 0 0 0
23
+ 3 5 1 0 0 0 0
24
+ 3 9 1 0 0 0 0
25
+ 4 6 2 0 0 0 0
26
+ 4 10 1 0 0 0 0
27
+ 5 7 2 0 0 0 0
28
+ 5 11 1 0 0 0 0
29
+ 6 7 1 0 0 0 0
30
+ 6 12 1 0 0 0 0
31
+ 7 13 1 0 0 0 0
32
+ 8 14 1 0 0 0 0
33
+ M END
34
+ > <PUBCHEM_COMPOUND_CID>
35
+ 240
36
+
37
+ > <PUBCHEM_CONFORMER_RMSD>
38
+ 0.4
39
+
40
+ > <PUBCHEM_CONFORMER_DIVERSEORDER>
41
+ 1
42
+
43
+ > <PUBCHEM_MMFF94_PARTIAL_CHARGES>
44
+ 14
45
+ 1 -0.57
46
+ 10 0.15
47
+ 11 0.15
48
+ 12 0.15
49
+ 13 0.15
50
+ 14 0.06
51
+ 2 0.09
52
+ 3 -0.15
53
+ 4 -0.15
54
+ 5 -0.15
55
+ 6 -0.15
56
+ 7 -0.15
57
+ 8 0.42
58
+ 9 0.15
59
+
60
+ > <PUBCHEM_EFFECTIVE_ROTOR_COUNT>
61
+ 1
62
+
63
+ > <PUBCHEM_PHARMACOPHORE_FEATURES>
64
+ 2
65
+ 1 1 acceptor
66
+ 6 2 3 4 5 6 7 rings
67
+
68
+ > <PUBCHEM_HEAVY_ATOM_COUNT>
69
+ 8
70
+
71
+ > <PUBCHEM_ATOM_DEF_STEREO_COUNT>
72
+ 0
73
+
74
+ > <PUBCHEM_ATOM_UDEF_STEREO_COUNT>
75
+ 0
76
+
77
+ > <PUBCHEM_BOND_DEF_STEREO_COUNT>
78
+ 0
79
+
80
+ > <PUBCHEM_BOND_UDEF_STEREO_COUNT>
81
+ 0
82
+
83
+ > <PUBCHEM_ISOTOPIC_ATOM_COUNT>
84
+ 0
85
+
86
+ > <PUBCHEM_COMPONENT_COUNT>
87
+ 1
88
+
89
+ > <PUBCHEM_CACTVS_TAUTO_COUNT>
90
+ 1
91
+
92
+ > <PUBCHEM_CONFORMER_ID>
93
+ 000000F000000001
94
+
95
+ > <PUBCHEM_MMFF94_ENERGY>
96
+ 18.0728
97
+
98
+ > <PUBCHEM_FEATURE_SELFOVERLAP>
99
+ 10.148
100
+
101
+ > <PUBCHEM_SHAPE_FINGERPRINT>
102
+ 16714656 1 18409731763581766061
103
+ 18185500 45 18263078975662380655
104
+ 21040471 1 18338797793165636005
105
+ 23552423 10 18187929525178951646
106
+ 29004967 10 16200157598908099156
107
+ 369184 2 15195566792725449510
108
+ 5084963 1 18412544318583771616
109
+
110
+ > <PUBCHEM_SHAPE_MULTIPOLES>
111
+ 158.77
112
+ 3.12
113
+ 1.41
114
+ 0.6
115
+ 1.61
116
+ 0.02
117
+ 0
118
+ 0.14
119
+ 0
120
+ -0.45
121
+ 0
122
+ -0.03
123
+ -0.01
124
+ 0
125
+
126
+ > <PUBCHEM_SHAPE_SELFOVERLAP>
127
+ 329.455
128
+
129
+ > <PUBCHEM_SHAPE_VOLUME>
130
+ 90.5
131
+
132
+ > <PUBCHEM_COORDINATE_TYPE>
133
+ 2
134
+ 5
135
+ 10
136
+
137
+ $$$$
inference_transform.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from rdkit import Chem
3
+ from rdkit.Chem import MolFromSmiles, SDWriter
4
+ import logging
5
+ from Bio import SeqIO
6
+
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ def process_smiles(smiles: str) -> str:
12
+ mol = MolFromSmiles(smiles)
13
+ if not mol:
14
+ raise ValueError(f"Invalid SMILES string: {smiles}")
15
+
16
+ sdf_file = "/tmp/output.sdf"
17
+ writer = SDWriter(sdf_file)
18
+ writer.write(mol)
19
+ writer.close()
20
+
21
+ return sdf_file
22
+
23
+ def process_pdb(file_path: str) -> str:
24
+ sequences = []
25
+ with open(file_path, "r") as handle:
26
+ for record in SeqIO.parse(handle, "pdb-seqres"):
27
+ sequences.append(str(record.seq))
28
+ return " ".join(sequences)
29
+
30
+ def process_sdf(file_path: str) -> str:
31
+ return file_path
32
+
33
+ def extract_smiles(text: str) -> str:
34
+ smiles_pattern = r"([^J][0-9BCOHNSOPrIFla@+\-\[\]\(\)\\\/%=#$]{6,})"
35
+ matches = re.findall(smiles_pattern, text)
36
+ if matches:
37
+ return matches[0]
38
+ return ""
39
+
40
+ def is_valid_smiles(smiles: str) -> bool:
41
+ mol = MolFromSmiles(smiles)
42
+ return mol is not None
43
+
44
+ def extract_and_convert_to_sdf(text: str) -> str:
45
+ smiles = extract_smiles(text)
46
+ if smiles and is_valid_smiles(smiles):
47
+ return process_smiles(smiles)
48
+ raise ValueError("No valid SMILES string found in the text.")
mistral_chat_script.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ from mistral_inference.generate import generate
4
+ from mistral_inference.model import Transformer
5
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
6
+
7
+ def run_chat(model_path: str, prompt: str, max_tokens: int = 256, temperature: float = 1.0, instruct: bool = True, lora_path: str = None):
8
+ # Find the correct tokenizer file
9
+ model_path = Path(model_path)
10
+ tokenizer_file = model_path / "tokenizer.model.v3"
11
+
12
+ if not tokenizer_file.is_file():
13
+ raise FileNotFoundError(f"Tokenizer model file not found at {tokenizer_file}")
14
+
15
+ mistral_tokenizer = MistralTokenizer.from_file(str(tokenizer_file))
16
+ tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
17
+
18
+ transformer = Transformer.from_folder(
19
+ model_path, max_batch_size=3, num_pipeline_ranks=1
20
+ )
21
+
22
+ if lora_path is not None:
23
+ transformer.load_lora(Path(lora_path))
24
+
25
+ tokens = tokenizer.encode(prompt, bos=True, eos=False)
26
+ generated_tokens, _ = generate(
27
+ [tokens],
28
+ transformer,
29
+ max_tokens=max_tokens,
30
+ temperature=temperature,
31
+ eos_id=tokenizer.eos_id,
32
+ )
33
+ answer = tokenizer.decode(generated_tokens[0])
34
+ print(answer)
35
+
36
+ if __name__ == "__main__":
37
+ import fire
38
+ fire.Fire(run_chat)