from typing import Any
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel
import torch
import base64
import logging
import numpy as np
from PIL import Image
from io import BytesIO
import subprocess

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        # load the image processor
        self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

        # tell accelerate which modules must never be split across devices
        Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]

        # first load: let accelerate infer a device map, then inspect it
        model = Swin2SRForImageSuperResolution.from_pretrained(
            "caidas/swin2SR-classical-sr-x2-64", device_map="auto"
        )
        logger.info(model.hf_device_map)

        # keep the post-body conv and the upsampler on the same device as the embeddings
        model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
        model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]

        # second load: reuse the patched device map
        self.model = Swin2SRForImageSuperResolution.from_pretrained(
            "caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map
        )

        # log GPU memory usage after loading
        print(subprocess.run(["nvidia-smi"]))

    def __call__(self, data: Any):
        """
        Args:
            data (:obj:`dict`):
                Request payload; ``data["inputs"]`` holds the image to upscale.
        Return:
            A :obj:`str`: base64-encoded JPEG of the upscaled image.
        """
        image = data["inputs"]
        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)

        # log GPU memory usage after inference
        print(subprocess.run(["nvidia-smi"]))

        # CHW float tensor in [0, 1] -> HWC uint8 array
        output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.moveaxis(output, source=0, destination=-1)
        output = (output * 255.0).round().astype(np.uint8)

        # encode the upscaled image as a base64 JPEG string
        img = Image.fromarray(output)
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return img_str.decode()
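

# --- Local smoke test (sketch) -----------------------------------------------
# Hypothetical usage example, not part of the Inference Endpoints request flow:
# it assumes a local test image at "example.jpg" (an assumed path) and passes a
# PIL image to the handler the way the toolkit would after deserializing the
# request body, then writes the decoded base64 result back to disk.
if __name__ == "__main__":
    handler = EndpointHandler()
    test_image = Image.open("example.jpg").convert("RGB")  # assumed local test image
    b64_jpeg = handler({"inputs": test_image})

    # decode the base64 string and save the upscaled JPEG for inspection
    with open("upscaled.jpg", "wb") as f:
        f.write(base64.b64decode(b64_jpeg))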