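# Hugging Face Spaces page header captured together with this file ("Spaces: Running Running");
# kept only as a comment so the module can be parsed.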
import shutil
import time
from glob import glob
from pathlib import Path
from typing import List

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse, Response
from tqdm import tqdm

# NOTE: the star import is kept before the concrete import (original order) so that any
# name clash keeps resolving exactly as before.
from utils import *

from concrete.ml.deployment import FHEModelClient, FHEModelServer
# NOTE(review): no FHE server is loaded at import time; the FHE functions below build their
# own ``FHEModelServer`` instances on every request.
# NOTE(review): none of the functions in this copy of the file carries an ``@app.get`` /
# ``@app.post`` decorator, so as written ``app`` exposes no routes - confirm the route
# decorators were not lost.
# Initialize an instance of FastAPI
app = FastAPI()
# Define the default route
# NOTE(review): no route decorator is attached here, so FastAPI does not serve this function;
# presumably ``@app.get("/")`` was intended - confirm.
def root():
    """
    Root endpoint of the encrypted DNA testing API.

    Returns:
        dict: The welcome message, under the ``"message"`` key.
    """
    return {"message": "Welcome to your encrypted DNA testing use-case with FHE!"}
def send_input(
    user_id: str = Form(...), root_dir: str = Form(...), files: List[UploadFile] = File(...)
):
    """Store the client's evaluation keys and encrypted input windows on the server side.

    Files are written under ``<root_dir>/<user_id>/server``:

    * ``files[0]`` -> ``<KEY_BASE_MODULE_DIR>/eval_key``
    * ``files[1]`` -> ``<KEY_SMOOTHER_MODULE_DIR>/eval_key``
    * ``files[i]`` for ``i >= 2`` -> ``<ENCRYPTED_INPUT_DIR>/encrypted_window_{i}``

    Args:
        user_id: Client session identifier (sub-directory of ``root_dir``).
        root_dir: Base directory under which the per-user ``server`` folder lives.
        files: The two evaluation keys (base module first), then the encrypted windows.

    Raises:
        HTTPException: 400 if fewer than the two evaluation keys were uploaded.
    """
    print("------------ Step 3.2: Send the data to the server")
    print(f"{user_id=}, {root_dir=}, {len(files)=}")

    if len(files) < 2:
        raise HTTPException(
            status_code=400,
            detail="Expected the base and smoother evaluation keys followed by the encrypted windows.",
        )

    # NOTE(review): ``root_dir`` and ``user_id`` come straight from the request, so the client
    # decides where on the server's filesystem these files are written. Confirm the endpoint is
    # only reachable by a trusted client, or confine these paths to a fixed base directory.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR

    # Make sure the first upload for a new user does not fail on a missing directory.
    for directory in (
        SERVER_KEY_BASE_MODULE_DIR,
        SERVER_KEY_SMOOTHER_MODULE_DIR,
        SERVER_ENCRYPTED_INPUT_DIR,
    ):
        directory.mkdir(parents=True, exist_ok=True)

    def _save(upload: UploadFile, destination: Path) -> None:
        # Stream the upload to disk in chunks instead of reading it fully into memory.
        with destination.open("wb") as out:
            shutil.copyfileobj(upload.file, out)

    _save(files[0], SERVER_KEY_BASE_MODULE_DIR / "eval_key")
    _save(files[1], SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key")

    print(f"{len(files)=}")
    # Each window keeps its position in ``files`` as suffix (so numbering starts at 2); the
    # FHE endpoints sort these files with ``extract_model_number``, so presumably only their
    # relative order matters.
    for i in tqdm(range(2, len(files))):
        _save(files[i], SERVER_ENCRYPTED_INPUT_DIR / f"encrypted_window_{i}")
def run_fhe(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the full FHE inference pipeline for one user, reporting progress as it goes.

    Stage 1 runs every shared base module (``SHARED_BASE_MODULE_DIR/model_*``) on the
    matching encrypted window uploaded by the client. Stage 2 builds sliding windows over
    the stacked stage-1 outputs (``slide_window`` with ``META["SS"]``) and runs each one
    through the shared smoother module. After every step the elapsed time and completion
    percentage are written to ``FHE_COMPUTATION_TIMELINE``.

    The list of encrypted smoother outputs is pickled to
    ``<root_dir>/<user_id>/server/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    Args:
        user_id: Client session identifier (sub-directory of ``root_dir``).
        root_dir: Base directory holding the per-user ``server`` folder.

    Returns:
        JSONResponse: total execution time in seconds, rounded to 2 decimals.

    Raises:
        RuntimeError: if the number of shared base modules differs from ``META["NW"]``.
        HTTPException: 400 if the number of uploaded encrypted windows differs from
            ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")

    # NOTE(review): these paths are built from request fields, so the client chooses which
    # server directory is read and written - confirm only trusted clients can call this.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    eval_key_base_module = (SERVER_KEY_BASE_MODULE_DIR / "eval_key").read_bytes()
    eval_key_smoother_module = (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").read_bytes()

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    if len(shared_base_modules_path) != META["NW"]:
        raise RuntimeError(
            f"Expected {META['NW']} shared base modules, found {len(shared_base_modules_path)}."
        )

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # One encrypted window per base module; without this check ``zip`` below would silently
    # drop or misalign windows.
    if len(client_encrypted_input_path) != META["NW"]:
        raise HTTPException(
            status_code=400,
            detail=f"Expected {META['NW']} encrypted windows, found {len(client_encrypted_input_path)}.",
        )

    nb_base_iterations = META["NW"]
    # Base-module steps plus smoother steps (the smoother is expected to yield NW windows too).
    nb_total_iterations = nb_base_iterations * 2
    start_time = time.time()

    def _report_progress(nb_done: int) -> None:
        # Overwrite the timeline file with the elapsed time and the completion percentage.
        with open(FHE_COMPUTATION_TIMELINE, "w", encoding="utf-8") as timeline:
            timeline.write(
                f"{time.time() - start_time:.2f} seconds ({nb_done/nb_total_iterations:.0%})"
            )

    # ---- Stage 1: one base module per encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        encrypted_window = Path(encrypted_window_path).read_bytes()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts each base-module output itself, with keys loaded
        # from the shared ``model_path``, so intermediate results are visible in clear on the
        # server. Confirm this is intended.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # Persist this window's encrypted result. The file name is kept unchanged in case
        # other code reads it.
        (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").write_bytes(encrypted_output)
        y_proba.append(decrypted_output)
        _report_progress(i + 1)

    # ---- Stage 2: smoother over sliding windows of the stage-1 outputs.
    # NOTE(review): inputs are encrypted with the keys in SHARED_SMOOTHER_MODULE_DIR but
    # evaluated with the client-uploaded smoother evaluation key; these only match if the
    # client generated its keys in that same directory - confirm.
    client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)
    # Stack the per-model outputs (model index on axis 0), then swap the first two axes so
    # the model index ends up on axis 1.
    # NOTE(review): astype(int8) truncates any fractional values - confirm the dequantized
    # outputs are integral.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2)).astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    yhat_encrypted = []
    for i, window in enumerate(tqdm(X_slide)):
        encrypted_input = client.quantize_encrypt_serialize(window.reshape(1, -1))
        encrypted_output = server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)
        # Continue counting after the stage-1 steps instead of restarting at 0%.
        _report_progress(nb_base_iterations + i + 1)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start_time, 2)
    return JSONResponse(content=fhe_execution_time)
def run_fhe_stage1(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the FHE inference pipeline for one user, without progress reporting.

    Runs every shared base module (``SHARED_BASE_MODULE_DIR/model_*``) on the matching
    uploaded encrypted window, then runs sliding windows of those outputs through the
    shared smoother module. It pickles the encrypted smoother outputs to
    ``<root_dir>/<user_id>/server/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    NOTE(review): despite its name this runs both stages - the same pipeline as ``run_fhe``
    minus the ``FHE_COMPUTATION_TIMELINE`` writes. Confirm whether it should stop after the
    base modules.

    Args:
        user_id: Client session identifier (sub-directory of ``root_dir``).
        root_dir: Base directory holding the per-user ``server`` folder.

    Returns:
        JSONResponse: total execution time in seconds, rounded to 2 decimals.

    Raises:
        RuntimeError: if the number of shared base modules differs from ``META["NW"]``.
        HTTPException: 400 if the number of uploaded encrypted windows differs from
            ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")

    # NOTE(review): these paths are built from request fields, so the client chooses which
    # server directory is read and written - confirm only trusted clients can call this.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    eval_key_base_module = (SERVER_KEY_BASE_MODULE_DIR / "eval_key").read_bytes()
    eval_key_smoother_module = (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").read_bytes()

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    if len(shared_base_modules_path) != META["NW"]:
        raise RuntimeError(
            f"Expected {META['NW']} shared base modules, found {len(shared_base_modules_path)}."
        )

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # One encrypted window per base module; without this check ``zip`` below would silently
    # drop or misalign windows.
    if len(client_encrypted_input_path) != META["NW"]:
        raise HTTPException(
            status_code=400,
            detail=f"Expected {META['NW']} encrypted windows, found {len(client_encrypted_input_path)}.",
        )

    start = time.time()

    # ---- Base modules: one per encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        encrypted_window = Path(encrypted_window_path).read_bytes()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts each base-module output itself (keys from the
        # shared ``model_path``), so intermediate results are visible in clear - confirm.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # Persist this window's encrypted result. The file name is kept unchanged in case
        # other code reads it.
        (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").write_bytes(encrypted_output)
        y_proba.append(decrypted_output)

    # ---- Smoother over sliding windows of the base-module outputs.
    client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)
    # Stack per-model outputs (model index on axis 0), then move the model index to axis 1.
    # NOTE(review): astype(int8) truncates fractional values - confirm outputs are integral.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2)).astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    yhat_encrypted = []
    for window in tqdm(X_slide):
        encrypted_input = client.quantize_encrypt_serialize(window.reshape(1, -1))
        encrypted_output = server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start, 2)
    return JSONResponse(content=fhe_execution_time)
def run_fhe_stage2(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the FHE inference pipeline for one user, without progress reporting.

    Runs every shared base module (``SHARED_BASE_MODULE_DIR/model_*``) on the matching
    uploaded encrypted window, then runs sliding windows of those outputs through the
    shared smoother module. It pickles the encrypted smoother outputs to
    ``<root_dir>/<user_id>/server/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    NOTE(review): despite its name this also re-runs the base modules (stage 1) before the
    smoother - the same pipeline as ``run_fhe`` minus the ``FHE_COMPUTATION_TIMELINE``
    writes. Confirm whether it should start from stored stage-1 outputs instead.

    Args:
        user_id: Client session identifier (sub-directory of ``root_dir``).
        root_dir: Base directory holding the per-user ``server`` folder.

    Returns:
        JSONResponse: total execution time in seconds, rounded to 2 decimals.

    Raises:
        RuntimeError: if the number of shared base modules differs from ``META["NW"]``.
        HTTPException: 400 if the number of uploaded encrypted windows differs from
            ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")

    # NOTE(review): these paths are built from request fields, so the client chooses which
    # server directory is read and written - confirm only trusted clients can call this.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    eval_key_base_module = (SERVER_KEY_BASE_MODULE_DIR / "eval_key").read_bytes()
    eval_key_smoother_module = (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").read_bytes()

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    if len(shared_base_modules_path) != META["NW"]:
        raise RuntimeError(
            f"Expected {META['NW']} shared base modules, found {len(shared_base_modules_path)}."
        )

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # One encrypted window per base module; without this check ``zip`` below would silently
    # drop or misalign windows.
    if len(client_encrypted_input_path) != META["NW"]:
        raise HTTPException(
            status_code=400,
            detail=f"Expected {META['NW']} encrypted windows, found {len(client_encrypted_input_path)}.",
        )

    start = time.time()

    # ---- Base modules: one per encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        encrypted_window = Path(encrypted_window_path).read_bytes()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts each base-module output itself (keys from the
        # shared ``model_path``), so intermediate results are visible in clear - confirm.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # Persist this window's encrypted result. The file name is kept unchanged in case
        # other code reads it.
        (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").write_bytes(encrypted_output)
        y_proba.append(decrypted_output)

    # ---- Smoother over sliding windows of the base-module outputs.
    client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)
    # Stack per-model outputs (model index on axis 0), then move the model index to axis 1.
    # NOTE(review): astype(int8) truncates fractional values - confirm outputs are integral.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2)).astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    yhat_encrypted = []
    for window in tqdm(X_slide):
        encrypted_input = client.quantize_encrypt_serialize(window.reshape(1, -1))
        encrypted_output = server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start, 2)
    return JSONResponse(content=fhe_execution_time)
def get_output(user_id: str = Form(), root_dir: str = Form()):
    """Copy the encrypted final output from the user's server folder to their client folder.

    Reads ``encrypted_final_output.pkl`` from ``<root_dir>/<user_id>/server/<ENCRYPTED_OUTPUT_DIR>``
    and writes the same list to ``<root_dir>/<user_id>/client/<ENCRYPTED_OUTPUT_DIR>``. The
    payload travels through the filesystem; the HTTP response body is just ``"OK"``.

    Args:
        user_id: Client session identifier (sub-directory of ``root_dir``).
        root_dir: Base directory holding the per-user ``server`` and ``client`` folders.

    Returns:
        Response: plain ``"OK"``.

    Raises:
        HTTPException: 500 if the stored result does not contain ``META["NW"]`` entries;
            nothing is written to the client folder in that case.
    """
    print("\nStep 5.2: Get the output from the server ............\n")

    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR
    # NOTE(review): ``load_pickle`` presumably unpickles, and the file location is derived from
    # request fields - unpickling a client-chosen path can execute arbitrary code. Confirm only
    # trusted clients can reach this endpoint, or confine the path.
    yhat_encrypted = load_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl")

    # Validate before handing anything to the client, so a malformed result is never copied.
    if len(yhat_encrypted) != META["NW"]:
        raise HTTPException(
            status_code=500,
            detail=f"Expected {META['NW']} encrypted outputs, found {len(yhat_encrypted)}.",
        )

    CLIENT_DIR = Path(root_dir) / f"{user_id}/client"
    CLIENT_ENCRYPTED_OUTPUT_DIR = CLIENT_DIR / ENCRYPTED_OUTPUT_DIR
    CLIENT_ENCRYPTED_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    write_pickle(CLIENT_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)

    # NOTE(review): the purpose of this fixed 1 s pause is not visible here (possibly UI
    # pacing); it blocks the worker handling this request - confirm it is still needed.
    time.sleep(1)
    # The encrypted output was delivered through the client folder above; only acknowledge here.
    return Response("OK")