# cog_sdxl/tests/test_predict.py

import base64
import os
import pickle
import subprocess
import sys
import time
from functools import partial
from io import BytesIO

import numpy as np
import pytest
import replicate
import requests
from PIL import Image, ImageChops

ENV = os.getenv('TEST_ENV', 'local')
LOCAL_ENDPOINT = "http://localhost:5000/predictions"
MODEL = os.getenv('STAGING_MODEL', 'no model configured')
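
# How these tests are typically invoked (a sketch - the exact commands depend on
# your local cog/docker setup and Replicate credentials, so treat this as an
# assumption rather than a documented interface):
#   TEST_ENV=local pytest tests/test_predict.py       # builds and hits a local cog container
#   TEST_ENV=staging STAGING_MODEL=<owner/model-name> pytest tests/test_predict.py
#                                                     # runs against a deployed Replicate model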


def local_run(model_endpoint: str, model_input: dict):
    response = requests.post(model_endpoint, json={"input": model_input})
    data = response.json()

    try:
        # TODO: this will break if we test batching
        datauri = data["output"][0]
        base64_encoded_data = datauri.split(",")[1]
        data = base64.b64decode(base64_encoded_data)
        return Image.open(BytesIO(data))
    except Exception as e:
        print("Error!")
        print("input:", model_input)
        print(data["logs"])
        raise e


def replicate_run(model: str, version: str, model_input: dict):
    output = replicate.run(
        f"{model}:{version}",
        input=model_input,
    )
    url = output[0]

    response = requests.get(url)
    return Image.open(BytesIO(response.content))


def wait_for_server_to_be_ready(url, timeout=300):
    """
    Waits for the server to be ready.

    Args:
    - url: The health check URL to poll.
    - timeout: Maximum time (in seconds) to wait for the server to be ready.
    """
    start_time = time.time()
    while True:
        try:
            response = requests.get(url)
            data = response.json()

            if data["status"] == "READY":
                return
            elif data["status"] == "SETUP_FAILED":
                raise RuntimeError(
                    "Server initialization failed with status: SETUP_FAILED"
                )
        except requests.RequestException:
            pass

        if time.time() - start_time > timeout:
            raise TimeoutError("Server did not become ready in the expected time.")

        time.sleep(5)  # Poll every 5 seconds
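
# Note on the health check: the polling loop above assumes cog's health-check
# endpoint returns JSON with a "status" field, where "READY" means the model has
# finished setup and "SETUP_FAILED" means initialization errored out. Any other
# status (or a connection error while the container is still starting up) is
# treated as "not ready yet" and simply retried.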
@pytest.fixture(scope="session")
def inference_func():
"""
local inference uses http API to hit local server; staging inference uses python API b/c it's cleaner.
"""
if ENV == 'local':
return partial(local_run, LOCAL_ENDPOINT)
elif ENV == 'staging':
model = replicate.models.get(MODEL)
print(f"model,", model)
version = model.versions.list()[0]
return partial(replicate_run, MODEL, version.id)
else:
raise Exception(f"env should be local or staging but was {ENV}")
@pytest.fixture(scope="session", autouse=True)
def service():
"""
Spins up local cog server to hit for tests if running locally, no-op otherwise
"""
if ENV == 'local':
print("building model")
# starts local server if we're running things locally
build_command = 'cog build -t test-model'.split()
subprocess.run(build_command, check=True)
container_name = 'cog-test'
try:
subprocess.check_output(['docker', 'inspect', '--format="{{.State.Running}}"', container_name])
print(f"Container '{container_name}' is running. Stopping and removing...")
subprocess.check_call(['docker', 'stop', container_name])
subprocess.check_call(['docker', 'rm', container_name])
print(f"Container '{container_name}' stopped and removed.")
except subprocess.CalledProcessError:
# Container not found
print(f"Container '{container_name}' not found or not running.")
run_command = f'docker run -d -p 5000:5000 --gpus all --name {container_name} test-model '.split()
process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr)
wait_for_server_to_be_ready("http://localhost:5000/health-check")
yield
process.terminate()
process.wait()
stop_command = "docker stop cog-test".split()
subprocess.run(stop_command)
else:
yield


def image_equal_fuzzy(img_expected, img_actual, test_name='default', tol=20):
    """
    Check that average pixel values differ by less than tol across the image.
    tol was determined empirically: holding everything else equal but varying the
    seed produces images whose mean pixel delta is at least 50.
    """
    img1 = np.array(img_expected, dtype=np.int32)
    img2 = np.array(img_actual, dtype=np.int32)

    mean_delta = np.mean(np.abs(img1 - img2))
    imgs_equal = (mean_delta < tol)
    if not imgs_equal:
        # save failures for quick inspection
        save_dir = f"tmp/{test_name}"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        img_expected.save(os.path.join(save_dir, 'expected.png'))
        img_actual.save(os.path.join(save_dir, 'actual.png'))
        difference = ImageChops.difference(img_expected, img_actual)
        difference.save(os.path.join(save_dir, 'delta.png'))

    return imgs_equal
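
# Rough intuition for the tolerance (based on the empirical note in the docstring,
# not an exact guarantee): two runs with the same seed should land well under a
# mean pixel delta of tol=20, while changing only the seed pushes the delta to 50+.
# For example, with a hypothetical render() helper:
#   assert image_equal_fuzzy(render(seed=1), render(seed=1))
#   assert not image_equal_fuzzy(render(seed=1), render(seed=2))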


def test_seeded_prediction(inference_func, request):
    """
    SDXL with a fixed seed should be deterministic. The tolerance may need
    adjusting for optimized SDXL variants.
    """
    data = {
        "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic",
        "num_inference_steps": 50,
        "width": 1024,
        "height": 1024,
        "scheduler": "DDIM",
        "refine": "expert_ensemble_refiner",
        "seed": 12103,
    }
    actual_image = inference_func(data)
    expected_image = Image.open("tests/assets/out.png")
    assert image_equal_fuzzy(actual_image, expected_image, test_name=request.node.name)


def test_lora_load_unload(inference_func, request):
    """
    Tests generation with & without LoRAs.
    This is checking for some gnarly state issues (can SDXL load / unload LoRAs),
    so predictions need to run in series.
    """
    SEED = 1234
    base_data = {
        "prompt": "A photo of a dog on the beach",
        "num_inference_steps": 50,
        # Add other parameters here
        "seed": SEED,
    }
    base_img_1 = inference_func(base_data)

    lora_a_data = {
        "prompt": "A photo of a TOK on the beach",
        "num_inference_steps": 50,
        # Add other parameters here
        "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/other_model.tar",
        "seed": SEED,
    }
    lora_a_img_1 = inference_func(lora_a_data)
    # LoRA A output should differ from the base model output...
    assert not image_equal_fuzzy(lora_a_img_1, base_img_1, test_name=request.node.name)

    lora_a_img_2 = inference_func(lora_a_data)
    # ...but be reproducible when the same LoRA is run again with the same seed.
    assert image_equal_fuzzy(lora_a_img_1, lora_a_img_2, test_name=request.node.name)

    lora_b_data = {
        "prompt": "A photo of a TOK on the beach",
        "num_inference_steps": 50,
        "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/monstertoy_model.tar",
        "seed": SEED,
    }
    lora_b_img = inference_func(lora_b_data)
    # LoRA B should differ from both LoRA A and the base model outputs.
    assert not image_equal_fuzzy(lora_a_img_1, lora_b_img, test_name=request.node.name)
    assert not image_equal_fuzzy(base_img_1, lora_b_img, test_name=request.node.name)