Nekshay commited on
Commit
bc4944b
·
verified ·
1 Parent(s): 288af86

Update dataloder_pytorch.py

Browse files
Files changed (1) hide show
  1. dataloder_pytorch.py +48 -0
dataloder_pytorch.py CHANGED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+
4
+ # Define a custom Dataset class
5
+ class CarShadowDataset(Dataset):
6
+ def __init__(self, root_dir, transform=None):
7
+ self.root_dir = root_dir
8
+ self.transform = transform
9
+ self.image_paths = [] # List to store image paths
10
+
11
+ # Loop through car and shadow folders to collect image paths
12
+ for phase in ['train', 'val', 'test']: # Adjust based on your data structure
13
+ car_folder = os.path.join(root_dir, phase, 'car')
14
+ shadow_folder = os.path.join(root_dir, phase, 'shadow')
15
+
16
+ for filename in os.listdir(car_folder):
17
+ car_path = os.path.join(car_folder, filename)
18
+ shadow_path = os.path.join(shadow_folder, filename.split('.')[0] + '_shadow.jpg') # Assuming consistent naming
19
+ self.image_paths.append((car_path, shadow_path))
20
+
21
+ def __len__(self):
22
+ return len(self.image_paths)
23
+
24
+ def __getitem__(self, idx):
25
+ car_path, shadow_path = self.image_paths[idx]
26
+ car_image = load_image(car_path) # Replace with your image loading function
27
+ shadow_image = load_image(shadow_path) # Replace with your image loading function
28
+
29
+ if self.transform:
30
+ car_image = self.transform(car_image)
31
+ shadow_image = self.transform(shadow_image)
32
+
33
+ return car_image, shadow_image
34
+
35
+ # Function to load image (replace with your preferred method)
36
+ def load_image(path):
37
+ # Implement image loading using libraries like OpenCV or PIL
38
+ # Ensure images are converted to tensors and normalized if needed
39
+ # ...
40
+
41
+ # Prepare data loaders
42
+ train_data = DataLoader(CarShadowDataset(root_dir='dataset/train', transform=your_transform), batch_size=32, shuffle=True)
43
+ val_data = DataLoader(CarShadowDataset(root_dir='dataset/val', transform=your_transform), batch_size=32) # Optional for validation
44
+
45
+ # Example usage
46
+ for car_image, shadow_image in train_data:
47
+ # Access your data for training
48
+ # ...