# LIVE / main.py
# Xu Ma
# update
# 3338dc9
# raw
# history blame
# 25.9 kB
"""
Here are some use cases:
python main.py --config config/all.yaml --experiment experiment_8x1 --signature demo1 --target data/demo1.png
"""
import pydiffvg
import torch
import cv2
import matplotlib.pyplot as plt
import random
import argparse
import math
import errno
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.nn.functional import adaptive_avg_pool2d
import warnings
warnings.filterwarnings("ignore")
import PIL
import PIL.Image
import os
import os.path as osp
import numpy as np
import numpy.random as npr
import shutil
import copy
# import skfmm
from xing_loss import xing_loss
import yaml
from easydict import EasyDict as edict
pydiffvg.set_print_timing(False)
gamma = 1.0
##########
# helper #
##########
from utils import \
get_experiment_id, \
get_path_schedule, \
edict_2_dict, \
check_and_create_dir
def get_bezier_circle(radius=1, segments=4, bias=None):
    """Return control points laid out evenly on a circle.

    ``segments * 3`` points are placed at equal angular steps on the unit
    circle (starting at angle 0), scaled by ``radius`` and shifted by
    ``bias``.

    Args:
        radius: Scale applied to the unit-circle points.
        segments: Number of path segments; three points are produced per
            segment.
        bias: ``(x, y)`` offset added to every point. When ``None``, both
            components are drawn with ``random.random()``.

    Returns:
        A ``torch.FloatTensor`` of shape ``(segments * 3, 2)``.
    """
    if bias is None:
        bias = (random.random(), random.random())

    num_points = segments * 3
    step_deg = 360 / num_points
    # Evaluate the trigonometry point by point so the values match a plain
    # Python loop exactly.
    angles = [np.deg2rad(k * step_deg) for k in range(num_points)]
    unit_circle = [(np.cos(theta), np.sin(theta)) for theta in angles]

    scaled = torch.tensor(unit_circle) * radius
    offset = torch.tensor(bias).unsqueeze(dim=0)
    return (scaled + offset).type(torch.FloatTensor)
def get_sdf(phi, method='skfmm', **kwargs):
    """Turn an occupancy image into a distance-based weight map.

    ``phi`` is shifted and scaled with ``(phi - 0.5) * 2`` so that its zero
    level set marks the boundary, then ``skfmm.distance`` is applied and the
    result is post-processed according to the keyword options.

    Args:
        phi: NumPy array; the caller in this file passes a rendered alpha
            mask (presumably with values in [0, 1]).
        method: Distance backend. Only ``'skfmm'`` is supported.
        **kwargs: Optional post-processing switches:
            flip_negative (bool, default True): take the absolute distance.
            truncate (number, default 10): clip distances to
                ``[-truncate, truncate]``.
            zero2max (bool, default True): replace ``sd`` by
                ``sd.max() - sd`` (requires ``flip_negative``).
            normalize (str, default 'sum'): ``'sum'`` divides by the sum,
                ``'to1'`` divides by the maximum, any other value leaves the
                map unnormalized.

    Returns:
        The processed distance map, or an all-zero ``float32`` array of the
        same shape when ``phi`` has no zero crossing.

    Raises:
        ValueError: If ``method`` is not supported, or if ``zero2max`` is
            requested without ``flip_negative``.
    """
    if method != 'skfmm':
        # Previously an unknown method fell through and returned None.
        raise ValueError(f"Unsupported SDF method: {method!r}")

    phi = (phi - 0.5) * 2
    if (phi.max() <= 0) or (phi.min() >= 0):
        # No sign change means there is no boundary to measure from.
        return np.zeros(phi.shape).astype(np.float32)

    # Imported lazily so the degenerate case above works without skfmm.
    import skfmm
    sd = skfmm.distance(phi, dx=1)

    flip_negative = kwargs.get('flip_negative', True)
    if flip_negative:
        sd = np.abs(sd)

    truncate = kwargs.get('truncate', 10)
    sd = np.clip(sd, -truncate, truncate)

    zero2max = kwargs.get('zero2max', True)
    if zero2max and flip_negative:
        # Invert so that pixels on the boundary get the largest weight.
        sd = sd.max() - sd
    elif zero2max:
        raise ValueError

    normalize = kwargs.get('normalize', 'sum')
    if normalize == 'sum':
        sd /= sd.sum()
    elif normalize == 'to1':
        sd /= sd.max()
    return sd
def parse_args():
    """Parse the command line into an ``edict`` of overrides.

    Returns:
        An ``edict`` with the keys ``debug``, ``config``, ``experiment``,
        ``seed``, ``target``, ``log_dir``, ``initial``, ``signature`` and
        ``num_segments``. When ``--seginit`` is given, a nested ``seginit``
        entry is added with ``type`` and, for the ``circle`` type, ``radius``
        (``None`` when the radius is omitted on the command line).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument("--config", type=str)
    parser.add_argument("--experiment", type=str)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--target", type=str, help="target image path")
    parser.add_argument('--log_dir', metavar='DIR', default="log/debug")
    parser.add_argument('--initial', type=str, default="random", choices=['random', 'circle'])
    parser.add_argument('--signature', nargs='+', type=str)
    parser.add_argument('--seginit', nargs='+', type=str)
    parser.add_argument("--num_segments", type=int, default=4)
    # Disabled options, kept for reference:
    # parser.add_argument("--num_paths", type=str, default="1,1,1")
    # parser.add_argument("--num_iter", type=int, default=500)
    # parser.add_argument('--free', action='store_true')
    # Please ensure that image resolution is divisible by pool_size; otherwise the performance would drop a lot.
    # parser.add_argument('--pool_size', type=int, default=40, help="the pooled image size for next path initialization")
    # parser.add_argument('--save_loss', action='store_true')
    # parser.add_argument('--save_init', action='store_true')
    # parser.add_argument('--save_image', action='store_true')
    # parser.add_argument('--save_video', action='store_true')
    # parser.add_argument('--print_weight', action='store_true')
    # parser.add_argument('--circle_init_radius', type=float)
    args = parser.parse_args()

    cfg = edict()
    cfg.debug = args.debug
    cfg.config = args.config
    cfg.experiment = args.experiment
    cfg.seed = args.seed
    cfg.target = args.target
    cfg.log_dir = args.log_dir
    cfg.initial = args.initial
    cfg.signature = args.signature
    # set cfg num_segments in command
    cfg.num_segments = args.num_segments
    if args.seginit is not None:
        cfg.seginit = edict()
        cfg.seginit.type = args.seginit[0]
        if cfg.seginit.type == 'circle':
            # The radius is optional: indexing args.seginit[1] blindly
            # raised IndexError for a bare "--seginit circle". A None
            # radius is handled by init_shapes, which then draws one.
            if len(args.seginit) > 1:
                cfg.seginit.radius = float(args.seginit[1])
            else:
                cfg.seginit.radius = None
    return cfg
def ycrcb_conversion(im, format='[bs x 3 x 2D]', reverse=False):
    """Apply a fixed 3x3 color matrix (or its inverse) to an image tensor.

    Args:
        im: Image tensor. Its layout is selected by ``format``.
        format: ``'[bs x 3 x 2D]'`` for channel-first batches (dim 1 holds
            the three channels) or ``'[2D x 3]'`` for tensors whose last
            dimension holds the three channels.
        reverse: When true, the inverse of the matrix is applied instead.

    Returns:
        The converted tensor in the same layout as the input.

    Raises:
        ValueError: If ``format`` is neither of the supported layouts.
    """
    forward = torch.FloatTensor([
        [ 65.481/255, 128.553/255,  24.966/255],  # ranged_from [0, 219/255]
        [-37.797/255, -74.203/255, 112.000/255],  # ranged_from [-112/255, 112/255]
        [112.000/255, -93.786/255, -18.214/255],  # ranged_from [-112/255, 112/255]
    ]).to(im.device)
    transform = forward.inverse() if reverse else forward

    if format == '[2D x 3]':
        return torch.matmul(im, transform.T)
    if format != '[bs x 3 x 2D]':
        raise ValueError

    # Move channels last for the matrix product, then restore the layout.
    channels_last = im.permute(0, 2, 3, 1)
    converted = torch.matmul(channels_last, transform.T)
    return converted.permute(0, 3, 1, 2).contiguous()
class random_coord_init():
    """Callable that returns a uniformly random position on the canvas."""

    def __init__(self, canvas_size):
        # canvas_size is unpacked as (height, width) when called.
        self.canvas_size = canvas_size

    def __call__(self):
        """Return ``[x, y]``; x is drawn first, then y."""
        height, width = self.canvas_size
        x = width * npr.uniform(0, 1)
        y = height * npr.uniform(0, 1)
        return [x, y]
class naive_coord_init():
    """Callable that returns the pixel with the largest squared error.

    Args:
        pred: Current rendering (NumPy array or torch tensor).
        gt: Target image in the same layout as ``pred``.
        format: ``'[bs x c x 2D]'`` (only batch item 0 is used, channels are
            summed over dim 0) or ``'[2D x c]'`` (channels are summed over
            the last dim).
        replace_sampling: When true, a returned pixel is marked with -1 in
            the error map so that later calls return a different pixel.

    Raises:
        ValueError: If ``format`` is not one of the supported layouts.
    """

    def __init__(self, pred, gt, format='[bs x c x 2D]', replace_sampling=True):
        if isinstance(pred, torch.Tensor):
            pred = pred.detach().cpu().numpy()
        if isinstance(gt, torch.Tensor):
            gt = gt.detach().cpu().numpy()

        if format == '[bs x c x 2D]':
            self.map = ((pred[0] - gt[0])**2).sum(0)
        elif format == '[2D x c]':
            # The format used to be compared against a one-element list,
            # which never matched, so this layout always raised ValueError.
            self.map = ((pred - gt)**2).sum(-1)
        else:
            raise ValueError
        self.replace_sampling = replace_sampling

    def __call__(self):
        """Return ``[x, y]`` of the first maximum of the error map."""
        # argmax yields the first maximum in row-major order.
        coord_h, coord_w = np.unravel_index(
            np.argmax(self.map), self.map.shape)
        if self.replace_sampling:
            self.map[coord_h, coord_w] = -1
        return [coord_w, coord_h]
class sparse_coord_init():
    """Callable that seeds new paths inside large regions of similar error.

    The per-pixel error map is quantized into bins by its own quantiles. On
    each call the bin with the most remaining pixels is chosen, its largest
    4-connected component is located, and the component pixel closest to
    the component centroid is returned. That component is then removed.

    Args:
        pred: Current rendering (NumPy array or torch tensor).
        gt: Target image in the same layout as ``pred``.
        format: ``'[bs x c x 2D]'`` (batch item 0, squared error summed over
            channels) or ``'[2D x c]'`` (absolute error summed over the last
            dim).
        quantile_interval: Number of evenly spaced quantile levels used to
            build the bins.
        nodiff_thres: Errors below this value are set to zero before
            quantization.

    Raises:
        ValueError: If ``format`` is not one of the supported layouts.
    """

    def __init__(self, pred, gt, format='[bs x c x 2D]', quantile_interval=200, nodiff_thres=0.1):
        if isinstance(pred, torch.Tensor):
            pred = pred.detach().cpu().numpy()
        if isinstance(gt, torch.Tensor):
            gt = gt.detach().cpu().numpy()

        if format == '[bs x c x 2D]':
            self.map = ((pred[0] - gt[0])**2).sum(0)
            self.reference_gt = copy.deepcopy(
                np.transpose(gt[0], (1, 2, 0)))
        elif format == '[2D x c]':
            # The format used to be compared against a one-element list,
            # which never matched, so this layout always raised ValueError.
            self.map = (np.abs(pred - gt)).sum(-1)
            # gt is already channel-last here; keep the whole image rather
            # than only its first row.
            self.reference_gt = copy.deepcopy(gt)
        else:
            raise ValueError

        # OptionA: Zero too small errors to avoid the error too small deadloop
        self.map[self.map < nodiff_thres] = 0
        quantile_levels = np.linspace(0., 1., quantile_interval)
        bin_edges = np.quantile(self.map, quantile_levels)
        # remove redundant edges, then drop the two extreme ones
        bin_edges = np.unique(bin_edges)
        bin_edges = sorted(bin_edges[1:-1])
        self.map = np.digitize(self.map, bin_edges, right=False)
        self.map = np.clip(self.map, 0, 255).astype(np.uint8)

        # Pixel count per bin id; np.unique returns the ids sorted.
        bin_ids, bin_counts = np.unique(self.map, return_counts=True)
        self.idcnt = dict(zip(bin_ids, bin_counts))
        # remove smallest one to remove the correct region
        self.idcnt.pop(min(self.idcnt.keys()))

    def __call__(self):
        """Return ``[x, y]`` for the next seed position."""
        if len(self.idcnt) == 0:
            # Every bin is used up: fall back to a uniformly random point.
            h, w = self.map.shape
            return [npr.uniform(0, 1)*w, npr.uniform(0, 1)*h]

        target_id = max(self.idcnt, key=self.idcnt.get)
        _, component, cstats, ccenter = cv2.connectedComponentsWithStats(
            (self.map == target_id).astype(np.uint8), connectivity=4)
        # remove cid = 0, it is the invalid area
        # The last stats column is the component area in pixels.
        csize = [ci[-1] for ci in cstats[1:]]
        largest = max(csize)
        target_cid = csize.index(largest) + 1
        # Centroids come as (x, y); reverse to (row, col) to match np.where.
        center = ccenter[target_cid][::-1]
        coord = np.stack(np.where(component == target_cid)).T
        dist = np.linalg.norm(coord - center, axis=1)
        target_coord_id = np.argmin(dist)
        coord_h, coord_w = coord[target_coord_id]

        # replace_sampling: remove the used component from the bookkeeping
        self.idcnt[target_id] -= largest
        if self.idcnt[target_id] == 0:
            self.idcnt.pop(target_id)
        self.map[component == target_cid] = 0
        return [coord_w, coord_h]
def init_shapes(num_paths,
                num_segments,
                canvas_size,
                seginit_cfg,
                shape_cnt,
                pos_init_method=None,
                trainable_stroke=False,
                **kwargs):
    """Create ``num_paths`` new closed paths with their shape groups.

    Args:
        num_paths: Number of paths to create.
        num_segments: Segments per path; each segment gets two control
            points.
        canvas_size: ``(height, width)`` of the canvas.
        seginit_cfg: Config object with a ``type`` attribute (``"random"``
            or ``"circle"``) and a ``radius`` attribute. For ``"circle"`` a
            ``None`` radius is replaced by a draw from ``[0.5, 1)``.
        shape_cnt: Number of shapes that already exist; new shape ids start
            from this value.
        pos_init_method: Zero-argument callable returning an ``[x, y]``
            start position. Defaults to ``random_coord_init``.
        trainable_stroke: When true, stroke widths and stroke colors are
            also marked trainable and returned.
        **kwargs: ``gt`` (optional) is the target image indexed as
            ``gt[0, :, y, x]``; when given, the fill color of each path is
            sampled from it at the path's start position.

    Returns:
        ``(shapes, shape_groups, point_var, color_var)``, extended by
        ``(stroke_width_var, stroke_color_var)`` when ``trainable_stroke``
        is true.

    Raises:
        ValueError: If ``seginit_cfg.type`` is not supported.
    """
    shapes = []
    shape_groups = []
    h, w = canvas_size
    # Take the target from the keyword arguments instead of relying on a
    # module-level global named ``gt`` being defined.
    gt = kwargs.get('gt')

    # change path init location
    if pos_init_method is None:
        pos_init_method = random_coord_init(canvas_size=canvas_size)

    for i in range(num_paths):
        num_control_points = [2] * num_segments
        if seginit_cfg.type == "random":
            p0 = pos_init_method()
            color_ref = copy.deepcopy(p0)
            radius = seginit_cfg.radius
            points = [p0]
            for j in range(num_segments):
                # Random walk: each point is offset from the previous one.
                p1 = (p0[0] + radius * npr.uniform(-0.5, 0.5),
                      p0[1] + radius * npr.uniform(-0.5, 0.5))
                p2 = (p1[0] + radius * npr.uniform(-0.5, 0.5),
                      p1[1] + radius * npr.uniform(-0.5, 0.5))
                p3 = (p2[0] + radius * npr.uniform(-0.5, 0.5),
                      p2[1] + radius * npr.uniform(-0.5, 0.5))
                points.append(p1)
                points.append(p2)
                # The path is closed, so the last segment ends at the
                # first point instead of adding a new one.
                if j < num_segments - 1:
                    points.append(p3)
                    p0 = p3
            points = torch.FloatTensor(points)

        # circle points initialization
        elif seginit_cfg.type == "circle":
            radius = seginit_cfg.radius
            if radius is None:
                radius = npr.uniform(0.5, 1)
            center = pos_init_method()
            color_ref = copy.deepcopy(center)
            points = get_bezier_circle(
                radius=radius, segments=num_segments,
                bias=center)

        else:
            # Previously this fell through to a NameError on ``points``.
            raise ValueError(
                "Unsupported seginit type: {!r}".format(seginit_cfg.type))

        path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points),
                             points=points,
                             stroke_width=torch.tensor(0.0),
                             is_closed=True)
        shapes.append(path)

        if gt is not None:
            # Sample the target color at the (clamped) start position.
            wref, href = color_ref
            wref = max(0, min(int(wref), w - 1))
            href = max(0, min(int(href), h - 1))
            fill_color_init = [float(c) for c in gt[0, :, href, wref]] + [1.]
            fill_color_init = torch.FloatTensor(fill_color_init)
            stroke_color_init = torch.FloatTensor(npr.uniform(size=[4]))
        else:
            fill_color_init = torch.FloatTensor(npr.uniform(size=[4]))
            stroke_color_init = torch.FloatTensor(npr.uniform(size=[4]))

        path_group = pydiffvg.ShapeGroup(
            shape_ids=torch.LongTensor([shape_cnt + i]),
            fill_color=fill_color_init,
            stroke_color=stroke_color_init,
        )
        shape_groups.append(path_group)

    point_var = []
    color_var = []
    for path in shapes:
        path.points.requires_grad = True
        point_var.append(path.points)
    for group in shape_groups:
        group.fill_color.requires_grad = True
        color_var.append(group.fill_color)

    if not trainable_stroke:
        return shapes, shape_groups, point_var, color_var

    stroke_width_var = []
    stroke_color_var = []
    for path in shapes:
        path.stroke_width.requires_grad = True
        stroke_width_var.append(path.stroke_width)
    for group in shape_groups:
        group.stroke_color.requires_grad = True
        stroke_color_var.append(group.stroke_color)
    return shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var
class linear_decay_lrlambda_f(object):
    """Learning-rate multiplier that decays piecewise-linearly.

    The multiplier equals ``decay_ratio ** k`` at step ``k * decay_every``
    and is interpolated linearly between consecutive such steps.
    """

    def __init__(self, decay_every, decay_ratio):
        self.decay_every = decay_every
        self.decay_ratio = decay_ratio

    def __call__(self, n):
        """Return the multiplier for step ``n``."""
        periods_done, offset = divmod(n, self.decay_every)
        factor_at_start = self.decay_ratio ** periods_done
        factor_at_end = self.decay_ratio ** (periods_done + 1)
        progress = offset / self.decay_every
        return factor_at_start * (1 - progress) + factor_at_end * progress
if __name__ == "__main__":
    # Script entry point: build the config, load the target image, then add
    # paths stage by stage (following path_schedule) and optimize them
    # against the target with the differentiable renderer.

    ###############
    # make config #
    ###############

    cfg_arg = parse_args()
    # NOTE: yaml.FullLoader is kept as in the original; the config file is
    # assumed to be trusted. Prefer yaml.safe_load for untrusted input.
    with open(cfg_arg.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg_default = edict(cfg['default'])
    cfg = edict(cfg[cfg_arg.experiment])
    # Later updates win: defaults overwrite experiment keys, and command
    # line values overwrite both.
    cfg.update(cfg_default)
    cfg.update(cfg_arg)
    cfg.exid = get_experiment_id(cfg.debug)

    cfg.experiment_dir = \
        osp.join(cfg.log_dir, '{}_{}'.format(cfg.exid, '_'.join(cfg.signature)))
    configfile = osp.join(cfg.experiment_dir, 'config.yaml')
    check_and_create_dir(configfile)
    with open(configfile, 'w') as f:
        yaml.dump(edict_2_dict(cfg), f)

    # Use GPU if available
    pydiffvg.set_use_gpu(torch.cuda.is_available())
    device = pydiffvg.get_device()

    with PIL.Image.open(cfg.target) as target_image:
        gt = np.array(target_image)
    print(f"Input image shape is: {gt.shape}")
    if len(gt.shape) == 2:
        print("Converting the gray-scale image to RGB.")
        # gt is still a NumPy array here, so the torch-style
        # unsqueeze/repeat used before raised AttributeError.
        gt = np.repeat(gt[:, :, None], 3, axis=2)
    if gt.shape[2] == 4:
        print("Input image includes alpha channel, simply dropout alpha channel.")
        gt = gt[:, :, :3]
    gt = (gt/255).astype(np.float32)
    # HWC -> 1 x C x H x W
    gt = torch.FloatTensor(gt).permute(2, 0, 1)[None].to(device)
    if cfg.use_ycrcb:
        gt = ycrcb_conversion(gt)
    h, w = gt.shape[2:]

    path_schedule = get_path_schedule(**cfg.path_schedule)

    if cfg.seed is not None:
        random.seed(cfg.seed)
        npr.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
    render = pydiffvg.RenderFunction.apply

    shapes_record, shape_groups_record = [], []

    region_loss = None
    loss_matrix = []

    para_point, para_color = {}, {}
    if cfg.trainable.stroke:
        para_stroke_width, para_stroke_color = {}, {}

    pathn_record = []
    # Background
    if cfg.trainable.bg:
        # meancolor = gt.mean([2, 3])[0]
        para_bg = torch.tensor([1., 1., 1.], requires_grad=True, device=device)
    else:
        if cfg.use_ycrcb:
            para_bg = torch.tensor([219/255, 0, 0], requires_grad=False, device=device)
        else:
            para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=device)

    ##################
    # start_training #
    ##################

    loss_weight = None
    loss_weight_keep = 0
    # The first stage is seeded from the error between the plain
    # background and the target.
    if cfg.coord_init.type == 'naive':
        pos_init_method = naive_coord_init(
            para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
    elif cfg.coord_init.type == 'sparse':
        pos_init_method = sparse_coord_init(
            para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt)
    elif cfg.coord_init.type == 'random':
        pos_init_method = random_coord_init([h, w])
    else:
        raise ValueError

    lrlambda_f = linear_decay_lrlambda_f(cfg.num_iter, 0.4)
    optim_schedular_dict = {}

    for path_idx, pathn in enumerate(path_schedule):
        loss_list = []
        print("=> Adding [{}] paths, [{}] ...".format(pathn, cfg.seginit.type))
        pathn_record.append(pathn)
        pathn_record_str = '-'.join([str(i) for i in pathn_record])

        # initialize new shapes related stuffs.
        if cfg.trainable.stroke:
            shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var = init_shapes(
                pathn, cfg.num_segments, (h, w),
                cfg.seginit, len(shapes_record),
                pos_init_method,
                trainable_stroke=True,
                gt=gt, )
            para_stroke_width[path_idx] = stroke_width_var
            para_stroke_color[path_idx] = stroke_color_var
        else:
            shapes, shape_groups, point_var, color_var = init_shapes(
                pathn, cfg.num_segments, (h, w),
                cfg.seginit, len(shapes_record),
                pos_init_method,
                trainable_stroke=False,
                gt=gt, )

        shapes_record += shapes
        shape_groups_record += shape_groups

        if cfg.save.init:
            filename = os.path.join(
                cfg.experiment_dir, "svg-init",
                "{}-init.svg".format(pathn_record_str))
            check_and_create_dir(filename)
            pydiffvg.save_svg(
                filename, w, h,
                shapes_record, shape_groups_record)

        # Parameters introduced by this stage, grouped by kind so each kind
        # gets its own base learning rate.
        para = {}
        if (cfg.trainable.bg) and (path_idx == 0):
            para['bg'] = [para_bg]
        para['point'] = point_var
        para['color'] = color_var
        if cfg.trainable.stroke:
            para['stroke_width'] = stroke_width_var
            para['stroke_color'] = stroke_color_var

        pg = [{'params': para[ki], 'lr': cfg.lr_base[ki]} for ki in sorted(para.keys())]
        optim = torch.optim.Adam(pg)

        if cfg.trainable.record:
            scheduler = LambdaLR(
                optim, lr_lambda=lrlambda_f, last_epoch=-1)
        else:
            # NOTE(review): with last_epoch != -1 PyTorch schedulers expect
            # 'initial_lr' in every param group; this branch presumably
            # raises KeyError as written -- confirm the intended schedule.
            scheduler = LambdaLR(
                optim, lr_lambda=lrlambda_f, last_epoch=cfg.num_iter)
        optim_schedular_dict[path_idx] = (optim, scheduler)

        # Inner loop training
        t_range = tqdm(range(cfg.num_iter))
        for t in t_range:

            for stage_optim, _ in optim_schedular_dict.values():
                stage_optim.zero_grad()

            # Forward pass: render the image.
            scene_args = pydiffvg.RenderFunction.serialize_scene(
                w, h, shapes_record, shape_groups_record)
            img = render(w, h, 2, 2, t, None, *scene_args)

            # Compose img with the background color
            img = img[:, :, 3:4] * img[:, :, :3] + \
                para_bg * (1 - img[:, :, 3:4])

            if cfg.save.video:
                filename = os.path.join(
                    cfg.experiment_dir, "video-png",
                    "{}-iter{}.png".format(pathn_record_str, t))
                check_and_create_dir(filename)
                if cfg.use_ycrcb:
                    imshow = ycrcb_conversion(
                        img, format='[2D x 3]', reverse=True).detach().cpu()
                else:
                    imshow = img.detach().cpu()
                pydiffvg.imwrite(imshow, filename, gamma=gamma)

            x = img.unsqueeze(0).permute(0, 3, 1, 2)  # HWC -> NCHW

            if cfg.use_ycrcb:
                color_reweight = torch.FloatTensor([255/219, 255/224, 255/255]).to(device)
                loss = ((x-gt)*(color_reweight.view(1, -1, 1, 1)))**2
            else:
                loss = ((x-gt)**2)

            if cfg.loss.use_l1_loss:
                loss = abs(x-gt)

            if cfg.loss.use_distance_weighted_loss:
                if cfg.use_ycrcb:
                    raise ValueError
                # Render only this stage's shapes, filled with opaque
                # white, to obtain a mask for the distance transform.
                shapes_forsdf = copy.deepcopy(shapes)
                shape_groups_forsdf = copy.deepcopy(shape_groups)
                for si in shapes_forsdf:
                    si.stroke_width = torch.FloatTensor([0]).to(device)
                for sg_idx, sgi in enumerate(shape_groups_forsdf):
                    sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(device)
                    sgi.shape_ids = torch.LongTensor([sg_idx]).to(device)

                sargs_forsdf = pydiffvg.RenderFunction.serialize_scene(
                    w, h, shapes_forsdf, shape_groups_forsdf)
                with torch.no_grad():
                    im_forsdf = render(w, h, 2, 2, 0, None, *sargs_forsdf)
                # use alpha channel is a trick to get 0-1 image
                im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy()
                loss_weight = get_sdf(im_forsdf, normalize='to1')
                # Add the weights kept from the previous stages.
                loss_weight += loss_weight_keep
                loss_weight = np.clip(loss_weight, 0, 1)
                loss_weight = torch.FloatTensor(loss_weight).to(device)

            if cfg.save.loss:
                save_loss = loss.squeeze(dim=0).mean(dim=0, keepdim=False).cpu().detach().numpy()
                loss_maps = [(save_loss, 'Reds', 'mseloss')]
                # The weight maps only exist when the distance-weighted
                # loss is enabled; previously this crashed on None.
                if loss_weight is not None:
                    save_weight = loss_weight.cpu().detach().numpy()
                    loss_maps.append((save_weight, 'Greys', 'sdfweight'))
                    loss_maps.append(
                        (save_loss * save_weight, 'Reds', 'weightedloss'))
                for loss_map, cmap, tag in loss_maps:
                    # normalize to [0,1]
                    loss_map = (loss_map - np.min(loss_map))/np.ptp(loss_map)
                    plt.imshow(loss_map, cmap=cmap)
                    plt.axis('off')
                    # plt.colorbar()
                    filename = os.path.join(
                        cfg.experiment_dir, "loss",
                        "{}-iter{}-{}.png".format(pathn_record_str, t, tag))
                    check_and_create_dir(filename)
                    plt.savefig(filename, dpi=800)
                    plt.close()

            if loss_weight is None:
                loss = loss.sum(1).mean()
            else:
                loss = (loss.sum(1)*loss_weight).mean()

            # if (cfg.loss.bis_loss_weight is not None) and (cfg.loss.bis_loss_weight > 0):
            #     loss_bis = bezier_intersection_loss(point_var[0]) * cfg.loss.bis_loss_weight
            #     loss = loss + loss_bis
            if (cfg.loss.xing_loss_weight is not None) \
                    and (cfg.loss.xing_loss_weight > 0):
                loss_xing = xing_loss(point_var) * cfg.loss.xing_loss_weight
                loss = loss + loss_xing

            loss_list.append(loss.item())
            t_range.set_postfix({'loss': loss.item()})
            loss.backward()

            # step
            for stage_optim, stage_scheduler in optim_schedular_dict.values():
                stage_optim.step()
                stage_scheduler.step()

            for group in shape_groups_record:
                group.fill_color.data.clamp_(0.0, 1.0)

        if cfg.loss.use_distance_weighted_loss:
            loss_weight_keep = loss_weight.detach().cpu().numpy() * 1

        if not cfg.trainable.record:
            # Freeze this stage's parameters and drop its optimizer. The
            # old code called .items() on a list and set a misspelled
            # attribute, so it crashed instead of freezing anything.
            for stage_params in para.values():
                for param in stage_params:
                    param.requires_grad = False
            optim_schedular_dict = {}

        if cfg.save.image:
            filename = os.path.join(
                cfg.experiment_dir, "demo-png", "{}.png".format(pathn_record_str))
            check_and_create_dir(filename)
            if cfg.use_ycrcb:
                imshow = ycrcb_conversion(
                    img, format='[2D x 3]', reverse=True).detach().cpu()
            else:
                imshow = img.detach().cpu()
            pydiffvg.imwrite(imshow, filename, gamma=gamma)

        if cfg.save.output:
            filename = os.path.join(
                cfg.experiment_dir, "output-svg", "{}.svg".format(pathn_record_str))
            check_and_create_dir(filename)
            pydiffvg.save_svg(filename, w, h, shapes_record, shape_groups_record)

        loss_matrix.append(loss_list)

        # calculate the pixel loss
        # pixel_loss = ((x-gt)**2).sum(dim=1, keepdim=True).sqrt_() # [N,1,H, W]
        # region_loss = adaptive_avg_pool2d(pixel_loss, cfg.region_loss_pool_size)
        # loss_weight = torch.softmax(region_loss.reshape(1, 1, -1), dim=-1)\
        #     .reshape_as(region_loss)

        # Seed the next stage from the error of the last rendering.
        if cfg.coord_init.type == 'naive':
            pos_init_method = naive_coord_init(x, gt)
        elif cfg.coord_init.type == 'sparse':
            pos_init_method = sparse_coord_init(x, gt)
        elif cfg.coord_init.type == 'random':
            pos_init_method = random_coord_init([h, w])
        else:
            raise ValueError

        if cfg.save.video:
            print("saving iteration video...")
            img_array = []
            for ii in range(0, cfg.num_iter):
                filename = os.path.join(
                    cfg.experiment_dir, "video-png",
                    "{}-iter{}.png".format(pathn_record_str, ii))
                # Use a separate name so the rendered ``img`` is not
                # overwritten by the frames read back from disk.
                frame = cv2.imread(filename)
                # cv2.putText(
                #     frame, "Path:{} \nIteration:{}".format(pathn_record_str, ii),
                #     (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
                img_array.append(frame)

            videoname = os.path.join(
                cfg.experiment_dir, "video-avi",
                "{}.avi".format(pathn_record_str))
            check_and_create_dir(videoname)
            out = cv2.VideoWriter(
                videoname,
                # cv2.VideoWriter_fourcc(*'mp4v'),
                cv2.VideoWriter_fourcc(*'FFV1'),
                20.0, (w, h))
            for frame in img_array:
                out.write(frame)
            out.release()
            # shutil.rmtree(os.path.join(cfg.experiment_dir, "video-png"))

    print("The last loss is: {}".format(loss.item()))