Nekshay commited on
Commit
7700523
·
verified ·
1 Parent(s): bc4944b

Update dataloder_pytorch.py

Browse files
Files changed (1) hide show
  1. dataloder_pytorch.py +60 -0
dataloder_pytorch.py CHANGED
@@ -46,3 +46,63 @@ val_data = DataLoader(CarShadowDataset(root_dir='dataset/val', transform=your_tr
46
  for car_image, shadow_image in train_data:
47
  # Access your data for training
48
  # ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  for car_image, shadow_image in train_data:
47
  # Access your data for training
48
  # ...
49
+
50
+
51
+ # ... (Previous code for model definition and DataLoader)
52
+
53
+ # Discriminator Training
54
+ def train_discriminator(d_optimizer, real_images, fake_images, real_labels, fake_labels):
55
+ # Clear gradients
56
+ d_optimizer.zero_grad()
57
+
58
+ # Forward pass through discriminator
59
+ d_real_output = discriminator(real_images, real_images) # Real images with real shadows
60
+ d_fake_output = discriminator(real_images, fake_images) # Real images with generated shadows
61
+
62
+ # Calculate loss
63
+ d_real_loss = criterion(d_real_output, torch.ones_like(d_real_output))
64
+ d_fake_loss = criterion(d_fake_output, torch.zeros_like(d_fake_output))
65
+ d_loss = (d_real_loss + d_fake_loss) / 2
66
+
67
+ # Backpropagate and update weights
68
+ d_loss.backward()
69
+ d_optimizer.step()
70
+
71
+ # Return the discriminator loss
72
+ return d_loss.item()
73
+
74
+ # Generator Training
75
+ def train_generator(g_optimizer, real_images, fake_images):
76
+ # Clear gradients
77
+ g_optimizer.zero_grad()
78
+
79
+ # Forward pass through discriminator (using generated shadows)
80
+ g_fake_output = discriminator(real_images, fake_images)
81
+
82
+ # Calculate loss (try to fool the discriminator)
83
+ g_loss = criterion(g_fake_output, torch.ones_like(g_fake_output))
84
+
85
+ # Backpropagate and update weights
86
+ g_loss.backward()
87
+ g_optimizer.step()
88
+
89
+ # Return the generator loss
90
+ return g_loss.item()
91
+
92
+ # Training loop
93
+ for epoch in range(epochs):
94
+ for i, (real_images, real_shadows) in enumerate(train_data):
95
+ # Generate fake shadows
96
+ fake_shadows = generator(real_images)
97
+
98
+ # Train discriminator
99
+ d_loss = train_discriminator(d_optimizer, real_images, fake_shadows, torch.ones(real_images.size(0)), torch.zeros(real_images.size(0)))
100
+
101
+ # Train generator
102
+ g_loss = train_generator(g_optimizer, real_images, fake_shadows)
103
+
104
+ # Print training progress
105
+ if i % 100 == 0:
106
+ print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_data)}], D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}')
107
+
108
+