from typing import Any
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel
import torch
import base64
import logging
import numpy as np
import gc
from PIL import Image
from io import BytesIO
import subprocess


logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu_count = torch.cuda.device_count()

class EndpointHandler:
    def __init__(self, path=""):
        # load the image processor that handles pre-processing for Swin2SR
        self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

        if gpu_count > 1:
            # tell accelerate not to split these modules across devices when sharding the model
            Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
            Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]

            # first load with an automatic device map to see how accelerate shards the model,
            # then pin conv_after_body and upsample to the same device as the embeddings so these
            # layers end up on the same GPU as the features they combine with
            model = Swin2SRForImageSuperResolution.from_pretrained(
                "caidas/swin2SR-classical-sr-x2-64", device_map="auto"
            )
            logger.info(model.hf_device_map)
            model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
            model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]

            # reload the model with the adjusted device map
            self.model = Swin2SRForImageSuperResolution.from_pretrained(
                "caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map
            )

            # log per-GPU memory usage after loading
            print(subprocess.run(["nvidia-smi"]))

        else:
            # single GPU (or CPU): load normally and move the model to the device
            self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
            self.model.to(device)


    def __call__(self, data: Any):
        """
        Args:
            data (:obj:`dict`):
                dict with an "inputs" key holding the image to be upscaled
        Return:
            A :obj:`string`: base64-encoded JPEG of the upscaled image,
            or a dict with an "error" key if inference fails
        """

        image = data["inputs"]

        if gpu_count > 1:
            # the model is sharded across GPUs; accelerate's hooks move the inputs
            # to the correct device, so keep them on CPU here
            inputs = self.processor(image, return_tensors="pt")
        else:
            inputs = self.processor(image, return_tensors="pt").to(device)

        try:
            with torch.no_grad():
                outputs = self.model(**inputs)

            # log per-GPU memory usage after the forward pass (skip on CPU-only hosts)
            if gpu_count > 0:
                print(subprocess.run(["nvidia-smi"]))

            # convert the reconstruction from a CHW float tensor in [0, 1]
            # to an HWC uint8 array
            output = outputs.reconstruction.squeeze().float().cpu().clamp_(0, 1).numpy()
            output = np.moveaxis(output, source=0, destination=-1)
            output = (output * 255.0).round().astype(np.uint8)

            # encode the upscaled image as a base64 JPEG string
            img = Image.fromarray(output)
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue())

            return img_str.decode()

        except Exception as e:
            # on failure (e.g. CUDA out of memory), log the error and free GPU memory
            logger.error(str(e))
            del inputs
            gc.collect()
            torch.cuda.empty_cache()

            return {"error": str(e)}
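

# A minimal local smoke-test sketch, not part of the Inference Endpoints request
# contract; the input path "test.jpg" and output path "upscaled.jpg" are
# illustrative placeholders only.
if __name__ == "__main__":
    handler = EndpointHandler()
    with Image.open("test.jpg") as test_image:
        result = handler({"inputs": test_image.convert("RGB")})

    if isinstance(result, dict):
        logger.error("Inference failed: %s", result["error"])
    else:
        with open("upscaled.jpg", "wb") as out_file:
            out_file.write(base64.b64decode(result))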