""" Implementation of YOLOv3 architecture """ import random from typing import Any, Optional import pytorch_lightning as pl import torch import torch.nn as nn import torch.optim as optim from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.optim.lr_scheduler import OneCycleLR from . import config from .loss import YoloLoss """ Information about architecture config: Tuple is structured by (filters, kernel_size, stride) Every conv is a same convolution. List is structured by "B" indicating a residual block followed by the number of repeats "S" is for scale prediction block and computing the yolo loss "U" is for upsampling the feature map and concatenating with a previous layer """ model_config = [ (32, 3, 1), (64, 3, 2), ["B", 1], (128, 3, 2), ["B", 2], (256, 3, 2), ["B", 8], (512, 3, 2), ["B", 8], (1024, 3, 2), ["B", 4], # To this point is Darknet-53 (512, 1, 1), (1024, 3, 1), "S", (256, 1, 1), "U", (256, 1, 1), (512, 3, 1), "S", (128, 1, 1), "U", (128, 1, 1), (256, 3, 1), "S", ] class CNNBlock(pl.LightningModule): def __init__(self, in_channels, out_channels, bn_act=True, **kwargs): super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs) self.bn = nn.BatchNorm2d(out_channels) self.leaky = nn.LeakyReLU(0.1) self.use_bn_act = bn_act def forward(self, x): if self.use_bn_act: return self.leaky(self.bn(self.conv(x))) else: return self.conv(x) class ResidualBlock(pl.LightningModule): def __init__(self, channels, use_residual=True, num_repeats=1): super().__init__() self.layers = nn.ModuleList() for repeat in range(num_repeats): self.layers += [ nn.Sequential( CNNBlock(channels, channels // 2, kernel_size=1), CNNBlock(channels // 2, channels, kernel_size=3, padding=1), ) ] self.use_residual = use_residual self.num_repeats = num_repeats def forward(self, x): for layer in self.layers: if self.use_residual: x = x + layer(x) else: x = layer(x) return x class ScalePrediction(pl.LightningModule): def __init__(self, in_channels, num_classes): super().__init__() self.pred = nn.Sequential( CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1), CNNBlock( 2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1 ), ) self.num_classes = num_classes def forward(self, x): return ( self.pred(x) .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3]) .permute(0, 1, 3, 4, 2) ) class YOLOv3(pl.LightningModule): def __init__(self, in_channels=3, num_classes=20): super().__init__() self.num_classes = num_classes self.in_channels = in_channels self.layers = self._create_conv_layers() self.scaled_anchors = ( torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) ).to(config.DEVICE) self.learning_rate = config.LEARNING_RATE self.weight_decay = config.WEIGHT_DECAY self.best_lr = 1e-3 def forward(self, x): outputs = [] # for each scale route_connections = [] for layer in self.layers: if isinstance(layer, ScalePrediction): outputs.append(layer(x)) continue x = layer(x) if isinstance(layer, ResidualBlock) and layer.num_repeats == 8: route_connections.append(x) elif isinstance(layer, nn.Upsample): x = torch.cat([x, route_connections[-1]], dim=1) route_connections.pop() return outputs def _create_conv_layers(self): layers = nn.ModuleList() in_channels = self.in_channels for module in model_config: if isinstance(module, tuple): out_channels, kernel_size, stride = module layers.append( CNNBlock( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1 if kernel_size == 3 else 0, 
) ) in_channels = out_channels elif isinstance(module, list): num_repeats = module[1] layers.append( ResidualBlock( in_channels, num_repeats=num_repeats, ) ) elif isinstance(module, str): if module == "S": layers += [ ResidualBlock(in_channels, use_residual=False, num_repeats=1), CNNBlock(in_channels, in_channels // 2, kernel_size=1), ScalePrediction(in_channels // 2, num_classes=self.num_classes), ] in_channels = in_channels // 2 elif module == "U": layers.append( nn.Upsample(scale_factor=2), ) in_channels = in_channels * 3 return layers def yololoss(self): return YoloLoss() def training_step(self, batch, batch_idx): x, y = batch y0, y1, y2 = y[0], y[1], y[2] out = self(x) # print(out[0].shape, y0.shape) loss = ( self.yololoss()(out[0], y0, self.scaled_anchors[0]) + self.yololoss()(out[1], y1, self.scaled_anchors[1]) + self.yololoss()(out[2], y2, self.scaled_anchors[2]) ) self.log( "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True ) # config.IMAGE_SIZE = 416 if random.random() < 0.5 else 544 # config.S = [ # config.IMAGE_SIZE // 32, # config.IMAGE_SIZE // 16, # config.IMAGE_SIZE // 8, # ] # print(f"{self.trainer.datamodule.train_dataset.S=}") # self.trainer.datamodule.train_dataset.S = [ # config.IMAGE_SIZE // 32, # config.IMAGE_SIZE // 16, # config.IMAGE_SIZE // 8, # ] # self.trainer.datamodule.train_dataset.image_size = config.IMAGE_SIZE # self.scaled_anchors = ( # torch.tensor(config.ANCHORS) # * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) # ).to(config.DEVICE) return loss def on_train_epoch_end(self) -> None: print( f"EPOCH: {self.current_epoch}, Loss: {self.trainer.callback_metrics['train_loss_epoch']}" ) def configure_optimizers(self) -> Any: optimizer = optim.Adam( self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay ) scheduler = OneCycleLR( optimizer, max_lr=self.best_lr, steps_per_epoch=len(self.trainer.datamodule.train_dataloader()), epochs=config.NUM_EPOCHS, pct_start=8 / config.NUM_EPOCHS, div_factor=100, three_phase=False, final_div_factor=100, anneal_strategy="linear", ) return [optimizer], [ {"scheduler": scheduler, "interval": "step", "frequency": 1} ] def on_train_end(self) -> None: torch.save(self.state_dict(), config.MODEL_STATE_DICT_PATH) if __name__ == "__main__": num_classes = 20 IMAGE_SIZE = 416 model = YOLOv3(num_classes=num_classes) x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE)) out = model(x) assert model(x)[0].shape == ( 2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5, ) assert model(x)[1].shape == ( 2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5, ) assert model(x)[2].shape == ( 2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5, ) print("Success!")
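
# Usage sketch: one way this module could be trained with a pl.Trainer.
# `YOLODataModule` is a placeholder name for a LightningDataModule that yields
# (image, (target_s1, target_s2, target_s3)) batches built for the three scales
# in config.S; configure_optimizers above expects such a datamodule to be attached.
#
#     import pytorch_lightning as pl
#
#     model = YOLOv3(num_classes=20)
#     datamodule = YOLODataModule()
#     trainer = pl.Trainer(max_epochs=config.NUM_EPOCHS, accelerator="auto", devices=1)
#     trainer.fit(model, datamodule=datamodule)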
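
# Decoding sketch: each head output `p` has shape (B, 3, S, S, num_classes + 5).
# The layout of the last dimension is assumed to match the accompanying YoloLoss
# (objectness at index 0, box offsets at 1:5, class scores from 5 onward) -- an
# assumption about loss.py, not verified here.
#
#     S = p.shape[2]
#     anchors = model.scaled_anchors[i].reshape(1, 3, 1, 1, 2)  # anchors for scale i, in cell units
#     obj = torch.sigmoid(p[..., 0:1])            # objectness score
#     xy = torch.sigmoid(p[..., 1:3])             # box centre offsets inside each cell
#     wh = torch.exp(p[..., 3:5]) * anchors       # box width/height in cell units
#     cls_idx = torch.argmax(p[..., 5:], dim=-1)  # most likely class per anchor/cell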