drscotthawley commited on
Commit
3bf82fd
·
1 Parent(s): 78b535c

fixing up py files for run

Browse files
Files changed (2) hide show
  1. pom/chords.py +1 -1
  2. pom/utils.py +46 -0
pom/chords.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn as nn
5
  import torch
6
  from PIL import Image
7
  import numpy as np
8
- from control_toys.utils import rect_to_square, square_to_rect
9
 
10
  CHORD_BORDER = 8 # chord border size in pixels
11
 
 
5
  import torch
6
  from PIL import Image
7
  import numpy as np
8
+ from . import rect_to_square, square_to_rect
9
 
10
  CHORD_BORDER = 8 # chord border size in pixels
11
 
pom/utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ simple routines used throughout the project, placed here to avoid circular imports
5
+ """
6
+
7
+ from PIL import Image
8
+ import torch
9
+
10
+ def flip_bottom_half_and_attach(sub_img):
11
+ "takes one 256x256 and returns on 512x128 image with the bottom half reversed and attached on the right"
12
+ h, w = sub_img.size
13
+ new_img = Image.new(sub_img.mode, (w*2, h//2))
14
+ new_img.paste(sub_img.crop((0, 0, w, h//2)), (0, 0))
15
+ new_img.paste(sub_img.crop((0, h//2, w, h)).transpose(Image.FLIP_LEFT_RIGHT), (w, 0))
16
+ return new_img
17
+
18
+ def square_to_rect_tensor(img_tensor:torch.Tensor):
19
+ if len(img_tensor.shape) <= 4:
20
+ img_tensor = img_tensor.unsqueeze(0)
21
+ channels_last = img_tensor.shape[-1] == 3
22
+ if channels_last:
23
+ img_tensor = img_tensor.permute(0, 3, 1, 2)
24
+ channels_first = img_tensor.shape[1] == 3
25
+ b,c,w,h = img_tensor.shape
26
+ new_img = torch.zeros((b,c,w*2,h//2), dtype=img_tensor.dtype).to(img_tensor.device)
27
+ new_img[:,:,:w, :] = img_tensor[:,:,:w,:]
28
+ new_img[:,:,w:, :] = torch.flip(img_tensor[:,:,w:, :], dims=[-2])
29
+ return new_img
30
+
31
+
32
+ def square_to_rect(img):
33
+ #"""just an alias for flip_bottom_half_and_attach"""
34
+ if isinstance(img, torch.Tensor):
35
+ return square_to_rect_tensor(img)
36
+ return flip_bottom_half_and_attach(img)
37
+
38
+ def rect_to_square(img):
39
+ "takes a 512x128 image and returns a 256x256 image with the bottom half reversed"
40
+ w, h = img.size
41
+ new_img = Image.new(img.mode, (w//2, h*2))
42
+ new_img.paste(img.crop((0, 0, w//2, h)), (0, 0))
43
+ new_img.paste(img.crop((w//2, 0, w, h)).transpose(Image.FLIP_LEFT_RIGHT), (0, h))
44
+ return new_img
45
+
46
+