File size: 2,015 Bytes
a721ac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97505a5
a721ac0
97505a5
 
 
 
 
 
 
a721ac0
 
 
 
 
 
 
 
 
 
 
97505a5
a721ac0
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import base64
import logging
from io import BytesIO
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel

# Module-scoped logger: setting DEBUG here affects only this handler's records
# instead of reconfiguring the root logger (and thus every library) on import.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Preferred torch device for this process: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Inference Endpoints handler that upscales an image 2x with Swin2SR.

    Loads the ``caidas/swin2SR-classical-sr-x2-64`` checkpoint once at start-up.
    Each request returns the upscaled image as a base64-encoded JPEG string.
    """

    # Hub id of the checkpoint used for both the image processor and the model.
    MODEL_ID = "caidas/swin2SR-classical-sr-x2-64"

    def __init__(self, path=""):
        """Load the image processor and the super-resolution model.

        Args:
            path: accepted to match the Inference Endpoints handler signature,
                but unused; the checkpoint is always loaded by ``MODEL_ID``.
        """
        self.processor = AutoImageProcessor.from_pretrained(self.MODEL_ID)

        # Ask accelerate never to split these submodules across devices.
        Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]

        # The first load exists only to let accelerate compute a device map...
        probe = Swin2SRForImageSuperResolution.from_pretrained(self.MODEL_ID, device_map="auto")
        logger.info(probe.hf_device_map)
        device_map = self._pin_to_embeddings_device(probe.hf_device_map)
        # ...and is released before the real load, so two copies never coexist.
        del probe

        self.model = Swin2SRForImageSuperResolution.from_pretrained(self.MODEL_ID, device_map=device_map)

    @staticmethod
    def _pin_to_embeddings_device(device_map):
        """Return a copy of *device_map* that co-locates the model tail with the embeddings.

        ``swin2sr.conv_after_body`` and ``upsample`` are placed on the device
        of ``swin2sr.embeddings``. NOTE(review): presumably so the end of the
        forward pass runs on one device; confirm against the Swin2SR forward.

        A module's device is looked up under its own key or, failing that,
        under its nearest listed ancestor (``""`` is the root). So a collapsed
        map such as ``{"": 0}`` is returned unchanged rather than raising
        ``KeyError``. The input mapping is never modified.
        """

        def resolve(module):
            # Walk "a.b.c" -> "a.b" -> "a" -> "" until a key is found.
            name = module
            while True:
                if name in device_map:
                    return device_map[name]
                if not name:
                    return None
                name = name.rpartition(".")[0]

        pinned = dict(device_map)
        target = resolve("swin2sr.embeddings")
        if target is None:
            return pinned
        for module in ("swin2sr.conv_after_body", "upsample"):
            if resolve(module) != target:
                pinned[module] = target
        return pinned

    @staticmethod
    def _to_uint8_hwc(reconstruction):
        """Convert a batch-of-one ``(1, C, H, W)`` float tensor to an ``(H, W, C)`` uint8 array.

        Values are clamped to [0, 1] and scaled to 0..255 with rounding.
        Only the leading batch dimension is squeezed, so size-1 spatial
        dimensions are preserved.
        """
        chw = reconstruction.detach().squeeze(0).float().cpu().clamp(0, 1).numpy()
        hwc = np.moveaxis(chw, source=0, destination=-1)
        return (hwc * 255.0).round().astype(np.uint8)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Upscale the image in ``data["inputs"]`` 2x.

        Args:
            data: request payload. ``data["inputs"]`` is handed directly to the
                image processor; presumably a PIL image decoded by the
                Inference Endpoints runtime (confirm for the deployed task).

        Returns:
            The upscaled image, JPEG-encoded, then base64-encoded, as ``str``.
        """
        inputs = self.processor(data["inputs"], return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)

        pixels = self._to_uint8_hwc(outputs.reconstruction)

        buffered = BytesIO()
        Image.fromarray(pixels).save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()