fashxp commited on
Commit
f6c11fb
·
1 Parent(s): 64ab5a0

make it compatible to 1 and multiple GPUs

Browse files
Files changed (1) hide show
  1. handler.py +24 -12
handler.py CHANGED
@@ -15,22 +15,30 @@ logger.setLevel(logging.DEBUG)
15
 
16
  # check for GPU
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
-
19
 
20
  class EndpointHandler:
21
  def __init__(self, path=""):
22
  # load the model
23
  self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
24
- Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
25
- Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
26
- model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map="auto")
27
- logger.info(model.hf_device_map)
28
- model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
29
- model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]
30
- self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map)
31
-
32
- print(subprocess.run(["nvidia-smi"]))
33
-
 
 
 
 
 
 
 
 
34
  def __call__(self, data: Any):
35
  """
36
  Args:
@@ -41,7 +49,11 @@ class EndpointHandler:
41
  """
42
 
43
  image = data["inputs"]
44
- inputs = self.processor(image, return_tensors="pt")
 
 
 
 
45
 
46
  try:
47
  with torch.no_grad():
 
15
 
16
  # check for GPU
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ gpu_count = torch.cuda.device_count()
19
 
20
  class EndpointHandler:
21
  def __init__(self, path=""):
22
  # load the model
23
  self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
24
+
25
+ if(gpu_count > 1):
26
+ Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
27
+ Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
28
+ model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map="auto")
29
+ logger.info(model.hf_device_map)
30
+ model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
31
+ model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]
32
+ self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map)
33
+
34
+ print(subprocess.run(["nvidia-smi"]))
35
+
36
+ else:
37
+ self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
38
+ # move model to device
39
+ self.model.to(device)
40
+
41
+
42
  def __call__(self, data: Any):
43
  """
44
  Args:
 
49
  """
50
 
51
  image = data["inputs"]
52
+
53
+ if(gpu_count > 1):
54
+ inputs = self.processor(image, return_tensors="pt")
55
+ else:
56
+ inputs = self.processor(image, return_tensors="pt").to(device)
57
 
58
  try:
59
  with torch.no_grad():