Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: [email protected] | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
from lib.common.seg3d_lossless import Seg3dLossless | |
from lib.common.train_util import * | |
# Enable cuDNN benchmark mode: cuDNN autotunes convolution algorithms per input
# configuration and caches the fastest one. This helps when input shapes stay
# fixed across iterations, at the cost of bitwise run-to-run reproducibility.
torch.backends.cudnn.benchmark = True
class IFGeo(pl.LightningModule):
    """PyTorch Lightning module that trains and validates an ``IFGeoNet``.

    Which network is built depends on ``cfg.dataset.prior_type``:

    - ``"SMPL"``: ``IFGeoNet`` is imported from ``lib.net.IFGeoNet``.
    - anything else: ``IFGeoNet`` is imported from ``lib.net.IFGeoNet_nobody``.

    A ``Seg3dLossless`` reconstruction engine, queried through
    ``query_func_IF``, is also created at construction time. The
    training/validation hooks defined here do not reference it.

    NOTE(review): ``training_epoch_end`` and ``validation_epoch_end`` were
    removed in PyTorch Lightning 2.0. This module presumably targets
    Lightning 1.x; confirm against the pinned version.
    """

    def __init__(self, cfg):
        """Build the network and the reconstruction engine from ``cfg``.

        Args:
            cfg: project configuration object. This method reads
                ``batch_size``, ``lr_G``, ``sdf``, ``mcube_res``,
                ``clean_mesh``, ``overfit`` and ``dataset.prior_type``.
        """
        super().__init__()

        self.cfg = cfg
        self.batch_size = cfg.batch_size
        self.lr_G = cfg.lr_G
        self.use_sdf = cfg.sdf
        self.mcube_res = cfg.mcube_res
        self.clean_mesh_flag = cfg.clean_mesh
        self.overfit = cfg.overfit

        # Select the network variant. The imports are local so that only the
        # chosen module is loaded.
        if cfg.dataset.prior_type == "SMPL":
            from lib.net.IFGeoNet import IFGeoNet
        else:
            from lib.net.IFGeoNet_nobody import IFGeoNet
        self.netG = IFGeoNet(cfg)

        # Increasing grid resolutions for Seg3dLossless: log2-spaced values
        # from 2**5 up to mcube_res, each shifted by +1. For a power-of-two
        # mcube_res this is [33, 65, ..., mcube_res + 1];
        # e.g. 256 -> [33, 65, 129, 257].
        num_levels = int(np.log2(self.mcube_res) - 4)
        resolutions = np.logspace(
            start=5,
            stop=np.log2(self.mcube_res),
            base=2,
            num=num_levels,
            endpoint=True,
        ) + 1.0
        self.resolutions = resolutions.astype(np.int16).tolist()

        # NOTE(review): the y bounds run 1.0 -> -1.0, which is inverted
        # relative to x and z. This is preserved as-is and presumably matches
        # the query-coordinate convention; confirm.
        self.reconEngine = Seg3dLossless(
            query_func=query_func_IF,
            b_min=[[-1.0, 1.0, -1.0]],
            b_max=[[1.0, -1.0, 1.0]],
            resolutions=self.resolutions,
            align_corners=True,
            balance_value=0.50,
            visualize=False,
            debug=False,
            use_cuda_impl=False,
            faster=True,
        )

        self.export_dir = None
        self.result_eval = {}

    # Training related
    def configure_optimizers(self):
        """Create the optimizer for ``netG`` and its MultiStepLR scheduler.

        ``cfg.optim`` selects Adadelta, Adam or RMSprop. All three use
        ``lr_G`` and ``cfg.weight_decay``; RMSprop additionally uses
        ``cfg.momentum``. The scheduler multiplies the learning rate by
        ``cfg.gamma`` at each milestone in ``cfg.schedule``.

        Returns:
            ``([optimizer_G], [scheduler_G])``, a return form Lightning accepts.

        Raises:
            NotImplementedError: if ``cfg.optim`` is not one of the
                supported names.
        """
        weight_decay = self.cfg.weight_decay
        optim_name = self.cfg.optim
        params_G = self.netG.parameters()

        if optim_name == "Adadelta":
            optimizer_G = torch.optim.Adadelta(
                params_G, lr=self.lr_G, weight_decay=weight_decay
            )
        elif optim_name == "Adam":
            optimizer_G = torch.optim.Adam(
                params_G, lr=self.lr_G, weight_decay=weight_decay
            )
        elif optim_name == "RMSprop":
            # Of the supported optimizers only RMSprop takes momentum, so it is
            # read here. Configs without `momentum` still work for the others.
            optimizer_G = torch.optim.RMSprop(
                params_G,
                lr=self.lr_G,
                weight_decay=weight_decay,
                momentum=self.cfg.momentum,
            )
        else:
            raise NotImplementedError(
                f"Unsupported optimizer {optim_name!r}; "
                "expected one of 'Adadelta', 'Adam', 'RMSprop'"
            )

        scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )
        return [optimizer_G], [scheduler_G]

    def training_step(self, batch, batch_idx):
        """Run ``netG`` on ``batch``, compute its loss and log it per step.

        Args:
            batch: mapping passed straight to ``netG``; it must contain
                ``"labels_geo"``, which is used as the loss target.
            batch_idx: index of the batch (unused).

        Returns:
            ``{"loss": error_G}``. Lightning optimizes the ``"loss"`` entry.
        """
        self.netG.train()
        preds_G = self.netG(batch)
        error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])

        metrics_log = {"loss": error_G}
        self.log_dict(
            metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
        )
        return metrics_log

    def training_epoch_end(self, outputs):
        """Log the epoch's average training loss as ``"train/avgloss"``.

        Args:
            outputs: per-step results of ``training_step`` collected by
                Lightning. They are reduced on their ``"loss"`` entry with the
                project helper ``batch_mean``.
        """
        metrics_log = {"train/avgloss": batch_mean(outputs, "loss")}
        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True,
        )

    def validation_step(self, batch, batch_idx):
        """Evaluate ``netG`` on ``batch`` and report the loss as ``"val/loss"``.

        The step loss goes to the progress bar only (``logger=False``). The
        epoch average is logged in ``validation_epoch_end``.

        Returns:
            ``{"val/loss": error_G}``.
        """
        # eval() sets `training = False` on netG and all of its submodules.
        self.netG.eval()
        preds_G = self.netG(batch)
        error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])

        metrics_log = {"val/loss": error_G}
        self.log_dict(
            metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True
        )
        return metrics_log

    def validation_epoch_end(self, outputs):
        """Log the epoch's average validation loss as ``"val/avgloss"``.

        Args:
            outputs: per-step results of ``validation_step``. They are reduced
                on their ``"val/loss"`` entry with the project helper
                ``batch_mean``.
        """
        metrics_log = {"val/avgloss": batch_mean(outputs, "val/loss")}
        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True,
        )