# Source: ECON / apps/IFGeo.py
# Commit 487ee6d ("Support TEXTure"), file size 4.96 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import numpy as np
import pytorch_lightning as pl
import torch
from lib.common.seg3d_lossless import Seg3dLossless
from lib.common.train_util import *
# Module-level side effect: turn on cuDNN benchmark mode as soon as this file
# is imported (per the PyTorch docs, cuDNN then benchmarks candidate
# convolution algorithms and picks the fastest one).
torch.backends.cudnn.benchmark = True
class IFGeo(pl.LightningModule):
    """Lightning training wrapper around the ``IFGeoNet`` generator network.

    The constructor selects the network implementation from
    ``cfg.dataset.prior_type``, derives a coarse-to-fine grid schedule from
    ``cfg.mcube_res`` and builds a ``Seg3dLossless`` reconstruction engine.
    The remaining methods are the optimizer setup and the training /
    validation hooks invoked by PyTorch Lightning.
    """

    def __init__(self, cfg):
        """Build the network and the reconstruction engine from ``cfg``.

        Args:
            cfg: Configuration object. The constructor reads ``batch_size``,
                ``lr_G``, ``sdf``, ``mcube_res``, ``clean_mesh``, ``overfit``
                and ``dataset.prior_type``. The whole object is kept as
                ``self.cfg`` and passed on to the network constructor.
        """
        super().__init__()

        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_G = self.cfg.lr_G

        self.use_sdf = cfg.sdf
        self.mcube_res = cfg.mcube_res
        self.clean_mesh_flag = cfg.clean_mesh
        self.overfit = cfg.overfit

        # Both modules export a class named IFGeoNet; only the import source
        # depends on the prior type, so the instantiation is shared.
        if cfg.dataset.prior_type == "SMPL":
            from lib.net.IFGeoNet import IFGeoNet
        else:
            from lib.net.IFGeoNet_nobody import IFGeoNet
        self.netG = IFGeoNet(cfg)

        # Grid schedule: 2**5 + 1, 2**6 + 1, ..., mcube_res + 1
        # (e.g. mcube_res=256 -> [33, 65, 129, 257]).
        max_exponent = np.log2(self.mcube_res)
        grid_sizes = (
            np.logspace(
                start=5,
                stop=max_exponent,
                base=2,
                num=int(max_exponent - 4),
                endpoint=True,
            ) + 1.0
        )
        # int64 rather than int16: a 16-bit cast wraps silently once
        # mcube_res + 1 exceeds 32767. tolist() yields plain Python ints.
        self.resolutions = grid_sizes.astype(np.int64).tolist()

        # query_func_IF is not defined in this file; presumably it comes from
        # the star import of lib.common.train_util -- TODO confirm.
        # NOTE(review): the Y components of b_min / b_max are swapped relative
        # to their names (+1 in b_min, -1 in b_max); kept exactly as in the
        # original -- confirm this Y flip is intended.
        self.reconEngine = Seg3dLossless(
            query_func=query_func_IF,
            b_min=[[-1.0, 1.0, -1.0]],
            b_max=[[1.0, -1.0, 1.0]],
            resolutions=self.resolutions,
            align_corners=True,
            balance_value=0.50,
            visualize=False,
            debug=False,
            use_cuda_impl=False,
            faster=True,
        )

        self.export_dir = None
        self.result_eval = {}

    # Training related
    def configure_optimizers(self):
        """Create the optimizer and LR scheduler for ``self.netG``.

        The optimizer type is chosen by ``cfg.optim`` ("Adadelta", "Adam" or
        "RMSprop"); the scheduler is a ``MultiStepLR`` driven by
        ``cfg.schedule`` (milestones) and ``cfg.gamma``.

        Returns:
            A pair ``([optimizer], [scheduler])``.

        Raises:
            NotImplementedError: If ``cfg.optim`` is not one of the supported
                optimizer names.
        """
        # set optimizer
        weight_decay = self.cfg.weight_decay
        optim_name = self.cfg.optim
        optim_params_G = [{"params": self.netG.parameters(), "lr": self.lr_G}]

        if optim_name == "Adadelta":
            optimizer_G = torch.optim.Adadelta(
                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
            )
        elif optim_name == "Adam":
            optimizer_G = torch.optim.Adam(
                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
            )
        elif optim_name == "RMSprop":
            # Momentum is only meaningful for RMSprop, so it is read here.
            optimizer_G = torch.optim.RMSprop(
                optim_params_G,
                lr=self.lr_G,
                weight_decay=weight_decay,
                momentum=self.cfg.momentum,
            )
        else:
            raise NotImplementedError(
                "Unsupported optimizer {!r}; expected 'Adadelta', 'Adam' or "
                "'RMSprop'".format(optim_name)
            )

        # set scheduler
        scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        return [optimizer_G], [scheduler_G]

    def training_step(self, batch, batch_idx):
        """Run one training step and log the loss.

        Args:
            batch: Mapping passed straight to ``self.netG``; must contain the
                key ``"labels_geo"`` used as the loss target.
            batch_idx: Index of the batch (unused).

        Returns:
            ``{"loss": error_G}`` where ``error_G`` is whatever
            ``self.netG.compute_loss`` returns.
        """
        self.netG.train()

        preds_G = self.netG(batch)
        error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])

        # metrics processing
        metrics_log = {
            "loss": error_G,
        }

        self.log_dict(
            metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
        )

        return metrics_log

    def training_epoch_end(self, outputs):
        """Log the epoch-level training loss as ``"train/avgloss"``.

        Args:
            outputs: Per-step outputs collected by Lightning (the dicts
                returned from ``training_step``).
        """
        # metrics processing
        # batch_mean is not defined in this file; presumably it averages the
        # "loss" entries across ``outputs`` -- TODO confirm.
        metrics_log = {
            "train/avgloss": batch_mean(outputs, "loss"),
        }

        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True
        )

    def validation_step(self, batch, batch_idx):
        """Run one validation step and report the loss on the progress bar.

        Args:
            batch: Mapping passed straight to ``self.netG``; must contain the
                key ``"labels_geo"`` used as the loss target.
            batch_idx: Index of the batch (unused).

        Returns:
            ``{"val/loss": error_G}``.
        """
        self.netG.eval()
        # Kept from the original: explicitly clears the flag on the top-level
        # module in addition to the eval() call above.
        self.netG.training = False

        preds_G = self.netG(batch)
        error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])

        metrics_log = {
            "val/loss": error_G,
        }

        # logger=False: the per-step value goes to the progress bar only.
        self.log_dict(
            metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True
        )

        return metrics_log

    def validation_epoch_end(self, outputs):
        """Log the epoch-level validation loss as ``"val/avgloss"``.

        Args:
            outputs: Per-step outputs collected by Lightning (the dicts
                returned from ``validation_step``).
        """
        # metrics processing
        # batch_mean is not defined in this file; presumably it averages the
        # "val/loss" entries across ``outputs`` -- TODO confirm.
        metrics_log = {
            "val/avgloss": batch_mean(outputs, "val/loss"),
        }

        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True
        )