File size: 2,802 Bytes
a721ac0 a8cfa26 a721ac0 64ab5a0 a721ac0 7651656 a721ac0 f6c11fb a721ac0 97505a5 a721ac0 f6c11fb a721ac0 f6c11fb 0dec772 1ef36b1 a721ac0 1ef36b1 a721ac0 1ef36b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from typing import Dict, List, Any
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel
import torch
import base64
import logging
import numpy as np
import gc
from PIL import Image
from io import BytesIO
import subprocess
# Module-wide logger. NOTE: this is the *root* logger, so lowering its level to
# DEBUG also affects records emitted by every other library in the process.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Compute device: the default CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Number of visible CUDA GPUs (0 on CPU-only hosts); EndpointHandler shards the
# model across GPUs when this is greater than 1.
gpu_count = torch.cuda.device_count()
class EndpointHandler:
    """Inference Endpoint handler for 2x Swin2SR image super-resolution.

    The processor and model are loaded once from ``MODEL_ID`` at start-up.
    Each call upscales ``data["inputs"]`` and returns the result as a
    base64-encoded JPEG string (or an ``{"error": ...}`` dict on failure).
    """

    # Hub checkpoint used for both the image processor and the model.
    MODEL_ID = "caidas/swin2SR-classical-sr-x2-64"

    def __init__(self, path=""):
        """Load the image processor and the super-resolution model.

        Args:
            path: Not used; the checkpoint is always loaded from ``MODEL_ID``.
        """
        self.processor = AutoImageProcessor.from_pretrained(self.MODEL_ID)
        if gpu_count > 1:
            self.model = self._load_sharded_model()
            self._log_gpu_usage()
        else:
            self.model = Swin2SRForImageSuperResolution.from_pretrained(self.MODEL_ID)
            # Single device: move the whole model onto the GPU (or the CPU).
            self.model.to(device)

    def _load_sharded_model(self):
        """Load the model sharded across all visible GPUs via ``device_map``."""
        # Never split an embeddings block or a Swin2SR stage across two GPUs.
        Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]

        # First load only to let accelerate compute an automatic placement.
        probe = Swin2SRForImageSuperResolution.from_pretrained(self.MODEL_ID, device_map="auto")
        logger.info(probe.hf_device_map)
        device_map = dict(probe.hf_device_map)
        # Free the probe before loading again so two full copies of the
        # weights never occupy GPU memory at the same time.
        del probe
        gc.collect()
        torch.cuda.empty_cache()

        # Pin conv_after_body and upsample to the embeddings' device.
        # NOTE(review): presumably so the residual path around the encoder
        # stays on one GPU -- confirm. If the automatic map did not place the
        # embeddings separately there is nothing to re-pin.
        if "swin2sr.embeddings" in device_map:
            embeddings_device = device_map["swin2sr.embeddings"]
            device_map["swin2sr.conv_after_body"] = embeddings_device
            device_map["upsample"] = embeddings_device
        return Swin2SRForImageSuperResolution.from_pretrained(self.MODEL_ID, device_map=device_map)

    @staticmethod
    def _log_gpu_usage():
        """Best-effort diagnostic: log ``nvidia-smi`` output at DEBUG level.

        Skipped when CUDA is unavailable; a missing or failing ``nvidia-smi``
        is logged and ignored, never raised.
        """
        if not torch.cuda.is_available():
            return
        try:
            result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
        except OSError as err:
            logger.debug("nvidia-smi unavailable: %s", err)
            return
        logger.debug("nvidia-smi:\n%s", result.stdout)

    @staticmethod
    def _to_base64_jpeg(reconstruction):
        """Encode the first image of a model reconstruction as a base64 JPEG string.

        Assumes ``reconstruction`` is a float tensor shaped
        (batch, channels, height, width) with values nominally in [0, 1]
        (TODO confirm); values are clamped to that range before quantizing.
        """
        # Index the batch axis rather than squeeze(), which would also drop
        # spatial axes of size 1.
        array = reconstruction[0].detach().float().cpu().clamp_(0, 1).numpy()
        array = np.moveaxis(array, source=0, destination=-1)  # CHW -> HWC for PIL
        array = (array * 255.0).round().astype(np.uint8)
        # NOTE(review): if the processor pads the input, the padded border is
        # upscaled too and kept in the output -- confirm whether to crop.
        buffer = BytesIO()
        Image.fromarray(array).save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode()

    def __call__(self, data: Any):
        """Upscale the image in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is passed to the image
                processor (presumably a ``PIL.Image.Image`` -- confirm with the
                endpoint toolkit).

        Returns:
            The upscaled image as a base64-encoded JPEG ``str`` on success, or
            ``{"error": message}`` if preprocessing, inference or encoding fails.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` key.
        """
        image = data["inputs"]
        inputs = outputs = None
        try:
            # JPEG cannot store alpha, and the model is assumed to take 3-channel
            # input; normalize non-RGB PIL uploads (RGBA, L, P, ...) first.
            if isinstance(image, Image.Image) and image.mode != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(image, return_tensors="pt")
            if gpu_count <= 1:
                # The sharded model is fed CPU tensors; presumably accelerate's
                # dispatch hooks move them to the right GPU (TODO confirm).
                inputs = inputs.to(device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            self._log_gpu_usage()
            return self._to_base64_jpeg(outputs.reconstruction)
        except Exception as e:
            # Endpoint boundary: report the failure to the client, keep the traceback in the logs.
            logger.exception(str(e))
            # Drop tensor references so the CUDA cache can actually be released.
            del inputs, outputs
            gc.collect()
            torch.cuda.empty_cache()
            return {"error": str(e)}