Vincentqyw commited on
Commit
f517bbf
·
1 Parent(s): 8af5ecd

fix: roma cpu

Browse files
app.py CHANGED
@@ -86,8 +86,7 @@ def ui_reset_state(
86
 
87
 
88
  def run(config):
89
- with gr.Blocks(css="footer {visibility: hidden}"
90
- ) as app:
91
  gr.Markdown(
92
  """
93
  <p align="center">
 
86
 
87
 
88
  def run(config):
89
+ with gr.Blocks(css="footer {visibility: hidden}") as app:
 
90
  gr.Markdown(
91
  """
92
  <p align="center">
third_party/Roma/roma/models/encoders.py CHANGED
@@ -6,6 +6,8 @@ import torch.nn.functional as F
6
  import torchvision.models as tvm
7
  import gc
8
 
 
 
9
 
10
  class ResNet50(nn.Module):
11
  def __init__(
@@ -47,7 +49,7 @@ class ResNet50(nn.Module):
47
  self.amp_dtype = torch.float32
48
 
49
  def forward(self, x, **kwargs):
50
- with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
51
  net = self.net
52
  feats = {1: x}
53
  x = net.conv1(x)
@@ -90,7 +92,7 @@ class VGG19(nn.Module):
90
  self.amp_dtype = torch.float32
91
 
92
  def forward(self, x, **kwargs):
93
- with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
94
  feats = {}
95
  scale = 1
96
  for layer in self.layers:
 
6
  import torchvision.models as tvm
7
  import gc
8
 
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
 
12
  class ResNet50(nn.Module):
13
  def __init__(
 
49
  self.amp_dtype = torch.float32
50
 
51
  def forward(self, x, **kwargs):
52
+ with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
53
  net = self.net
54
  feats = {1: x}
55
  x = net.conv1(x)
 
92
  self.amp_dtype = torch.float32
93
 
94
  def forward(self, x, **kwargs):
95
+ with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
96
  feats = {}
97
  scale = 1
98
  for layer in self.layers:
third_party/Roma/roma/models/matcher.py CHANGED
@@ -14,6 +14,8 @@ from roma.utils.local_correlation import local_correlation
14
  from roma.utils.utils import cls_to_flow_refine
15
  from roma.utils.kde import kde
16
 
 
 
17
 
18
  class ConvRefiner(nn.Module):
19
  def __init__(
@@ -118,7 +120,7 @@ class ConvRefiner(nn.Module):
118
 
119
  def forward(self, x, y, flow, scale_factor=1, logits=None):
120
  b, c, hs, ws = x.shape
121
- with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
122
  with torch.no_grad():
123
  x_hat = F.grid_sample(
124
  y,
@@ -129,8 +131,8 @@ class ConvRefiner(nn.Module):
129
  if self.has_displacement_emb:
130
  im_A_coords = torch.meshgrid(
131
  (
132
- torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
133
- torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
134
  )
135
  )
136
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -423,7 +425,7 @@ class Decoder(nn.Module):
423
  corresps[ins] = {}
424
  f1_s, f2_s = f1[ins], f2[ins]
425
  if new_scale in self.proj:
426
- with torch.autocast("cuda", self.amp_dtype):
427
  f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
428
 
429
  if ins in coarse_scales:
@@ -643,7 +645,7 @@ class RegressionMatcher(nn.Module):
643
  device=None,
644
  ):
645
  if device is None:
646
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
647
  from PIL import Image
648
 
649
  if isinstance(im_A_path, (str, os.PathLike)):
@@ -739,8 +741,8 @@ class RegressionMatcher(nn.Module):
739
  # Create im_A meshgrid
740
  im_A_coords = torch.meshgrid(
741
  (
742
- torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
743
- torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
744
  )
745
  )
746
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
 
14
  from roma.utils.utils import cls_to_flow_refine
15
  from roma.utils.kde import kde
16
 
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
 
20
  class ConvRefiner(nn.Module):
21
  def __init__(
 
120
 
121
  def forward(self, x, y, flow, scale_factor=1, logits=None):
122
  b, c, hs, ws = x.shape
123
+ with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
124
  with torch.no_grad():
125
  x_hat = F.grid_sample(
126
  y,
 
131
  if self.has_displacement_emb:
132
  im_A_coords = torch.meshgrid(
133
  (
134
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
135
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
136
  )
137
  )
138
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
 
425
  corresps[ins] = {}
426
  f1_s, f2_s = f1[ins], f2[ins]
427
  if new_scale in self.proj:
428
+ with torch.autocast(device, self.amp_dtype):
429
  f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
430
 
431
  if ins in coarse_scales:
 
645
  device=None,
646
  ):
647
  if device is None:
648
+ device = torch.device(device if torch.cuda.is_available() else "cpu")
649
  from PIL import Image
650
 
651
  if isinstance(im_A_path, (str, os.PathLike)):
 
741
  # Create im_A meshgrid
742
  im_A_coords = torch.meshgrid(
743
  (
744
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
745
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
746
  )
747
  )
748
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
third_party/Roma/roma/models/transformer/__init__.py CHANGED
@@ -7,6 +7,8 @@ from .layers.block import Block
7
  from .layers.attention import MemEffAttention
8
  from .dinov2 import vit_large
9
 
 
 
10
 
11
  class TransformerDecoder(nn.Module):
12
  def __init__(
@@ -51,7 +53,7 @@ class TransformerDecoder(nn.Module):
51
  return self._scales.copy()
52
 
53
  def forward(self, gp_posterior, features, old_stuff, new_scale):
54
- with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
55
  B, C, H, W = gp_posterior.shape
56
  x = torch.cat((gp_posterior, features), dim=1)
57
  B, C, H, W = x.shape
 
7
  from .layers.attention import MemEffAttention
8
  from .dinov2 import vit_large
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
 
13
  class TransformerDecoder(nn.Module):
14
  def __init__(
 
53
  return self._scales.copy()
54
 
55
  def forward(self, gp_posterior, features, old_stuff, new_scale):
56
+ with torch.autocast(device, dtype=self.amp_dtype, enabled=self.amp):
57
  B, C, H, W = gp_posterior.shape
58
  x = torch.cat((gp_posterior, features), dim=1)
59
  B, C, H, W = x.shape
third_party/Roma/roma/utils/local_correlation.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  import torch.nn.functional as F
3
 
 
 
4
 
5
  def local_correlation(
6
  feature0,
@@ -20,8 +22,8 @@ def local_correlation(
20
  # If flow is None, assume feature0 and feature1 are aligned
21
  coords = torch.meshgrid(
22
  (
23
- torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
24
- torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
25
  )
26
  )
27
  coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
@@ -30,10 +32,10 @@ def local_correlation(
30
  local_window = torch.meshgrid(
31
  (
32
  torch.linspace(
33
- -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device="cuda"
34
  ),
35
  torch.linspace(
36
- -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device="cuda"
37
  ),
38
  )
39
  )
 
1
  import torch
2
  import torch.nn.functional as F
3
 
4
+ device = "cuda" if torch.cuda.is_available() else "cpu"
5
+
6
 
7
  def local_correlation(
8
  feature0,
 
22
  # If flow is None, assume feature0 and feature1 are aligned
23
  coords = torch.meshgrid(
24
  (
25
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
26
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
27
  )
28
  )
29
  coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
 
32
  local_window = torch.meshgrid(
33
  (
34
  torch.linspace(
35
+ -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device
36
  ),
37
  torch.linspace(
38
+ -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device
39
  ),
40
  )
41
  )