Update dataloder_pytorch.py
Browse files- dataloder_pytorch.py +48 -0
dataloder_pytorch.py
CHANGED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
|
4 |
+
# Define a custom Dataset class
|
5 |
+
class CarShadowDataset(Dataset):
|
6 |
+
def __init__(self, root_dir, transform=None):
|
7 |
+
self.root_dir = root_dir
|
8 |
+
self.transform = transform
|
9 |
+
self.image_paths = [] # List to store image paths
|
10 |
+
|
11 |
+
# Loop through car and shadow folders to collect image paths
|
12 |
+
for phase in ['train', 'val', 'test']: # Adjust based on your data structure
|
13 |
+
car_folder = os.path.join(root_dir, phase, 'car')
|
14 |
+
shadow_folder = os.path.join(root_dir, phase, 'shadow')
|
15 |
+
|
16 |
+
for filename in os.listdir(car_folder):
|
17 |
+
car_path = os.path.join(car_folder, filename)
|
18 |
+
shadow_path = os.path.join(shadow_folder, filename.split('.')[0] + '_shadow.jpg') # Assuming consistent naming
|
19 |
+
self.image_paths.append((car_path, shadow_path))
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
return len(self.image_paths)
|
23 |
+
|
24 |
+
def __getitem__(self, idx):
|
25 |
+
car_path, shadow_path = self.image_paths[idx]
|
26 |
+
car_image = load_image(car_path) # Replace with your image loading function
|
27 |
+
shadow_image = load_image(shadow_path) # Replace with your image loading function
|
28 |
+
|
29 |
+
if self.transform:
|
30 |
+
car_image = self.transform(car_image)
|
31 |
+
shadow_image = self.transform(shadow_image)
|
32 |
+
|
33 |
+
return car_image, shadow_image
|
34 |
+
|
35 |
+
# Function to load image (replace with your preferred method)
|
36 |
+
def load_image(path):
|
37 |
+
# Implement image loading using libraries like OpenCV or PIL
|
38 |
+
# Ensure images are converted to tensors and normalized if needed
|
39 |
+
# ...
|
40 |
+
|
41 |
+
# Prepare data loaders
|
42 |
+
train_data = DataLoader(CarShadowDataset(root_dir='dataset/train', transform=your_transform), batch_size=32, shuffle=True)
|
43 |
+
val_data = DataLoader(CarShadowDataset(root_dir='dataset/val', transform=your_transform), batch_size=32) # Optional for validation
|
44 |
+
|
45 |
+
# Example usage
|
46 |
+
for car_image, shadow_image in train_data:
|
47 |
+
# Access your data for training
|
48 |
+
# ...
|