# encrypted_dna / server.py
# Last commit by kcelia: dbdd71f (unverified) - "chore: add relevant files"
# File size: 12.9 kB
import shutil
import time
from glob import glob
from pathlib import Path
from typing import List

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse, Response
from tqdm import tqdm
from utils import *
from concrete.ml.deployment import FHEModelClient, FHEModelServer
# FastAPI application exposing the server-side endpoints of the encrypted DNA demo.
# NOTE(review): no FHE server is loaded at import time; FHEModelServer instances are
# created inside the request handlers of this module.
app = FastAPI()
# Default route: returns a static welcome message.
@app.get("/")
def root():
    """Return the welcome message of the encrypted DNA testing API.

    Returns:
        dict: A single-key mapping ``{"message": <welcome text>}``.
    """
    welcome_text = "Welcome to your encrypted DNA testing use-case with FHE!"
    return dict(message=welcome_text)
@app.post("/send_input")
def send_input(
    user_id: str = Form(...), root_dir: str = Form(...), files: List[UploadFile] = File(...)
):
    """Store the client's evaluation keys and encrypted input windows on the server side.

    Expected order of ``files``:
      * ``files[0]``: evaluation key for the base modules,
      * ``files[1]``: evaluation key for the smoother module,
      * ``files[2:]``: serialized encrypted input windows.

    Everything is written under ``<root_dir>/<user_id>/server``, in the sub-directories named
    by ``KEY_BASE_MODULE_DIR``, ``KEY_SMOOTHER_MODULE_DIR`` and ``ENCRYPTED_INPUT_DIR``.
    Missing directories are created.

    Returns:
        None on success (serialized by FastAPI as ``null``), or a 400 JSONResponse when fewer
        than the two mandatory evaluation-key files were uploaded.
    """
    print("------------ Step 3.2: Send the data to the server")
    print(f"{user_id=}, {root_dir=}, {len(files)=}")

    # NOTE(review): ``root_dir`` and ``user_id`` come straight from the request and are used to
    # build filesystem paths, so a client can choose where the server writes. Confine/validate
    # them if this endpoint is ever exposed beyond a trusted demo setup.
    if len(files) < 2:
        return JSONResponse(
            status_code=400,
            content={"detail": f"Expected at least 2 files (evaluation keys), got {len(files)}."},
        )

    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR

    for directory in (
        SERVER_KEY_BASE_MODULE_DIR,
        SERVER_KEY_SMOOTHER_MODULE_DIR,
        SERVER_ENCRYPTED_INPUT_DIR,
    ):
        directory.mkdir(parents=True, exist_ok=True)

    def _save_upload(upload: UploadFile, destination: Path) -> None:
        # Stream the upload to disk instead of loading it fully into memory
        # (evaluation keys can be large).
        with destination.open("wb") as out:
            shutil.copyfileobj(upload.file, out)

    _save_upload(files[0], SERVER_KEY_BASE_MODULE_DIR / "eval_key")
    _save_upload(files[1], SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key")

    print(f"{len(files)=}")
    # Window files keep their upload index (starting at 2) in the name; the run_fhe* endpoints
    # sort them with extract_model_number, presumably so only their relative order matters.
    for i in tqdm(range(2, len(files))):
        _save_upload(files[i], SERVER_ENCRYPTED_INPUT_DIR / f"encrypted_window_{i}")
def _write_fhe_timeline(start_time: float, fraction: float) -> None:
    """Overwrite ``FHE_COMPUTATION_TIMELINE`` with the elapsed time and a progress percentage.

    Args:
        start_time: ``time.time()`` value captured when the computation started.
        fraction: Completed share of the work; values above 1.0 are clamped to 100%.
    """
    fraction = min(fraction, 1.0)
    with open(FHE_COMPUTATION_TIMELINE, "w", encoding="utf-8") as f:
        f.write(f"{time.time() - start_time:.2f} seconds ({fraction:.0%})")


@app.post("/run_fhe")
def run_fhe(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the two-stage FHE inference for one user and store the encrypted result.

    Stage 1 runs each shared base module (``SHARED_BASE_MODULE_DIR/model_*``) on the matching
    uploaded encrypted window, using the client's base-module evaluation key.
    Stage 2 applies ``slide_window`` (with ``META["SS"]``) to the stage-1 outputs, encrypts
    each slice for the shared smoother module and runs it with the client's smoother
    evaluation key. The list of smoother ciphertexts is pickled to
    ``<server>/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.
    Progress is written to ``FHE_COMPUTATION_TIMELINE`` after every iteration of both stages.

    Returns:
        JSONResponse with the total execution time in seconds (rounded to 2 decimals), or a
        400 JSONResponse when the number of uploaded windows differs from ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")
    # NOTE(review): paths are built from client-supplied ``root_dir``/``user_id``.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR

    eval_key_base_module = (SERVER_KEY_BASE_MODULE_DIR / "eval_key").read_bytes()
    eval_key_smoother_module = (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").read_bytes()

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    assert len(shared_base_modules_path) == META["NW"]

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # Validate the client's windows too: zip() below would otherwise silently drop any model
    # without a matching window (or vice versa) and yield a truncated result.
    if len(client_encrypted_input_path) != META["NW"]:
        return JSONResponse(
            status_code=400,
            content={
                "detail": f"Expected {META['NW']} encrypted windows, "
                f"found {len(client_encrypted_input_path)}."
            },
        )

    SERVER_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    nb_total_iterations = META["NW"] * 2
    start_time = time.time()

    # Stage 1: run every shared base module on its encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        with open(encrypted_window_path, "rb") as f:
            encrypted_window = f.read()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts the intermediate result with keys loaded from the
        # shared model directory (key_dir=model_path), so stage-1 outputs are in clear here.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # NOTE(review): this stores the *input* ciphertext under a "decrypted_window_*" name;
        # kept unchanged because other code may read these files - confirm the intent.
        with (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").open("wb") as f:
            f.write(encrypted_window)
        y_proba.append(decrypted_output)
        _write_fhe_timeline(start_time, (i + 1) / nb_total_iterations)

    # Stage 2: encrypt sliding windows of the stage-1 outputs and run the smoother on them.
    smoother_client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    smoother_server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)
    # Swap the first two axes of the stacked stage-1 outputs, then cast to int8.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2)).astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    yhat_encrypted = []
    for i in tqdm(range(len(X_slide))):
        x_row = X_slide[i].reshape(1, -1)
        encrypted_input = smoother_client.quantize_encrypt_serialize(x_row)
        encrypted_output = smoother_server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)
        # Offset by the stage-1 iterations so the reported progress keeps increasing
        # instead of restarting from 0 when stage 2 begins.
        _write_fhe_timeline(start_time, (META["NW"] + i + 1) / nb_total_iterations)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start_time, 2)
    return JSONResponse(content=fhe_execution_time)
@app.post("/run_fhe_stage1")
def run_fhe_stage1(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the two-stage FHE inference (base modules, then smoother) without progress output.

    NOTE(review): despite its name, this endpoint currently runs the same full pipeline as
    ``run_fhe_stage2`` (both stages). It differs from ``run_fhe`` only in that it does not write
    ``FHE_COMPUTATION_TIMELINE``. Confirm whether a split into separate stages was intended.

    The list of smoother ciphertexts is pickled to
    ``<server>/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    Returns:
        JSONResponse with the total execution time in seconds (rounded to 2 decimals), or a
        400 JSONResponse when the number of uploaded windows differs from ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")
    # NOTE(review): paths are built from client-supplied ``root_dir``/``user_id``.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR

    eval_key_base_module = (SERVER_KEY_BASE_MODULE_DIR / "eval_key").read_bytes()
    eval_key_smoother_module = (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").read_bytes()

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    assert len(shared_base_modules_path) == META["NW"]

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # Validate the client's windows too: zip() below would otherwise silently drop any model
    # without a matching window (or vice versa) and yield a truncated result.
    if len(client_encrypted_input_path) != META["NW"]:
        return JSONResponse(
            status_code=400,
            content={
                "detail": f"Expected {META['NW']} encrypted windows, "
                f"found {len(client_encrypted_input_path)}."
            },
        )

    SERVER_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    start = time.time()

    # Stage 1: run every shared base module on its encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        with open(encrypted_window_path, "rb") as f:
            encrypted_window = f.read()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts the intermediate result with keys loaded from the
        # shared model directory (key_dir=model_path), so stage-1 outputs are in clear here.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # NOTE(review): this stores the *input* ciphertext under a "decrypted_window_*" name;
        # kept unchanged because other code may read these files - confirm the intent.
        with (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").open("wb") as f:
            f.write(encrypted_window)
        y_proba.append(decrypted_output)

    # Stage 2: encrypt sliding windows of the stage-1 outputs and run the smoother on them.
    smoother_client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    smoother_server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)
    # Swap the first two axes of the stacked stage-1 outputs, then cast to int8.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2)).astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    yhat_encrypted = []
    for i in tqdm(range(len(X_slide))):
        x_row = X_slide[i].reshape(1, -1)
        encrypted_input = smoother_client.quantize_encrypt_serialize(x_row)
        encrypted_output = smoother_server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start, 2)
    return JSONResponse(content=fhe_execution_time)
@app.post("/run_fhe_stage2")
def run_fhe_stage2(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the two-stage FHE inference (base modules, then smoother) without progress output.

    NOTE(review): despite its name, this endpoint currently runs the same full pipeline as
    ``run_fhe_stage1`` (both stages). It differs from ``run_fhe`` only in that it does not write
    ``FHE_COMPUTATION_TIMELINE``. Confirm whether a split into separate stages was intended.

    The list of smoother ciphertexts is pickled to
    ``<server>/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    Returns:
        JSONResponse with the total execution time in seconds (rounded to 2 decimals), or a
        400 JSONResponse when the number of uploaded windows differs from ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")
    # NOTE(review): paths are built from client-supplied ``root_dir``/``user_id``.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR

    eval_key_base_module = (SERVER_KEY_BASE_MODULE_DIR / "eval_key").read_bytes()
    eval_key_smoother_module = (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").read_bytes()

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    assert len(shared_base_modules_path) == META["NW"]

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # Validate the client's windows too: zip() below would otherwise silently drop any model
    # without a matching window (or vice versa) and yield a truncated result.
    if len(client_encrypted_input_path) != META["NW"]:
        return JSONResponse(
            status_code=400,
            content={
                "detail": f"Expected {META['NW']} encrypted windows, "
                f"found {len(client_encrypted_input_path)}."
            },
        )

    SERVER_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    start = time.time()

    # Stage 1: run every shared base module on its encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        with open(encrypted_window_path, "rb") as f:
            encrypted_window = f.read()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts the intermediate result with keys loaded from the
        # shared model directory (key_dir=model_path), so stage-1 outputs are in clear here.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # NOTE(review): this stores the *input* ciphertext under a "decrypted_window_*" name;
        # kept unchanged because other code may read these files - confirm the intent.
        with (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").open("wb") as f:
            f.write(encrypted_window)
        y_proba.append(decrypted_output)

    # Stage 2: encrypt sliding windows of the stage-1 outputs and run the smoother on them.
    smoother_client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    smoother_server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)
    # Swap the first two axes of the stacked stage-1 outputs, then cast to int8.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2)).astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    yhat_encrypted = []
    for i in tqdm(range(len(X_slide))):
        x_row = X_slide[i].reshape(1, -1)
        encrypted_input = smoother_client.quantize_encrypt_serialize(x_row)
        encrypted_output = smoother_server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start, 2)
    return JSONResponse(content=fhe_execution_time)
@app.post("/get_output")
def get_output(user_id: str = Form(), root_dir: str = Form()):
    """Copy the encrypted smoother outputs from the server directory to the client directory.

    Loads ``<root_dir>/<user_id>/server/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``
    (written by the run_fhe* endpoints), checks that it holds ``META["NW"]`` entries, and writes
    the same list to the matching path under ``<root_dir>/<user_id>/client`` (creating the
    directory if needed).

    Returns:
        A plain ``"OK"`` response. The ciphertexts themselves are exchanged through the
        filesystem, not in the HTTP body.

    Raises:
        AssertionError: if the stored output list does not contain ``META["NW"]`` entries.
    """
    print("\nStep 5.2: Get the output from the server ............\n")
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR
    # NOTE(review): load_pickle presumably unpickles; the file location is derived from
    # request parameters (``root_dir``/``user_id``), and unpickling attacker-controlled data
    # can execute arbitrary code.
    yhat_encrypted = load_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl")
    # Validate before writing, so an incomplete result is never copied to the client side.
    assert len(yhat_encrypted) == META["NW"]

    CLIENT_DIR = Path(root_dir) / f"{user_id}/client"
    CLIENT_ENCRYPTED_OUTPUT_DIR = CLIENT_DIR / ENCRYPTED_OUTPUT_DIR
    CLIENT_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    write_pickle(CLIENT_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)

    # NOTE(review): the reason for this fixed delay is not visible here; kept to preserve timing.
    time.sleep(1)
    return Response("OK")