diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..0a239b92663a62fce9bceb9f22d300dc75a43b26 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +core/data/deg_kair_utils/test.png filter=lfs diff=lfs merge=lfs -text +figures/California_000490.jpg filter=lfs diff=lfs merge=lfs -text +figures/example_dataset/000008.jpg filter=lfs diff=lfs merge=lfs -text +figures/example_dataset/000012.jpg filter=lfs diff=lfs merge=lfs -text +figures/teaser.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed382f1907ddc86c7e9a9618c21441755a6221a9 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,372 @@ +import os +import yaml +import torch +from torch import nn +import wandb +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass +from torch.utils.data import Dataset, DataLoader + +from torch.distributed import init_process_group, destroy_process_group, barrier +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictType +) + +from .utils import Base, EXPECTED, EXPECTED_TRAIN +from .utils import create_folder_if_necessary, safe_save, load_or_fail + +# pylint: disable=unused-argument +class WarpCore(ABC): + @dataclass(frozen=True) + class Config(Base): + experiment_id: str = EXPECTED_TRAIN + checkpoint_path: str = EXPECTED_TRAIN + output_path: str = EXPECTED_TRAIN + checkpoint_extension: str = "safetensors" + dist_file_subfolder: str = "" + allow_tf32: bool = True + + wandb_project: str = None + wandb_entity: str = None + + @dataclass() # not frozen, means that fields are mutable + class Info(): # not inheriting from Base, because we don't want to enforce the default fields + wandb_run_id: str = None + total_steps: int = 0 + iter: int = 0 + + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + + @dataclass(frozen=True) + class Models(Base): + pass + + @dataclass(frozen=True) + class Optimizers(Base): + pass + + @dataclass(frozen=True) + class Schedulers(Base): + pass + + @dataclass(frozen=True) + class Extras(Base): + pass + # --------------------------------------- + info: Info + config: Config + + # FSDP stuff + fsdp_defaults = { + "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP, + "cpu_offload": None, + "mixed_precision": MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + "limit_all_gathers": True, + } + fsdp_fullstate_save_policy = FullStateDictConfig( + offload_to_cpu=True, rank0_only=True + ) + # ------------ + + # OVERRIDEABLE METHODS + + # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup + def setup_extras_pre(self) -> Extras: + return self.Extras() + + # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator + @abstractmethod + def setup_data(self, extras: Extras) -> Data: + raise NotImplementedError("This method needs to be overriden") + + # return a dict with all models that are going to be used in the training + @abstractmethod + def setup_models(self, extras: Extras) -> Models: + raise NotImplementedError("This method needs to be overriden") + + # return a dict with all optimizers that are going to be used in the training + @abstractmethod + def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: + raise NotImplementedError("This method needs to be overriden") + + # [optionally] return a dict with all schedulers that are going to be used in the training + def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: + return self.Schedulers() + + # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup + def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras: + return self.Extras.from_dict(extras.to_dict()) + + # perform the training here + @abstractmethod + def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): + raise NotImplementedError("This method needs to be overriden") + # ------------ + + def setup_info(self, full_path=None) -> Info: + if full_path is None: + full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json") + info_dict = load_or_fail(full_path, wandb_run_id=None) or {} + info_dto = self.Info(**info_dict) + if info_dto.total_steps > 0 and self.is_main_node: + print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps) + return info_dto + + def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config: + if config_file_path is not None: + if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"): + with open(config_file_path, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + elif config_file_path.endswith(".json"): + with open(config_file_path, "r", encoding="utf-8") as file: + loaded_config = json.load(file) + else: + raise ValueError("Config file must be either a .yml|.yaml or .json file") + return self.Config.from_dict({**loaded_config, 'training': training}) + if config_dict is not None: + return self.Config.from_dict({**config_dict, 'training': training}) + return self.Config(training=training) + + def setup_ddp(self, experiment_id, single_gpu=False): + if not single_gpu: + local_rank = int(os.environ.get("SLURM_LOCALID")) + process_id = int(os.environ.get("SLURM_PROCID")) + world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}" + # if os.path.exists(dist_file_path) and self.is_main_node: + # os.remove(dist_file_path) + + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=process_id, + world_size=world_size, + init_method=f"file://{dist_file_path}", + ) + print(f"[GPU {process_id}] READY") + else: + print("Running in single thread, DDP not enabled.") + + def setup_wandb(self): + if self.is_main_node and self.config.wandb_project is not None: + self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id() + wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict()) + + if self.info.total_steps > 0: + wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}") + else: + wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started") + + # LOAD UTILITIES ---------- + def load_model(self, model, model_id=None, full_path=None, strict=True): + print('in line 181 load model', type(model), model_id, full_path, strict) + if model_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" + elif full_path is None and model_id is None: + raise ValueError( + "This method expects either 'model_id' or 'full_path' to be defined" + ) + + checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) + if checkpoint is not None: + model.load_state_dict(checkpoint, strict=strict) + del checkpoint + + return model + + def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None): + if optim_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" + elif full_path is None and optim_id is None: + raise ValueError( + "This method expects either 'optim_id' or 'full_path' to be defined" + ) + + checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) + if checkpoint is not None: + try: + if fsdp_model is not None: + sharded_optimizer_state_dict = ( + FSDP.scatter_full_optim_state_dict( # <---- FSDP + checkpoint + if ( + self.is_main_node + or self.fsdp_defaults["sharding_strategy"] + == ShardingStrategy.NO_SHARD + ) + else None, + fsdp_model, + ) + ) + optim.load_state_dict(sharded_optimizer_state_dict) + del checkpoint, sharded_optimizer_state_dict + else: + optim.load_state_dict(checkpoint) + # pylint: disable=broad-except + except Exception as e: + print("!!! Failed loading optimizer, skipping... Exception:", e) + + return optim + + # SAVE UTILITIES ---------- + def save_info(self, info, suffix=""): + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json" + create_folder_if_necessary(full_path) + if self.is_main_node: + safe_save(vars(self.info), full_path) + + def save_model(self, model, model_id=None, full_path=None, is_fsdp=False): + if model_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" + elif full_path is None and model_id is None: + raise ValueError( + "This method expects either 'model_id' or 'full_path' to be defined" + ) + create_folder_if_necessary(full_path) + if is_fsdp: + with FSDP.summon_full_params(model): + pass + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy + ): + checkpoint = model.state_dict() + if self.is_main_node: + safe_save(checkpoint, full_path) + del checkpoint + else: + if self.is_main_node: + checkpoint = model.state_dict() + safe_save(checkpoint, full_path) + del checkpoint + + def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None): + if optim_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" + elif full_path is None and optim_id is None: + raise ValueError( + "This method expects either 'optim_id' or 'full_path' to be defined" + ) + create_folder_if_necessary(full_path) + if fsdp_model is not None: + optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) + if self.is_main_node: + safe_save(optim_statedict, full_path) + del optim_statedict + else: + if self.is_main_node: + checkpoint = optim.state_dict() + safe_save(checkpoint, full_path) + del checkpoint + # ----- + + def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True): + # Temporary setup, will be overriden by setup_ddp if required + self.device = device + self.process_id = 0 + self.is_main_node = True + self.world_size = 1 + # ---- + + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.info: self.Info = self.setup_info() + + def __call__(self, single_gpu=False): + self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank + self.setup_wandb() + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + if self.config.wandb_project is not None: + wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") diff --git a/core/data/__init__.py b/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b687719914b2e303909f7c280347e4bdee607d13 --- /dev/null +++ b/core/data/__init__.py @@ -0,0 +1,69 @@ +import json +import subprocess +import yaml +import os +from .bucketeer import Bucketeer + +class MultiFilter(): + def __init__(self, rules, default=False): + self.rules = rules + self.default = default + + def __call__(self, x): + try: + x_json = x['json'] + if isinstance(x_json, bytes): + x_json = json.loads(x_json) + validations = [] + for k, r in self.rules.items(): + if isinstance(k, tuple): + v = r(*[x_json[kv] for kv in k]) + else: + v = r(x_json[k]) + validations.append(v) + return all(validations) + except Exception: + return False + +class MultiGetter(): + def __init__(self, rules): + self.rules = rules + + def __call__(self, x_json): + if isinstance(x_json, bytes): + x_json = json.loads(x_json) + outputs = [] + for k, r in self.rules.items(): + if isinstance(k, tuple): + v = r(*[x_json[kv] for kv in k]) + else: + v = r(x_json[k]) + outputs.append(v) + if len(outputs) == 1: + outputs = outputs[0] + return outputs + +def setup_webdataset_path(paths, cache_path=None): + if cache_path is None or not os.path.exists(cache_path): + tar_paths = [] + if isinstance(paths, str): + paths = [paths] + for path in paths: + if path.strip().endswith(".tar"): + # Avoid looking up s3 if we already have a tar file + tar_paths.append(path) + continue + bucket = "/".join(path.split("/")[:3]) + result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True) + files = result.stdout.decode('utf-8').split() + files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")] + tar_paths += files + + with open(cache_path, 'w', encoding='utf-8') as outfile: + yaml.dump(tar_paths, outfile, default_flow_style=False) + else: + with open(cache_path, 'r', encoding='utf-8') as file: + tar_paths = yaml.safe_load(file) + + tar_paths_str = ",".join([f"{p}" for p in tar_paths]) + return f"pipe:aws s3 cp {{ {tar_paths_str} }} -" diff --git a/core/data/bucketeer.py b/core/data/bucketeer.py new file mode 100644 index 0000000000000000000000000000000000000000..131e6ba4293bd7c00399f08609aba184b712d5e8 --- /dev/null +++ b/core/data/bucketeer.py @@ -0,0 +1,88 @@ +import torch +import torchvision +import numpy as np +from torchtools.transforms import SmartCrop +import math + +class Bucketeer(): + def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): + assert crop_mode in ['center', 'random', 'smart'] + self.crop_mode = crop_mode + self.ratios = ratios + if reverse_list: + for r in list(ratios): + if 1/r not in self.ratios: + self.ratios.append(1/r) + self.sizes = {} + for dd in density: + self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] + + self.batch_size = dataloader.batch_size + self.iterator = iter(dataloader) + all_sizes = [] + for k, vs in self.sizes.items(): + all_sizes += vs + self.buckets = {s: [] for s in all_sizes} + self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None + self.p_random_ratio = p_random_ratio + self.interpolate_nearest = interpolate_nearest + + def get_available_batch(self): + for b in self.buckets: + if len(self.buckets[b]) >= self.batch_size: + batch = self.buckets[b][:self.batch_size] + self.buckets[b] = self.buckets[b][self.batch_size:] + return batch + return None + + def get_closest_size(self, x): + w, h = x.size(-1), x.size(-2) + + + best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) + find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} + min_ = find_dict[list(find_dict.keys())[0]] + find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] + for dd, val in find_dict.items(): + if val < min_: + min_ = val + find_size = self.sizes[dd][best_size_idx] + + return find_size + + def get_resize_size(self, orig_size, tgt_size): + if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: + alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) + resize_size = max(alt_min, min(tgt_size)) + else: + alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) + resize_size = max(alt_max, max(tgt_size)) + + return resize_size + + def __next__(self): + batch = self.get_available_batch() + while batch is None: + elements = next(self.iterator) + for dct in elements: + img = dct['images'] + size = self.get_closest_size(img) + resize_size = self.get_resize_size(img.shape[-2:], size) + + if self.interpolate_nearest: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) + else: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) + if self.crop_mode == 'center': + img = torchvision.transforms.functional.center_crop(img, size) + elif self.crop_mode == 'random': + img = torchvision.transforms.RandomCrop(size)(img) + elif self.crop_mode == 'smart': + self.smartcrop.output_size = size + img = self.smartcrop(img) + + self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) + batch = self.get_available_batch() + + out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} + return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} diff --git a/core/data/bucketeer_deg.py b/core/data/bucketeer_deg.py new file mode 100644 index 0000000000000000000000000000000000000000..6deb4bcd18392183b71b1f9a4360e21d6383d1bc --- /dev/null +++ b/core/data/bucketeer_deg.py @@ -0,0 +1,91 @@ +import torch +import torchvision +import numpy as np +from torchtools.transforms import SmartCrop +import math + +class Bucketeer(): + def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): + assert crop_mode in ['center', 'random', 'smart'] + self.crop_mode = crop_mode + self.ratios = ratios + if reverse_list: + for r in list(ratios): + if 1/r not in self.ratios: + self.ratios.append(1/r) + self.sizes = {} + for dd in density: + self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] + print('in line 17 buckteer', self.sizes) + self.batch_size = dataloader.batch_size + self.iterator = iter(dataloader) + all_sizes = [] + for k, vs in self.sizes.items(): + all_sizes += vs + self.buckets = {s: [] for s in all_sizes} + self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None + self.p_random_ratio = p_random_ratio + self.interpolate_nearest = interpolate_nearest + + def get_available_batch(self): + for b in self.buckets: + if len(self.buckets[b]) >= self.batch_size: + batch = self.buckets[b][:self.batch_size] + self.buckets[b] = self.buckets[b][self.batch_size:] + return batch + return None + + def get_closest_size(self, x): + w, h = x.size(-1), x.size(-2) + #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio: + # best_size_idx = np.random.randint(len(self.ratios)) + #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio) + #else: + + best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) + find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} + min_ = find_dict[list(find_dict.keys())[0]] + find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] + for dd, val in find_dict.items(): + if val < min_: + min_ = val + find_size = self.sizes[dd][best_size_idx] + + return find_size + + def get_resize_size(self, orig_size, tgt_size): + if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: + alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) + resize_size = max(alt_min, min(tgt_size)) + else: + alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) + resize_size = max(alt_max, max(tgt_size)) + #print('in line 50', orig_size, tgt_size, resize_size) + return resize_size + + def __next__(self): + batch = self.get_available_batch() + while batch is None: + elements = next(self.iterator) + for dct in elements: + img = dct['images'] + size = self.get_closest_size(img) + resize_size = self.get_resize_size(img.shape[-2:], size) + #print('in line 74', img.size(), resize_size) + if self.interpolate_nearest: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) + else: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) + if self.crop_mode == 'center': + img = torchvision.transforms.functional.center_crop(img, size) + elif self.crop_mode == 'random': + img = torchvision.transforms.RandomCrop(size)(img) + elif self.crop_mode == 'smart': + self.smartcrop.output_size = size + img = self.smartcrop(img) + print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img)) + self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) + batch = self.get_available_batch() + + out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} + return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} diff --git a/core/data/deg_kair_utils/test.bmp b/core/data/deg_kair_utils/test.bmp new file mode 100644 index 0000000000000000000000000000000000000000..6c40c310074fedb7907c7755b5fceda8b9d65b76 Binary files /dev/null and b/core/data/deg_kair_utils/test.bmp differ diff --git a/core/data/deg_kair_utils/test.png b/core/data/deg_kair_utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..6bc29ad47ae18a1202eaddf850b6ec833226eb0f --- /dev/null +++ b/core/data/deg_kair_utils/test.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f36b49d23d9bb206f8c3ec537b57bd44ed008f3f632706ac02eba73e175fa3d0 +size 1457992 diff --git a/core/data/deg_kair_utils/utils_alignfaces.py b/core/data/deg_kair_utils/utils_alignfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..fa74e8a2e8984f5075d0cbd06afd494c9661a015 --- /dev/null +++ b/core/data/deg_kair_utils/utils_alignfaces.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Apr 24 15:43:29 2017 +@author: zhaoy +""" +import cv2 +import numpy as np +from skimage import transform as trans + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [ + [30.29459953, 51.69630051], + [65.53179932, 51.50139999], + [48.02519989, 71.73660278], + [33.54930115, 92.3655014], + [62.72990036, 92.20410156] +] + +DEFAULT_CROP_SIZE = (96, 112) + + +def _umeyama(src, dst, estimate_scale=True, scale=1.0): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = dst_demean.T @ src_demean / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = U @ V + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = U @ np.diag(d) @ V + d[dim - 1] = s + else: + T[:dim, :dim] = U @ np.diag(d) @ V + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) + else: + scale = scale + + T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) + T[:dim, :dim] *= scale + + return T, scale + + +class FaceWarpException(Exception): + def __str__(self): + return 'In File {}:{}'.format( + __file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, + inner_padding_factor=0.0, + outer_padding=(0, 0), + default_square=False): + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if (output_size and + output_size[0] == tmp_crop_size[0] and + output_size[1] == tmp_crop_size[1]): + print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) + return tmp_5pts + + if (inner_padding_factor == 0 and + outer_padding == (0, 0)): + if output_size is None: + print('No paddings to do: return default reference points') + return tmp_5pts + else: + raise FaceWarpException( + 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) + and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + print(' deduced from paddings, output_size = ', output_size) + + if not (outer_padding[0] < output_size[0] + and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0]' + 'and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + # print('---> STEP1: pad the inner region according inner_padding_factor') + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + # 2) resize the padded inner region + # print('---> STEP2: resize the padded inner region') + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + # print(' crop_size = ', tmp_crop_size) + # print(' size_bf_outer_pad = ', size_bf_outer_pad) + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException('Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + # print(' resize scale_factor = ', scale_factor) + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + # print('---> STEP3: add outer_padding to make output_size') + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + # + # print('===> end get_reference_facial_points\n') + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([ + [A[0, 0], A[1, 0], A[2, 0]], + [A[0, 1], A[1, 1], A[2, 1]] + ]) + elif rank == 2: + tfm = np.float32([ + [A[0, 0], A[1, 0], 0], + [A[0, 1], A[1, 1], 0] + ]) + + return tfm + + +def warp_and_crop_face(src_img, + facial_pts, + reference_pts=None, + crop_size=(96, 112), + align_type='smilarity'): #smilarity cv2_affine affine + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, + inner_padding_factor, + outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException( + 'reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException( + 'facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException( + 'facial_pts and reference_pts must have the same shape') + + if align_type is 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) + elif align_type is 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) + else: + params, scale = _umeyama(src_pts, ref_pts) + tfm = params[:2, :] + + params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale) + tfm_inv = params[:2, :] + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3) + + return face_img, tfm_inv diff --git a/core/data/deg_kair_utils/utils_blindsr.py b/core/data/deg_kair_utils/utils_blindsr.py new file mode 100644 index 0000000000000000000000000000000000000000..b76d1e163b5beb96ebefed575dae3af96ca176e0 --- /dev/null +++ b/core/data/deg_kair_utils/utils_blindsr.py @@ -0,0 +1,631 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from core.data.deg_kair_utils import utils_image as util + +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth + + + + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf-1)*0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w-1) + y1 = np.clip(y1, 0, h-1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1,c,1,1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z-MU + ZZ_t = ZZ.transpose(0,1,3,2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) + arg = -(x*x + y*y)/(2*std*std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h/sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha,1])]) + h1 = alpha/(alpha+1) + h2 = (1-alpha)/(alpha+1) + h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1/sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + + ''' bicubic downsampling + blur + + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2*sf + if random.random() < 0.5: + l1 = wd2*random.random() + l2 = wd2*random.random() + k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5/sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2/255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3,3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2/255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3,3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10**(2*random.random()+2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h-lq_patchsize) + rnd_w = random.randint(0, w-lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize*sf or w < lq_patchsize*sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3])) + else: + img = util.imresize_np(img, 1/2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1,2*sf) + img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + + + +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize*sf or w < lq_patchsize*sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + + +if __name__ == '__main__': + img = util.imread_uint('utils/test.png', 3) + img = util.uint2single(img) + sf = 4 + + for i in range(20): + img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72) + print(i) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) + img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i)+'.png') + +# for i in range(10): +# img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64) +# print(i) +# lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) +# img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) +# util.imsave(img_concat, str(i)+'.png') + +# run utils/utils_blindsr.py diff --git a/core/data/deg_kair_utils/utils_bnorm.py b/core/data/deg_kair_utils/utils_bnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd346e05b66efd074f81f1961068e2de45ac5da --- /dev/null +++ b/core/data/deg_kair_utils/utils_bnorm.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + + +""" +# -------------------------------------------- +# Batch Normalization +# -------------------------------------------- + +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# 01/Jan/2019 +# -------------------------------------------- +""" + + +# -------------------------------------------- +# remove/delete specified layer +# -------------------------------------------- +def deleteLayer(model, layer_type=nn.BatchNorm2d): + ''' Kai Zhang, 11/Jan/2019. + ''' + for k, m in list(model.named_children()): + if isinstance(m, layer_type): + del model._modules[k] + deleteLayer(m, layer_type) + + +# -------------------------------------------- +# merge bn, "conv+bn" --> "conv" +# -------------------------------------------- +def merge_bn(model): + ''' Kai Zhang, 11/Jan/2019. + merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv') + based on https://github.com/pytorch/pytorch/pull/901 + ''' + prev_m = None + for k, m in list(model.named_children()): + if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)): + + w = prev_m.weight.data + + if prev_m.bias is None: + zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type()) + prev_m.bias = nn.Parameter(zeros) + b = prev_m.bias.data + + invstd = m.running_var.clone().add_(m.eps).pow_(-0.5) + if isinstance(prev_m, nn.ConvTranspose2d): + w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w)) + else: + w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) + b.add_(-m.running_mean).mul_(invstd) + if m.affine: + if isinstance(prev_m, nn.ConvTranspose2d): + w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w)) + else: + w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) + b.mul_(m.weight.data).add_(m.bias.data) + + del model._modules[k] + prev_m = m + merge_bn(m) + + +# -------------------------------------------- +# add bn, "conv" --> "conv+bn" +# -------------------------------------------- +def add_bn(model): + ''' Kai Zhang, 11/Jan/2019. + ''' + for k, m in list(model.named_children()): + if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)): + b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True) + b.weight.data.fill_(1) + new_m = nn.Sequential(model._modules[k], b) + model._modules[k] = new_m + add_bn(m) + + +# -------------------------------------------- +# tidy model after removing bn +# -------------------------------------------- +def tidy_sequential(model): + ''' Kai Zhang, 11/Jan/2019. + ''' + for k, m in list(model.named_children()): + if isinstance(m, nn.Sequential): + if m.__len__() == 1: + model._modules[k] = m.__getitem__(0) + tidy_sequential(m) diff --git a/core/data/deg_kair_utils/utils_deblur.py b/core/data/deg_kair_utils/utils_deblur.py new file mode 100644 index 0000000000000000000000000000000000000000..c5457b9c1df3bd7bbe8758cf8be5824273b8db29 --- /dev/null +++ b/core/data/deg_kair_utils/utils_deblur.py @@ -0,0 +1,655 @@ +# -*- coding: utf-8 -*- +import numpy as np +import scipy +from scipy import fftpack +import torch + +from math import cos, sin +from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round +from numpy.random import randn, rand +from scipy.signal import convolve2d +import cv2 +import random +# import utils_image as util + +''' +modified by Kai Zhang (github: https://github.com/cszn) +03/03/2019 +''' + + +def get_uperleft_denominator(img, kernel): + ''' + img: HxWxC + kernel: hxw + denominator: HxWx1 + upperleft: HxWxC + ''' + V = psf2otf(kernel, img.shape[:2]) + denominator = np.expand_dims(np.abs(V)**2, axis=2) + upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1]) + return upperleft, denominator + + +def get_uperleft_denominator_pytorch(img, kernel): + ''' + img: NxCxHxW + kernel: Nx1xhxw + denominator: Nx1xHxW + upperleft: NxCxHxWx2 + ''' + V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2 + denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW + upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2 + return upperleft, denominator + + +def c2c(x): + return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) + + +def r2c(x): + return torch.stack([x, torch.zeros_like(x)], -1) + + +def cdiv(x, y): + a, b = x[..., 0], x[..., 1] + c, d = y[..., 0], y[..., 1] + cd2 = c**2 + d**2 + return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) + + +def cabs(x): + return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) + + +def cmul(t1, t2): + ''' + complex multiplication + t1: NxCxHxWx2 + output: NxCxHxWx2 + ''' + real1, imag1 = t1[..., 0], t1[..., 1] + real2, imag2 = t2[..., 0], t2[..., 1] + return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) + + +def cconj(t, inplace=False): + ''' + # complex's conjugation + t: NxCxHxWx2 + output: NxCxHxWx2 + ''' + c = t.clone() if not inplace else t + c[..., 1] *= -1 + return c + + +def rfft(t): + return torch.rfft(t, 2, onesided=False) + + +def irfft(t): + return torch.irfft(t, 2, onesided=False) + + +def fft(t): + return torch.fft(t, 2) + + +def ifft(t): + return torch.ifft(t, 2) + + +def p2o(psf, shape): + ''' + # psf: NxCxhxw + # shape: [H,W] + # otf: NxCxHxWx2 + ''' + otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) + otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) + for axis, axis_size in enumerate(psf.shape[2:]): + otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) + otf = torch.rfft(otf, 2, onesided=False) + n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) + otf[...,1][torch.abs(otf[...,1])= abs(y)] = abs(x)[abs(x) >= abs(y)] + maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)] + minxy = np.zeros(x.shape) + minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)] + minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)] + m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\ + (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\ + np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2) + m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\ + (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\ + np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2) + h = None + return h + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) + arg = -(x*x + y*y)/(2*std*std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h/sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha,1])]) + h1 = alpha/(alpha+1) + h2 = (1-alpha)/(alpha+1) + h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial_log(hsize, sigma): + raise(NotImplemented) + + +def fspecial_motion(motion_len, theta): + raise(NotImplemented) + + +def fspecial_prewitt(): + return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]]) + + +def fspecial_sobel(): + return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'average': + return fspecial_average(*args, **kwargs) + if filter_type == 'disk': + return fspecial_disk(*args, **kwargs) + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + if filter_type == 'log': + return fspecial_log(*args, **kwargs) + if filter_type == 'motion': + return fspecial_motion(*args, **kwargs) + if filter_type == 'prewitt': + return fspecial_prewitt(*args, **kwargs) + if filter_type == 'sobel': + return fspecial_sobel(*args, **kwargs) + + +def fspecial_gauss(size, sigma): + x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1] + g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) + return g / g.sum() + + +def blurkernel_synthesis(h=37, w=None): + # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py + w = h if w is None else w + kdims = [h, w] + x = randomTrajectory(250) + k = None + while k is None: + k = kernelFromTrajectory(x) + + # center pad to kdims + pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2) + pad_width = [(pad_width[0],), (pad_width[1],)] + + if pad_width[0][0]<0 or pad_width[1][0]<0: + k = k[0:h, 0:h] + else: + k = pad(k, pad_width, "constant") + x1,x2 = k.shape + if np.random.randint(0, 4) == 1: + k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR) + y1, y2 = k.shape + k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2] + + if sum(k)<0.1: + k = fspecial_gaussian(h, 0.1+6*np.random.rand(1)) + k = k / sum(k) + # import matplotlib.pyplot as plt + # plt.imshow(k, interpolation="nearest", cmap="gray") + # plt.show() + return k + + +def kernelFromTrajectory(x): + h = 5 - log(rand()) / 0.15 + h = round(min([h, 27])).astype(int) + h = h + 1 - h % 2 + w = h + k = zeros((h, w)) + + xmin = min(x[0]) + xmax = max(x[0]) + ymin = min(x[1]) + ymax = max(x[1]) + xthr = arange(xmin, xmax, (xmax - xmin) / w) + ythr = arange(ymin, ymax, (ymax - ymin) / h) + + for i in range(1, xthr.size): + for j in range(1, ythr.size): + idx = ( + (x[0, :] >= xthr[i - 1]) + & (x[0, :] < xthr[i]) + & (x[1, :] >= ythr[j - 1]) + & (x[1, :] < ythr[j]) + ) + k[i - 1, j - 1] = sum(idx) + if sum(k) == 0: + return + k = k / sum(k) + k = convolve2d(k, fspecial_gauss(3, 1), "same") + k = k / sum(k) + return k + + +def randomTrajectory(T): + x = zeros((3, T)) + v = randn(3, T) + r = zeros((3, T)) + trv = 1 / 1 + trr = 2 * pi / T + for t in range(1, T): + F_rot = randn(3) / (t + 1) + r[:, t - 1] + F_trans = randn(3) / (t + 1) + r[:, t] = r[:, t - 1] + trr * F_rot + v[:, t] = v[:, t - 1] + trv * F_trans + st = v[:, t] + st = rot3D(st, r[:, t]) + x[:, t] = x[:, t - 1] + st + return x + + +def rot3D(x, r): + Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]]) + Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]]) + Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]]) + R = Rz @ Ry @ Rx + x = R @ x + return x + + +if __name__ == '__main__': + a = opt_fft_size([111]) + print(a) + + print(fspecial('gaussian', 5, 1)) + + print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape) + + k = blurkernel_synthesis(11) + import matplotlib.pyplot as plt + plt.imshow(k, interpolation="nearest", cmap="gray") + plt.show() diff --git a/core/data/deg_kair_utils/utils_dist.py b/core/data/deg_kair_utils/utils_dist.py new file mode 100644 index 0000000000000000000000000000000000000000..7729e3af0b8fc3f48bb050b5eb31eaf971488d1e --- /dev/null +++ b/core/data/deg_kair_utils/utils_dist.py @@ -0,0 +1,201 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +# ---------------------------------- +# init +# ---------------------------------- +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + + +# ---------------------------------- +# get rank and world_size +# ---------------------------------- +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def get_rank(): + if not dist.is_available(): + return 0 + + if not dist.is_initialized(): + return 0 + + return dist.get_rank() + + +def get_world_size(): + if not dist.is_available(): + return 1 + + if not dist.is_initialized(): + return 1 + + return dist.get_world_size() + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + + + + + +# ---------------------------------- +# operation across ranks +# ---------------------------------- +def reduce_sum(tensor): + if not dist.is_available(): + return tensor + + if not dist.is_initialized(): + return tensor + + tensor = tensor.clone() + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + return tensor + + +def gather_grad(params): + world_size = get_world_size() + + if world_size == 1: + return + + for param in params: + if param.grad is not None: + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data.div_(world_size) + + +def all_gather(data): + world_size = get_world_size() + + if world_size == 1: + return [data] + + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to('cuda') + + local_size = torch.IntTensor([tensor.numel()]).to('cuda') + size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) + + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') + tensor = torch.cat((tensor, padding), 0) + + dist.all_gather(tensor_list, tensor) + + data_list = [] + + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_loss_dict(loss_dict): + world_size = get_world_size() + + if world_size < 2: + return loss_dict + + with torch.no_grad(): + keys = [] + losses = [] + + for k in sorted(loss_dict.keys()): + keys.append(k) + losses.append(loss_dict[k]) + + losses = torch.stack(losses, 0) + dist.reduce(losses, dst=0) + + if dist.get_rank() == 0: + losses /= world_size + + reduced_losses = {k: v for k, v in zip(keys, losses)} + + return reduced_losses + diff --git a/core/data/deg_kair_utils/utils_googledownload.py b/core/data/deg_kair_utils/utils_googledownload.py new file mode 100644 index 0000000000000000000000000000000000000000..f4acaf78d7cc60bec569cae2f02f2ec049407615 --- /dev/null +++ b/core/data/deg_kair_utils/utils_googledownload.py @@ -0,0 +1,93 @@ +import math +import requests +from tqdm import tqdm + + +''' +borrowed from +https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py +''' + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get( + URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int( + response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, + destination, + file_size=None, + chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' + f'/ {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +if __name__ == "__main__": + file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv' + save_path = 'BSRGAN.pth' + download_file_from_google_drive(file_id, save_path) diff --git a/core/data/deg_kair_utils/utils_image.py b/core/data/deg_kair_utils/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0e513a8bc1594c9ce2ba47ce3fe3b497269b7f16 --- /dev/null +++ b/core/data/deg_kair_utils/utils_image.py @@ -0,0 +1,1016 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +# import torchvision.transforms as transforms +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if isinstance(dataroot, str): + paths = sorted(_get_paths_from_images(dataroot)) + elif isinstance(dataroot, list): + paths = [] + for i in dataroot: + paths += sorted(_get_paths_from_images(i)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) + # print(w1) + # print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(uint) +# numpy(single) <---> tensor +# numpy(uint) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(uint) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(uint) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.uint8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + rlt = np.clip(rlt, 0, 255) + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR, SSIM and PSNRB +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def _blocking_effect_factor(im): + block_size = 8 + + block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8) + block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8) + + horizontal_block_difference = ( + (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum( + 3).sum(2).sum(1) + vertical_block_difference = ( + (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum( + 2).sum(1) + + nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions) + nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions) + + horizontal_nonblock_difference = ( + (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum( + 3).sum(2).sum(1) + vertical_nonblock_difference = ( + (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum( + 3).sum(2).sum(1) + + n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1) + n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1) + boundary_difference = (horizontal_block_difference + vertical_block_difference) / ( + n_boundary_horiz + n_boundary_vert) + + n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz + n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert + nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / ( + n_nonboundary_horiz + n_nonboundary_vert) + + scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]])) + bef = scaler * (boundary_difference - nonboundary_difference) + + bef[boundary_difference <= nonboundary_difference] = 0 + return bef + + +def calculate_psnrb(img1, img2, border=0): + """Calculate PSNR-B (Peak Signal-to-Noise Ratio). + Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation + # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + Returns: + float: psnr result. + """ + + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + + if img1.ndim == 2: + img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2) + + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py + img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255. + img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255. + + total = 0 + for c in range(img1.shape[1]): + mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none') + bef = _blocking_effect_factor(img1[:, c:c + 1, :, :]) + + mse = mse.view(mse.shape[0], -1).mean(1) + total += 10 * torch.log10(1 / (mse + bef)) + + return float(total) / img1.shape[1] + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) +# imshow(single2uint(img_bicubic)) +# +# img_tensor = single2tensor4(img) +# for i in range(8): +# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1)) + +# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200) +# imssave(patches,'a.png') + + + + + + + diff --git a/core/data/deg_kair_utils/utils_lmdb.py b/core/data/deg_kair_utils/utils_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..75192c346bb9c0b96f8b09635ed548bd6e797d89 --- /dev/null +++ b/core/data/deg_kair_utils/utils_lmdb.py @@ -0,0 +1,205 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + # deal with `libpng error: Read Error` + if img is None: + print(f'To deal with `libpng error: Read Error`, use PIL to load {path}') + from PIL import Image + import numpy as np + img = Image.open(path) + img = np.asanyarray(img) + img = img[:, :, [2, 1, 0]] + + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/core/data/deg_kair_utils/utils_logger.py b/core/data/deg_kair_utils/utils_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..3067190e1b09b244814e0ccc4496b18f06e22b54 --- /dev/null +++ b/core/data/deg_kair_utils/utils_logger.py @@ -0,0 +1,66 @@ +import sys +import datetime +import logging + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +def log(*args, **kwargs): + print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) + + +''' +# -------------------------------------------- +# logger +# -------------------------------------------- +''' + + +def logger_info(logger_name, log_path='default_logger.log'): + ''' set up logger + modified by Kai Zhang (github: https://github.com/cszn) + ''' + log = logging.getLogger(logger_name) + if log.hasHandlers(): + print('LogHandlers exist!') + else: + print('LogHandlers setup!') + level = logging.INFO + formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') + fh = logging.FileHandler(log_path, mode='a') + fh.setFormatter(formatter) + log.setLevel(level) + log.addHandler(fh) + # print(len(log.handlers)) + + sh = logging.StreamHandler() + sh.setFormatter(formatter) + log.addHandler(sh) + + +''' +# -------------------------------------------- +# print to file and std_out simultaneously +# -------------------------------------------- +''' + + +class logger_print(object): + def __init__(self, log_path="default.log"): + self.terminal = sys.stdout + self.log = open(log_path, 'a') + + def write(self, message): + self.terminal.write(message) + self.log.write(message) # write the message + + def flush(self): + pass diff --git a/core/data/deg_kair_utils/utils_mat.py b/core/data/deg_kair_utils/utils_mat.py new file mode 100644 index 0000000000000000000000000000000000000000..cd25d500c0eae77a3b815b8e956205b737ee43d4 --- /dev/null +++ b/core/data/deg_kair_utils/utils_mat.py @@ -0,0 +1,88 @@ +import os +import json +import scipy.io as spio +import pandas as pd + + +def loadmat(filename): + ''' + this function should be called instead of direct spio.loadmat + as it cures the problem of not properly recovering python dictionaries + from mat files. It calls the function check keys to cure all entries + which are still mat-objects + ''' + data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True) + return dict_to_nonedict(_check_keys(data)) + +def _check_keys(dict): + ''' + checks if entries in dictionary are mat-objects. If yes + todict is called to change them to nested dictionaries + ''' + for key in dict: + if isinstance(dict[key], spio.matlab.mio5_params.mat_struct): + dict[key] = _todict(dict[key]) + return dict + +def _todict(matobj): + ''' + A recursive function which constructs from matobjects nested dictionaries + ''' + dict = {} + for strg in matobj._fieldnames: + elem = matobj.__dict__[strg] + if isinstance(elem, spio.matlab.mio5_params.mat_struct): + dict[strg] = _todict(elem) + else: + dict[strg] = elem + return dict + + +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +class NoneDict(dict): + def __missing__(self, key): + return None + + +def mat2json(mat_path=None, filepath = None): + """ + Converts .mat file to .json and writes new file + Parameters + ---------- + mat_path: Str + path/filename .mat存放路径 + filepath: Str + 如果需要保存成json, 添加这一路径. 否则不保存 + Returns + 返回转化的字典 + ------- + None + Examples + -------- + >>> mat2json(blah blah) + """ + + matlabFile = loadmat(mat_path) + #pop all those dumb fields that don't let you jsonize file + matlabFile.pop('__header__') + matlabFile.pop('__version__') + matlabFile.pop('__globals__') + #jsonize the file - orientation is 'index' + matlabFile = pd.Series(matlabFile).to_json() + + if filepath: + json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json' + with open(json_path, 'w') as f: + f.write(matlabFile) + return matlabFile \ No newline at end of file diff --git a/core/data/deg_kair_utils/utils_matconvnet.py b/core/data/deg_kair_utils/utils_matconvnet.py new file mode 100644 index 0000000000000000000000000000000000000000..37d5929692e8eadf5ec57d1616626a0611492ee2 --- /dev/null +++ b/core/data/deg_kair_utils/utils_matconvnet.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +from collections import OrderedDict + +# import scipy.io as io +import hdf5storage + +""" +# -------------------------------------------- +# Convert matconvnet SimpleNN model into pytorch model +# -------------------------------------------- +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# 28/Nov/2019 +# -------------------------------------------- +""" + + +def weights2tensor(x, squeeze=False, in_features=None, out_features=None): + """Modified version of https://github.com/albanie/pytorch-mcn + Adjust memory layout and load weights as torch tensor + Args: + x (ndaray): a numpy array, corresponding to a set of network weights + stored in column major order + squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove + singletons from the trailing dimensions. So after converting to + pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1) + it will be reshaped to a matrix with shape (A,B). + in_features (int :: None): used to reshape weights for a linear block. + out_features (int :: None): used to reshape weights for a linear block. + Returns: + torch.tensor: a permuted sets of weights, matching the pytorch layout + convention + """ + if x.ndim == 4: + x = x.transpose((3, 2, 0, 1)) +# for FFDNet, pixel-shuffle layer +# if x.shape[1]==13: +# x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:] +# if x.shape[0]==12: +# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] +# if x.shape[1]==5: +# x=x[:,[0,2,1,3, 4],:,:] +# if x.shape[0]==4: +# x=x[[0,2,1,3],:,:,:] +## for SRMD, pixel-shuffle layer +# if x.shape[0]==12: +# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] +# if x.shape[0]==27: +# x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:] +# if x.shape[0]==48: +# x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:] + + elif x.ndim == 3: # add by Kai + x = x[:,:,:,None] + x = x.transpose((3, 2, 0, 1)) + elif x.ndim == 2: + if x.shape[1] == 1: + x = x.flatten() + if squeeze: + if in_features and out_features: + x = x.reshape((out_features, in_features)) + x = np.squeeze(x) + return torch.from_numpy(np.ascontiguousarray(x)) + + +def save_model(network, save_path): + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + torch.save(state_dict, save_path) + + +if __name__ == '__main__': + + +# from utils import utils_logger +# import logging +# utils_logger.logger_info('a', 'a.log') +# logger = logging.getLogger('a') +# + # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat') + mcn = hdf5storage.loadmat('models/modelcolor.mat') + + + #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0]) + + mat_net = OrderedDict() + for idx in range(25): + mat_net[str(idx)] = OrderedDict() + count = -1 + + print(idx) + for i in range(13): + + if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv': + + count += 1 + w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0] + # print(w.shape) + w = weights2tensor(w) + # print(w.shape) + + b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1] + b = weights2tensor(b) + print(b.shape) + + mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w + mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b + + torch.save(mat_net, 'model_zoo/modelcolor.pth') + + + +# from models.network_dncnn import IRCNN as net +# network = net(in_nc=3, out_nc=3, nc=64) +# state_dict = network.state_dict() +# +# #show_kv(state_dict) +# +# for i in range(len(mcn['net'][0][0][0])): +# print(mcn['net'][0][0][0][i][0][0][0][0]) +# +# count = -1 +# mat_net = OrderedDict() +# for i in range(len(mcn['net'][0][0][0])): +# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv': +# +# count += 1 +# w = mcn['net'][0][0][0][i][0][1][0][0] +# print(w.shape) +# w = weights2tensor(w) +# print(w.shape) +# +# b = mcn['net'][0][0][0][i][0][1][0][1] +# b = weights2tensor(b) +# print(b.shape) +# +# mat_net['model.{:d}.weight'.format(count*2)] = w +# mat_net['model.{:d}.bias'.format(count*2)] = b +# +# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth') +# +# +# +# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth') +# def show_kv(net): +# for k, v in net.items(): +# print(k) +# +# show_kv(crt_net) + + +# from models.network_dncnn import DnCNN as net +# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') + +# from models.network_srmd import SRMD as net +# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R') +# network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle') +# +# from models.network_rrdb import RRDB as net +# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') +# +# state_dict = network.state_dict() +# for key, param in state_dict.items(): +# print(key) +# from models.network_imdn import IMDN as net +# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') +# state_dict = network.state_dict() +# mat_net = OrderedDict() +# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()): +# mat_net[key] = param2 +# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth') +# + +# net_old = torch.load('net_old.pth') +# def show_kv(net): +# for k, v in net.items(): +# print(k) +# +# show_kv(net_old) +# from models.network_dpsr import MSRResNet_prior as net +# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle') +# state_dict = network.state_dict() +# net_new = OrderedDict() +# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()): +# net_new[key] = param_old +# torch.save(net_new, 'net_new.pth') + + + # print(key) + # print(param.size()) + + + + # run utils/utils_matconvnet.py diff --git a/core/data/deg_kair_utils/utils_model.py b/core/data/deg_kair_utils/utils_model.py new file mode 100644 index 0000000000000000000000000000000000000000..94ced53c0e34bd0938e5e55ed22b1cf214885477 --- /dev/null +++ b/core/data/deg_kair_utils/utils_model.py @@ -0,0 +1,330 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +from utils import utils_image as util +import re +import glob +import os + + +''' +# -------------------------------------------- +# Model +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +''' + + +def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): + """ + # --------------------------------------- + # Kai Zhang (github: https://github.com/cszn) + # 03/Mar/2019 + # --------------------------------------- + Args: + save_dir: model folder + net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' + pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path + + Return: + init_iter: iteration number + init_path: model path + # --------------------------------------- + """ + + file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) + if file_list: + iter_exist = [] + for file_ in file_list: + iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) + iter_exist.append(int(iter_current[0])) + init_iter = max(iter_exist) + init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) + else: + init_iter = 0 + init_path = pretrained_path + return init_iter, init_path + + +def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): + ''' + # --------------------------------------- + # Kai Zhang (github: https://github.com/cszn) + # 03/Mar/2019 + # --------------------------------------- + Args: + model: trained model + L: input Low-quality image + mode: + (0) normal: test(model, L) + (1) pad: test_pad(model, L, modulo=16) + (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) + (3) x8: test_x8(model, L, modulo=1) ^_^ + (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) + refield: effective receptive filed of the network, 32 is enough + useful when split, i.e., mode=2, 4 + min_size: min_sizeXmin_size image, e.g., 256X256 image + useful when split, i.e., mode=2, 4 + sf: scale factor for super-resolution, otherwise 1 + modulo: 1 if split + useful when pad, i.e., mode=1 + + Returns: + E: estimated image + # --------------------------------------- + ''' + if mode == 0: + E = test(model, L) + elif mode == 1: + E = test_pad(model, L, modulo, sf) + elif mode == 2: + E = test_split(model, L, refield, min_size, sf, modulo) + elif mode == 3: + E = test_x8(model, L, modulo, sf) + elif mode == 4: + E = test_split_x8(model, L, refield, min_size, sf, modulo) + return E + + +''' +# -------------------------------------------- +# normal (0) +# -------------------------------------------- +''' + + +def test(model, L): + E = model(L) + return E + + +''' +# -------------------------------------------- +# pad (1) +# -------------------------------------------- +''' + + +def test_pad(model, L, modulo=16, sf=1): + h, w = L.size()[-2:] + paddingBottom = int(np.ceil(h/modulo)*modulo-h) + paddingRight = int(np.ceil(w/modulo)*modulo-w) + L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L) + E = model(L) + E = E[..., :h*sf, :w*sf] + return E + + +''' +# -------------------------------------------- +# split (function) +# -------------------------------------------- +''' + + +def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1): + """ + Args: + model: trained model + L: input Low-quality image + refield: effective receptive filed of the network, 32 is enough + min_size: min_sizeXmin_size image, e.g., 256X256 image + sf: scale factor for super-resolution, otherwise 1 + modulo: 1 if split + + Returns: + E: estimated result + """ + h, w = L.size()[-2:] + if h*w <= min_size**2: + L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L) + E = model(L) + E = E[..., :h*sf, :w*sf] + else: + top = slice(0, (h//2//refield+1)*refield) + bottom = slice(h - (h//2//refield+1)*refield, h) + left = slice(0, (w//2//refield+1)*refield) + right = slice(w - (w//2//refield+1)*refield, w) + Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] + + if h * w <= 4*(min_size**2): + Es = [model(Ls[i]) for i in range(4)] + else: + Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)] + + b, c = Es[0].size()[:2] + E = torch.zeros(b, c, sf * h, sf * w).type_as(L) + + E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf] + E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:] + E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf] + E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:] + return E + + +''' +# -------------------------------------------- +# split (2) +# -------------------------------------------- +''' + + +def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1): + E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo) + return E + + +''' +# -------------------------------------------- +# x8 (3) +# -------------------------------------------- +''' + + +def test_x8(model, L, modulo=1, sf=1): + E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)] + for i in range(len(E_list)): + if i == 3 or i == 5: + E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i) + else: + E_list[i] = util.augment_img_tensor4(E_list[i], mode=i) + output_cat = torch.stack(E_list, dim=0) + E = output_cat.mean(dim=0, keepdim=False) + return E + + +''' +# -------------------------------------------- +# split and x8 (4) +# -------------------------------------------- +''' + + +def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1): + E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)] + for k, i in enumerate(range(len(E_list))): + if i==3 or i==5: + E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i) + else: + E_list[k] = util.augment_img_tensor4(E_list[k], mode=i) + output_cat = torch.stack(E_list, dim=0) + E = output_cat.mean(dim=0, keepdim=False) + return E + + +''' +# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- +# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ +# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- +''' + + +''' +# -------------------------------------------- +# print +# -------------------------------------------- +''' + + +# -------------------------------------------- +# print model +# -------------------------------------------- +def print_model(model): + msg = describe_model(model) + print(msg) + + +# -------------------------------------------- +# print params +# -------------------------------------------- +def print_params(model): + msg = describe_params(model) + print(msg) + + +''' +# -------------------------------------------- +# information +# -------------------------------------------- +''' + + +# -------------------------------------------- +# model inforation +# -------------------------------------------- +def info_model(model): + msg = describe_model(model) + return msg + + +# -------------------------------------------- +# params inforation +# -------------------------------------------- +def info_params(model): + msg = describe_params(model) + return msg + + +''' +# -------------------------------------------- +# description +# -------------------------------------------- +''' + + +# -------------------------------------------- +# model name and total number of parameters +# -------------------------------------------- +def describe_model(model): + if isinstance(model, torch.nn.DataParallel): + model = model.module + msg = '\n' + msg += 'models name: {}'.format(model.__class__.__name__) + '\n' + msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n' + msg += 'Net structure:\n{}'.format(str(model)) + '\n' + return msg + + +# -------------------------------------------- +# parameters description +# -------------------------------------------- +def describe_params(model): + if isinstance(model, torch.nn.DataParallel): + model = model.module + msg = '\n' + msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n' + for name, param in model.state_dict().items(): + if not 'num_batches_tracked' in name: + v = param.data.clone().float() + msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n' + return msg + + +if __name__ == '__main__': + + class Net(torch.nn.Module): + def __init__(self, in_channels=3, out_channels=3): + super(Net, self).__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) + + def forward(self, x): + x = self.conv(x) + return x + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + model = Net() + model = model.eval() + print_model(model) + print_params(model) + x = torch.randn((2,3,401,401)) + torch.cuda.empty_cache() + with torch.no_grad(): + for mode in range(5): + y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1) + print(y.shape) + + # run utils/utils_model.py diff --git a/core/data/deg_kair_utils/utils_modelsummary.py b/core/data/deg_kair_utils/utils_modelsummary.py new file mode 100644 index 0000000000000000000000000000000000000000..5e040e31d8ddffbb8b7b2e2dc4ddf0b9cdca6a23 --- /dev/null +++ b/core/data/deg_kair_utils/utils_modelsummary.py @@ -0,0 +1,485 @@ +import torch.nn as nn +import torch +import numpy as np + +''' +---- 1) FLOPs: floating point operations +---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs +---- 3) #Conv2d: the number of ‘Conv2d’ layers +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 21/July/2020 +# -------------------------------------------- +# Reference +https://github.com/sovrasov/flops-counter.pytorch.git + +# If you use this code, please consider the following citation: + +@inproceedings{zhang2020aim, % + title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results}, + author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others}, + booktitle={European Conference on Computer Vision Workshops}, + year={2020} +} +# -------------------------------------------- +''' + +def get_model_flops(model, input_res, print_per_layer_stat=True, + input_constructor=None): + assert type(input_res) is tuple, 'Please provide the size of the input image.' + assert len(input_res) >= 3, 'Input image should have 3 dimensions.' + flops_model = add_flops_counting_methods(model) + flops_model.eval().start_flops_count() + if input_constructor: + input = input_constructor(input_res) + _ = flops_model(**input) + else: + device = list(flops_model.parameters())[-1].device + batch = torch.FloatTensor(1, *input_res).to(device) + _ = flops_model(batch) + + if print_per_layer_stat: + print_model_with_flops(flops_model) + flops_count = flops_model.compute_average_flops_cost() + flops_model.stop_flops_count() + + return flops_count + +def get_model_activation(model, input_res, input_constructor=None): + assert type(input_res) is tuple, 'Please provide the size of the input image.' + assert len(input_res) >= 3, 'Input image should have 3 dimensions.' + activation_model = add_activation_counting_methods(model) + activation_model.eval().start_activation_count() + if input_constructor: + input = input_constructor(input_res) + _ = activation_model(**input) + else: + device = list(activation_model.parameters())[-1].device + batch = torch.FloatTensor(1, *input_res).to(device) + _ = activation_model(batch) + + activation_count, num_conv = activation_model.compute_average_activation_cost() + activation_model.stop_activation_count() + + return activation_count, num_conv + + +def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True, + input_constructor=None): + assert type(input_res) is tuple + assert len(input_res) >= 3 + flops_model = add_flops_counting_methods(model) + flops_model.eval().start_flops_count() + if input_constructor: + input = input_constructor(input_res) + _ = flops_model(**input) + else: + batch = torch.FloatTensor(1, *input_res) + _ = flops_model(batch) + + if print_per_layer_stat: + print_model_with_flops(flops_model) + flops_count = flops_model.compute_average_flops_cost() + params_count = get_model_parameters_number(flops_model) + flops_model.stop_flops_count() + + if as_strings: + return flops_to_string(flops_count), params_to_string(params_count) + + return flops_count, params_count + + +def flops_to_string(flops, units='GMac', precision=2): + if units is None: + if flops // 10**9 > 0: + return str(round(flops / 10.**9, precision)) + ' GMac' + elif flops // 10**6 > 0: + return str(round(flops / 10.**6, precision)) + ' MMac' + elif flops // 10**3 > 0: + return str(round(flops / 10.**3, precision)) + ' KMac' + else: + return str(flops) + ' Mac' + else: + if units == 'GMac': + return str(round(flops / 10.**9, precision)) + ' ' + units + elif units == 'MMac': + return str(round(flops / 10.**6, precision)) + ' ' + units + elif units == 'KMac': + return str(round(flops / 10.**3, precision)) + ' ' + units + else: + return str(flops) + ' Mac' + + +def params_to_string(params_num): + if params_num // 10 ** 6 > 0: + return str(round(params_num / 10 ** 6, 2)) + ' M' + elif params_num // 10 ** 3: + return str(round(params_num / 10 ** 3, 2)) + ' k' + else: + return str(params_num) + + +def print_model_with_flops(model, units='GMac', precision=3): + total_flops = model.compute_average_flops_cost() + + def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + def flops_repr(self): + accumulated_flops_cost = self.accumulate_flops() + return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision), + '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), + self.original_extra_repr()]) + + def add_extra_repr(m): + m.accumulate_flops = accumulate_flops.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(model) + model.apply(del_extra_repr) + + +def get_model_parameters_number(model): + params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params_num + + +def add_flops_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + # embed() + net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) + net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) + + net_main_module.reset_flops_count() + return net_main_module + + +def compute_average_flops_cost(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Returns current mean flops consumption per image. + + """ + + flops_sum = 0 + for module in self.modules(): + if is_supported_instance(module): + flops_sum += module.__flops__ + + return flops_sum + + +def start_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Activates the computation of mean flops consumption per image. + Call it before you run the network. + + """ + self.apply(add_flops_counter_hook_function) + + +def stop_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Stops computing the mean flops consumption per image. + Call whenever you want to pause the computation. + + """ + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Resets statistics computed so far. + + """ + self.apply(add_flops_counter_variable_or_reset) + + +def add_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + return + + if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)): + handle = module.register_forward_hook(conv_flops_counter_hook) + elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)): + handle = module.register_forward_hook(relu_flops_counter_hook) + elif isinstance(module, nn.Linear): + handle = module.register_forward_hook(linear_flops_counter_hook) + elif isinstance(module, (nn.BatchNorm2d)): + handle = module.register_forward_hook(bn_flops_counter_hook) + else: + handle = module.register_forward_hook(empty_flops_counter_hook) + module.__flops_handle__ = handle + + +def remove_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + module.__flops_handle__.remove() + del module.__flops_handle__ + + +def add_flops_counter_variable_or_reset(module): + if is_supported_instance(module): + module.__flops__ = 0 + + +# ---- Internal functions +def is_supported_instance(module): + if isinstance(module, + ( + nn.Conv2d, nn.ConvTranspose2d, + nn.BatchNorm2d, + nn.Linear, + nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, + )): + return True + + return False + + +def conv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + # input = input[0] + + batch_size = output.shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(conv_module.kernel_size) + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel + + active_elements_count = batch_size * np.prod(output_dims) + overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count) + + # overall_flops = overall_conv_flops + + conv_module.__flops__ += int(overall_conv_flops) + # conv_module.__output_dims__ = output_dims + + +def relu_flops_counter_hook(module, input, output): + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + # print(module.__flops__, id(module)) + # print(module) + + +def linear_flops_counter_hook(module, input, output): + input = input[0] + if len(input.shape) == 1: + batch_size = 1 + module.__flops__ += int(batch_size * input.shape[0] * output.shape[0]) + else: + batch_size = input.shape[0] + module.__flops__ += int(batch_size * input.shape[1] * output.shape[1]) + + +def bn_flops_counter_hook(module, input, output): + # input = input[0] + # TODO: need to check here + # batch_flops = np.prod(input.shape) + # if module.affine: + # batch_flops *= 2 + # module.__flops__ += int(batch_flops) + batch = output.shape[0] + output_dims = output.shape[2:] + channels = module.num_features + batch_flops = batch * channels * np.prod(output_dims) + if module.affine: + batch_flops *= 2 + module.__flops__ += int(batch_flops) + + +# ---- Count the number of convolutional layers and the activation +def add_activation_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + # embed() + net_main_module.start_activation_count = start_activation_count.__get__(net_main_module) + net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module) + net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module) + net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module) + + net_main_module.reset_activation_count() + return net_main_module + + +def compute_average_activation_cost(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Returns current mean activation consumption per image. + + """ + + activation_sum = 0 + num_conv = 0 + for module in self.modules(): + if is_supported_instance_for_activation(module): + activation_sum += module.__activation__ + num_conv += module.__num_conv__ + return activation_sum, num_conv + + +def start_activation_count(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Activates the computation of mean activation consumption per image. + Call it before you run the network. + + """ + self.apply(add_activation_counter_hook_function) + + +def stop_activation_count(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Stops computing the mean activation consumption per image. + Call whenever you want to pause the computation. + + """ + self.apply(remove_activation_counter_hook_function) + + +def reset_activation_count(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Resets statistics computed so far. + + """ + self.apply(add_activation_counter_variable_or_reset) + + +def add_activation_counter_hook_function(module): + if is_supported_instance_for_activation(module): + if hasattr(module, '__activation_handle__'): + return + + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): + handle = module.register_forward_hook(conv_activation_counter_hook) + module.__activation_handle__ = handle + + +def remove_activation_counter_hook_function(module): + if is_supported_instance_for_activation(module): + if hasattr(module, '__activation_handle__'): + module.__activation_handle__.remove() + del module.__activation_handle__ + + +def add_activation_counter_variable_or_reset(module): + if is_supported_instance_for_activation(module): + module.__activation__ = 0 + module.__num_conv__ = 0 + + +def is_supported_instance_for_activation(module): + if isinstance(module, + ( + nn.Conv2d, nn.ConvTranspose2d, + )): + return True + + return False + +def conv_activation_counter_hook(module, input, output): + """ + Calculate the activations in the convolutional operation. + Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces. + :param module: + :param input: + :param output: + :return: + """ + module.__activation__ += output.numel() + module.__num_conv__ += 1 + + +def empty_flops_counter_hook(module, input, output): + module.__flops__ += 0 + + +def upsample_flops_counter_hook(module, input, output): + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + + +def pool_flops_counter_hook(module, input, output): + input = input[0] + module.__flops__ += int(np.prod(input.shape)) + + +def dconv_flops_counter_hook(dconv_module, input, output): + input = input[0] + + batch_size = input.shape[0] + output_dims = list(output.shape[2:]) + + m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape + out_channels, _, kernel_dim2, _, = dconv_module.projection.shape + # groups = dconv_module.groups + + # filters_per_channel = out_channels // groups + conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels + conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels + active_elements_count = batch_size * np.prod(output_dims) + + overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count + overall_flops = overall_conv_flops + + dconv_module.__flops__ += int(overall_flops) + # dconv_module.__output_dims__ = output_dims + + + + + diff --git a/core/data/deg_kair_utils/utils_option.py b/core/data/deg_kair_utils/utils_option.py new file mode 100644 index 0000000000000000000000000000000000000000..cf096210e2d8ea553b06a91ac5cdaa21127d837c --- /dev/null +++ b/core/data/deg_kair_utils/utils_option.py @@ -0,0 +1,255 @@ +import os +from collections import OrderedDict +from datetime import datetime +import json +import re +import glob + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +def get_timestamp(): + return datetime.now().strftime('_%y%m%d_%H%M%S') + + +def parse(opt_path, is_train=True): + + # ---------------------------------------- + # remove comments starting with '//' + # ---------------------------------------- + json_str = '' + with open(opt_path, 'r') as f: + for line in f: + line = line.split('//')[0] + '\n' + json_str += line + + # ---------------------------------------- + # initialize opt + # ---------------------------------------- + opt = json.loads(json_str, object_pairs_hook=OrderedDict) + + opt['opt_path'] = opt_path + opt['is_train'] = is_train + + # ---------------------------------------- + # set default + # ---------------------------------------- + if 'merge_bn' not in opt: + opt['merge_bn'] = False + opt['merge_bn_startpoint'] = -1 + + if 'scale' not in opt: + opt['scale'] = 1 + + # ---------------------------------------- + # datasets + # ---------------------------------------- + for phase, dataset in opt['datasets'].items(): + phase = phase.split('_')[0] + dataset['phase'] = phase + dataset['scale'] = opt['scale'] # broadcast + dataset['n_channels'] = opt['n_channels'] # broadcast + if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None: + dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H']) + if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None: + dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L']) + + # ---------------------------------------- + # path + # ---------------------------------------- + for key, path in opt['path'].items(): + if path and key in opt['path']: + opt['path'][key] = os.path.expanduser(path) + + path_task = os.path.join(opt['path']['root'], opt['task']) + opt['path']['task'] = path_task + opt['path']['log'] = path_task + opt['path']['options'] = os.path.join(path_task, 'options') + + if is_train: + opt['path']['models'] = os.path.join(path_task, 'models') + opt['path']['images'] = os.path.join(path_task, 'images') + else: # test + opt['path']['images'] = os.path.join(path_task, 'test_images') + + # ---------------------------------------- + # network + # ---------------------------------------- + opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1 + + # ---------------------------------------- + # GPU devices + # ---------------------------------------- + gpu_list = ','.join(str(x) for x in opt['gpu_ids']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + print('export CUDA_VISIBLE_DEVICES=' + gpu_list) + + # ---------------------------------------- + # default setting for distributeddataparallel + # ---------------------------------------- + if 'find_unused_parameters' not in opt: + opt['find_unused_parameters'] = True + if 'use_static_graph' not in opt: + opt['use_static_graph'] = False + if 'dist' not in opt: + opt['dist'] = False + opt['num_gpu'] = len(opt['gpu_ids']) + print('number of GPUs is: ' + str(opt['num_gpu'])) + + # ---------------------------------------- + # default setting for perceptual loss + # ---------------------------------------- + if 'F_feature_layer' not in opt['train']: + opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34] + if 'F_weights' not in opt['train']: + opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0] + if 'F_lossfn_type' not in opt['train']: + opt['train']['F_lossfn_type'] = 'l1' + if 'F_use_input_norm' not in opt['train']: + opt['train']['F_use_input_norm'] = True + if 'F_use_range_norm' not in opt['train']: + opt['train']['F_use_range_norm'] = False + + # ---------------------------------------- + # default setting for optimizer + # ---------------------------------------- + if 'G_optimizer_type' not in opt['train']: + opt['train']['G_optimizer_type'] = "adam" + if 'G_optimizer_betas' not in opt['train']: + opt['train']['G_optimizer_betas'] = [0.9,0.999] + if 'G_scheduler_restart_weights' not in opt['train']: + opt['train']['G_scheduler_restart_weights'] = 1 + if 'G_optimizer_wd' not in opt['train']: + opt['train']['G_optimizer_wd'] = 0 + if 'G_optimizer_reuse' not in opt['train']: + opt['train']['G_optimizer_reuse'] = False + if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']: + opt['train']['D_optimizer_reuse'] = False + + # ---------------------------------------- + # default setting of strict for model loading + # ---------------------------------------- + if 'G_param_strict' not in opt['train']: + opt['train']['G_param_strict'] = True + if 'netD' in opt and 'D_param_strict' not in opt['path']: + opt['train']['D_param_strict'] = True + if 'E_param_strict' not in opt['path']: + opt['train']['E_param_strict'] = True + + # ---------------------------------------- + # Exponential Moving Average + # ---------------------------------------- + if 'E_decay' not in opt['train']: + opt['train']['E_decay'] = 0 + + # ---------------------------------------- + # default setting for discriminator + # ---------------------------------------- + if 'netD' in opt: + if 'net_type' not in opt['netD']: + opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet + if 'in_nc' not in opt['netD']: + opt['netD']['in_nc'] = 3 + if 'base_nc' not in opt['netD']: + opt['netD']['base_nc'] = 64 + if 'n_layers' not in opt['netD']: + opt['netD']['n_layers'] = 3 + if 'norm_type' not in opt['netD']: + opt['netD']['norm_type'] = 'spectral' + + + return opt + + +def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): + """ + Args: + save_dir: model folder + net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' + pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path + + Return: + init_iter: iteration number + init_path: model path + """ + file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) + if file_list: + iter_exist = [] + for file_ in file_list: + iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) + iter_exist.append(int(iter_current[0])) + init_iter = max(iter_exist) + init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) + else: + init_iter = 0 + init_path = pretrained_path + return init_iter, init_path + + +''' +# -------------------------------------------- +# convert the opt into json file +# -------------------------------------------- +''' + + +def save(opt): + opt_path = opt['opt_path'] + opt_path_copy = opt['path']['options'] + dirname, filename_ext = os.path.split(opt_path) + filename, ext = os.path.splitext(filename_ext) + dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext) + with open(dump_path, 'w') as dump_file: + json.dump(opt, dump_file, indent=2) + + +''' +# -------------------------------------------- +# dict to string for logger +# -------------------------------------------- +''' + + +def dict2str(opt, indent_l=1): + msg = '' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_l * 2) + k + ':[\n' + msg += dict2str(v, indent_l + 1) + msg += ' ' * (indent_l * 2) + ']\n' + else: + msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' + return msg + + +''' +# -------------------------------------------- +# convert OrderedDict to NoneDict, +# return None for missing key +# -------------------------------------------- +''' + + +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +class NoneDict(dict): + def __missing__(self, key): + return None diff --git a/core/data/deg_kair_utils/utils_params.py b/core/data/deg_kair_utils/utils_params.py new file mode 100644 index 0000000000000000000000000000000000000000..def1cb79e11472b9b8ebbaae4bd83e7216af2ccb --- /dev/null +++ b/core/data/deg_kair_utils/utils_params.py @@ -0,0 +1,135 @@ +import torch + +import torchvision + +from models import basicblock as B + +def show_kv(net): + for k, v in net.items(): + print(k) + +# should run train debug mode first to get an initial model +#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth') +# +#for k, v in crt_net.items(): +# print(k) +#for k, v in crt_net.items(): +# if k in pretrained_net: +# crt_net[k] = pretrained_net[k] +# print('replace ... ', k) + +# x2 -> x4 +#crt_net['model.5.weight'] = pretrained_net['model.2.weight'] +#crt_net['model.5.bias'] = pretrained_net['model.2.bias'] +#crt_net['model.8.weight'] = pretrained_net['model.5.weight'] +#crt_net['model.8.bias'] = pretrained_net['model.5.bias'] +#crt_net['model.10.weight'] = pretrained_net['model.7.weight'] +#crt_net['model.10.bias'] = pretrained_net['model.7.bias'] +#torch.save(crt_net, '../pretrained_tmp.pth') + +# x2 -> x3 +''' +in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3 +new_filter = torch.Tensor(576, 64, 3, 3) +new_filter[0:256, :, :, :] = in_filter +new_filter[256:512, :, :, :] = in_filter +new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :] +crt_net['model.2.weight'] = new_filter + +in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3 +new_bias = torch.Tensor(576) +new_bias[0:256] = in_bias +new_bias[256:512] = in_bias +new_bias[512:] = in_bias[0:576 - 512] +crt_net['model.2.bias'] = new_bias + +torch.save(crt_net, '../pretrained_tmp.pth') +''' + +# x2 -> x8 +''' +crt_net['model.5.weight'] = pretrained_net['model.2.weight'] +crt_net['model.5.bias'] = pretrained_net['model.2.bias'] +crt_net['model.8.weight'] = pretrained_net['model.2.weight'] +crt_net['model.8.bias'] = pretrained_net['model.2.bias'] +crt_net['model.11.weight'] = pretrained_net['model.5.weight'] +crt_net['model.11.bias'] = pretrained_net['model.5.bias'] +crt_net['model.13.weight'] = pretrained_net['model.7.weight'] +crt_net['model.13.bias'] = pretrained_net['model.7.bias'] +torch.save(crt_net, '../pretrained_tmp.pth') +''' + +# x3/4/8 RGB -> Y + +def rgb2gray_net(net, only_input=True): + + if only_input: + in_filter = net['0.weight'] + in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114 + in_new_filter.unsqueeze_(1) + net['0.weight'] = in_new_filter + +# out_filter = pretrained_net['model.13.weight'] +# out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \ +# out_filter[2, :, :, :] * 0.114 +# out_new_filter.unsqueeze_(0) +# crt_net['model.13.weight'] = out_new_filter +# out_bias = pretrained_net['model.13.bias'] +# out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114 +# out_new_bias = torch.Tensor(1).fill_(out_new_bias) +# crt_net['model.13.bias'] = out_new_bias + +# torch.save(crt_net, '../pretrained_tmp.pth') + + return net + + + +if __name__ == '__main__': + + net = torchvision.models.vgg19(pretrained=True) + for k,v in net.features.named_parameters(): + if k=='0.weight': + in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114 + in_new_filter.unsqueeze_(1) + v = in_new_filter + print(v.shape) + print(v[0,0,0,0]) + if k=='0.bias': + in_new_bias = v + print(v[0]) + + print(net.features[0]) + + net.features[0] = B.conv(1, 64, mode='C') + + print(net.features[0]) + net.features[0].weight.data=in_new_filter + net.features[0].bias.data=in_new_bias + + for k,v in net.features.named_parameters(): + if k=='0.weight': + print(v[0,0,0,0]) + if k=='0.bias': + print(v[0]) + + # transfer parameters of old model to new one + model_old = torch.load(model_path) + state_dict = model.state_dict() + for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()): + state_dict[key2] = param + print([key, key2]) + # print([param.size(), param2.size()]) + torch.save(state_dict, 'model_new.pth') + + + # rgb2gray_net(net) + + + + + + + + + diff --git a/core/data/deg_kair_utils/utils_receptivefield.py b/core/data/deg_kair_utils/utils_receptivefield.py new file mode 100644 index 0000000000000000000000000000000000000000..394456390644ba9edc406b810f67d09b0e2ff114 --- /dev/null +++ b/core/data/deg_kair_utils/utils_receptivefield.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# online calculation: https://fomoro.com/research/article/receptive-field-calculator# + +# [filter size, stride, padding] +#Assume the two dimensions are the same +#Each kernel requires the following parameters: +# - k_i: kernel size +# - s_i: stride +# - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow) +# +#Each layer i requires the following parameters to be fully represented: +# - n_i: number of feature (data layer has n_1 = imagesize ) +# - j_i: distance (projected to image pixel distance) between center of two adjacent features +# - r_i: receptive field of a feature in layer i +# - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding) + +import math + +def outFromIn(conv, layerIn): + n_in = layerIn[0] + j_in = layerIn[1] + r_in = layerIn[2] + start_in = layerIn[3] + k = conv[0] + s = conv[1] + p = conv[2] + + n_out = math.floor((n_in - k + 2*p)/s) + 1 + actualP = (n_out-1)*s - n_in + k + pR = math.ceil(actualP/2) + pL = math.floor(actualP/2) + + j_out = j_in * s + r_out = r_in + (k - 1)*j_in + start_out = start_in + ((k-1)/2 - pL)*j_in + return n_out, j_out, r_out, start_out + +def printLayer(layer, layer_name): + print(layer_name + ":") + print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3])) + + + +layerInfos = [] +if __name__ == '__main__': + + convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]] + layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12'] + imsize = 128 + + print ("-------Net summary------") + currentLayer = [imsize, 1, 1, 0.5] + printLayer(currentLayer, "input image") + for i in range(len(convnet)): + currentLayer = outFromIn(convnet[i], currentLayer) + layerInfos.append(currentLayer) + printLayer(currentLayer, layer_names[i]) + + +# run utils/utils_receptivefield.py + \ No newline at end of file diff --git a/core/data/deg_kair_utils/utils_regularizers.py b/core/data/deg_kair_utils/utils_regularizers.py new file mode 100644 index 0000000000000000000000000000000000000000..17e7c8524b716f36e10b41d72fee2e375af69454 --- /dev/null +++ b/core/data/deg_kair_utils/utils_regularizers.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +''' + + +# -------------------------------------------- +# SVD Orthogonal Regularization +# -------------------------------------------- +def regularizer_orth(m): + """ + # ---------------------------------------- + # SVD Orthogonal Regularization + # ---------------------------------------- + # Applies regularization to the training by performing the + # orthogonalization technique described in the paper + # This function is to be called by the torch.nn.Module.apply() method, + # which applies svd_orthogonalization() to every layer of the model. + # usage: net.apply(regularizer_orth) + # ---------------------------------------- + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + w = m.weight.data.clone() + c_out, c_in, f1, f2 = w.size() + # dtype = m.weight.data.type() + w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) + # self.netG.apply(svd_orthogonalization) + u, s, v = torch.svd(w) + s[s > 1.5] = s[s > 1.5] - 1e-4 + s[s < 0.5] = s[s < 0.5] + 1e-4 + w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) + m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) + else: + pass + + +# -------------------------------------------- +# SVD Orthogonal Regularization +# -------------------------------------------- +def regularizer_orth2(m): + """ + # ---------------------------------------- + # Applies regularization to the training by performing the + # orthogonalization technique described in the paper + # This function is to be called by the torch.nn.Module.apply() method, + # which applies svd_orthogonalization() to every layer of the model. + # usage: net.apply(regularizer_orth2) + # ---------------------------------------- + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + w = m.weight.data.clone() + c_out, c_in, f1, f2 = w.size() + # dtype = m.weight.data.type() + w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) + u, s, v = torch.svd(w) + s_mean = s.mean() + s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4 + s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4 + w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) + m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) + else: + pass + + + +def regularizer_clip(m): + """ + # ---------------------------------------- + # usage: net.apply(regularizer_clip) + # ---------------------------------------- + """ + eps = 1e-4 + c_min = -1.5 + c_max = 1.5 + + classname = m.__class__.__name__ + if classname.find('Conv') != -1 or classname.find('Linear') != -1: + w = m.weight.data.clone() + w[w > c_max] -= eps + w[w < c_min] += eps + m.weight.data = w + + if m.bias is not None: + b = m.bias.data.clone() + b[b > c_max] -= eps + b[b < c_min] += eps + m.bias.data = b + +# elif classname.find('BatchNorm2d') != -1: +# +# rv = m.running_var.data.clone() +# rm = m.running_mean.data.clone() +# +# if m.affine: +# m.weight.data +# m.bias.data diff --git a/core/data/deg_kair_utils/utils_sisr.py b/core/data/deg_kair_utils/utils_sisr.py new file mode 100644 index 0000000000000000000000000000000000000000..fde7881526c5544ed09657872b044af5fa99b3a9 --- /dev/null +++ b/core/data/deg_kair_utils/utils_sisr.py @@ -0,0 +1,848 @@ +# -*- coding: utf-8 -*- +from utils import utils_image as util +import random + +import scipy +import scipy.stats as ss +import scipy.io as io +from scipy import ndimage +from scipy.interpolate import interp2d + +import numpy as np +import torch + + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# modified by Kai Zhang (github: https://github.com/cszn) +# 03/03/2020 +# -------------------------------------------- +""" + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +""" +# -------------------------------------------- +# calculate PCA projection matrix +# -------------------------------------------- +""" + + +def get_pca_matrix(x, dim_pca=15): + """ + Args: + x: 225x10000 matrix + dim_pca: 15 + Returns: + pca_matrix: 15x225 + """ + C = np.dot(x, x.T) + w, v = scipy.linalg.eigh(C) + pca_matrix = v[:, -dim_pca:].T + + return pca_matrix + + +def show_pca(x): + """ + x: PCA projection matrix, e.g., 15x225 + """ + for i in range(x.shape[0]): + xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F") + util.surf(xc) + + +def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500): + kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32) + for i in range(num_samples): + + theta = np.pi*np.random.rand(1) + l1 = 0.1+l_max*np.random.rand(1) + l2 = 0.1+(l1-0.1)*np.random.rand(1) + + k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0]) + + # util.imshow(k) + + kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F') + + # io.savemat('k.mat', {'k': kernels}) + + pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca) + + io.savemat(path, {'p': pca_matrix}) + + return pca_matrix + + +""" +# -------------------------------------------- +# shifted anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z-MU + ZZ_t = ZZ.transpose(0,1,3,2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + sf = random.choice([1, 2, 3, 4]) + scale_factor = np.array([sf, sf]) + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z-MU + ZZ_t = ZZ.transpose(0,1,3,2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1/sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +''' +# ================= +# Numpy +# ================= +''' + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH, image or kernel + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf-1)*0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w-1) + y1 = np.clip(y1, 0, h-1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +''' +# ================= +# pytorch +# ================= +''' + + +def splits(a, sf): + ''' + a: tensor NxCxWxHx2 + sf: scale factor + out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2) + ''' + b = torch.stack(torch.chunk(a, sf, dim=2), dim=5) + b = torch.cat(torch.chunk(b, sf, dim=3), dim=5) + return b + + +def c2c(x): + return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) + + +def r2c(x): + return torch.stack([x, torch.zeros_like(x)], -1) + + +def cdiv(x, y): + a, b = x[..., 0], x[..., 1] + c, d = y[..., 0], y[..., 1] + cd2 = c**2 + d**2 + return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) + + +def csum(x, y): + return torch.stack([x[..., 0] + y, x[..., 1]], -1) + + +def cabs(x): + return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) + + +def cmul(t1, t2): + ''' + complex multiplication + t1: NxCxHxWx2 + output: NxCxHxWx2 + ''' + real1, imag1 = t1[..., 0], t1[..., 1] + real2, imag2 = t2[..., 0], t2[..., 1] + return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) + + +def cconj(t, inplace=False): + ''' + # complex's conjugation + t: NxCxHxWx2 + output: NxCxHxWx2 + ''' + c = t.clone() if not inplace else t + c[..., 1] *= -1 + return c + + +def rfft(t): + return torch.rfft(t, 2, onesided=False) + + +def irfft(t): + return torch.irfft(t, 2, onesided=False) + + +def fft(t): + return torch.fft(t, 2) + + +def ifft(t): + return torch.ifft(t, 2) + + +def p2o(psf, shape): + ''' + Args: + psf: NxCxhxw + shape: [H,W] + + Returns: + otf: NxCxHxWx2 + ''' + otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) + otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) + for axis, axis_size in enumerate(psf.shape[2:]): + otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) + otf = torch.rfft(otf, 2, onesided=False) + n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) + otf[...,1][torch.abs(otf[...,1]) x[N, 1, W + 2 pad, H + 2 pad] (pariodic padding) + ''' + x = torch.cat([x, x[:, :, 0:pad, :]], dim=2) + x = torch.cat([x, x[:, :, :, 0:pad]], dim=3) + x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2) + x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3) + return x + + +def pad_circular(input, padding): + # type: (Tensor, List[int]) -> Tensor + """ + Arguments + :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))` + :param padding: (tuple): m-elem tuple where m is the degree of convolution + Returns + :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0], + H + 2 * padding[1]], W + 2 * padding[2]))` + """ + offset = 3 + for dimension in range(input.dim() - offset + 1): + input = dim_pad_circular(input, padding[dimension], dimension + offset) + return input + + +def dim_pad_circular(input, padding, dimension): + # type: (Tensor, int, int) -> Tensor + input = torch.cat([input, input[[slice(None)] * (dimension - 1) + + [slice(0, padding)]]], dim=dimension - 1) + input = torch.cat([input[[slice(None)] * (dimension - 1) + + [slice(-2 * padding, -padding)]], input], dim=dimension - 1) + return input + + +def imfilter(x, k): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + ''' + x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2)) + x = torch.nn.functional.conv2d(x, k, groups=x.shape[1]) + return x + + +def G(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + sf: scale factor + center: the first one or the moddle one + + Matlab function: + tmp = imfilter(x,h,'circular'); + y = downsample2(tmp,K); + ''' + x = downsample(imfilter(x, k), sf=sf, center=center) + return x + + +def Gt(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + sf: scale factor + center: the first one or the moddle one + + Matlab function: + tmp = upsample2(x,K); + y = imfilter(tmp,h,'circular'); + ''' + x = imfilter(upsample(x, sf=sf, center=center), k) + return x + + +def interpolation_down(x, sf, center=False): + mask = torch.zeros_like(x) + if center: + start = torch.tensor((sf-1)//2) + mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x) + LR = x[..., start::sf, start::sf] + else: + mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x) + LR = x[..., ::sf, ::sf] + y = x.mul(mask) + + return LR, y, mask + + +''' +# ================= +Numpy +# ================= +''' + + +def blockproc(im, blocksize, fun): + xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0) + xblocks_proc = [] + for xb in xblocks: + yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1) + yblocks_proc = [] + for yb in yblocks: + yb_proc = fun(yb) + yblocks_proc.append(yb_proc) + xblocks_proc.append(np.concatenate(yblocks_proc, axis=1)) + + proc = np.concatenate(xblocks_proc, axis=0) + + return proc + + +def fun_reshape(a): + return np.reshape(a, (-1,1,a.shape[-1]), order='F') + + +def fun_mul(a, b): + return a*b + + +def BlockMM(nr, nc, Nb, m, x1): + ''' + myfun = @(block_struct) reshape(block_struct.data,m,1); + x1 = blockproc(x1,[nr nc],myfun); + x1 = reshape(x1,m,Nb); + x1 = sum(x1,2); + x = reshape(x1,nr,nc); + ''' + fun = fun_reshape + x1 = blockproc(x1, blocksize=(nr, nc), fun=fun) + x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F') + x1 = np.sum(x1, 1) + x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F') + return x + + +def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m): + ''' + x1 = FB.*FR; + FBR = BlockMM(nr,nc,Nb,m,x1); + invW = BlockMM(nr,nc,Nb,m,F2B); + invWBR = FBR./(invW + tau*Nb); + fun = @(block_struct) block_struct.data.*invWBR; + FCBinvWBR = blockproc(FBC,[nr,nc],fun); + FX = (FR-FCBinvWBR)/tau; + Xest = real(ifft2(FX)); + ''' + x1 = FB*FR + FBR = BlockMM(nr, nc, Nb, m, x1) + invW = BlockMM(nr, nc, Nb, m, F2B) + invWBR = FBR/(invW + tau*Nb) + FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR)) + FX = (FR-FCBinvWBR)/tau + Xest = np.real(np.fft.ifft2(FX, axes=(0, 1))) + return Xest + + +def psf2otf(psf, shape=None): + """ + Convert point-spread function to optical transfer function. + Compute the Fast Fourier Transform (FFT) of the point-spread + function (PSF) array and creates the optical transfer function (OTF) + array that is not influenced by the PSF off-centering. + By default, the OTF array is the same size as the PSF array. + To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF + post-pads the PSF array (down or to the right) with zeros to match + dimensions specified in OUTSIZE, then circularly shifts the values of + the PSF array up (or to the left) until the central pixel reaches (1,1) + position. + Parameters + ---------- + psf : `numpy.ndarray` + PSF array + shape : int + Output shape of the OTF array + Returns + ------- + otf : `numpy.ndarray` + OTF array + Notes + ----- + Adapted from MATLAB psf2otf function + """ + if type(shape) == type(None): + shape = psf.shape + shape = np.array(shape) + if np.all(psf == 0): + # return np.zeros_like(psf) + return np.zeros(shape) + if len(psf.shape) == 1: + psf = psf.reshape((1, psf.shape[0])) + inshape = psf.shape + psf = zero_pad(psf, shape, position='corner') + for axis, axis_size in enumerate(inshape): + psf = np.roll(psf, -int(axis_size / 2), axis=axis) + # Compute the OTF + otf = np.fft.fft2(psf, axes=(0, 1)) + # Estimate the rough number of operations involved in the FFT + # and discard the PSF imaginary part if within roundoff error + # roundoff error = machine epsilon = sys.float_info.epsilon + # or np.finfo().eps + n_ops = np.sum(psf.size * np.log2(psf.shape)) + otf = np.real_if_close(otf, tol=n_ops) + return otf + + +def zero_pad(image, shape, position='corner'): + """ + Extends image to a certain size with zeros + Parameters + ---------- + image: real 2d `numpy.ndarray` + Input image + shape: tuple of int + Desired output shape of the image + position : str, optional + The position of the input image in the output one: + * 'corner' + top-left corner (default) + * 'center' + centered + Returns + ------- + padded_img: real `numpy.ndarray` + The zero-padded image + """ + shape = np.asarray(shape, dtype=int) + imshape = np.asarray(image.shape, dtype=int) + if np.alltrue(imshape == shape): + return image + if np.any(shape <= 0): + raise ValueError("ZERO_PAD: null or negative shape given") + dshape = shape - imshape + if np.any(dshape < 0): + raise ValueError("ZERO_PAD: target size smaller than source one") + pad_img = np.zeros(shape, dtype=image.dtype) + idx, idy = np.indices(imshape) + if position == 'center': + if np.any(dshape % 2 != 0): + raise ValueError("ZERO_PAD: source and target shapes " + "have different parity.") + offx, offy = dshape // 2 + else: + offx, offy = (0, 0) + pad_img[idx + offx, idy + offy] = image + return pad_img + + +def upsample_np(x, sf=3, center=False): + st = (sf-1)//2 if center else 0 + z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2])) + z[st::sf, st::sf, ...] = x + return z + + +def downsample_np(x, sf=3, center=False): + st = (sf-1)//2 if center else 0 + return x[st::sf, st::sf, ...] + + +def imfilter_np(x, k): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def G_np(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + + Matlab function: + tmp = imfilter(x,h,'circular'); + y = downsample2(tmp,K); + ''' + x = downsample_np(imfilter_np(x, k), sf=sf, center=center) + return x + + +def Gt_np(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + + Matlab function: + tmp = upsample2(x,K); + y = imfilter(tmp,h,'circular'); + ''' + x = imfilter_np(upsample_np(x, sf=sf, center=center), k) + return x + + +if __name__ == '__main__': + img = util.imread_uint('test.bmp', 3) + + img = util.uint2single(img) + k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6) + util.imshow(k*10) + + + for sf in [2, 3, 4]: + + # modcrop + img = modcrop_np(img, sf=sf) + + # 1) bicubic degradation + img_b = bicubic_degradation(img, sf=sf) + print(img_b.shape) + + # 2) srmd degradation + img_s = srmd_degradation(img, k, sf=sf) + print(img_s.shape) + + # 3) dpsr degradation + img_d = dpsr_degradation(img, k, sf=sf) + print(img_d.shape) + + # 4) classical degradation + img_d = classical_degradation(img, k, sf=sf) + print(img_d.shape) + + k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01) + #print(k) +# util.imshow(k*10) + + k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0) +# util.imshow(k*10) + + + # PCA +# pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500) +# print(pca_matrix.shape) +# show_pca(pca_matrix) + # run utils/utils_sisr.py + # run utils_sisr.py + + + + + + + diff --git a/core/data/deg_kair_utils/utils_video.py b/core/data/deg_kair_utils/utils_video.py new file mode 100644 index 0000000000000000000000000000000000000000..596dd4203098cf7b36f3d8499ccbf299623381ae --- /dev/null +++ b/core/data/deg_kair_utils/utils_video.py @@ -0,0 +1,493 @@ +import os +import cv2 +import numpy as np +import torch +import random +from os import path as osp +from torch.nn import functional as F +from abc import ABCMeta, abstractmethod + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + return_imgname(bool): Whether return image names. Default False. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + list[str]: Returned image name list. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] + + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + + if return_imgname: + imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths] + return imgs, imgnames + else: + return imgs + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): + """Paired random crop. Support Numpy array and Tensor inputs. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. Default: None. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + # determine input type: Numpy array or Tensor + input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' + + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + if input_type == 'Tensor': + img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] + else: + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] + else: + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + diff --git a/core/data/deg_kair_utils/utils_videoio.py b/core/data/deg_kair_utils/utils_videoio.py new file mode 100644 index 0000000000000000000000000000000000000000..5be8c7f06802d5aaa7155a1cdcb27d2838a0882c --- /dev/null +++ b/core/data/deg_kair_utils/utils_videoio.py @@ -0,0 +1,555 @@ +import os +import cv2 +import numpy as np +import torch +import random +from os import path as osp +from torchvision.utils import make_grid +import sys +from pathlib import Path +import six +from collections import OrderedDict +import math +import glob +import av +import io +from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT, + CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH, + CAP_PROP_POS_FRAMES, VideoWriter_fourcc) + +if sys.version_info <= (3, 3): + FileNotFoundError = IOError +else: + FileNotFoundError = FileNotFoundError + + +def is_str(x): + """Whether the input is an string instance.""" + return isinstance(x, six.string_types) + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError('`filepath` should be a string or a Path') + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == '': + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + Args: + dir_path (str | :obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = suffix.lower() if isinstance(suffix, str) else tuple( + item.lower() for item in suffix) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, + case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +class Cache: + + def __init__(self, capacity): + self._cache = OrderedDict() + self._capacity = int(capacity) + if capacity <= 0: + raise ValueError('capacity must be a positive integer') + + @property + def capacity(self): + return self._capacity + + @property + def size(self): + return len(self._cache) + + def put(self, key, val): + if key in self._cache: + return + if len(self._cache) >= self.capacity: + self._cache.popitem(last=False) + self._cache[key] = val + + def get(self, key, default=None): + val = self._cache[key] if key in self._cache else default + return val + + +class VideoReader: + """Video class with similar usage to a list object. + + This video warpper class provides convenient apis to access frames. + There exists an issue of OpenCV's VideoCapture class that jumping to a + certain frame may be inaccurate. It is fixed in this class by checking + the position after jumping each time. + Cache is used when decoding videos. So if the same frame is visited for + the second time, there is no need to decode again if it is stored in the + cache. + + """ + + def __init__(self, filename, cache_capacity=10): + # Check whether the video path is a url + if not filename.startswith(('https://', 'http://')): + check_file_exist(filename, 'Video file not found: ' + filename) + self._vcap = cv2.VideoCapture(filename) + assert cache_capacity > 0 + self._cache = Cache(cache_capacity) + self._position = 0 + # get basic info + self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH)) + self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT)) + self._fps = self._vcap.get(CAP_PROP_FPS) + self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT)) + self._fourcc = self._vcap.get(CAP_PROP_FOURCC) + + @property + def vcap(self): + """:obj:`cv2.VideoCapture`: The raw VideoCapture object.""" + return self._vcap + + @property + def opened(self): + """bool: Indicate whether the video is opened.""" + return self._vcap.isOpened() + + @property + def width(self): + """int: Width of video frames.""" + return self._width + + @property + def height(self): + """int: Height of video frames.""" + return self._height + + @property + def resolution(self): + """tuple: Video resolution (width, height).""" + return (self._width, self._height) + + @property + def fps(self): + """float: FPS of the video.""" + return self._fps + + @property + def frame_cnt(self): + """int: Total frames of the video.""" + return self._frame_cnt + + @property + def fourcc(self): + """str: "Four character code" of the video.""" + return self._fourcc + + @property + def position(self): + """int: Current cursor position, indicating frame decoded.""" + return self._position + + def _get_real_position(self): + return int(round(self._vcap.get(CAP_PROP_POS_FRAMES))) + + def _set_real_position(self, frame_id): + self._vcap.set(CAP_PROP_POS_FRAMES, frame_id) + pos = self._get_real_position() + for _ in range(frame_id - pos): + self._vcap.read() + self._position = frame_id + + def read(self): + """Read the next frame. + + If the next frame have been decoded before and in the cache, then + return it directly, otherwise decode, cache and return it. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + # pos = self._position + if self._cache: + img = self._cache.get(self._position) + if img is not None: + ret = True + else: + if self._position != self._get_real_position(): + self._set_real_position(self._position) + ret, img = self._vcap.read() + if ret: + self._cache.put(self._position, img) + else: + ret, img = self._vcap.read() + if ret: + self._position += 1 + return img + + def get_frame(self, frame_id): + """Get frame by index. + + Args: + frame_id (int): Index of the expected frame, 0-based. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + if frame_id < 0 or frame_id >= self._frame_cnt: + raise IndexError( + f'"frame_id" must be between 0 and {self._frame_cnt - 1}') + if frame_id == self._position: + return self.read() + if self._cache: + img = self._cache.get(frame_id) + if img is not None: + self._position = frame_id + 1 + return img + self._set_real_position(frame_id) + ret, img = self._vcap.read() + if ret: + if self._cache: + self._cache.put(self._position, img) + self._position += 1 + return img + + def current_frame(self): + """Get the current frame (frame that is just visited). + + Returns: + ndarray or None: If the video is fresh, return None, otherwise + return the frame. + """ + if self._position == 0: + return None + return self._cache.get(self._position - 1) + + def cvt2frames(self, + frame_dir, + file_start=0, + filename_tmpl='{:06d}.jpg', + start=0, + max_num=0, + show_progress=False): + """Convert a video to frame images. + + Args: + frame_dir (str): Output directory to store all the frame images. + file_start (int): Filenames will start from the specified number. + filename_tmpl (str): Filename template with the index as the + placeholder. + start (int): The starting frame index. + max_num (int): Maximum number of frames to be written. + show_progress (bool): Whether to show a progress bar. + """ + mkdir_or_exist(frame_dir) + if max_num == 0: + task_num = self.frame_cnt - start + else: + task_num = min(self.frame_cnt - start, max_num) + if task_num <= 0: + raise ValueError('start must be less than total frame number') + if start > 0: + self._set_real_position(start) + + def write_frame(file_idx): + img = self.read() + if img is None: + return + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + cv2.imwrite(filename, img) + + if show_progress: + pass + #track_progress(write_frame, range(file_start,file_start + task_num)) + else: + for i in range(task_num): + write_frame(file_start + i) + + def __len__(self): + return self.frame_cnt + + def __getitem__(self, index): + if isinstance(index, slice): + return [ + self.get_frame(i) + for i in range(*index.indices(self.frame_cnt)) + ] + # support negative indexing + if index < 0: + index += self.frame_cnt + if index < 0: + raise IndexError('index out of range') + return self.get_frame(index) + + def __iter__(self): + self._set_real_position(0) + return self + + def __next__(self): + img = self.read() + if img is not None: + return img + else: + raise StopIteration + + next = __next__ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._vcap.release() + + +def frames2video(frame_dir, + video_file, + fps=30, + fourcc='XVID', + filename_tmpl='{:06d}.jpg', + start=0, + end=0, + show_progress=False): + """Read the frame images from a directory and join them as a video. + + Args: + frame_dir (str): The directory containing video frames. + video_file (str): Output filename. + fps (float): FPS of the output video. + fourcc (str): Fourcc of the output video, this should be compatible + with the output file type. + filename_tmpl (str): Filename template with the index as the variable. + start (int): Starting frame index. + end (int): Ending frame index. + show_progress (bool): Whether to show a progress bar. + """ + if end == 0: + ext = filename_tmpl.split('.')[-1] + end = len([name for name in scandir(frame_dir, ext)]) + first_file = osp.join(frame_dir, filename_tmpl.format(start)) + check_file_exist(first_file, 'The start frame not found: ' + first_file) + img = cv2.imread(first_file) + height, width = img.shape[:2] + resolution = (width, height) + vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps, + resolution) + + def write_frame(file_idx): + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + img = cv2.imread(filename) + vwriter.write(img) + + if show_progress: + pass + # track_progress(write_frame, range(start, end)) + else: + for i in range(start, end): + write_frame(i) + vwriter.release() + + +def video2images(video_path, output_dir): + vidcap = cv2.VideoCapture(video_path) + in_fps = vidcap.get(cv2.CAP_PROP_FPS) + print('video fps:', in_fps) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + loaded, frame = vidcap.read() + total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f'number of total frames is: {total_frames:06}') + for i_frame in range(total_frames): + if i_frame % 100 == 0: + print(f'{i_frame:06} / {total_frames:06}') + frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png') + cv2.imwrite(frame_name, frame) + loaded, frame = vidcap.read() + + +def images2video(image_dir, video_path, fps=24, image_ext='png'): + ''' + #codec = cv2.VideoWriter_fourcc(*'XVID') + #codec = cv2.VideoWriter_fourcc('A','V','C','1') + #codec = cv2.VideoWriter_fourcc('Y','U','V','1') + #codec = cv2.VideoWriter_fourcc('P','I','M','1') + #codec = cv2.VideoWriter_fourcc('M','J','P','G') + codec = cv2.VideoWriter_fourcc('M','P','4','2') + #codec = cv2.VideoWriter_fourcc('D','I','V','3') + #codec = cv2.VideoWriter_fourcc('D','I','V','X') + #codec = cv2.VideoWriter_fourcc('U','2','6','3') + #codec = cv2.VideoWriter_fourcc('I','2','6','3') + #codec = cv2.VideoWriter_fourcc('F','L','V','1') + #codec = cv2.VideoWriter_fourcc('H','2','6','4') + #codec = cv2.VideoWriter_fourcc('A','Y','U','V') + #codec = cv2.VideoWriter_fourcc('I','U','Y','V') + 编码器常用的几种: + cv2.VideoWriter_fourcc("I", "4", "2", "0") + 压缩的yuv颜色编码器,4:2:0色彩度子采样 兼容性好,产生很大的视频 avi + cv2.VideoWriter_fourcc("P", I", "M", "1") + 采用mpeg-1编码,文件为avi + cv2.VideoWriter_fourcc("X", "V", "T", "D") + 采用mpeg-4编码,得到视频大小平均 拓展名avi + cv2.VideoWriter_fourcc("T", "H", "E", "O") + Ogg Vorbis, 拓展名为ogv + cv2.VideoWriter_fourcc("F", "L", "V", "1") + FLASH视频,拓展名为.flv + ''' + image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext)))) + print(len(image_files)) + height, width, _ = cv2.imread(image_files[0]).shape + out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V') + out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height)) + + for image_file in image_files: + img = cv2.imread(image_file) + img = cv2.resize(img, (width, height), interpolation=3) + out_video.write(img) + out_video.release() + + +def add_video_compression(imgs): + codec_type = ['libx264', 'h264', 'mpeg4'] + codec_prob = [1 / 3., 1 / 3., 1 / 3.] + codec = random.choices(codec_type, codec_prob)[0] + # codec = 'mpeg4' + bitrate = [1e4, 1e5] + bitrate = np.random.randint(bitrate[0], bitrate[1] + 1) + + buf = io.BytesIO() + with av.open(buf, 'w', 'mp4') as container: + stream = container.add_stream(codec, rate=1) + stream.height = imgs[0].shape[0] + stream.width = imgs[0].shape[1] + stream.pix_fmt = 'yuv420p' + stream.bit_rate = bitrate + + for img in imgs: + img = np.uint8((img.clip(0, 1)*255.).round()) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame.pict_type = 'NONE' + # pdb.set_trace() + for packet in stream.encode(frame): + container.mux(packet) + + # Flush stream + for packet in stream.encode(): + container.mux(packet) + + outputs = [] + with av.open(buf, 'r', 'mp4') as container: + if container.streams.video: + for frame in container.decode(**{'video': 0}): + outputs.append( + frame.to_rgb().to_ndarray().astype(np.float32) / 255.) + + #outputs = np.stack(outputs, axis=0) + return outputs + + +if __name__ == '__main__': + + # ----------------------------------- + # test VideoReader(filename, cache_capacity=10) + # ----------------------------------- +# video_reader = VideoReader('utils/test.mp4') +# from utils import utils_image as util +# inputs = [] +# for frame in video_reader: +# print(frame.dtype) +# util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) +# #util.imshow(np.flip(frame, axis=2)) + + # ----------------------------------- + # test video2images(video_path, output_dir) + # ----------------------------------- +# video2images('utils/test.mp4', 'frames') + + # ----------------------------------- + # test images2video(image_dir, video_path, fps=24, image_ext='png') + # ----------------------------------- +# images2video('frames', 'video_02.mp4', fps=30, image_ext='png') + + + # ----------------------------------- + # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png') + # ----------------------------------- +# frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png') + + + # ----------------------------------- + # test add_video_compression(imgs) + # ----------------------------------- +# imgs = [] +# image_ext = 'png' +# frames = 'frames' +# from utils import utils_image as util +# image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext)))) +# for i, image_file in enumerate(image_files): +# if i < 7: +# img = util.imread_uint(image_file, 3) +# img = util.uint2single(img) +# imgs.append(img) +# +# results = add_video_compression(imgs) +# for i, img in enumerate(results): +# util.imshow(util.single2uint(img)) +# util.imsave(util.single2uint(img),f'{i:05}.png') + + # run utils/utils_video.py + + + + + + + diff --git a/core/scripts/__init__.py b/core/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/core/scripts/cli.py b/core/scripts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe3ecc330ecf9f0b3af1e7dc6b3758673712cc7 --- /dev/null +++ b/core/scripts/cli.py @@ -0,0 +1,41 @@ +import sys +import argparse +from .. import WarpCore +from .. import templates + + +def template_init(args): + return '''' + + + '''.strip() + + +def init_template(args): + parser = argparse.ArgumentParser(description='WarpCore template init tool') + parser.add_argument('-t', '--template', type=str, default='WarpCore') + args = parser.parse_args(args) + + if args.template == 'WarpCore': + template_cls = WarpCore + else: + try: + template_cls = __import__(args.template) + except ModuleNotFoundError: + template_cls = getattr(templates, args.template) + print(template_cls) + + +def main(): + if len(sys.argv) < 2: + print('Usage: core ') + sys.exit(1) + if sys.argv[1] == 'init': + init_template(sys.argv[2:]) + else: + print('Unknown command') + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/core/templates/__init__.py b/core/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..570f16de78bcce68aa49ff0a5d0fad63284f6948 --- /dev/null +++ b/core/templates/__init__.py @@ -0,0 +1 @@ +from .diffusion import DiffusionCore \ No newline at end of file diff --git a/core/templates/diffusion.py b/core/templates/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f36dc3f5efa14669cc36cc3c0cffcc8def037289 --- /dev/null +++ b/core/templates/diffusion.py @@ -0,0 +1,236 @@ +from .. import WarpCore +from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from abc import abstractmethod +from dataclasses import dataclass +import torch +from torch import nn +from torch.utils.data import DataLoader +from gdf import GDF +import numpy as np +from tqdm import tqdm +import wandb + +import webdataset as wds +from webdataset.handlers import warn_and_continue +from torch.distributed import barrier +from enum import Enum + +class TargetReparametrization(Enum): + EPSILON = 'epsilon' + X0 = 'x0' + +class DiffusionCore(WarpCore): + @dataclass(frozen=True) + class Config(WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + grad_accum_steps: int = EXPECTED_TRAIN + batch_size: int = EXPECTED_TRAIN + updates: int = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + save_every: int = 500 + backup_every: int = 20000 + use_fsdp: bool = True + + # EMA UPDATE + ema_start_iters: int = None + ema_iters: int = None + ema_beta: float = None + + # GDF setting + gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0 + + @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED + class Info(WarpCore.Info): + ema_loss: float = None + + @dataclass(frozen=True) + class Models(WarpCore.Models): + generator : nn.Module = EXPECTED + generator_ema : nn.Module = None # optional + + @dataclass(frozen=True) + class Optimizers(WarpCore.Optimizers): + generator : any = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + + # -------------------------------------------- + info: Info + config: Config + + @abstractmethod + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_path(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_filters(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_preprocessors(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + # ------------- + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.webdataset_path(extras) + preprocessors = self.webdataset_preprocessors(extras) + filters = self.webdataset_filters(extras) + + handler = warn_and_continue # None + # handler = None + dataset = wds.WebDataset( + dataset_path, resampled=True, handler=handler + ).select(filters).shuffle(690, handler=handler).decode( + "pilrgb", handler=handler + ).to_tuple( + *[p[0] for p in preprocessors], handler=handler + ).map_tuple( + *[p[1] for p in preprocessors], handler=handler + ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)}) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True + ) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader)) + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + # FORWARD PASS + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON: + pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss + target = noise + elif self.config.gdf_target_reparametrization == TargetReparametrization.X0: + pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss + target = latents + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + return loss, loss_adjusted + + def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): + start_iter = self.info.iter+1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP + models.generator.train() + for i in pbar: + # FORWARD PASS + loss, loss_adjusted = self.forward_pass(data, extras, models) + + # BACKWARD PASS + if i % self.config.grad_accum_steps == 0 or i == max_iters: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + schedulers_dict[k].step() + models.generator.zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + with models.generator.no_sync(): + loss_adjusted.backward() + self.info.iter = i + + # UPDATE EMA + if models.generator_ema is not None and i % self.config.ema_iters == 0: + update_weights_ema( + models.generator_ema, models.generator, + beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) + ) + + # UPDATE LOSS METRICS + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): + wandb.alert( + title=f"NaN value encountered in training run {self.info.wandb_run_id}", + text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}", + wait_duration=60*30 + ) + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'raw_loss': loss.mean().item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'], + 'total_steps': self.info.total_steps, + } + + pbar.set_postfix(logs) + if self.config.wandb_project is not None: + wandb.log(logs) + + if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters: + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + tqdm.write("Skipping sampling & checkpoint because the loss is NaN") + wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") + else: + self.save_checkpoints(models, optimizers) + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + self.sample(models, data, extras) + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): + barrier() + suffix = '' if suffix is None else suffix + self.save_info(self.info, suffix=suffix) + models_dict = models.to_dict() + optimizers_dict = optimizers.to_dict() + for key in self.models_to_save(): + model = models_dict[key] + if model is not None: + self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) + for key in optimizers_dict: + optimizer = optimizers_dict[key] + if optimizer is not None: + self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None) + if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: + self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k") + torch.cuda.empty_cache() diff --git a/core/utils/__init__.py b/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e71b37e8d1690a00ab1e0958320775bc822b6f5 --- /dev/null +++ b/core/utils/__init__.py @@ -0,0 +1,9 @@ +from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN +from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail + +# MOVE IT SOMERWHERE ELSE +def update_weights_ema(tgt_model, src_model, beta=0.999): + for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta) + for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta) \ No newline at end of file diff --git a/core/utils/__pycache__/__init__.cpython-310.pyc b/core/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63c0a7e0fbf358f557d6bea755a0f550b4010a48 Binary files /dev/null and b/core/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/core/utils/__pycache__/__init__.cpython-39.pyc b/core/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f18d6921da3c9d93087c1b6d8eacd7a5e46a8e5 Binary files /dev/null and b/core/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/core/utils/__pycache__/base_dto.cpython-310.pyc b/core/utils/__pycache__/base_dto.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de093eb65813d4abf69edfbb6923f2cabab21ad7 Binary files /dev/null and b/core/utils/__pycache__/base_dto.cpython-310.pyc differ diff --git a/core/utils/__pycache__/base_dto.cpython-39.pyc b/core/utils/__pycache__/base_dto.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b80d348c7959338709ec24c3ac24dfc4f6dab3dc Binary files /dev/null and b/core/utils/__pycache__/base_dto.cpython-39.pyc differ diff --git a/core/utils/__pycache__/save_and_load.cpython-310.pyc b/core/utils/__pycache__/save_and_load.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7a0f63ac8bbaf073dcd8a046ed112cec181d33a Binary files /dev/null and b/core/utils/__pycache__/save_and_load.cpython-310.pyc differ diff --git a/core/utils/__pycache__/save_and_load.cpython-39.pyc b/core/utils/__pycache__/save_and_load.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec04e9aba6f83ab76f0bbc243bb95fda07ad8d16 Binary files /dev/null and b/core/utils/__pycache__/save_and_load.cpython-39.pyc differ diff --git a/core/utils/base_dto.py b/core/utils/base_dto.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf185f00e5c6f56d23774cec8591b8d4554971e --- /dev/null +++ b/core/utils/base_dto.py @@ -0,0 +1,56 @@ +import dataclasses +from dataclasses import dataclass, _MISSING_TYPE +from munch import Munch + +EXPECTED = "___REQUIRED___" +EXPECTED_TRAIN = "___REQUIRED_TRAIN___" + +# pylint: disable=invalid-field-call +def nested_dto(x, raw=False): + return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) + +@dataclass(frozen=True) +class Base: + training: bool = None + def __new__(cls, **kwargs): + training = kwargs.get('training', True) + setteable_fields = cls.setteable_fields(**kwargs) + mandatory_fields = cls.mandatory_fields(**kwargs) + invalid_kwargs = [ + {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False) + ] + print(mandatory_fields) + assert ( + len(invalid_kwargs) == 0 + ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable." + missing_kwargs = [f for f in mandatory_fields if f not in kwargs] + assert ( + len(missing_kwargs) == 0 + ), f"Required fields missing initializing this DTO: {missing_kwargs}." + return object.__new__(cls) + + + @classmethod + def setteable_fields(cls, **kwargs): + return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN] + + @classmethod + def mandatory_fields(cls, **kwargs): + training = kwargs.get('training', True) + return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)] + + @classmethod + def from_dict(cls, kwargs): + for k in kwargs: + if isinstance(kwargs[k], (dict, list, tuple)): + kwargs[k] = Munch.fromDict(kwargs[k]) + return cls(**kwargs) + + def to_dict(self): + # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes + selfdict = {} + for k in dataclasses.fields(self): + selfdict[k.name] = getattr(self, k.name) + if isinstance(selfdict[k.name], Munch): + selfdict[k.name] = selfdict[k.name].toDict() + return selfdict diff --git a/core/utils/save_and_load.py b/core/utils/save_and_load.py new file mode 100644 index 0000000000000000000000000000000000000000..0215f664f5a8e738147d0828b6a7e65b9c3a8507 --- /dev/null +++ b/core/utils/save_and_load.py @@ -0,0 +1,59 @@ +import os +import torch +import json +from pathlib import Path +import safetensors +import wandb + + +def create_folder_if_necessary(path): + path = "/".join(path.split("/")[:-1]) + Path(path).mkdir(parents=True, exist_ok=True) + + +def safe_save(ckpt, path): + try: + os.remove(f"{path}.bak") + except OSError: + pass + try: + os.rename(path, f"{path}.bak") + except OSError: + pass + if path.endswith(".pt") or path.endswith(".ckpt"): + torch.save(ckpt, path) + elif path.endswith(".json"): + with open(path, "w", encoding="utf-8") as f: + json.dump(ckpt, f, indent=4) + elif path.endswith(".safetensors"): + safetensors.torch.save_file(ckpt, path) + else: + raise ValueError(f"File extension not supported: {path}") + + +def load_or_fail(path, wandb_run_id=None): + accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"] + try: + assert any( + [path.endswith(ext) for ext in accepted_extensions] + ), f"Automatic loading not supported for this extension: {path}" + if not os.path.exists(path): + checkpoint = None + elif path.endswith(".pt") or path.endswith(".ckpt"): + checkpoint = torch.load(path, map_location="cpu") + elif path.endswith(".json"): + with open(path, "r", encoding="utf-8") as f: + checkpoint = json.load(f) + elif path.endswith(".safetensors"): + checkpoint = {} + with safetensors.safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + return checkpoint + except Exception as e: + if wandb_run_id is not None: + wandb.alert( + title=f"Corrupt checkpoint for run {wandb_run_id}", + text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed", + ) + raise e diff --git a/figures/California_000490.jpg b/figures/California_000490.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d8216b079289058c8901cdaf74d91df0d26864fa --- /dev/null +++ b/figures/California_000490.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84fbae1a942a233e619fcc3ddc89b6038f909df6003bbbb1b10a15309e0ecd2e +size 6064531 diff --git a/figures/example_dataset/000008.jpg b/figures/example_dataset/000008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b5bc20400ccc11df60a621afca62cd56ca790240 --- /dev/null +++ b/figures/example_dataset/000008.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:763beb950f9794497a2fbdd62a48c8fd91f9543e9b9e6b5cf6c38bfcfdd34a02 +size 2875950 diff --git a/figures/example_dataset/000008.json b/figures/example_dataset/000008.json new file mode 100644 index 0000000000000000000000000000000000000000..51a40614db448251cc586b414cd1b47a512b342d --- /dev/null +++ b/figures/example_dataset/000008.json @@ -0,0 +1,2 @@ +{ "caption": "The image captures the iconic Shard, a modern skyscraper that stands as the tallest building in the United Kingdom. The Shard, with its glass and steel structure, pierces the sky, its pointed top reaching towards the heavens. The photograph is taken from a low angle, which emphasizes the height and grandeur of the building. The sky forms a beautiful backdrop, painted in hues of pinkish-orange, suggesting that the photo was taken at sunset. The Shard is nestled between two other buildings, their presence subtly hinted at in the foreground. The image does not contain any discernible text or countable objects, and there are no visible actions taking place. The relative positions of the objects confirm that the Shard is the central focus of the image, with the other buildings and the sky providing context to its location. The image is devoid of any aesthetic descriptions, focusing solely on the factual representation of the scene." +} \ No newline at end of file diff --git a/figures/example_dataset/000012.jpg b/figures/example_dataset/000012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..84f66e9d537be55850c1464f5ada0c5add93a398 --- /dev/null +++ b/figures/example_dataset/000012.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fea5f3176ed289556acf32c7dd3635db8944cb64be53f109b84554eb4da5bf3 +size 2674310 diff --git a/figures/example_dataset/000012.json b/figures/example_dataset/000012.json new file mode 100644 index 0000000000000000000000000000000000000000..e537d01c3f3d638ed6b49f379aff021a2e111b1f --- /dev/null +++ b/figures/example_dataset/000012.json @@ -0,0 +1 @@ +{"caption": "cars in a road during daytime"} \ No newline at end of file diff --git a/figures/teaser.jpg b/figures/teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a0d576b17d5e2ae0e86ba503d23745ba1a9c7498 --- /dev/null +++ b/figures/teaser.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:def5700a069d5f754b45ec02802e258c1c1473ad82fd10d2e62cc87e75a8a5e1 +size 7951540 diff --git a/gdf/__init__.py b/gdf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..753b52e2e07e2540385594627a6faf4f6091b0a0 --- /dev/null +++ b/gdf/__init__.py @@ -0,0 +1,205 @@ +import torch +from .scalers import * +from .targets import * +from .schedulers import * +from .noise_conditions import * +from .loss_weights import * +from .samplers import * +import torch.nn.functional as F +import math +class GDF(): + def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0): + self.schedule = schedule + self.input_scaler = input_scaler + self.target = target + self.noise_cond = noise_cond + self.loss_weight = loss_weight + self.offset_noise = offset_noise + + def setup_limits(self, stretch_max=True, stretch_min=True, shift=1): + stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift) + return stretched_limits + + def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None): + if epsilon is None: + epsilon = torch.randn_like(x0) + if self.offset_noise > 0: + if offset is None: + offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device) + epsilon = epsilon + offset * self.offset_noise + logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device) + a, b = self.input_scaler(logSNR) # B + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) # BxCxHxW + #print('in line 33 a b', a.shape, b.shape, x0.shape, logSNR.shape, logSNR, self.noise_cond(logSNR)) + target = self.target(x0, epsilon, logSNR, a, b) + + # noised, noise, logSNR, t_cond + #noised, noise, target, logSNR, noise_cond, loss_weight + return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift) + + def undiffuse(self, x, logSNR, pred): + a, b = self.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1)) + return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b) + + def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): + sampler_params = {} if sampler_params is None else sampler_params + if sampler is None: + sampler = DDPMSampler(self) + r_range = torch.linspace(t_start, t_end, timesteps+1) + schedule = self.schedule if schedule is None else schedule + logSNR_range = schedule(r_range, shift=shift)[:, None].expand( + -1, shape[0] if x_init is None else x_init.size(0) + ).to(device) + + x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() + + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = { + k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) + else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) + else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) + else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) + } + + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) + if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) + + pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2) + + pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x, noise_cond, **model_inputs) + x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) + x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) + #print('in line 86', x0.shape, x.shape, i, ) + altered_vars = yield (x0, x, pred) + + # Update some running variables if the user wants + if altered_vars is not None: + cfg = altered_vars.get('cfg', cfg) + cfg_rho = altered_vars.get('cfg_rho', cfg_rho) + sampler = altered_vars.get('sampler', sampler) + model_inputs = altered_vars.get('model_inputs', model_inputs) + x = altered_vars.get('x', x) + x_init = altered_vars.get('x_init', x_init) + +class GDF_dual_fixlrt(GDF): + def ref_noise(self, noised, x0, logSNR): + a, b = self.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) + #print('in line 210', a.shape, b.shape, x0.shape, noised.shape) + return self.target.noise_givenx0_noised(x0, noised, logSNR, a, b) + + def sample(self, model, model_inputs, shape, shape_lr, unconditional_inputs=None, sampler=None, + schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, + cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): + sampler_params = {} if sampler_params is None else sampler_params + if sampler is None: + sampler = DDPMSampler(self) + r_range = torch.linspace(t_start, t_end, timesteps+1) + schedule = self.schedule if schedule is None else schedule + logSNR_range = schedule(r_range, shift=shift)[:, None].expand( + -1, shape[0] if x_init is None else x_init.size(0) + ).to(device) + + x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() + x_lr = sampler.init_x(shape_lr).to(device) if x_init is None else x_init.clone() + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = { + k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) + else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) + else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) + else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) + } + + ###############################################lr sampling + + guide_feas = [None] * timesteps + + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) + if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) + + + + if i == timesteps -1 : + output, guide_lr_enc, guide_lr_dec = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs) + guide_feas[i] = ([f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_enc], [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_dec]) + else: + output, _, _ = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs) + + pred, pred_unconditional = output.chunk(2) + + + pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x_lr, noise_cond, **model_inputs) + x0_lr, epsilon_lr = self.undiffuse(x_lr, logSNR_range[i], pred) + x_lr = sampler(x_lr, x0_lr, epsilon_lr, logSNR_range[i], logSNR_range[i+1], **sampler_params) + + ###############################################hr HR sampling + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) + if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) + + out_pred, t_emb = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), \ + lr_guide=guide_feas[timesteps -1] if i <=19 else None , **model_inputs, require_t=True, guide_weight=1 - i/timesteps) + pred, pred_unconditional = out_pred.chunk(2) + pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x, noise_cond, guide_lr=(guide_lr_enc, guide_lr_dec), **model_inputs) + x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) + + x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) + altered_vars = yield (x0, x, pred, x_lr) + + + + # Update some running variables if the user wants + if altered_vars is not None: + cfg = altered_vars.get('cfg', cfg) + cfg_rho = altered_vars.get('cfg_rho', cfg_rho) + sampler = altered_vars.get('sampler', sampler) + model_inputs = altered_vars.get('model_inputs', model_inputs) + x = altered_vars.get('x', x) + x_init = altered_vars.get('x_init', x_init) + + + + diff --git a/gdf/loss_weights.py b/gdf/loss_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..d14ddaadeeb3f8de6c68aea4c364d9b852f2f15c --- /dev/null +++ b/gdf/loss_weights.py @@ -0,0 +1,101 @@ +import torch +import numpy as np + +# --- Loss Weighting +class BaseLossWeight(): + def weight(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + if shift != 1: + logSNR = logSNR.clone() + 2 * np.log(shift) + return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) + +class ComposedLossWeight(BaseLossWeight): + def __init__(self, div, mul): + self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul + self.div = [div] if isinstance(div, BaseLossWeight) else div + + def weight(self, logSNR): + prod, div = 1, 1 + for m in self.mul: + prod *= m.weight(logSNR) + for d in self.div: + div *= d.weight(logSNR) + return prod/div + +class ConstantLossWeight(BaseLossWeight): + def __init__(self, v=1): + self.v = v + + def weight(self, logSNR): + return torch.ones_like(logSNR) * self.v + +class SNRLossWeight(BaseLossWeight): + def weight(self, logSNR): + return logSNR.exp() + +class P2LossWeight(BaseLossWeight): + def __init__(self, k=1.0, gamma=1.0, s=1.0): + self.k, self.gamma, self.s = k, gamma, s + + def weight(self, logSNR): + return (self.k + (logSNR * self.s).exp()) ** -self.gamma + +class SNRPlusOneLossWeight(BaseLossWeight): + def weight(self, logSNR): + return logSNR.exp() + 1 + +class MinSNRLossWeight(BaseLossWeight): + def __init__(self, max_snr=5): + self.max_snr = max_snr + + def weight(self, logSNR): + return logSNR.exp().clamp(max=self.max_snr) + +class MinSNRPlusOneLossWeight(BaseLossWeight): + def __init__(self, max_snr=5): + self.max_snr = max_snr + + def weight(self, logSNR): + return (logSNR.exp() + 1).clamp(max=self.max_snr) + +class TruncatedSNRLossWeight(BaseLossWeight): + def __init__(self, min_snr=1): + self.min_snr = min_snr + + def weight(self, logSNR): + return logSNR.exp().clamp(min=self.min_snr) + +class SechLossWeight(BaseLossWeight): + def __init__(self, div=2): + self.div = div + + def weight(self, logSNR): + return 1/(logSNR/self.div).cosh() + +class DebiasedLossWeight(BaseLossWeight): + def weight(self, logSNR): + return 1/logSNR.exp().sqrt() + +class SigmoidLossWeight(BaseLossWeight): + def __init__(self, s=1): + self.s = s + + def weight(self, logSNR): + return (logSNR * self.s).sigmoid() + +class AdaptiveLossWeight(BaseLossWeight): + def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): + self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1) + self.bucket_losses = torch.ones(buckets) + self.weight_range = weight_range + + def weight(self, logSNR): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) + return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) + + def update_buckets(self, logSNR, loss, beta=0.99): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() + self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) diff --git a/gdf/noise_conditions.py b/gdf/noise_conditions.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2791f50a6f63eff8f9bed9b827f87517cc0be8 --- /dev/null +++ b/gdf/noise_conditions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + +class BaseNoiseCond(): + def __init__(self, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + self.shift = shift + self.clamp_range = clamp_range + self.setup(*args, **kwargs) + + def setup(self, *args, **kwargs): + pass # this method is optional, override it if required + + def cond(self, logSNR): + raise NotImplementedError("this method needs to be overriden") + + def __call__(self, logSNR): + if self.shift != 1: + logSNR = logSNR.clone() + 2 * np.log(self.shift) + return self.cond(logSNR).clamp(*self.clamp_range) + +class CosineTNoiseCond(BaseNoiseCond): + def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def cond(self, logSNR): + var = logSNR.sigmoid() + var = var.clamp(*self.clamp_range) + s, min_var = self.s.to(var.device), self.min_var.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t + +class EDMNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return -logSNR/8 + +class SigmoidNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return (-logSNR).sigmoid() + +class LogSNRNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return logSNR + +class EDMSigmaNoiseCond(BaseNoiseCond): + def setup(self, sigma_data=1): + self.sigma_data = sigma_data + + def cond(self, logSNR): + return torch.exp(-logSNR / 2) * self.sigma_data + +class RectifiedFlowsNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + _a = logSNR.exp() - 1 + _a[_a == 0] = 1e-3 # Avoid division by zero + a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) + return a + +# Any NoiseCond that cannot be described easily as a continuous function of t +# It needs to define self.x and self.y in the setup() method +class PiecewiseLinearNoiseCond(BaseNoiseCond): + def setup(self): + self.x = None + self.y = None + + def piecewise_linear(self, y, xs, ys): + indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y) + x_min, x_max = xs[indices], xs[indices+1] + y_min, y_max = ys[indices], ys[indices+1] + x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min) + return x + + def cond(self, logSNR): + var = logSNR.sigmoid() + t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0) + return t + +class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond): + def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): + self.total_steps = total_steps + linear_range_sqrt = [r**0.5 for r in linear_range] + self.x = torch.linspace(0, 1, total_steps+1) + + alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 + self.y = alphas.cumprod(dim=-1) + + def cond(self, logSNR): + return super().cond(logSNR).clamp(0, 1) + +class DiscreteNoiseCond(BaseNoiseCond): + def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]): + self.noise_cond = noise_cond + self.steps = steps + self.continuous_range = continuous_range + + def cond(self, logSNR): + cond = self.noise_cond(logSNR) + cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0]) + return cond.mul(self.steps).long() + \ No newline at end of file diff --git a/gdf/readme.md b/gdf/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..9a63691513c9da6804fba53e36acc8e0cd7f5d7f --- /dev/null +++ b/gdf/readme.md @@ -0,0 +1,86 @@ +# Generic Diffusion Framework (GDF) + +# Basic usage +GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM +, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine different parts of different +frameworks + +Using GDF is very straighforward, first of all just define an instance of the GDF class: + +```python +from gdf import GDF +from gdf import CosineSchedule +from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight + +gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=P2LossWeight(), +) +``` + +You need to define the following components: +* **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution. +* **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule. +* **Input Scaler**: If you want to use Variance Preserving or LERP (rectified flows) +* **Target**: What the target is during training, usually: epsilon, x0 or v +* **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8` +* **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use + +All of those classes are actually very simple logSNR centric definitions, for example the VPScaler is defined as just: +```python +class VPScaler(): + def __call__(self, logSNR): + a_squared = logSNR.sigmoid() + a = a_squared.sqrt() + b = (1-a_squared).sqrt() + return a, b + +``` + +So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc... + +### Training + +When you define your training loop you can get all you need by just doing: +```python +shift, loss_shift = 1, 1 # this can be set to higher values as per what the Simple Diffusion paper sugested for high resolution +for inputs, extra_conditions in dataloader_iterator: + noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift) + pred = diffusion_model(noised, noise_cond, extra_conditions) + + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() + + loss_adjusted.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) +``` + +And that's all, you have a diffusion model training, where it's very easy to customize the different elements of the +training from the GDF class. + +### Sampling + +The other important part is sampling, when you want to use this framework to sample you can just do the following: + +```python +from gdf import DDPMSampler + +shift = 1 +sampling_configs = { + "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift, + "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999]) +} + +*_, (sampled, _, _) = gdf.sample( + diffusion_model, {"cond": extra_conditions}, latents.shape, + unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)}, + device=device, **sampling_configs +) +``` + +# Available modules + +TODO diff --git a/gdf/samplers.py b/gdf/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..b6048c86a261d53d0440a3b2c1591a03d9978c4f --- /dev/null +++ b/gdf/samplers.py @@ -0,0 +1,43 @@ +import torch + +class SimpleSampler(): + def __init__(self, gdf): + self.gdf = gdf + self.current_step = -1 + + def __call__(self, *args, **kwargs): + self.current_step += 1 + return self.step(*args, **kwargs) + + def init_x(self, shape): + return torch.randn(*shape) + + def step(self, x, x0, epsilon, logSNR, logSNR_prev): + raise NotImplementedError("You should override the 'apply' function.") + +class DDIMSampler(SimpleSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0): + a, b = self.gdf.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) + + a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) + if len(a_prev.shape) == 1: + a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) + + sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0 + # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) + x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) + return x + +class DDPMSampler(DDIMSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): + return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) + +class LCMSampler(SimpleSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev): + a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) + if len(a_prev.shape) == 1: + a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) + return x0 * a_prev + torch.randn_like(epsilon) * b_prev + \ No newline at end of file diff --git a/gdf/scalers.py b/gdf/scalers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1adb8b0269667f3d006c7d7d17cbf2b7ef56ca9 --- /dev/null +++ b/gdf/scalers.py @@ -0,0 +1,42 @@ +import torch + +class BaseScaler(): + def __init__(self): + self.stretched_limits = None + + def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): + min_logSNR = schedule(torch.ones(1), shift=shift) + max_logSNR = schedule(torch.zeros(1), shift=shift) + + min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] + max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] + self.stretched_limits = [min_a, max_a, min_b, max_b] + return self.stretched_limits + + def stretch_limits(self, a, b): + min_a, max_a, min_b, max_b = self.stretched_limits + return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b) + + def scalers(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR): + a, b = self.scalers(logSNR) + if self.stretched_limits is not None: + a, b = self.stretch_limits(a, b) + return a, b + +class VPScaler(BaseScaler): + def scalers(self, logSNR): + a_squared = logSNR.sigmoid() + a = a_squared.sqrt() + b = (1-a_squared).sqrt() + return a, b + +class LERPScaler(BaseScaler): + def scalers(self, logSNR): + _a = logSNR.exp() - 1 + _a[_a == 0] = 1e-3 # Avoid division by zero + a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) + b = 1-a + return a, b diff --git a/gdf/schedulers.py b/gdf/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..caa6e174da1d766ea5828616bb8113865106b628 --- /dev/null +++ b/gdf/schedulers.py @@ -0,0 +1,200 @@ +import torch +import numpy as np + +class BaseSchedule(): + def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs): + self.setup(*args, **kwargs) + self.limits = None + self.discrete_steps = discrete_steps + self.shift = shift + if force_limits: + self.reset_limits() + + def reset_limits(self, shift=1, disable=False): + try: + self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max + return self.limits + except Exception: + print("WARNING: this schedule doesn't support t and will be unbounded") + return None + + def setup(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overriden") + + def schedule(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overriden") + + def __call__(self, t, *args, shift=1, **kwargs): + if isinstance(t, torch.Tensor): + batch_size = None + if self.discrete_steps is not None: + if t.dtype != torch.long: + t = (t * (self.discrete_steps-1)).round().long() + t = t / (self.discrete_steps-1) + t = t.clamp(0, 1) + else: + batch_size = t + t = None + logSNR = self.schedule(t, batch_size, *args, **kwargs) + if shift*self.shift != 1: + logSNR += 2 * np.log(1/(shift*self.shift)) + if self.limits is not None: + logSNR = logSNR.clamp(*self.limits) + return logSNR + +class CosineSchedule(BaseSchedule): + def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False): + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.norm_instead = norm_instead + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def schedule(self, t, batch_size): + if t is None: + t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0) + s, min_var = self.s.to(t.device), self.min_var.to(t.device) + var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var + if self.norm_instead: + var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] + else: + var = var.clamp(*self.clamp_range) + logSNR = (var/(1-var)).log() + return logSNR + +class CosineSchedule2(BaseSchedule): + def setup(self, logsnr_range=[-15, 15]): + self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1])) + self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0])) + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log() + +class SqrtSchedule(BaseSchedule): + def setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False): + self.s = s + self.clamp_range = clamp_range + self.norm_instead = norm_instead + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + var = 1 - (t + self.s)**0.5 + if self.norm_instead: + var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] + else: + var = var.clamp(*self.clamp_range) + logSNR = (var/(1-var)).log() + return logSNR + +class RectifiedFlowsSchedule(BaseSchedule): + def setup(self, logsnr_range=[-15, 15]): + self.logsnr_range = logsnr_range + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + logSNR = (((1-t)**2)/(t**2)).log() + logSNR = logSNR.clamp(*self.logsnr_range) + return logSNR + +class EDMSampleSchedule(BaseSchedule): + def setup(self, sigma_range=[0.002, 80], p=7): + self.sigma_range = sigma_range + self.p = p + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + smin, smax, p = *self.sigma_range, self.p + sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p + logSNR = (1/sigma**2).log() + return logSNR + +class EDMTrainSchedule(BaseSchedule): + def setup(self, mu=-1.2, std=1.2): + self.mu = mu + self.std = std + + def schedule(self, t, batch_size): + if t is not None: + raise Exception("EDMTrainSchedule doesn't support passing timesteps: t") + logSNR = -2*(torch.randn(batch_size) * self.std - self.mu) + return logSNR + +class LinearSchedule(BaseSchedule): + def setup(self, logsnr_range=[-10, 10]): + self.logsnr_range = logsnr_range + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1] + return logSNR + +# Any schedule that cannot be described easily as a continuous function of t +# It needs to define self.x and self.y in the setup() method +class PiecewiseLinearSchedule(BaseSchedule): + def setup(self): + self.x = None + self.y = None + + def piecewise_linear(self, x, xs, ys): + indices = torch.searchsorted(xs[:-1], x) - 1 + x_min, x_max = xs[indices], xs[indices+1] + y_min, y_max = ys[indices], ys[indices+1] + var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min) + return var + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device)) + logSNR = (var/(1-var)).log() + return logSNR + +class StableDiffusionSchedule(PiecewiseLinearSchedule): + def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): + linear_range_sqrt = [r**0.5 for r in linear_range] + self.x = torch.linspace(0, 1, total_steps+1) + + alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 + self.y = alphas.cumprod(dim=-1) + +class AdaptiveTrainSchedule(BaseSchedule): + def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0): + th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1) + self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)]) + self.bucket_probs = torch.ones(buckets) + self.min_probs = min_probs + + def schedule(self, t, batch_size): + if t is not None: + raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t") + norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum()) + buckets = torch.multinomial(norm_probs, batch_size, replacement=True) + ranges = self.bucket_ranges[buckets] + logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0] + return logSNR + + def update_buckets(self, logSNR, loss, beta=0.99): + range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device) + range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float() + range_idx = range_mask.argmax(-1).cpu() + self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta) + +class InterpolatedSchedule(BaseSchedule): + def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]): + self.scheduler1 = scheduler1 + self.scheduler2 = scheduler2 + self.shifts = shifts + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan + low_logSNR = self.scheduler1(t, shift=self.shifts[0]) + high_logSNR = self.scheduler2(t, shift=self.shifts[1]) + return low_logSNR * t + high_logSNR * (1-t) + diff --git a/gdf/targets.py b/gdf/targets.py new file mode 100644 index 0000000000000000000000000000000000000000..115062b6001f93082fa836e1f3742723e5972efe --- /dev/null +++ b/gdf/targets.py @@ -0,0 +1,46 @@ +class EpsilonTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return epsilon + + def x0(self, noised, pred, logSNR, a, b): + return (noised - pred * b) / a + + def epsilon(self, noised, pred, logSNR, a, b): + return pred + def noise_givenx0_noised(self, x0, noised , logSNR, a, b): + return (noised - a * x0) / b + def xt(self, x0, noise, logSNR, a, b): + + return x0 * a + noise*b +class X0Target(): + def __call__(self, x0, epsilon, logSNR, a, b): + return x0 + + def x0(self, noised, pred, logSNR, a, b): + return pred + + def epsilon(self, noised, pred, logSNR, a, b): + return (noised - pred * a) / b + +class VTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return a * epsilon - b * x0 + + def x0(self, noised, pred, logSNR, a, b): + squared_sum = a**2 + b**2 + return a/squared_sum * noised - b/squared_sum * pred + + def epsilon(self, noised, pred, logSNR, a, b): + squared_sum = a**2 + b**2 + return b/squared_sum * noised + a/squared_sum * pred + +class RectifiedFlowsTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return epsilon - x0 + + def x0(self, noised, pred, logSNR, a, b): + return noised - pred * b + + def epsilon(self, noised, pred, logSNR, a, b): + return noised + pred * a + \ No newline at end of file diff --git a/inference/__init__.py b/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/test_controlnet.py b/inference/test_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8ea6b462f821e0cec00c952c74c37075e3e04e --- /dev/null +++ b/inference/test_controlnet.py @@ -0,0 +1,166 @@ +import os +import yaml +import torch +import torchvision +from tqdm import tqdm +import sys +sys.path.append(os.path.abspath('./')) + +from inference.utils import * +from core.utils import load_or_fail +from train import WurstCore_control_lrguide, WurstCoreB +from PIL import Image +from core.utils import load_or_fail +import math +import argparse +import time +import random +import numpy as np +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( '--height', type=int, default=3840, help='image height') + parser.add_argument('--width', type=int, default=2160, help='image width') + parser.add_argument('--control_weight', type=float, default=0.70, help='[ 0.3, 0.8]') + parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') + parser.add_argument('--seed', type=int, default=123, help='random seed') + parser.add_argument('--config_c', type=str, + default='configs/training/cfg_control_lr.yaml' ,help='config file for stage c, latent generation') + parser.add_argument('--config_b', type=str, + default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') + parser.add_argument( '--prompt', type=str, + default='A peaceful lake surrounded by mountain, white cloud in the sky, high quality,', help='text prompt') + parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') + parser.add_argument( '--output_dir', type=str, default='figures/controlnet_results/', help='output directory for generated image') + parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') + parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') + parser.add_argument( '--canny_source_url', type=str, default="figures/California_000490.jpg", help='image used to extract canny edge map') + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + + args = parse_args() + width = args.width + height = args.height + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float + + + # SETUP STAGE C + with open(args.config_c, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + core = WurstCore_control_lrguide(config_dict=loaded_config, device=device, training=False) + + # SETUP STAGE B + with open(args.config_b, "r", encoding="utf-8") as file: + config_file_b = yaml.safe_load(file) + + core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) + + extras = core.setup_extras_pre() + models = core.setup_models(extras) + models.generator.eval().requires_grad_(False) + print("CONTROLNET READY") + + extras_b = core_b.setup_extras_pre() + models_b = core_b.setup_models(extras_b, skip_clip=True) + models_b = WurstCoreB.Models( + **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} + ) + models_b.generator.eval().requires_grad_(False) + print("STAGE B READY") + + batch_size = 1 + save_dir = args.output_dir + url = args.canny_source_url + images = resize_image(Image.open(url).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1) + batch = {'images': images} + + + + + + + cnet_multiplier = args.control_weight # 0.8 0.6 0.3 control strength + caption_list = [args.prompt] * args.num_image + height_lr, width_lr = get_target_lr_size(height / width, std_size=32) + stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) + stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) + + + + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + + sdd = torch.load(args.pretrained_path, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + models.train_norm.load_state_dict(collect_sd, strict=True) + + + + + models.controlnet.load_state_dict(load_or_fail(core.config.controlnet_checkpoint_path), strict=True) + # Stage C Parameters + extras.sampling_configs['cfg'] = 1 + extras.sampling_configs['shift'] = 2 + extras.sampling_configs['timesteps'] = 20 + extras.sampling_configs['t_start'] = 1.0 + + # Stage B Parameters + extras_b.sampling_configs['cfg'] = 1.1 + extras_b.sampling_configs['shift'] = 1 + extras_b.sampling_configs['timesteps'] = 10 + extras_b.sampling_configs['t_start'] = 1.0 + + # PREPARE CONDITIONS + + + + + for out_cnt, caption in enumerate(caption_list): + with torch.no_grad(): + + batch['captions'] = [caption + ' high quality'] * batch_size + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + cnet, cnet_input = core.get_cnet(batch, models, extras) + cnet_uncond = cnet + conditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet] + unconditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet_uncond] + edge_images = show_images(cnet_input) + models.generator.cuda() + for idx, img in enumerate(edge_images): + img.save(os.path.join(save_dir, f"edge_{url.split('/')[-1]}")) + + + print('STAGE C GENERATION***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions, unconditions) + models.generator.cpu() + torch.cuda.empty_cache() + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + conditions_b['effnet'] = sampled_c + unconditions_b['effnet'] = torch.zeros_like(sampled_c) + print('STAGE B + A DECODING***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) + + torch.cuda.empty_cache() + imgs = show_images(sampled) + + for idx, img in enumerate(imgs): + img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(out_cnt).zfill(5) + '.jpg')) + print('finished! Results at ', save_dir ) diff --git a/inference/test_personalized.py b/inference/test_personalized.py new file mode 100644 index 0000000000000000000000000000000000000000..34c14eb650e2612a6d93b0ce9051a544b9cec266 --- /dev/null +++ b/inference/test_personalized.py @@ -0,0 +1,180 @@ + +import os +import yaml +import torch +from tqdm import tqdm +import sys +sys.path.append(os.path.abspath('./')) +from inference.utils import * +from train import WurstCoreB +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from train import WurstCore_personalized as WurstCoreC +import torch.nn.functional as F +import numpy as np +import random +import math +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( '--height', type=int, default=3072, help='image height') + parser.add_argument('--width', type=int, default=4096, help='image width') + parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') + parser.add_argument('--seed', type=int, default=23, help='random seed') + parser.add_argument('--config_c', type=str, + default="configs/training/lora_personalization.yaml" ,help='config file for stage c, latent generation') + parser.add_argument('--config_b', type=str, + default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') + parser.add_argument( '--prompt', type=str, + default='A photo of cat [roubaobao] with sunglasses, Time Square in the background, high quality, detail rich, 8k', help='text prompt') + parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') + parser.add_argument( '--output_dir', type=str, default='figures/personalized/', help='output directory for generated image') + parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') + parser.add_argument( '--pretrained_path_lora', type=str, default='models/lora_cat.safetensors',help='pretrained path of personalized lora parameter') + parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_args() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float + + + # SETUP STAGE C + with open(args.config_c, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + core = WurstCoreC(config_dict=loaded_config, device=device, training=False) + + # SETUP STAGE B + with open(args.config_b, "r", encoding="utf-8") as file: + config_file_b = yaml.safe_load(file) + core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) + + extras = core.setup_extras_pre() + models = core.setup_models(extras) + models.generator.eval().requires_grad_(False) + print("STAGE C READY") + + extras_b = core_b.setup_extras_pre() + models_b = core_b.setup_models(extras_b, skip_clip=True) + models_b = WurstCoreB.Models( + **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} + ) + models_b.generator.bfloat16().eval().requires_grad_(False) + print("STAGE B READY") + + + batch_size = 1 + captions = [args.prompt] * args.num_image + height, width = args.height, args.width + save_dir = args.output_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + + pretrained_pth = args.pretrained_path + sdd = torch.load(pretrained_pth, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + + models.train_norm.load_state_dict(collect_sd) + + + pretrained_pth_lora = args.pretrained_path_lora + sdd = torch.load(pretrained_pth_lora, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + + models.train_lora.load_state_dict(collect_sd) + + + models.generator.eval() + models.train_norm.eval() + + + height_lr, width_lr = get_target_lr_size(height / width, std_size=32) + stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) + stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) + + # Stage C Parameters + + extras.sampling_configs['cfg'] = 4 + extras.sampling_configs['shift'] = 1 + extras.sampling_configs['timesteps'] = 20 + extras.sampling_configs['t_start'] = 1.0 + extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) + + + + # Stage B Parameters + + extras_b.sampling_configs['cfg'] = 1.1 + extras_b.sampling_configs['shift'] = 1 + extras_b.sampling_configs['timesteps'] = 10 + extras_b.sampling_configs['t_start'] = 1.0 + + + for cnt, caption in enumerate(captions): + + batch = {'captions': [caption] * batch_size} + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + + + + for cnt, caption in enumerate(captions): + + + batch = {'captions': [caption] * batch_size} + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + + with torch.no_grad(): + + + models.generator.cuda() + print('STAGE C GENERATION***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) + + + + models.generator.cpu() + torch.cuda.empty_cache() + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + conditions_b['effnet'] = sampled_c + unconditions_b['effnet'] = torch.zeros_like(sampled_c) + print('STAGE B + A DECODING***************************') + + with torch.cuda.amp.autocast(dtype=dtype): + sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) + + torch.cuda.empty_cache() + imgs = show_images(sampled) + for idx, img in enumerate(imgs): + print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) + img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) + + + print('finished! Results at ', save_dir ) + + + diff --git a/inference/test_t2i.py b/inference/test_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..f16a0e62f24c387476467770cccdf146a4a1aa23 --- /dev/null +++ b/inference/test_t2i.py @@ -0,0 +1,170 @@ + +import os +import yaml +import torch +from tqdm import tqdm +import sys +sys.path.append(os.path.abspath('./')) +from inference.utils import * +from core.utils import load_or_fail +from train import WurstCoreB +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from train import WurstCore_t2i as WurstCoreC +import torch.nn.functional as F +from core.utils import load_or_fail +import numpy as np +import random +import math +import argparse +from einops import rearrange +import math +#inrfft_3b_strc_WurstCore +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( '--height', type=int, default=2560, help='image height') + parser.add_argument('--width', type=int, default=5120, help='image width') + parser.add_argument('--seed', type=int, default=123, help='random seed') + parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') + parser.add_argument('--config_c', type=str, + default='configs/training/t2i.yaml' ,help='config file for stage c, latent generation') + parser.add_argument('--config_b', type=str, + default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') + parser.add_argument( '--prompt', type=str, + default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt') + parser.add_argument( '--num_image', type=int, default=10, help='how many images generated') + parser.add_argument( '--output_dir', type=str, default='figures/output_results/', help='output directory for generated image') + parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') + parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') + args = parser.parse_args() + return args + + + +if __name__ == "__main__": + + args = parse_args() + print(args) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(device) + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float + #gdf = gdf_refine( + # schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + # input_scaler=VPScaler(), target=EpsilonTarget(), + # noise_cond=CosineTNoiseCond(), + # loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + # ) + # SETUP STAGE C + config_file = args.config_c + with open(config_file, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + + core = WurstCoreC(config_dict=loaded_config, device=device, training=False) + + # SETUP STAGE B + config_file_b = args.config_b + with open(config_file_b, "r", encoding="utf-8") as file: + config_file_b = yaml.safe_load(file) + + core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) + + extras = core.setup_extras_pre() + models = core.setup_models(extras) + models.generator.eval().requires_grad_(False) + print("STAGE C READY") + + extras_b = core_b.setup_extras_pre() + models_b = core_b.setup_models(extras_b, skip_clip=True) + models_b = WurstCoreB.Models( + **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} + ) + models_b.generator.bfloat16().eval().requires_grad_(False) + print("STAGE B READY") + + captions = [args.prompt] * args.num_image + + + height, width = args.height, args.width + save_dir = args.output_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pretrained_path = args.pretrained_path + sdd = torch.load(pretrained_path, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + + models.train_norm.load_state_dict(collect_sd) + + + models.generator.eval() + models.train_norm.eval() + + batch_size=1 + height_lr, width_lr = get_target_lr_size(height / width, std_size=32) + stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) + stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) + + # Stage C Parameters + extras.sampling_configs['cfg'] = 4 + extras.sampling_configs['shift'] = 1 + extras.sampling_configs['timesteps'] = 20 + extras.sampling_configs['t_start'] = 1.0 + extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) + + + + # Stage B Parameters + extras_b.sampling_configs['cfg'] = 1.1 + extras_b.sampling_configs['shift'] = 1 + extras_b.sampling_configs['timesteps'] = 10 + extras_b.sampling_configs['t_start'] = 1.0 + + + + + for cnt, caption in enumerate(captions): + + + batch = {'captions': [caption] * batch_size} + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + + with torch.no_grad(): + + + models.generator.cuda() + print('STAGE C GENERATION***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) + + + + models.generator.cpu() + torch.cuda.empty_cache() + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + conditions_b['effnet'] = sampled_c + unconditions_b['effnet'] = torch.zeros_like(sampled_c) + print('STAGE B + A DECODING***************************') + + with torch.cuda.amp.autocast(dtype=dtype): + sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) + + torch.cuda.empty_cache() + imgs = show_images(sampled) + for idx, img in enumerate(imgs): + print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) + img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) + + + print('finished! Results at ', save_dir ) diff --git a/inference/utils.py b/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5af277069ec7803d53ff8f5fa29bed41fde29b --- /dev/null +++ b/inference/utils.py @@ -0,0 +1,131 @@ +import PIL +import torch +import requests +import torchvision +from math import ceil +from io import BytesIO +import matplotlib.pyplot as plt +import torchvision.transforms.functional as F +import math +from tqdm import tqdm +def download_image(url): + return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") + + +def resize_image(image, size=768): + tensor_image = F.to_tensor(image) + resized_image = F.resize(tensor_image, size, antialias=True) + return resized_image + + +def downscale_images(images, factor=3/4): + scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32) + scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + return scaled_image + + + +def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): + resolution_multiple = 42.67 + latent_height = ceil(height / compression_factor_b) + latent_width = ceil(width / compression_factor_b) + stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) + + latent_height = ceil(height / compression_factor_a) + latent_width = ceil(width / compression_factor_a) + stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) + + return stage_c_latent_shape, stage_b_latent_shape + + +def get_views(H, W, window_size=64, stride=16): + ''' + - H, W: height and width of the latent + ''' + num_blocks_height = (H - window_size) // stride + 1 + num_blocks_width = (W - window_size) // stride + 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + views.append((h_start, h_end, w_start, w_end)) + return views + + + +def show_images(images, rows=None, cols=None, **kwargs): + if images.size(1) == 1: + images = images.repeat(1, 3, 1, 1) + elif images.size(1) > 3: + images = images[:, :3] + + if rows is None: + rows = 1 + if cols is None: + cols = images.size(0) // rows + + _, _, h, w = images.shape + + imgs = [] + for i, img in enumerate(images): + imgs.append( torchvision.transforms.functional.to_pil_image(img.clamp(0, 1))) + + return imgs + + + +def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device, \ + stage_a_tiled=False, num_instance=4, patch_size=256, stride=24): + + + sampling_b = extras_b.gdf.sample( + models_b.generator.half(), conditions_b, bshape, + unconditions_b, device=device, + **extras_b.sampling_configs, + ) + models_b.generator.cuda() + for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']): + sampled_b = sampled_b + models_b.generator.cpu() + torch.cuda.empty_cache() + if stage_a_tiled: + with torch.cuda.amp.autocast(dtype=torch.float16): + padding = (stride*2, stride*2, stride*2, stride*2) + sampled_b = torch.nn.functional.pad(sampled_b, padding, mode='reflect') + count = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) + sampled = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) + views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride) + + for view_idx, (h_start, h_end, w_start, w_end) in enumerate(tqdm(views, total=len(views))): + + sampled[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += models_b.stage_a.decode(sampled_b[:, :, h_start:h_end, w_start:w_end]).float() + count[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += 1 + sampled /= count + sampled = sampled[:, :, stride*4*2:-stride*4*2, stride*4*2:-stride*4*2] + else: + + sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled) + + return sampled.float() + + +def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None): + if conditions is None: + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + if unconditions is None: + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + sampling_c = extras.gdf.sample( + models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr, + unconditions, device=device, **extras.sampling_configs, + ) + for idx, (sampled_c, sampled_c_curr, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])): + sampled_c = sampled_c + return sampled_c + +def get_target_lr_size(ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w *32 ) + diff --git a/models/models_checklist.txt b/models/models_checklist.txt new file mode 100644 index 0000000000000000000000000000000000000000..2fdec27a72db473c51893abc64826514b1d9d065 --- /dev/null +++ b/models/models_checklist.txt @@ -0,0 +1,7 @@ +https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors +https://huggingface.co/roubaofeipi/UltraPixel/blob/main/ultrapixel_t2i.safetensors +https://huggingface.co/roubaofeipi/UltraPixel/blob/main/lora_cat.safetensors (only required for personalization) \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fcf5aa2a39061c3f4f82dde6ff063411223cb3 --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,6 @@ +from .effnet import EfficientNetEncoder +from .stage_c import StageC +from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from .previewer import Previewer +from .controlnet import ControlNet, ControlNetDeliverer +from . import controlnet as controlnet_filters diff --git a/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c74bb92cb0db0876acda8aa3d102141526fd428 Binary files /dev/null and b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc differ diff --git a/modules/cnet_modules/face_id/arcface.py b/modules/cnet_modules/face_id/arcface.py new file mode 100644 index 0000000000000000000000000000000000000000..64e918bb90437f6f193a7ec384bea1fcd73c7abb --- /dev/null +++ b/modules/cnet_modules/face_id/arcface.py @@ -0,0 +1,276 @@ +import numpy as np +import onnx, onnx2torch, cv2 +import torch +from insightface.utils import face_align + + +class ArcFaceRecognizer: + def __init__(self, model_file=None, device='cpu', dtype=torch.float32): + assert model_file is not None + self.model_file = model_file + + self.device = device + self.dtype = dtype + self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + self.input_mean = 127.5 + self.input_std = 127.5 + self.input_size = (112, 112) + self.input_shape = ['None', 3, 112, 112] + + def get(self, img, face): + aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) + face.embedding = self.get_feat(aimg).flatten() + return face.embedding + + def compute_sim(self, feat1, feat2): + from numpy.linalg import norm + feat1 = feat1.ravel() + feat2 = feat2.ravel() + sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) + return sim + + def get_feat(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.input_size + + blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + + blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) + net_out = self.model(blob_torch) + return net_out[0].float().cpu() + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i % 2] + distance[:, i] + py = points[:, i % 2 + 1] + distance[:, i + 1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + + +class FaceDetector: + def __init__(self, model_file=None, dtype=torch.float32, device='cuda'): + self.model_file = model_file + self.taskname = 'detection' + self.center_cache = {} + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + + self.device = device + self.dtype = dtype + self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + input_shape = (320, 320) + self.input_size = input_shape + self.input_shape = input_shape + + self.input_mean = 127.5 + self.input_std = 128.0 + self._anchor_ratio = 1.0 + self._num_anchors = 1 + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + + self.det_thresh = 0.5 + self.nms_thresh = 0.4 + + def forward(self, img, threshold): + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) + net_outs_torch = self.model(blob_torch) + # print(list(map(lambda x: x.shape, net_outs_torch))) + net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch)) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + scores = net_outs[idx] + bbox_preds = net_outs[idx + fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2] * stride + height = input_height // stride + width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + # solution-1, c style: + # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + # for i in range(height): + # anchor_centers[i, :, 1] = i + # for i in range(width): + # anchor_centers[:, i, 0] = i + + # solution-2: + # ax = np.arange(width, dtype=np.float32) + # ay = np.arange(height, dtype=np.float32) + # xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + # solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + # print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape((-1, 2)) + if self._num_anchors > 1: + anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2)) + if len(self.center_cache) < 100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores >= threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + # kpss = kps_preds + kpss = kpss.reshape((kpss.shape[0], -1, 2)) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size=None, max_num=0, metric='default'): + assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio > model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order, :, :] + kpss = kpss[keep, :, :] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric == 'max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8200104d6d66a1084685c76373c38d752ed9c3d4 Binary files /dev/null and b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc differ diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca432e5c5eed7ba17fc6cafb06a3ebe16002f67e Binary files /dev/null and b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc differ diff --git a/modules/cnet_modules/inpainting/saliency_model.pt b/modules/cnet_modules/inpainting/saliency_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..e1b02cc60b2999a8f9ff90557182e3dafab63db7 --- /dev/null +++ b/modules/cnet_modules/inpainting/saliency_model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:225a602e1f2a5d159424be011a63b27d83b56343a4379a90710eca9a26bab920 +size 451123 diff --git a/modules/cnet_modules/inpainting/saliency_model.py b/modules/cnet_modules/inpainting/saliency_model.py new file mode 100644 index 0000000000000000000000000000000000000000..82355a02baead47f50fe643e57b81f8caca78f79 --- /dev/null +++ b/modules/cnet_modules/inpainting/saliency_model.py @@ -0,0 +1,81 @@ +import torch +import torchvision +from torch import nn +from PIL import Image +import numpy as np +import os + + +# MICRO RESNET +class ResBlock(nn.Module): + def __init__(self, channels): + super(ResBlock, self).__init__() + + self.resblock = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(channels, channels, kernel_size=3), + nn.InstanceNorm2d(channels, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(channels, channels, kernel_size=3), + nn.InstanceNorm2d(channels, affine=True), + ) + + def forward(self, x): + out = self.resblock(x) + return out + x + + +class Upsample2d(nn.Module): + def __init__(self, scale_factor): + super(Upsample2d, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') + return x + + +class MicroResNet(nn.Module): + def __init__(self): + super(MicroResNet, self).__init__() + + self.downsampler = nn.Sequential( + nn.ReflectionPad2d(4), + nn.Conv2d(3, 8, kernel_size=9, stride=4), + nn.InstanceNorm2d(8, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(8, 16, kernel_size=3, stride=2), + nn.InstanceNorm2d(16, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(16, 32, kernel_size=3, stride=2), + nn.InstanceNorm2d(32, affine=True), + nn.ReLU(), + ) + + self.residual = nn.Sequential( + ResBlock(32), + nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), + ResBlock(64), + ) + + self.segmentator = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(64, 16, kernel_size=3), + nn.InstanceNorm2d(16, affine=True), + nn.ReLU(), + Upsample2d(scale_factor=2), + nn.ReflectionPad2d(4), + nn.Conv2d(16, 1, kernel_size=9), + nn.Sigmoid() + ) + + def forward(self, x): + out = self.downsampler(x) + out = self.residual(out) + out = self.segmentator(out) + return out diff --git a/modules/cnet_modules/pidinet/__init__.py b/modules/cnet_modules/pidinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b4625bf915cc6c4053b7d7861a22ff371bc641 --- /dev/null +++ b/modules/cnet_modules/pidinet/__init__.py @@ -0,0 +1,37 @@ +# Pidinet +# https://github.com/hellozhuo/pidinet + +import os +import torch +import numpy as np +from einops import rearrange +from .model import pidinet +from .util import annotator_ckpts_path, safe_step + + +class PidiNetDetector: + def __init__(self, device): + remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" + modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + self.netNetwork = pidinet() + self.netNetwork.load_state_dict( + {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()}) + self.netNetwork.to(device).eval().requires_grad_(False) + + def __call__(self, input_image): # , safe=False): + return self.netNetwork(input_image)[-1] + # assert input_image.ndim == 3 + # input_image = input_image[:, :, ::-1].copy() + # with torch.no_grad(): + # image_pidi = torch.from_numpy(input_image).float().cuda() + # image_pidi = image_pidi / 255.0 + # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') + # edge = self.netNetwork(image_pidi)[-1] + + # if safe: + # edge = safe_step(edge) + # edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + # return edge[0][0] diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07fca0abb9c90b7b40746b4044c4000ae69e00c7 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a060aa2baa87a3670aa0bf8276e2f34bafe9451 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2243c853d18e2a404ced3eb4ac6a95a7a9ee6874 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f70342fc64759bc7459abf0f7986ee3b7fd2126 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2e7ab031924860f1262f4d44bf2eaf57ca78edd Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4da8564d03f99caa7a45d9ccb1358cb282cd2711 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc differ diff --git a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth new file mode 100644 index 0000000000000000000000000000000000000000..1ceba1de87e7bb3c81961b80acbb3a106ca249c0 --- /dev/null +++ b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80860ac267258b5f27486e0ef152a211d0b08120f62aeb185a050acc30da486c +size 2871148 diff --git a/modules/cnet_modules/pidinet/model.py b/modules/cnet_modules/pidinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..26644c6f6174c3b5407bd10c914045758cbadefe --- /dev/null +++ b/modules/cnet_modules/pidinet/model.py @@ -0,0 +1,654 @@ +""" +Author: Zhuo Su, Wenzhe Liu +Date: Feb 18, 2021 +""" + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +nets = { + 'baseline': { + 'layer0': 'cv', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'c-v15': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'a-v15': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'r-v15': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cvvv4': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'avvv4': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'rvvv4': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cccv4': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cv', + }, + 'aaav4': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'cv', + }, + 'rrrv4': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'cv', + }, + 'c16': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cd', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cd', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cd', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cd', + }, + 'a16': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'ad', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'ad', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'ad', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'ad', + }, + 'r16': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'rd', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'rd', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'rd', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'rd', + }, + 'carv4': { + 'layer0': 'cd', + 'layer1': 'ad', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'ad', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'ad', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'ad', + 'layer14': 'rd', + 'layer15': 'cv', + }, +} + + +def createConvFunc(op_type): + assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) + if op_type == 'cv': + return F.conv2d + + if op_type == 'cd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' + assert padding == dilation, 'padding for cd_conv set wrong' + + weights_c = weights.sum(dim=[2, 3], keepdim=True) + yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) + y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y - yc + + return func + elif op_type == 'ad': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' + assert padding == dilation, 'padding for ad_conv set wrong' + + shape = weights.shape + weights = weights.view(shape[0], shape[1], -1) + weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise + y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + + return func + elif op_type == 'rd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' + padding = 2 * dilation + + shape = weights.shape + if weights.is_cuda: + buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) + else: + buffer = torch.zeros(shape[0], shape[1], 5 * 5) + weights = weights.view(shape[0], shape[1], -1) + buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] + buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] + buffer[:, :, 12] = 0 + buffer = buffer.view(shape[0], shape[1], 5, 5) + y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + + return func + else: + print('impossible to be here unless you force that') + return None + + +class Conv2d(nn.Module): + def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False): + super(Conv2d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + self.pdc = pdc + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + + return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class CSAM(nn.Module): + """ + Compact Spatial Attention Module + """ + + def __init__(self, channels): + super(CSAM, self).__init__() + + mid_channels = 4 + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) + self.sigmoid = nn.Sigmoid() + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + y = self.relu1(x) + y = self.conv1(y) + y = self.conv2(y) + y = self.sigmoid(y) + + return x * y + + +class CDCM(nn.Module): + """ + Compact Dilation Convolution based Module + """ + + def __init__(self, in_channels, out_channels): + super(CDCM, self).__init__() + + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) + self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) + self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) + self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + x = self.relu1(x) + x = self.conv1(x) + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x3 = self.conv2_3(x) + x4 = self.conv2_4(x) + return x1 + x2 + x3 + x4 + + +class MapReduce(nn.Module): + """ + Reduce feature maps into a single edge map + """ + + def __init__(self, channels): + super(MapReduce, self).__init__() + self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + return self.conv(x) + + +class PDCBlock(nn.Module): + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock, self).__init__() + self.stride = stride + + self.stride = stride + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + + +class PDCBlock_converted(nn.Module): + """ + CPDC, APDC can be converted to vanilla 3x3 convolution + RPDC can be converted to vanilla 5x5 convolution + """ + + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock_converted, self).__init__() + self.stride = stride + + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + if pdc == 'rd': + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) + else: + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + + +class PiDiNet(nn.Module): + def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): + super(PiDiNet, self).__init__() + self.sa = sa + if dil is not None: + assert isinstance(dil, int), 'dil should be an int' + self.dil = dil + + self.fuseplanes = [] + + self.inplane = inplane + if convert: + if pdcs[0] == 'rd': + init_kernel_size = 5 + init_padding = 2 + else: + init_kernel_size = 3 + init_padding = 1 + self.init_block = nn.Conv2d(3, self.inplane, + kernel_size=init_kernel_size, padding=init_padding, bias=False) + block_class = PDCBlock_converted + else: + self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) + block_class = PDCBlock + + self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) + self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) + self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) + self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) + self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) + self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 2C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) + self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) + self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) + self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) + self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) + self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) + self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.conv_reduces = nn.ModuleList() + if self.sa and self.dil is not None: + self.attentions = nn.ModuleList() + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.attentions.append(CSAM(self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + elif self.sa: + self.attentions = nn.ModuleList() + for i in range(4): + self.attentions.append(CSAM(self.fuseplanes[i])) + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + elif self.dil is not None: + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + else: + for i in range(4): + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + + self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias + nn.init.constant_(self.classifier.weight, 0.25) + nn.init.constant_(self.classifier.bias, 0) + + # print('initialization done') + + def get_weights(self): + conv_weights = [] + bn_weights = [] + relu_weights = [] + for pname, p in self.named_parameters(): + if 'bn' in pname: + bn_weights.append(p) + elif 'relu' in pname: + relu_weights.append(p) + else: + conv_weights.append(p) + + return conv_weights, bn_weights, relu_weights + + def forward(self, x): + H, W = x.size()[2:] + + x = self.init_block(x) + + x1 = self.block1_1(x) + x1 = self.block1_2(x1) + x1 = self.block1_3(x1) + + x2 = self.block2_1(x1) + x2 = self.block2_2(x2) + x2 = self.block2_3(x2) + x2 = self.block2_4(x2) + + x3 = self.block3_1(x2) + x3 = self.block3_2(x3) + x3 = self.block3_3(x3) + x3 = self.block3_4(x3) + + x4 = self.block4_1(x3) + x4 = self.block4_2(x4) + x4 = self.block4_3(x4) + x4 = self.block4_4(x4) + + x_fuses = [] + if self.sa and self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](self.dilations[i](xi))) + elif self.sa: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](xi)) + elif self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.dilations[i](xi)) + else: + x_fuses = [x1, x2, x3, x4] + + e1 = self.conv_reduces[0](x_fuses[0]) + e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) + + e2 = self.conv_reduces[1](x_fuses[1]) + e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) + + e3 = self.conv_reduces[2](x_fuses[2]) + e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) + + e4 = self.conv_reduces[3](x_fuses[3]) + e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) + + outputs = [e1, e2, e3, e4] + + output = self.classifier(torch.cat(outputs, dim=1)) + # if not self.training: + # return torch.sigmoid(output) + + outputs.append(output) + outputs = [torch.sigmoid(r) for r in outputs] + return outputs + + +def config_model(model): + model_options = list(nets.keys()) + assert model in model_options, \ + 'unrecognized model, please choose from %s' % str(model_options) + + # print(str(nets[model])) + + pdcs = [] + for i in range(16): + layer_name = 'layer%d' % i + op = nets[model][layer_name] + pdcs.append(createConvFunc(op)) + + return pdcs + + +def pidinet(): + pdcs = config_model('carv4') + dil = 24 # if args.dil else None + return PiDiNet(60, pdcs, dil=dil, sa=True) diff --git a/modules/cnet_modules/pidinet/util.py b/modules/cnet_modules/pidinet/util.py new file mode 100644 index 0000000000000000000000000000000000000000..aec00770c7706f95abf3a0b9b02dbe3232930596 --- /dev/null +++ b/modules/cnet_modules/pidinet/util.py @@ -0,0 +1,97 @@ +import random + +import numpy as np +import cv2 +import os + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + + +def make_noise_disk(H, W, C, F): + noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) + noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) + noise = noise[F: F + H, F: F + W] + noise -= np.min(noise) + noise /= np.max(noise) + if C == 1: + noise = noise[:, :, None] + return noise + + +def min_max_norm(x): + x -= np.min(x) + x /= np.maximum(np.max(x), 1e-5) + return x + + +def safe_step(x, step=2): + y = x.astype(np.float32) * float(step + 1) + y = y.astype(np.int32).astype(np.float32) / float(step) + return y + + +def img2mask(img, H, W, low=10, high=90): + assert img.ndim == 3 or img.ndim == 2 + assert img.dtype == np.uint8 + + if img.ndim == 3: + y = img[:, :, random.randrange(0, img.shape[2])] + else: + y = img + + y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) + + if random.uniform(0, 1) < 0.5: + y = 255 - y + + return y < np.percentile(y, random.randrange(low, high)) diff --git a/modules/common.py b/modules/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4ad71649f60f2dd38947c9ebc23bc51db2b544 --- /dev/null +++ b/modules/common.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from einops import rearrange +import torch.fft as fft +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + kv = torch.cat([x, kv], dim=1) + #if x.shape[1] >= 72 * 72: + # x = x * math.sqrt(math.log(64*64, 24*24)) + + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + res = self.attention(self.norm(x), kv, self_attn=self.self_attn) + + #print(torch.unique(res), torch.unique(x), self.self_attn) + #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24)) + x = x + res + + return x + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0): + super().__init__() + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca']): + super().__init__() + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/modules/common_ckpt.py b/modules/common_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..bf196ef5f95a50ac6696331207d4327d74ceef36 --- /dev/null +++ b/modules/common_ckpt.py @@ -0,0 +1,360 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from einops import rearrange +from modules.speed_util import checkpoint +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + +class AttnBlock_lrfuse_backup(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + self.fuse_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + self.use_checkpoint = use_checkpoint + + def forward(self, hr, lr): + return checkpoint(self._forward, (hr, lr), self.paramters(), self.use_checkpoint) + def _forward(self, hr, lr): + res = hr + hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c')) + lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr + + lr_fuse = self.fuse_mapper(rearrange(lr_fuse, 'b c h w -> b (h w ) c')) + hr = self.attention(self.norm(res), lr_fuse, self_attn=False) + res + return hr + + +class AttnBlock_lrfuse(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, kernel_size=3, use_checkpoint=True): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + + + self.depthwise = Conv2d(c, c , kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + + self.channelwise = nn.Sequential( + Linear(c + c, c ), + nn.GELU(), + GlobalResponseNorm(c ), + nn.Dropout(dropout), + Linear(c , c) + ) + self.use_checkpoint = use_checkpoint + + + def forward(self, hr, lr): + return checkpoint(self._forward, (hr, lr), self.parameters(), self.use_checkpoint) + + def _forward(self, hr, lr): + res = hr + hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c')) + lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr + + lr_fuse = torch.nn.functional.interpolate(lr_fuse.float(), res.shape[2:]) + #print('in line 65', lr_fuse.shape, res.shape) + media = torch.cat((self.depthwise(lr_fuse), res), dim=1) + out = self.channelwise(media.permute(0,2,3,1)).permute(0,3,1,2) + res + + return out + + + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + + kv = torch.cat([x, kv], dim=1) + #if x.shape[1] > 48 * 48 and not self.training: + # x = x * math.sqrt(math.log(x.shape[1] , 24*24)) + + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x +class Attention2D_splitpatch(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + + #x = rearrange(x, 'b c h w -> b c (nh wh) (nw ww)', wh=24, ww=24, nh=orig_shape[-2] // 24, nh=orig_shape[-1] // 24,) + x = rearrange(x, 'b c (nh wh) (nw ww) -> (b nh nw) (wh ww) c', wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24,) + #print('in line 168', x.shape) + #x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + num = (orig_shape[-2] // 24) * (orig_shape[-1] // 24) + kv = torch.cat([x, kv.repeat(num, 1, 1)], dim=1) + #if x.shape[1] > 48 * 48 and not self.training: + # x = x * math.sqrt(math.log(x.shape[1] / math.sqrt(16), 24*24)) + + x = self.attn(x, kv, kv, need_weights=False)[0] + x = rearrange(x, ' (b nh nw) (wh ww) c -> b c (nh wh) (nw ww)', b=orig_shape[0], wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24) + #x = x.permute(0, 2, 1).view(*orig_shape) + + return x +class Attention2D_extra(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, extra_emb=None, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + num_x = x.shape[1] + + + if extra_emb is not None: + ori_extra_shape = extra_emb.shape + extra_emb = extra_emb.view(extra_emb.size(0), extra_emb.size(1), -1).permute(0, 2, 1) + x = torch.cat((x, extra_emb), dim=1) + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + kv = torch.cat([x, kv], dim=1) + x = self.attn(x, kv, kv, need_weights=False)[0] + img = x[:, :num_x, :].permute(0, 2, 1).view(*orig_shape) + if extra_emb is not None: + fix = x[:, num_x:, :].permute(0, 2, 1).view(*ori_extra_shape) + return img, fix + else: + return img +class AttnBlock_extraq(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + #self.norm2 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D_extra(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + # norm2 initialization in generator in init extra parameter + def forward(self, x, kv, extra_emb=None): + #print('in line 84', x.shape, kv.shape, self.self_attn, extra_emb if extra_emb is None else extra_emb.shape) + #in line 84 torch.Size([1, 1536, 32, 32]) torch.Size([1, 85, 1536]) True None + #if extra_emb is not None: + + kv = self.kv_mapper(kv) + if extra_emb is not None: + res_x, res_extra = self.attention(self.norm(x), kv, extra_emb=self.norm2(extra_emb), self_attn=self.self_attn) + x = x + res_x + extra_emb = extra_emb + res_extra + return x, extra_emb + else: + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x +class AttnBlock_latent2ex(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + + def forward(self, x, kv): + #print('in line 84', x.shape, kv.shape, self.self_attn) + kv = F.interpolate(kv.float(), x.shape[2:]) + kv = kv.view(kv.size(0), kv.size(1), -1).permute(0, 2, 1) + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) +class AttnBlock_crossbranch(nn.Module): + def __init__(self, attnmodule, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.attn = AttnBlock(c, c_cond, nhead, self_attn, dropout) + #print('in line 108', attnmodule.device) + self.attn.load_state_dict(attnmodule.state_dict()) + self.norm1 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + + self.channelwise1 = nn.Sequential( + Linear(c *2, c ), + nn.GELU(), + GlobalResponseNorm(c ), + nn.Dropout(dropout), + Linear(c, c) + ) + self.channelwise2 = nn.Sequential( + Linear(c *2, c ), + nn.GELU(), + GlobalResponseNorm(c ), + nn.Dropout(dropout), + Linear(c, c) + ) + self.c = c + def forward(self, x, kv, main_x): + #print('in line 84', x.shape, kv.shape, main_x.shape, self.c) + + x = self.channelwise1(torch.cat((x, F.interpolate(main_x.float(), x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x + x = self.attn(x, kv) + main_x = self.channelwise2(torch.cat((main_x, F.interpolate(x.float(), main_x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + main_x + return main_x, x + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, use_checkpoint =True): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + self.use_checkpoint = use_checkpoint + def forward(self, x, x_skip=None): + + if x_skip is not None: + return checkpoint(self._forward_skip, (x, x_skip), self.parameters(), self.use_checkpoint) + else: + #print('in line 298', x.shape) + return checkpoint(self._forward_woskip, (x, ), self.parameters(), self.use_checkpoint) + + + + def _forward_skip(self, x, x_skip): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + def _forward_woskip(self, x): + x_res = x + x = self.norm(self.depthwise(x)) + + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + self.use_checkpoint = use_checkpoint + def forward(self, x, kv): + return checkpoint(self._forward, (x, kv), self.parameters(), self.use_checkpoint) + def _forward(self, x, kv): + kv = self.kv_mapper(kv) + res = self.attention(self.norm(x), kv, self_attn=self.self_attn) + + #print(torch.unique(res), torch.unique(x), self.self_attn) + #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24)) + x = x + res + + return x +class AttnBlock_mytest(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + nn.Linear(c_cond, c) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0): + super().__init__() + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca'], use_checkpoint=True): + super().__init__() + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + self.use_checkpoint = use_checkpoint + def forward(self, x, t): + return checkpoint(self._forward, (x, t), self.parameters(), self.use_checkpoint) + + def _forward(self, x, t): + #print('in line 284', x.shape, t.shape, self.conds) + #in line 284 torch.Size([4, 2048, 19, 29]) torch.Size([4, 192]) ['sca', 'crp'] + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/modules/controlnet.py b/modules/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c187aecb725e00e19924ae308e3aac401acfdf06 --- /dev/null +++ b/modules/controlnet.py @@ -0,0 +1,349 @@ +import torchvision +import torch +from torch import nn +import numpy as np +import kornia +import cv2 +from core.utils import load_or_fail +#from insightface.app.common import Face +from .effnet import EfficientNetEncoder +from .cnet_modules.pidinet import PidiNetDetector +from .cnet_modules.inpainting.saliency_model import MicroResNet +#from .cnet_modules.face_id.arcface import FaceDetector, ArcFaceRecognizer +from .common import LayerNorm2d + + +class CNetResBlock(nn.Module): + def __init__(self, c): + super().__init__() + self.blocks = nn.Sequential( + LayerNorm2d(c), + nn.GELU(), + nn.Conv2d(c, c, kernel_size=3, padding=1), + LayerNorm2d(c), + nn.GELU(), + nn.Conv2d(c, c, kernel_size=3, padding=1), + ) + + def forward(self, x): + return x + self.blocks(x) + + +class ControlNet(nn.Module): + def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None): + super().__init__() + if bottleneck_mode is None: + bottleneck_mode = 'effnet' + self.proj_blocks = proj_blocks + if bottleneck_mode == 'effnet': + embd_channels = 1280 + #self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + if c_in != 3: + in_weights = self.backbone[0][0].weight.data + self.backbone[0][0] = nn.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False) + if c_in > 3: + nn.init.constant_(self.backbone[0][0].weight, 0) + self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone() + else: + self.backbone[0][0].weight.data = in_weights[:, :c_in].clone() + elif bottleneck_mode == 'simple': + embd_channels = c_in + self.backbone = nn.Sequential( + nn.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1), + ) + elif bottleneck_mode == 'large': + self.backbone = nn.Sequential( + nn.Conv2d(c_in, 4096 * 4, kernel_size=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(4096 * 4, 1024, kernel_size=1), + *[CNetResBlock(1024) for _ in range(8)], + nn.Conv2d(1024, 1280, kernel_size=1), + ) + embd_channels = 1280 + else: + raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}') + self.projections = nn.ModuleList() + for _ in range(len(proj_blocks)): + self.projections.append(nn.Sequential( + nn.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False), + )) + nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection + + def forward(self, x): + x = self.backbone(x) + proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] + for i, idx in enumerate(self.proj_blocks): + proj_outputs[idx] = self.projections[i](x) + return proj_outputs + + +class ControlNetDeliverer(): + def __init__(self, controlnet_projections): + self.controlnet_projections = controlnet_projections + self.restart() + + def restart(self): + self.idx = 0 + return self + + def __call__(self): + if self.idx < len(self.controlnet_projections): + output = self.controlnet_projections[self.idx] + else: + output = None + self.idx += 1 + return output + + +# CONTROLNET FILTERS ---------------------------------------------------- + +class BaseFilter(): + def __init__(self, device): + self.device = device + + def num_channels(self): + return 3 + + def __call__(self, x): + return x + + +class CannyFilter(BaseFilter): + def __init__(self, device, resize=224): + super().__init__(device) + self.resize = resize + + def num_channels(self): + return 1 + + def __call__(self, x): + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') + edges = [cv2.Canny(x[i].mul(255).permute(1, 2, 0).cpu().numpy().astype(np.uint8), 100, 200) for i in range(len(x))] + edges = torch.stack([torch.tensor(e).div(255).unsqueeze(0) for e in edges], dim=0) + if self.resize is not None: + edges = nn.functional.interpolate(edges, size=orig_size, mode='bilinear') + return edges + + +class QRFilter(BaseFilter): + def __init__(self, device, resize=224, blobify=True, dilation_kernels=[3, 5, 7], blur_kernels=[15]): + super().__init__(device) + self.resize = resize + self.blobify = blobify + self.dilation_kernels = dilation_kernels + self.blur_kernels = blur_kernels + + def num_channels(self): + return 1 + + def __call__(self, x): + x = x.to(self.device) + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') + + x = kornia.color.rgb_to_hsv(x)[:, -1:] + # blobify + if self.blobify: + d_kernel = np.random.choice(self.dilation_kernels) + d_blur = np.random.choice(self.blur_kernels) + if d_blur > 0: + x = torchvision.transforms.GaussianBlur(d_blur)(x) + if d_kernel > 0: + blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, + d_kernel).pow(2)[:, + None]) < 0.3).float().to(self.device) + x = kornia.morphology.dilation(x, blob_mask) + x = kornia.morphology.erosion(x, blob_mask) + # mask + vmax, vmin = x.amax(dim=[2, 3], keepdim=True)[0], x.amin(dim=[2, 3], keepdim=True)[0] + th = (vmax - vmin) * 0.33 + high_brightness, low_brightness = (x > (vmax - th)).float(), (x < (vmin + th)).float() + mask = (torch.ones_like(x) - low_brightness + high_brightness) * 0.5 + + if self.resize is not None: + mask = nn.functional.interpolate(mask, size=orig_size, mode='bilinear') + return mask.cpu() + + +class PidiFilter(BaseFilter): + def __init__(self, device, resize=224, dilation_kernels=[0, 3, 5, 7, 9], binarize=True): + super().__init__(device) + self.resize = resize + self.model = PidiNetDetector(device) + self.dilation_kernels = dilation_kernels + self.binarize = binarize + + def num_channels(self): + return 1 + + def __call__(self, x): + x = x.to(self.device) + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') + + x = self.model(x) + d_kernel = np.random.choice(self.dilation_kernels) + if d_kernel > 0: + blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, d_kernel).pow(2)[ + :, None]) < 0.3).float().to(self.device) + x = kornia.morphology.dilation(x, blob_mask) + if self.binarize: + th = np.random.uniform(0.05, 0.7) + x = (x > th).float() + + if self.resize is not None: + x = nn.functional.interpolate(x, size=orig_size, mode='bilinear') + return x.cpu() + + +class SRFilter(BaseFilter): + def __init__(self, device, scale_factor=1 / 4): + super().__init__(device) + self.scale_factor = scale_factor + + def num_channels(self): + return 3 + + def __call__(self, x): + x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") + return torch.nn.functional.interpolate(x, scale_factor=1 / self.scale_factor, mode="nearest") + + +class SREffnetFilter(BaseFilter): + def __init__(self, device, scale_factor=1/2): + super().__init__(device) + self.scale_factor = scale_factor + + self.effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + self.effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail("models/effnet_encoder.safetensors") + self.effnet.load_state_dict(effnet_checkpoint) + self.effnet.eval().requires_grad_(False) + + def num_channels(self): + return 16 + + def __call__(self, x): + x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") + with torch.no_grad(): + effnet_embedding = self.effnet(self.effnet_preprocess(x.to(self.device))).cpu() + effnet_embedding = torch.nn.functional.interpolate(effnet_embedding, scale_factor=1/self.scale_factor, mode="nearest") + upscaled_image = torch.nn.functional.interpolate(x, scale_factor=1/self.scale_factor, mode="nearest") + return effnet_embedding, upscaled_image + + +class InpaintFilter(BaseFilter): + def __init__(self, device, thresold=[0.04, 0.4], p_outpaint=0.4): + super().__init__(device) + self.saliency_model = MicroResNet().eval().requires_grad_(False).to(device) + self.saliency_model.load_state_dict(load_or_fail("modules/cnet_modules/inpainting/saliency_model.pt")) + self.thresold = thresold + self.p_outpaint = p_outpaint + + def num_channels(self): + return 4 + + def __call__(self, x, mask=None, threshold=None, outpaint=None): + x = x.to(self.device) + resized_x = torchvision.transforms.functional.resize(x, 240, antialias=True) + if threshold is None: + threshold = np.random.uniform(self.thresold[0], self.thresold[1]) + if mask is None: + saliency_map = self.saliency_model(resized_x) > threshold + if outpaint is None: + if np.random.rand() < self.p_outpaint: + saliency_map = ~saliency_map + else: + if outpaint: + saliency_map = ~saliency_map + interpolated_saliency_map = torch.nn.functional.interpolate(saliency_map.float(), size=x.shape[2:], mode="nearest") + saliency_map = torchvision.transforms.functional.gaussian_blur(interpolated_saliency_map, 141) > 0.5 + inpainted_images = torch.where(saliency_map, torch.ones_like(x), x) + mask = torch.nn.functional.interpolate(saliency_map.float(), size=inpainted_images.shape[2:], mode="nearest") + else: + mask = mask.to(self.device) + inpainted_images = torch.where(mask, torch.ones_like(x), x) + c_inpaint = torch.cat([inpainted_images, mask], dim=1) + return c_inpaint.cpu() + + +# IDENTITY +''' +class IdentityFilter(BaseFilter): + def __init__(self, device, max_faces=4, p_drop=0.05, p_full=0.3): + detector_path = 'modules/cnet_modules/face_id/models/buffalo_l/det_10g.onnx' + recognizer_path = 'modules/cnet_modules/face_id/models/buffalo_l/w600k_r50.onnx' + + super().__init__(device) + self.max_faces = max_faces + self.p_drop = p_drop + self.p_full = p_full + + self.detector = FaceDetector(detector_path, device=device) + self.recognizer = ArcFaceRecognizer(recognizer_path, device=device) + + self.id_colors = torch.tensor([ + [1.0, 0.0, 0.0], # RED + [0.0, 1.0, 0.0], # GREEN + [0.0, 0.0, 1.0], # BLUE + [1.0, 0.0, 1.0], # PURPLE + [0.0, 1.0, 1.0], # CYAN + [1.0, 1.0, 0.0], # YELLOW + [0.5, 0.0, 0.0], # DARK RED + [0.0, 0.5, 0.0], # DARK GREEN + [0.0, 0.0, 0.5], # DARK BLUE + [0.5, 0.0, 0.5], # DARK PURPLE + [0.0, 0.5, 0.5], # DARK CYAN + [0.5, 0.5, 0.0], # DARK YELLOW + ]) + + def num_channels(self): + return 512 + + def get_faces(self, image): + npimg = image.permute(1, 2, 0).mul(255).to(device="cpu", dtype=torch.uint8).cpu().numpy() + bgr = cv2.cvtColor(npimg, cv2.COLOR_RGB2BGR) + bboxes, kpss = self.detector.detect(bgr, max_num=self.max_faces) + N = len(bboxes) + ids = torch.zeros((N, 512), dtype=torch.float32) + for i in range(N): + face = Face(bbox=bboxes[i, :4], kps=kpss[i], det_score=bboxes[i, 4]) + ids[i, :] = self.recognizer.get(bgr, face) + tbboxes = torch.tensor(bboxes[:, :4], dtype=torch.int) + + ids = ids / torch.linalg.norm(ids, dim=1, keepdim=True) + return tbboxes, ids # returns bounding boxes (N x 4) and ID vectors (N x 512) + + def __call__(self, x): + visual_aid = x.clone().cpu() + face_mtx = torch.zeros(x.size(0), 512, x.size(-2) // 32, x.size(-1) // 32) + + for i in range(x.size(0)): + bounding_boxes, ids = self.get_faces(x[i]) + for j in range(bounding_boxes.size(0)): + if np.random.rand() > self.p_drop: + sx, sy, ex, ey = (bounding_boxes[j] / 32).clamp(min=0).round().int().tolist() + ex, ey = max(ex, sx + 1), max(ey, sy + 1) + if bounding_boxes.size(0) == 1 and np.random.rand() < self.p_full: + sx, sy, ex, ey = 0, 0, x.size(-1) // 32, x.size(-2) // 32 + face_mtx[i, :, sy:ey, sx:ex] = ids[j:j + 1, :, None, None] + visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] += self.id_colors[j % 13, :, + None, None] + visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] *= 0.5 + + return face_mtx.to(x.device), visual_aid.to(x.device) +''' diff --git a/modules/effnet.py b/modules/effnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb2690c2547c8c7553aec8a9f9e838241f8f61c --- /dev/null +++ b/modules/effnet.py @@ -0,0 +1,17 @@ +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + + def forward(self, x): + return self.mapper(self.backbone(x)) + diff --git a/modules/inr_fea_res_lite.py b/modules/inr_fea_res_lite.py new file mode 100644 index 0000000000000000000000000000000000000000..f44c38ddd6c590f8c19cd14449b426e436460b3b --- /dev/null +++ b/modules/inr_fea_res_lite.py @@ -0,0 +1,435 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import einops +import numpy as np +import models +from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d +#from modules.common_ckpt import AttnBlock, +from einops import rearrange +import torch.fft as fft +from modules.speed_util import checkpoint +def batched_linear_mm(x, wb): + # x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2) + one = torch.ones(*x.shape[:-1], 1, device=x.device) + return torch.matmul(torch.cat([x, one], dim=-1), wb) +def make_coord_grid(shape, range, device=None): + """ + Args: + shape: tuple + range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim + Returns: + grid: shape (*shape, ) + """ + l_lst = [] + for i, s in enumerate(shape): + l = (0.5 + torch.arange(s, device=device)) / s + if isinstance(range[0], list) or isinstance(range[0], tuple): + minv, maxv = range[i] + else: + minv, maxv = range + l = minv + (maxv - minv) * l + l_lst.append(l) + grid = torch.meshgrid(*l_lst, indexing='ij') + grid = torch.stack(grid, dim=-1) + return grid +def init_wb(shape): + weight = torch.empty(shape[1], shape[0] - 1) + nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + + bias = torch.empty(shape[1], 1) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(bias, -bound, bound) + + return torch.cat([weight, bias], dim=1).t().detach() + +def init_wb_rewrite(shape): + weight = torch.empty(shape[1], shape[0] - 1) + + torch.nn.init.xavier_uniform_(weight) + + bias = torch.empty(shape[1], 1) + torch.nn.init.xavier_uniform_(bias) + + + return torch.cat([weight, bias], dim=1).t().detach() +class HypoMlp(nn.Module): + + def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024): + super().__init__() + self.use_pe = use_pe + self.pe_dim = pe_dim + self.pe_sigma = pe_sigma + self.depth = depth + self.param_shapes = dict() + if use_pe: + last_dim = in_dim * pe_dim + else: + last_dim = in_dim + for i in range(depth): # for each layer the weight + cur_dim = hidden_dim if i < depth - 1 else out_dim + self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim) + last_dim = cur_dim + self.relu = nn.ReLU() + self.params = None + self.out_bias = out_bias + + def set_params(self, params): + self.params = params + + def convert_posenc(self, x): + w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device)) + x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1) + x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1) + return x + + def forward(self, x): + B, query_shape = x.shape[0], x.shape[1: -1] + x = x.view(B, -1, x.shape[-1]) + if self.use_pe: + x = self.convert_posenc(x) + #print('in line 79 after pos embedding', x.shape) + for i in range(self.depth): + x = batched_linear_mm(x, self.params[f'wb{i}']) + if i < self.depth - 1: + x = self.relu(x) + else: + x = x + self.out_bias + x = x.view(B, *query_shape, -1) + return x + + + +class Attention(nn.Module): + + def __init__(self, dim, n_head, head_dim, dropout=0.): + super().__init__() + self.n_head = n_head + inner_dim = n_head * head_dim + self.to_q = nn.Sequential( + nn.SiLU(), + Linear(dim, inner_dim )) + self.to_kv = nn.Sequential( + nn.SiLU(), + Linear(dim, inner_dim * 2)) + self.scale = head_dim ** -0.5 + # self.to_out = nn.Sequential( + # Linear(inner_dim, dim), + # nn.Dropout(dropout), + # ) + + def forward(self, fr, to=None): + if to is None: + to = fr + q = self.to_q(fr) + k, v = self.to_kv(to).chunk(2, dim=-1) + q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v]) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = F.softmax(dots, dim=-1) # b h n n + out = torch.matmul(attn, v) + out = einops.rearrange(out, 'b h n d -> b n (h d)') + return out + + +class FeedForward(nn.Module): + + def __init__(self, dim, ff_dim, dropout=0.): + super().__init__() + + self.net = nn.Sequential( + Linear(dim, ff_dim), + nn.GELU(), + #GlobalResponseNorm(ff_dim), + nn.Dropout(dropout), + Linear(ff_dim, dim) + ) + + def forward(self, x): + return self.net(x) + + +class PreNorm(nn.Module): + + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + + +#TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds = []) +class TransformerEncoder(nn.Module): + + def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)), + PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)), + ])) + + def forward(self, x): + for norm_attn, norm_ff in self.layers: + x = x + norm_attn(x) + x = x + norm_ff(x) + return x +class ImgrecTokenizer(nn.Module): + + def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16): + super().__init__() + + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + if isinstance(padding, int): + padding = (padding, padding) + self.patch_size = patch_size + self.padding = padding + self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim) + + self.posemb = nn.Parameter(torch.randn(input_size, dim)) + + def forward(self, x): + #print(x.shape) + p = self.patch_size + x = F.unfold(x, p, stride=p, padding=self.padding) # (B, C * p * p, L) + #print('in line 185 after unfoding', x.shape) + x = x.permute(0, 2, 1).contiguous() + ttt = self.prefc(x) + + x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0) + return x + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.sigmoid(x) + +class TimestepBlock_res(nn.Module): + def __init__(self, c, c_timestep, conds=['sca']): + super().__init__() + + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + + + + def forward(self, x, t): + #print(x.shape, t.shape, self.conds, 'in line 269') + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + + +class ScaleNormalize_res(nn.Module): + def __init__(self, c, scale_c, conds=['sca']): + super().__init__() + self.c_r = scale_c + self.mapping = TimestepBlock_res(c, scale_c, conds=conds) + self.t_conds = conds + self.alpha = nn.Conv2d(c, c, kernel_size=1) + self.gamma = nn.Conv2d(c, c, kernel_size=1) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + def forward(self, x, std_size=24*24): + scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size)) + scale_val = torch.ones(x.shape[0]).to(x.device)*scale_val + scale_val_f = self.gen_r_embedding(scale_val) + for c in self.t_conds: + t_cond = torch.zeros_like(scale_val) + scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1) + + f = self.mapping(x, scale_val_f) + + return f + x + + +class TransInr_withnorm(nn.Module): + + def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]): + super().__init__() + self.input_layer= nn.Conv2d(ind, ch, 1) + self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch) + #self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) + #self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3*f_dim, ) + + self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) + self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim) + #self.transformer_encoder = TransInr( ch=ch, n_head=16, head_dim=16, n_groups=64, f_dim=ch, time_dim=time_dim, t_conds = []) + self.base_params = nn.ParameterDict() + n_wtokens = 0 + self.wtoken_postfc = nn.ModuleDict() + self.wtoken_rng = dict() + for name, shape in self.hyponet.param_shapes.items(): + self.base_params[name] = nn.Parameter(init_wb(shape)) + g = min(n_groups, shape[1]) + assert shape[1] % g == 0 + self.wtoken_postfc[name] = nn.Sequential( + nn.LayerNorm(f_dim), + nn.Linear(f_dim, shape[0] - 1), + ) + self.wtoken_rng[name] = (n_wtokens, n_wtokens + g) + n_wtokens += g + self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim)) + self.output_layer= nn.Conv2d(ch, ind, 1) + + + self.mapp_t = TimestepBlock_res( ind, time_dim, conds = t_conds) + + + self.hr_norm = ScaleNormalize_res(ind, 64, conds=[]) + + self.normalize_final = nn.Sequential( + LayerNorm2d(ind, elementwise_affine=False, eps=1e-6), + ) + + self.toout = nn.Sequential( + Linear( ind*2, ind // 4), + nn.GELU(), + Linear( ind // 4, ind) + ) + self.apply(self._init_weights) + + mask = torch.zeros((1, 1, 32, 32)) + h, w = 32, 32 + center_h, center_w = h // 2, w // 2 + low_freq_h, low_freq_w = h // 4, w // 4 + mask[:, :, center_h-low_freq_h:center_h+low_freq_h, center_w-low_freq_w:center_w+low_freq_w] = 1 + self.mask = mask + + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + #nn.init.constant_(self.last.weight, 0) + def adain(self, feature_a, feature_b): + norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True) + norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True) + #feature_a = F.interpolate(feature_a, feature_b.shape[2:]) + feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean + return feature_b + def forward(self, target_shape, target, dtokens, t_emb): + #print(target.shape, dtokens.shape, 'in line 290') + hlr, wlr = dtokens.shape[2:] + original = dtokens + + dtokens = self.input_layer(dtokens) + dtokens = self.tokenizer(dtokens) + B = dtokens.shape[0] + wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B) + #print(wtokens.shape, dtokens.shape) + trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1)) + trans_out = trans_out[:, -len(self.wtokens):, :] + + params = dict() + for name, shape in self.hyponet.param_shapes.items(): + wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B) + w, b = wb[:, :-1, :], wb[:, -1:, :] + + l, r = self.wtoken_rng[name] + x = self.wtoken_postfc[name](trans_out[:, l: r, :]) + x = x.transpose(-1, -2) # (B, shape[0] - 1, g) + w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1) + + wb = torch.cat([w, b], dim=1) + params[name] = wb + coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device) + coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0]) + self.hyponet.set_params(params) + ori_up = F.interpolate(original.float(), target_shape[2:]) + hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up + #print(hr_rec.shape, target.shape, torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1).shape, 'in line 537') + + output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + #print(output.shape, 'in line 540') + #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)* 0.3 + output = self.mapp_t(output, t_emb) + output = self.normalize_final(output) + output = self.hr_norm(output) + #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + #output = self.mapp_t(output, t_emb) + #output = self.weight(output) * output + + return output + + + + + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + + +if __name__ == '__main__': + #ef __init__(self, ch, n_head, head_dim, n_groups): + trans_inr = TransInr(16, 24, 32, 64).cuda() + input = torch.randn((1, 16, 24, 24)).cuda() + source = torch.randn((1, 16, 16, 16)).cuda() + t = torch.randn((1, 128)).cuda() + output, hr = trans_inr(input, t, source) + + total_up = sum([ param.nelement() for param in trans_inr.parameters()]) + print(output.shape, hr.shape, total_up /1e6 ) + diff --git a/modules/lora.py b/modules/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0a2bd797f3669a465f6c2c4255b52fe1bda7a7 --- /dev/null +++ b/modules/lora.py @@ -0,0 +1,71 @@ +import torch +from torch import nn + + +class LoRA(nn.Module): + def __init__(self, layer, name='weight', rank=16, alpha=1): + super().__init__() + weight = getattr(layer, name) + self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1)))) + self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank))) + nn.init.normal_(self.lora_up, mean=0, std=1) + + self.scale = alpha / rank + self.enabled = True + + def forward(self, original_weights): + if self.enabled: + lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2) + lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale + return original_weights + lora_weights + else: + return original_weights + + +def apply_lora(model, filters=None, rank=16): + def check_parameter(module, name): + return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( + getattr(module, name), nn.Parameter) + + for name, module in model.named_modules(): + if filters is None or any([f in name for f in filters]): + if check_parameter(module, "weight"): + device, dtype = module.weight.device, module.weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device)) + elif check_parameter(module, "in_proj_weight"): + device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device)) + + +class ReToken(nn.Module): + def __init__(self, indices=None): + super().__init__() + assert indices is not None + self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280)) + self.register_buffer('indices', torch.tensor(indices)) + self.enabled = True + + def forward(self, embeddings): + if self.enabled: + embeddings = embeddings.clone() + for i, idx in enumerate(self.indices): + embeddings[idx] += self.embeddings[i] + return embeddings + + +def apply_retoken(module, indices=None): + def check_parameter(module, name): + return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( + getattr(module, name), nn.Parameter) + + if check_parameter(module, "weight"): + device, dtype = module.weight.device, module.weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device)) + + +def remove_lora(model, leave_parametrized=True): + for module in model.modules(): + if torch.nn.utils.parametrize.is_parametrized(module, "weight"): + nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized) + elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): + nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized) diff --git a/modules/model_4stage_lite.py b/modules/model_4stage_lite.py new file mode 100644 index 0000000000000000000000000000000000000000..702a1f39c6719681a312f04a2402b3f4ac04f7ce --- /dev/null +++ b/modules/model_4stage_lite.py @@ -0,0 +1,458 @@ +import torch +from torch import nn +import numpy as np +import math +from modules.common_ckpt import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock +from .controlnet import ControlNetDeliverer +import torch.nn.functional as F +from modules.inr_fea_res_lite import TransInr_withnorm as TransInr +from modules.inr_fea_res_lite import ScaleNormalize_res +from einops import rearrange +import torch.fft as fft +import random +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = nn.Conv2d(c_in, c_out, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x.float()) + return x +def ada_in(a, b): + mean_a = torch.mean(a, dim=(2, 3), keepdim=True) + std_a = torch.std(a, dim=(2, 3), keepdim=True) + + mean_b = torch.mean(b, dim=(2, 3), keepdim=True) + std_b = torch.std(b, dim=(2, 3), keepdim=True) + + return (b - mean_b) / (1e-8 + std_b) * std_a + mean_a +def feature_dist_loss(x1, x2): + mu1 = torch.mean(x1, dim=(2, 3)) + mu2 = torch.mean(x2, dim=(2, 3)) + + std1 = torch.std(x1, dim=(2, 3)) + std2 = torch.std(x2, dim=(2, 3)) + std_loss = torch.mean(torch.abs(torch.log(std1+ 1e-8) - torch.log(std2+ 1e-8))) + mean_loss = torch.mean(torch.abs(mu1 - mu2)) + #print('in line 36', std_loss, mean_loss) + return std_loss + mean_loss*0.1 +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], + lr_h=24, lr_w=24): + super().__init__() + + self.lr_h, self.lr_w = lr_h, lr_w + self.block_repeat = block_repeat + self.c_in = c_in + self.c_cond = c_cond + self.patch_size = patch_size + self.c_hidden = c_hidden + self.nhead = nhead + self.blocks = blocks + self.level_config = level_config + self.kernel_size = kernel_size + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + self.self_attn = self_attn + self.dropout = dropout + self.switch_level = switch_level + # CONDITIONING + self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) + self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) + self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + + + #extra down blocks + + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + def _init_extra_parameter(self): + + + + self.agg_net = nn.ModuleList() + for _ in range(2): + + self.agg_net.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) # + + self.agg_net_up = nn.ModuleList() + for _ in range(2): + + self.agg_net_up.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) # + + + + + + self.norm_down_blocks = nn.ModuleList() + for i in range(len(self.c_hidden)): + + up_blocks = nn.ModuleList() + for j in range(self.blocks[0][i]): + if j % 4 == 0: + up_blocks.append( + ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[])) + self.norm_down_blocks.append(up_blocks) + + + self.norm_up_blocks = nn.ModuleList() + for i in reversed(range(len(self.c_hidden))): + + up_block = nn.ModuleList() + for j in range(self.blocks[1][::-1][i]): + if j % 4 == 0: + up_block.append(ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[])) + self.norm_up_blocks.append(up_block) + + + + + self.agg_net.apply(self._init_weights) + self.agg_net_up.apply(self._init_weights) + self.norm_up_blocks.apply(self._init_weights) + self.norm_down_blocks.apply(self._init_weights) + for block in self.agg_net + self.agg_net_up: + #for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + + + + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None, require_q=False, lr_guide=None, r_emb_lite=None, guide_weight=1): + level_outputs = [] + if require_q: + qs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for stage_cnt, (down_block, downscaler, repmap) in enumerate(block_group): + x = downscaler(x) + for i in range(len(repmap) + 1): + for inner_cnt, block in enumerate(down_block): + + + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None and lr_guide is None: + #if cnet is not None : + next_cnet = cnet() + if next_cnet is not None: + + x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + + x = block(x, clip) + if require_q and (inner_cnt == 2 ): + qs.append(x.clone()) + if lr_guide is not None and (inner_cnt == 2 ) : + + guide = self.agg_net[stage_cnt](x.shape, x, lr_guide[stage_cnt], r_emb_lite) + x = x + guide + + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) # 0 indicate last output + if require_q: + return level_outputs, qs + return level_outputs + + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None, require_ff=False, agg_f=None, r_emb_lite=None, guide_weight=1): + if require_ff: + agg_feas = [] + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + + + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', + align_corners=True) + + if cnet is not None and agg_f is None: + next_cnet = cnet() + if next_cnet is not None: + + x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear', + align_corners=True) + + + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + + + x = block(x, clip) + if require_ff and (k == 2 ): + agg_feas.append(x.clone()) + if agg_f is not None and (k == 2 ) : + + guide = self.agg_net_up[i](x.shape, x, agg_f[i], r_emb_lite) # training 1 test 4k 0.8 2k 0.7 + if not self.training: + hw = x.shape[-2] * x.shape[-1] + if hw >= 96*96: + guide = 0.7*guide + + else: + + if hw >= 72*72: + guide = 0.5* guide + else: + + guide = 0.3* guide + + x = x + guide + + + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + #if require_ff: + # agg_feas.append(x.clone()) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + + + if require_ff: + return x, agg_feas + + return x + + + + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, lr_guide=None, reuire_f=False, cnet=None, require_t=False, guide_weight=0.5, **kwargs): + + r_embed = self.gen_r_embedding(r) + + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + + x = self.embedding(x) + + + + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + + if not reuire_f: + level_outputs = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, \ + require_q=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight) + x = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, \ + require_ff=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight) + else: + level_outputs, lr_enc = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, require_q=True) + x, lr_dec = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, require_ff=True) + + if reuire_f and require_t: + return self.clf(x), r_embed, lr_enc, lr_dec + if reuire_f: + return self.clf(x), lr_enc, lr_dec + if require_t: + return self.clf(x), r_embed + return self.clf(x) + + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) + + + +if __name__ == '__main__': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + total_ori = sum([ param.nelement() for param in generator.parameters()]) + generator._init_extra_parameter() + generator = generator.cuda() + total = sum([ param.nelement() for param in generator.parameters()]) + total_down = sum([ param.nelement() for param in generator.down_blocks.parameters()]) + + total_up = sum([ param.nelement() for param in generator.up_blocks.parameters()]) + total_pro = sum([ param.nelement() for param in generator.project.parameters()]) + + + print(total_ori / 1e6, total / 1e6, total_up / 1e6, total_down / 1e6, total_pro / 1e6) + + # for name, module in generator.down_blocks.named_modules(): + # print(name, module) + output, out_lr = generator( + x=torch.randn(1, 16, 24, 24).cuda(), + x_lr=torch.randn(1, 16, 16, 16).cuda(), + r=torch.tensor([0.7056]).cuda(), + clip_text=torch.randn(1, 77, 1280).cuda(), + clip_text_pooled = torch.randn(1, 1, 1280).cuda(), + clip_img = torch.randn(1, 1, 768).cuda() + ) + print(output.shape, out_lr.shape) + # cnt diff --git a/modules/previewer.py b/modules/previewer.py new file mode 100644 index 0000000000000000000000000000000000000000..51ab24292d8ac0da8d24b17d8fc0ac9e1419a3d7 --- /dev/null +++ b/modules/previewer.py @@ -0,0 +1,45 @@ +from torch import nn + + +# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return self.blocks(x) diff --git a/modules/resnet.py b/modules/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c3de556733f231815a57dc1683a1cfd1f1ab46b5 --- /dev/null +++ b/modules/resnet.py @@ -0,0 +1,415 @@ +import torch +from torch import nn +import torch.nn.functional as F +#import fvcore.nn.weight_init as weight_init + +""" +Functions for building the BottleneckBlock from Detectron2. +# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py +""" + +def get_norm(norm, out_channels, num_norm_groups=32): + """ + Args: + norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; + or a callable that takes a channel number and returns + the normalization layer as a nn.Module. + Returns: + nn.Module or None: the normalization layer + """ + if norm is None: + return None + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels), + }[norm] + return norm(out_channels) + +class Conv2d(nn.Conv2d): + """ + A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. + """ + + def __init__(self, *args, **kwargs): + """ + Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + x = F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + +class CNNBlockBase(nn.Module): + """ + A CNN block is assumed to have input channels, output channels and a stride. + The input and output of `forward()` method must be NCHW tensors. + The method can perform arbitrary computation but must match the given + channels and stride specification. + Attribute: + in_channels (int): + out_channels (int): + stride (int): + """ + + def __init__(self, in_channels, out_channels, stride): + """ + The `__init__` method of any subclass should also contain these arguments. + Args: + in_channels (int): + out_channels (int): + stride (int): + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + +class BottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block used by ResNet-50, 101 and 152 + defined in :paper:`ResNet`. It contains 3 conv layers with kernels + 1x1, 3x3, 1x1, and a projection shortcut if needed. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="GN", + stride_in_1x1=False, + dilation=1, + num_norm_groups=32 + ): + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels, num_norm_groups), + ) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels, num_norm_groups), + ) + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + norm=get_norm(norm, bottleneck_channels, num_norm_groups), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels, num_norm_groups), + ) + + #for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + # if layer is not None: # shortcut can be None + # weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. + # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient �� is initialized + # to be 1, except for each residual block's last BN + # where �� is initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + out = self.conv2(out) + out = F.relu_(out) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + +class ResNet(nn.Module): + """ + Implement :paper:`ResNet`. + """ + + def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[CNNBlockBase]]): several (typically 4) stages, + each contains multiple :class:`CNNBlockBase`. + num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stage_names, self.stages = [], [] + + if out_features is not None: + # Avoid keeping unused layers in this module. They consume extra memory + # and may cause allreduce to fail + num_stages = max( + [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features] + ) + stages = stages[:num_stages] + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "res" + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stage_names.append(name) + self.stages.append(stage) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + self.stage_names = tuple(self.stage_names) # Make it static for scripting + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." + nn.init.normal_(self.linear.weight, std=0.01) + name = "linear" + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format(", ".join(children)) + self.freeze(freeze_at) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for name, stage in zip(self.stage_names, self.stages): + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def freeze(self, freeze_at=0): + """ + Freeze the first several stages of the ResNet. Commonly used in + fine-tuning. + Layers that produce the same feature map spatial size are defined as one + "stage" by :paper:`FPN`. + Args: + freeze_at (int): number of stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + one residual stage, etc. + Returns: + nn.Module: this ResNet itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + @staticmethod + def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks of the same type that forms one ResNet stage. + Args: + block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this + stage. A module of this type must not change spatial resolution of inputs unless its + stride != 1. + num_blocks (int): number of blocks in this stage + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of + `block_class`. If the argument name is "xx_per_block", the + argument is a list of values to be passed to each block in the + stage. Otherwise, the same argument is passed to every block + in the stage. + Returns: + list[CNNBlockBase]: a list of block module. + Examples: + :: + stage = ResNet.make_stage( + BottleneckBlock, 3, in_channels=16, out_channels=64, + bottleneck_channels=16, num_groups=1, + stride_per_block=[2, 1, 1], + dilations_per_block=[1, 1, 2] + ) + Usually, layers that produce the same feature map spatial size are defined as one + "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should + all be 1. + """ + blocks = [] + for i in range(num_blocks): + curr_kwargs = {} + for k, v in kwargs.items(): + if k.endswith("_per_block"): + assert len(v) == num_blocks, ( + f"Argument '{k}' of make_stage should have the " + f"same length as num_blocks={num_blocks}." + ) + newk = k[: -len("_per_block")] + assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!" + curr_kwargs[newk] = v[i] + else: + curr_kwargs[k] = v + + blocks.append( + block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs) + ) + in_channels = out_channels + return blocks + + @staticmethod + def make_default_stages(depth, block_class=None, **kwargs): + """ + Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). + If it doesn't create the ResNet variant you need, please use :meth:`make_stage` + instead for fine-grained customization. + Args: + depth (int): depth of ResNet + block_class (type): the CNN block class. Has to accept + `bottleneck_channels` argument for depth > 50. + By default it is BasicBlock or BottleneckBlock, based on the + depth. + kwargs: + other arguments to pass to `make_stage`. Should not contain + stride and channels, as they are predefined for each depth. + Returns: + list[list[CNNBlockBase]]: modules in all stages; see arguments of + :class:`ResNet.__init__`. + """ + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + if block_class is None: + block_class = BasicBlock if depth < 50 else BottleneckBlock + if depth < 50: + in_channels = [64, 64, 128, 256] + out_channels = [64, 128, 256, 512] + else: + in_channels = [64, 256, 512, 1024] + out_channels = [256, 512, 1024, 2048] + ret = [] + for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels): + if depth >= 50: + kwargs["bottleneck_channels"] = o // 4 + ret.append( + ResNet.make_stage( + block_class=block_class, + num_blocks=n, + stride_per_block=[s] + [1] * (n - 1), + in_channels=i, + out_channels=o, + **kwargs, + ) + ) + return ret \ No newline at end of file diff --git a/modules/speed_util.py b/modules/speed_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe582e7ff805f1b4fc6fb9b8df3eba8c057531e --- /dev/null +++ b/modules/speed_util.py @@ -0,0 +1,55 @@ +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) \ No newline at end of file diff --git a/modules/stage_a.py b/modules/stage_a.py new file mode 100644 index 0000000000000000000000000000000000000000..2840ef71d30e3da74954ab4a05e724fd7fef86cf --- /dev/null +++ b/modules/stage_a.py @@ -0,0 +1,183 @@ +import torch +from torch import nn +from torchtools.nn import VectorQuantize +from einops import rearrange +import torch.nn.functional as F +import math +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + + #x = x.to(torch.float64) + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +def extract_patches(tensor, patch_size, stride): + b, c, H, W = tensor.shape + pad_h = (patch_size - (H - patch_size) % stride) % stride + pad_w = (patch_size - (W - patch_size) % stride) % stride + tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect') + + + patches = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride) + patches = patches.contiguous().view(b, c, -1, patch_size, patch_size) + patches = patches.permute(0, 2, 1, 3, 4) + return patches, (H, W) + +def fuse_patches(patches, patch_size, stride, H, W): + + b, num_patches, c, _, _ = patches.shape + patches = patches.permute(0, 2, 1, 3, 4) + + + + pad_h = (patch_size - (H - patch_size) % stride) % stride + pad_w = (patch_size - (W - patch_size) % stride) % stride + out_h = H + pad_h + out_w = W + pad_w + patches = patches.contiguous().view(b, c , -1, patch_size*patch_size ).permute(0, 1, 3, 2) + patches = patches.contiguous().view(b, c*patch_size*patch_size, -1) + + tensor = F.fold(patches, output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) + overlap_cnt = F.fold(torch.ones_like(patches), output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) + tensor = tensor / overlap_cnt + print('end fuse patch', tensor.shape, (tensor.dtype)) + return tensor[:, :, :H, :W] + + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, + scale_factor=0.43): # 0.3764 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + else: + return x / self.scale_factor, None, None, None + + + + def decode(self, x, tiled_decoding=False): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/modules/stage_b.py b/modules/stage_b.py new file mode 100644 index 0000000000000000000000000000000000000000..f89b42d61327278820e164b1c093cbf8d1048ee1 --- /dev/null +++ b/modules/stage_b.py @@ -0,0 +1,239 @@ +import math +import numpy as np +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock + + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True, + t_conds=['sca']): + super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + self.pixels_mapper = nn.Sequential( + nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/modules/stage_c.py b/modules/stage_c.py new file mode 100644 index 0000000000000000000000000000000000000000..53b73d0197712b981ec1a154428c21af2149646a --- /dev/null +++ b/modules/stage_c.py @@ -0,0 +1,252 @@ +import torch +from torch import nn +import numpy as np +import math +from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock +#from .controlnet import ControlNetDeliverer + + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = nn.Conv2d(c_in, c_out, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x.float()) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False]): + super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) + self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) + self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + x = self.embedding(x) + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/prompt_list.txt b/prompt_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..27cd31b4750d2f15fdb6f2a3f4bdd117a7377267 --- /dev/null +++ b/prompt_list.txt @@ -0,0 +1,32 @@ +A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening +in the early morning light. + +A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a +clear blue sky. + +A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing +a vintage floral dress and standing in front of a blooming garden. + +The image features a snow-covered mountain range with a large, snow-covered mountain in the background. +The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the +winter season, with snow covering the ground and the trees. + +Crocodile in a sweater. + +A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school +uniform, standing under a cherry blossom tree with petals falling around her. The background shows a +traditional Japanese school with cherry blossoms in full bloom. + +A playful Labrador retriever puppy with a shiny, golden coat, chasing a red ball in a spacious backyard, with +green grass and a wooden fence. + +A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm +lights glowing from the windows, and a path of footprints leading to the front door. + +A highly detailed, high-quality image of the Banff National Park in Canada. The turquoise waters of Lake +Louise are surrounded by snow-capped mountains and dense pine forests. A wooden canoe is docked at the +edge of the lake. The sky is a clear, bright blue, and the air is crisp and fresh. + +A highly detailed, high-quality image of a Shih Tzu receiving a bath in a home bathroom. The dog is standing +in a tub, covered in suds, with a slightly wet and adorable look. The background includes bathroom fixtures, +towels, and a clean, tiled floor. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a2397e70942601b5d37bd0b370cb842e1feb7bb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +accelerate>=0.25.0 +torch==2.1.2+cu118 +torchvision==0.16.2+cu118 +transformers>=4.30.0 +numpy>=1.23.5 +kornia>=0.7.0 +insightface>=0.7.3 +opencv-python>=4.8.1.78 +tqdm>=4.66.1 +matplotlib>=3.7.4 +webdataset>=0.2.79 +wandb>=0.16.2 +munch>=4.0.0 +onnxruntime>=1.16.3 +einops>=0.7.0 +onnx2torch>=1.5.13 +warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git +torchtools @ git+https://github.com/pabloppp/pytorch-tools diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1331f6b933f63c99a6bdf074201fdb4b8f78c2 --- /dev/null +++ b/train/__init__.py @@ -0,0 +1,5 @@ +from .train_b import WurstCore as WurstCoreB +from .train_c import WurstCore as WurstCoreC +from .train_t2i import WurstCore as WurstCore_t2i +from .train_ultrapixel_control import WurstCore as WurstCore_control_lrguide +from .train_personalized import WurstCore as WurstCore_personalized \ No newline at end of file diff --git a/train/base.py b/train/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8a6ef306e40da8c9d8db33ceba2f8b2982a9a9 --- /dev/null +++ b/train/base.py @@ -0,0 +1,402 @@ +import yaml +import json +import torch +import wandb +import torchvision +import numpy as np +from torch import nn +from tqdm import tqdm +from abc import abstractmethod +from fractions import Fraction +import matplotlib.pyplot as plt +from dataclasses import dataclass +from torch.distributed import barrier +from torch.utils.data import DataLoader + +from gdf import GDF +from gdf import AdaptiveLossWeight + +from core import WarpCore +from core.data import setup_webdataset_path, MultiGetter, MultiFilter, Bucketeer +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary + +import webdataset as wds +from webdataset.handlers import warn_and_continue + +import transformers +transformers.utils.logging.set_verbosity_error() + + +class DataCore(WarpCore): + @dataclass(frozen=True) + class Config(WarpCore.Config): + image_size: int = EXPECTED_TRAIN + webdataset_path: str = EXPECTED_TRAIN + grad_accum_steps: int = EXPECTED_TRAIN + batch_size: int = EXPECTED_TRAIN + multi_aspect_ratio: list = None + + captions_getter: list = None + dataset_filters: list = None + + bucketeer_random_ratio: float = 0.05 + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + transforms: torchvision.transforms.Compose = EXPECTED + clip_preprocess: torchvision.transforms.Compose = EXPECTED + + @dataclass(frozen=True) + class Models(WarpCore.Models): + tokenizer: nn.Module = EXPECTED + text_model: nn.Module = EXPECTED + image_model: nn.Module = None + + config: Config + + def webdataset_path(self): + if isinstance(self.config.webdataset_path, str) and (self.config.webdataset_path.strip().startswith( + 'pipe:') or self.config.webdataset_path.strip().startswith('file:')): + return self.config.webdataset_path + else: + dataset_path = self.config.webdataset_path + if isinstance(self.config.webdataset_path, str) and self.config.webdataset_path.strip().endswith('.yml'): + with open(self.config.webdataset_path, 'r', encoding='utf-8') as file: + dataset_path = yaml.safe_load(file) + return setup_webdataset_path(dataset_path, cache_path=f"{self.config.experiment_id}_webdataset_cache.yml") + + def webdataset_preprocessors(self, extras: Extras): + def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x + + # CUSTOM CAPTIONS GETTER ----- + def get_caption(oc, c, p_og=0.05): # cog_contexual, cog_caption + if p_og > 0 and np.random.rand() < p_og and len(oc) > 0: + return identity(oc) + else: + return identity(c) + + captions_getter = MultiGetter(rules={ + ('old_caption', 'caption'): lambda oc, c: get_caption(json.loads(oc)['og_caption'], c, p_og=0.05) + }) + + return [ + ('jpg;png', + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None else extras.transforms, + 'images'), + ('txt', identity, 'captions') if self.config.captions_getter is None else ( + self.config.captions_getter[0], eval(self.config.captions_getter[1]), 'captions'), + ] + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.webdataset_path() + preprocessors = self.webdataset_preprocessors(extras) + + handler = warn_and_continue + dataset = wds.WebDataset( + dataset_path, resampled=True, handler=handler + ).select( + MultiFilter(rules={ + f[0]: eval(f[1]) for f in self.config.dataset_filters + }) if self.config.dataset_filters is not None else lambda _: True + ).shuffle(690, handler=handler).decode( + "pilrgb", handler=handler + ).to_tuple( + *[p[0] for p in preprocessors], handler=handler + ).map_tuple( + *[p[1] for p in preprocessors], handler=handler + ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)}) + + def identity(x): + return x + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=self.config.image_size ** 2, factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + if return_fields is None: + return_fields = ['clip_text', 'clip_text_pooled', 'clip_img'] + + captions = batch.get('captions', None) + images = batch.get('images', None) + batch_size = len(captions) + + text_embeddings = None + text_pooled_embeddings = None + if 'clip_text' in return_fields or 'clip_text_pooled' in return_fields: + if is_eval: + if is_unconditional: + captions_unpooled = ["" for _ in range(batch_size)] + else: + captions_unpooled = captions + else: + rand_idx = np.random.rand(batch_size) > 0.05 + captions_unpooled = [str(c) if keep else "" for c, keep in zip(captions, rand_idx)] + clip_tokens_unpooled = models.tokenizer(captions_unpooled, truncation=True, padding="max_length", + max_length=models.tokenizer.model_max_length, + return_tensors="pt").to(self.device) + text_encoder_output = models.text_model(**clip_tokens_unpooled, output_hidden_states=True) + if 'clip_text' in return_fields: + text_embeddings = text_encoder_output.hidden_states[-1] + if 'clip_text_pooled' in return_fields: + text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1) + + image_embeddings = None + if 'clip_img' in return_fields: + image_embeddings = torch.zeros(batch_size, 768, device=self.device) + if images is not None: + images = images.to(self.device) + if is_eval: + if not is_unconditional and eval_image_embeds: + image_embeddings = models.image_model(extras.clip_preprocess(images)).image_embeds + else: + rand_idx = np.random.rand(batch_size) > 0.9 + if any(rand_idx): + image_embeddings[rand_idx] = models.image_model(extras.clip_preprocess(images[rand_idx])).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + return { + 'clip_text': text_embeddings, + 'clip_text_pooled': text_pooled_embeddings, + 'clip_img': image_embeddings + } + + +class TrainingCore(DataCore, WarpCore): + @dataclass(frozen=True) + class Config(DataCore.Config, WarpCore.Config): + updates: int = EXPECTED_TRAIN + backup_every: int = EXPECTED_TRAIN + save_every: int = EXPECTED_TRAIN + + # EMA UPDATE + ema_start_iters: int = None + ema_iters: int = None + ema_beta: float = None + + use_fsdp: bool = None + + @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED + class Info(WarpCore.Info): + ema_loss: float = None + adaptive_loss: dict = None + + @dataclass(frozen=True) + class Models(WarpCore.Models): + generator: nn.Module = EXPECTED + generator_ema: nn.Module = None # optional + + @dataclass(frozen=True) + class Optimizers(WarpCore.Optimizers): + generator: any = EXPECTED + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + + info: Info + config: Config + + @abstractmethod + def forward_pass(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers, + schedulers: WarpCore.Schedulers): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def models_to_save(self) -> list: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter, + max_iters + 1) # <--- DDP + if 'generator' in self.models_to_save(): + models.generator.train() + for i in pbar: + # FORWARD PASS + loss, loss_adjusted = self.forward_pass(data, extras, models) + + # # BACKWARD PASS + grad_norm = self.backward_pass( + i % self.config.grad_accum_steps == 0 or i == max_iters, loss, loss_adjusted, + models, optimizers, schedulers + ) + self.info.iter = i + + # UPDATE EMA + if models.generator_ema is not None and i % self.config.ema_iters == 0: + update_weights_ema( + models.generator_ema, models.generator, + beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) + ) + + # UPDATE LOSS METRICS + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan( + grad_norm.item()): + wandb.alert( + title=f"NaN value encountered in training run {self.info.wandb_run_id}", + text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}", + wait_duration=60 * 30 + ) + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'raw_loss': loss.mean().item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, + 'total_steps': self.info.total_steps, + } + + pbar.set_postfix(logs) + if self.config.wandb_project is not None: + wandb.log(logs) + + if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters: + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + tqdm.write("Skipping sampling & checkpoint because the loss is NaN") + wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.wandb_run_id}", + text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + self.save_checkpoints(models, optimizers) + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + self.sample(models, data, extras) + + def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): + barrier() + suffix = '' if suffix is None else suffix + self.save_info(self.info, suffix=suffix) + models_dict = models.to_dict() + optimizers_dict = optimizers.to_dict() + for key in self.models_to_save(): + model = models_dict[key] + if model is not None: + self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) + for key in optimizers_dict: + optimizer = optimizers_dict[key] + if optimizer is not None: + self.save_optimizer(optimizer, f'{key}_optim{suffix}', + fsdp_model=models_dict[key] if self.config.use_fsdp else None) + if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: + self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k") + torch.cuda.empty_cache() + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + if 'generator' in self.models_to_save(): + models.generator.eval() + with torch.no_grad(): + batch = next(data.iterator) + + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + latents = self.encode_latents(batch, models, extras) + noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + *_, (sampled, _, _) = extras.gdf.sample( + models.generator, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + + if models.generator_ema is not None: + *_, (sampled_ema, _, _) = extras.gdf.sample( + models.generator_ema, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + else: + sampled_ema = sampled + + if self.is_main_node: + noised_images = torch.cat( + [self.decode_latents(noised[i:i + 1], batch, models, extras) for i in range(len(noised))], dim=0) + pred_images = torch.cat( + [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0) + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in pred_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') + + captions = batch['captions'] + if self.config.wandb_project is not None: + log_data = [ + [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ + wandb.Image(images[i])] for i in range(len(images))] + log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) + wandb.log({"Log": log_table}) + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) + plt.ylabel('Raw Loss') + plt.ylabel('LogSNR') + wandb.log({"Loss/LogSRN": plt}) + + if 'generator' in self.models_to_save(): + models.generator.train() diff --git a/train/dist_core.py b/train/dist_core.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5a75b906dec6e2ec412258ad1db31b05c94b21 --- /dev/null +++ b/train/dist_core.py @@ -0,0 +1,47 @@ +import os +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" diff --git a/train/train_b.py b/train/train_b.py new file mode 100644 index 0000000000000000000000000000000000000000..c3441a5841750a7c33b49756d2d60064a68d82d8 --- /dev/null +++ b/train/train_b.py @@ -0,0 +1,305 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import numpy as np + +import sys +import os +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_a import StageA + +from modules.stage_b import StageB +from modules.stage_b import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + shift: float = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3BB or 700M + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + stage_a_checkpoint_path: str = EXPECTED + effnet_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + stage_a: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 1.5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 10} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) if self.config.training else torchvision.transforms.CenterCrop(self.config.image_size) + ]) + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=None + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None): + images = batch.get('images', None) + + if images is not None: + images = images.to(self.device) + if is_eval and not is_unconditional: + effnet_embeddings = models.effnet(extras.effnet_preprocess(images)) + else: + if is_eval: + effnet_factor = 1 + else: + effnet_factor = np.random.uniform(0.5, 1) # f64 to f32 + effnet_height, effnet_width = int(((images.size(-2)*effnet_factor)//32)*32), int(((images.size(-1)*effnet_factor)//32)*32) + + effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height//32, effnet_width//32, device=self.device) + if not is_eval: + effnet_images = torchvision.transforms.functional.resize(images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + rand_idx = np.random.rand(len(images)) <= 0.9 + if any(rand_idx): + effnet_embeddings[rand_idx] = models.effnet(extras.effnet_preprocess(effnet_images[rand_idx])) + else: + effnet_embeddings = None + + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text_pooled'] + ) + + return {'effnet': effnet_embeddings, 'clip': conditions['clip_text_pooled']} + + def setup_models(self, extras: Extras, skip_clip: bool = False) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False) + del effnet_checkpoint + + # vqGAN + stage_a = StageA().to(self.device) + stage_a_checkpoint = load_or_fail(self.config.stage_a_checkpoint_path) + stage_a.load_state_dict(stage_a_checkpoint if 'state_dict' not in stage_a_checkpoint else stage_a_checkpoint['state_dict']) + stage_a.eval().requires_grad_(False) + del stage_a_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3B': + generator = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) + if self.config.ema_start_iters is not None: + generator_ema = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) + elif self.config.model_version == '700M': + generator = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) + if self.config.ema_start_iters is not None: + generator_ema = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + if generator_ema is not None: + if loading_context is dummy_context: + generator_ema.load_state_dict(generator.state_dict()) + else: + for param_name, param in generator.state_dict().items(): + set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) + generator_ema = self.load_model(generator_ema, 'generator_ema') + generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) + + if self.config.use_fsdp: + fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) + generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + if generator_ema is not None: + generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + if skip_clip: + tokenizer = None + text_model = None + else: + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, stage_a=stage_a, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, + optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def _pyramid_noise(self, epsilon, size_range=None, levels=10, scale_mode='nearest'): + epsilon = epsilon.clone() + multipliers = [1] + for i in range(1, levels): + m = 0.75 ** i + h, w = epsilon.size(-2) // (2 ** i), epsilon.size(-2) // (2 ** i) + if size_range is None or (size_range[0] <= h <= size_range[1] or size_range[0] <= w <= size_range[1]): + offset = torch.randn(epsilon.size(0), epsilon.size(1), h, w, device=self.device) + epsilon = epsilon + torch.nn.functional.interpolate(offset, size=epsilon.shape[-2:], + mode=scale_mode) * m + multipliers.append(m) + if h <= 1 or w <= 1: + break + epsilon = epsilon / sum([m ** 2 for m in multipliers]) ** 0.5 + # epsilon = epsilon / epsilon.std() + return epsilon + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + epsilon = torch.randn_like(latents) + epsilon = self._pyramid_noise(epsilon, size_range=[1, 16]) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1, + epsilon=epsilon) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.stage_a.encode(images)[0] + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.stage_a.decode(latents.float()).clamp(0, 1) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_c.py b/train/train_c.py new file mode 100644 index 0000000000000000000000000000000000000000..c4490c6eebc3e1c5126dd13c53603872f1459a3e --- /dev/null +++ b/train/train_c.py @@ -0,0 +1,266 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler + +import sys +import os +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_c import StageC +from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: + generator_ema = StageC() + elif self.config.model_version == '1B': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + if self.config.ema_start_iters is not None: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + if generator_ema is not None: + if loading_context is dummy_context: + generator_ema.load_state_dict(generator.state_dict()) + else: + for param_name, param in generator.state_dict().items(): + set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) + generator_ema = self.load_model(generator_ema, 'generator_ema') + generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) + + if self.config.use_fsdp: + fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) + generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + if generator_ema is not None: + generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, previewer=previewer, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model, image_model=image_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + # Training loop -------------------------------- + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_c_lora.py b/train/train_c_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..8b83eee0f250e5359901d39b8d4052254cfff4fa --- /dev/null +++ b/train/train_c_lora.py @@ -0,0 +1,330 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler + +import sys +import os +import re +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_c import StageC +from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +import functools +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + lora_checkpoint_path: str = None + + # LoRA STUFF + module_filters: list = EXPECTED + rank: int = EXPECTED + train_tokens: list = EXPECTED + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + lora: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + lora: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED + class Info(TrainingCore.Info): + train_tokens: list = None + + @dataclass(frozen=True) + class Optimizers(TrainingCore.Optimizers, WarpCore.Optimizers): + generator: any = None + lora: any = EXPECTED + + # -------------------------------------------- + info: Info + config: Config + + # Extras: gdf, transforms and preprocessors -------------------------------- + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + # Data -------------------------------- + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + # Models, Optimizers & Schedulers setup -------------------------------- + def setup_models(self, extras: Extras) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False) + del effnet_checkpoint + + # Previewer + previewer = Previewer().to(self.device) + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + with loading_context(): + # Diffusion models + if self.config.model_version == '3.6B': + generator = StageC() + elif self.config.model_version == '1B': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + # if self.config.use_fsdp: + # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) + # generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + # PREPARE LORA + update_tokens = [] + for tkn_regex, aggr_regex in self.config.train_tokens: + if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')): + # Insert new token + tokenizer.add_tokens([tkn_regex]) + # add new zeros embedding + new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1] + if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline + aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None] + if len(aggr_tokens) > 0: + new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True) + elif self.is_main_node: + print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.") + text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([ + text_model.text_model.embeddings.token_embedding.weight.data, new_embedding + ], dim=0) + selected_tokens = [len(tokenizer.vocab) - 1] + else: + selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None] + update_tokens += selected_tokens + update_tokens = list(set(update_tokens)) # remove duplicates + + apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens) + apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank) + text_model.text_model.to(self.device) + generator.to(self.device) + lora = nn.ModuleDict() + lora['embeddings'] = text_model.text_model.embeddings.token_embedding.parametrizations.weight[0] + lora['weights'] = nn.ModuleList() + for module in generator.modules(): + if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)): + lora['weights'].append(module) + + self.info.train_tokens = [(i, tokenizer.decode(i)) for i in update_tokens] + if self.is_main_node: + print("Updating tokens:", self.info.train_tokens) + print(f"LoRA training {len(lora['weights'])} layers") + + if self.config.lora_checkpoint_path is not None: + lora_checkpoint = load_or_fail(self.config.lora_checkpoint_path) + lora.load_state_dict(lora_checkpoint if 'state_dict' not in lora_checkpoint else lora_checkpoint['state_dict']) + + lora = self.load_model(lora, 'lora') + lora.to(self.device).train().requires_grad_(True) + if self.config.use_fsdp: + # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) + fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken]) + lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + return self.Models( + effnet=effnet, previewer=previewer, + generator=generator, generator_ema=None, + lora=lora, + tokenizer=tokenizer, text_model=text_model, image_model=image_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: + optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'lora_optim', + fsdp_model=models.lora if self.config.use_fsdp else None) + return self.Optimizers(generator=None, lora=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.lora, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(lora=scheduler) + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + conditions = self.get_conditions(batch, models, extras) + with torch.no_grad(): + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.lora.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if optimizers_dict[k] is not None and k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if optimizers_dict[k] is not None and k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['lora'] + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + models.lora.eval() + super().sample(models, data, extras) + models.lora.train(), models.generator.eval() + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_personalized.py b/train/train_personalized.py new file mode 100644 index 0000000000000000000000000000000000000000..978426e5e1d5804ac006245ee2f5e9c9fab1aa42 --- /dev/null +++ b/train/train_personalized.py @@ -0,0 +1,899 @@ +import torch +import json +import yaml +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import torch.multiprocessing as mp +import os +import numpy as np +import re +import sys +sys.path.append(os.path.abspath('./')) + +from dataclasses import dataclass +from torch.distributed import init_process_group, destroy_process_group, barrier +from gdf import GDF_dual_fixlrt as GDF +from gdf import EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop +from fractions import Fraction +from modules.effnet import EfficientNetEncoder +from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.common_ckpt import GlobalResponseNorm +from modules.previewer import Previewer +from core.data import Bucketeer +from train.base import DataCore, TrainingCore +from tqdm import tqdm +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager +from train.dist_core import * +import glob +from torch.utils.data import DataLoader, Dataset +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from PIL import Image +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from core.utils import Base +import torch.nn.functional as F +import functools +import math +import copy +import random +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken + +Image.MAX_IMAGE_PIXELS = None +torch.manual_seed(23) +random.seed(23) +np.random.seed(23) +#7978026 + +class Null_Model(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + pass + + + + +def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x +def check_nan_inmodel(model, meta=''): + for name, param in model.named_parameters(): + if torch.isnan(param).any(): + print(f"nan detected in {name}", meta) + return True + print('no nan', meta) + return False +class mydist_dataset(Dataset): + def __init__(self, rootpath, tmp_prompt, img_processor=None): + + self.img_pathlist = glob.glob(os.path.join(rootpath, '*.jpg')) + self.img_pathlist = self.img_pathlist * 100000 + self.img_processor = img_processor + self.length = len( self.img_pathlist) + self.caption = tmp_prompt + + + def __getitem__(self, idx): + + imgpath = self.img_pathlist[idx] + txt = self.caption + + + + + try: + img = Image.open(imgpath).convert('RGB') + w, h = img.size + if self.img_processor is not None: + img = self.img_processor(img) + + except: + print('exception', imgpath) + return self.__getitem__(random.randint(0, self.length -1 ) ) + return dict(captions=txt, images=img) + def __len__(self): + return self.length +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + ultrapixel_path: str = EXPECTED + + # gdf customization + adaptive_loss_weight: str = None + + # LoRA STUFF + module_filters: list = EXPECTED + rank: int = EXPECTED + train_tokens: list = EXPECTED + use_ddp: bool=EXPECTED + tmp_prompt: str=EXPECTED + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + sampler: DistributedSampler = EXPECTED + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + train_norm: nn.Module = EXPECTED + train_lora: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: # configure model + + + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 + + # EfficientNet encoderin + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: # default setting + generator_ema = StageC() + elif self.config.model_version == '1B': + print('in line 155 1b light model', self.config.model_version ) + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + + if self.config.ema_start_iters is not None and self.config.training: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + + + if loading_context is dummy_context: + generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + + generator._init_extra_parameter() + generator = generator.to(torch.bfloat16).to(self.device) + + train_norm = nn.ModuleList() + + + cnt_norm = 0 + for mm in generator.modules(): + if isinstance(mm, GlobalResponseNorm): + + train_norm.append(Null_Model()) + cnt_norm += 1 + + + + + train_norm.append(generator.agg_net) + train_norm.append(generator.agg_net_up) + sdd = torch.load(self.config.ultrapixel_path, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_norm.load_state_dict(collect_sd) + + + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + # PREPARE LORA + train_lora = nn.ModuleList() + update_tokens = [] + for tkn_regex, aggr_regex in self.config.train_tokens: + if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')): + # Insert new token + tokenizer.add_tokens([tkn_regex]) + # add new zeros embedding + new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1] + if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline + aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None] + if len(aggr_tokens) > 0: + new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True) + elif self.is_main_node: + print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.") + text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([ + text_model.text_model.embeddings.token_embedding.weight.data, new_embedding + ], dim=0) + selected_tokens = [len(tokenizer.vocab) - 1] + else: + selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None] + update_tokens += selected_tokens + update_tokens = list(set(update_tokens)) # remove duplicates + + apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens) + + apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank) + for module in generator.modules(): + if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)): + train_lora.append(module) + + + train_lora.append(text_model.text_model.embeddings.token_embedding.parametrizations.weight[0]) + + if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors')): + sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors'), map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_lora.load_state_dict(collect_sd, strict=True) + + + train_norm.to(self.device).train().requires_grad_(True) + + if generator_ema is not None: + + generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + generator_ema._init_extra_parameter() + pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') + if os.path.exists(pretrained_pth): + generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) + + generator_ema.eval().requires_grad_(False) + + check_nan_inmodel(generator, 'generator') + + + + if self.config.use_ddp and self.config.training: + + train_lora = DDP(train_lora, device_ids=[self.device], find_unused_parameters=True) + + + + return self.Models( + effnet=effnet, previewer=previewer, train_norm = train_norm, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model, image_model=image_model, + train_lora=train_lora + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + params = [] + params += list(models.train_lora.module.parameters()) + optimizer = optim.AdamW(params, lr=self.config.lr) + + return self.Optimizers(generator=optimizer) + + def ema_update(self, ema_model, source_model, beta): + for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): + param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) + + def sync_ema(self, ema_model): + print('sync ema', torch.distributed.get_world_size()) + for param in ema_model.parameters(): + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) + param.data /= torch.distributed.get_world_size() + def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + optimizer = optim.AdamW( + models.generator.up_blocks.parameters() , + lr=self.config.lr) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.config.webdataset_path + + + dataset = mydist_dataset(dataset_path, self.config.tmp_prompt, \ + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ + else extras.transforms) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + + sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None, + sampler = sampler + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) + + + + + + def setup_ddp(self, experiment_id, single_gpu=False, rank=0): + + if not single_gpu: + local_rank = rank + process_id = rank + world_size = get_world_size() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '14443' + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=local_rank, + world_size=world_size, + # init_method=init_method, + ) + print(f"[GPU {process_id}] READY") + else: + self.is_main_node = rank == 0 + self.process_id = rank + self.device = torch.device('cuda:0') + self.world_size = 1 + print("Running in single thread, DDP not enabled.") + # Training loop -------------------------------- + def get_target_lr_size(self, ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w * 32) + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + + batch = data + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + shape_lr = self.get_target_lr_size(ratio) + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) + + + + flag_lr = random.random() < 0.5 or self.info.iter <5000 + + if flag_lr: + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1) + else: + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + if not flag_lr: + noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = \ + extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + + if not flag_lr: + with torch.no_grad(): + _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) + + + pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if not flag_lr else None , **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + + loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps + + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + return loss, loss_adjusted + + def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + + if update: + + torch.distributed.barrier() + loss_adjusted.backward() + + grad_norm = nn.utils.clip_grad_norm_(models.train_lora.module.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: + + images = batch['images'].to(self.device) + if target_size is not None: + images = F.interpolate(images, target_size) + + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): + + self.is_main_node = (rank == 0) + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) + self.info: self.Info = self.setup_info() + print('in line 292', self.config.experiment_id, rank, world_size <= 1) + p = [i for i in range( 2 * 768 // 32)] + p = [num / sum(p) for num in p] + self.rand_pro = p + self.res_list = [o for o in range(800, 2336, 32)] + + + + def __call__(self, single_gpu=False): + + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device ) + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + if self.config.wandb_project is not None: + wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") + + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + if 'generator' in self.models_to_save(): + models.generator.train() + + iter_cnt = 0 + epoch_cnt = 0 + models.train_norm.train() + while True: + epoch_cnt += 1 + if self.world_size > 1: + + data.sampler.set_epoch(epoch_cnt) + for ggg in range(len(data.dataloader)): + iter_cnt += 1 + # FORWARD PASS + + loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models) + + + # # BACKWARD PASS + + grad_norm = self.backward_pass( + iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted, + models, optimizers, schedulers + ) + + + + self.info.iter = iter_cnt + + + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + + if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): + print(f"gggg NaN value encountered in training run {self.info.wandb_run_id}", \ + f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'backward_loss': loss_adjusted.mean().item(), + + 'ema_loss': self.info.ema_loss, + 'raw_ori_loss': loss.mean().item(), + + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, + 'total_steps': self.info.total_steps, + } + + + print(iter_cnt, max_iters, logs, epoch_cnt, ) + + + + + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters: + + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + print(f"NaN value encountered in training run {self.info.wandb_run_id}", \ + f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") + + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + + + if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: + print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) + torch.save(models.train_lora.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_lora.safetensors') + + + torch.save(models.train_lora.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_lora_{iter_cnt}.safetensors') + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: + + if self.is_main_node: + + self.sample(models, data, extras) + if False: + param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()} + threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10% + important_params = [name for name, change in param_changes.items() if change > threshold] + print(important_params, threshold, len(param_changes), self.process_id) + json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4) + + + if self.info.iter >= max_iters: + break + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + + + models.generator.eval() + models.train_norm.eval() + with torch.no_grad(): + batch = next(data.iterator) + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + + shape_lr = self.get_target_lr_size(ratio) + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) + + if self.is_main_node: + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + + sampled_ema = sampled + sampled_ema_lr = sampled_lr + + + if self.is_main_node: + print('sampling results hr latent shape ', latents.shape, 'lr latent shape', latents_lr.shape, ) + noised_images = torch.cat( + [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) + + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + noised_images_lr = torch.cat( + [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) + + sampled_images_lr = torch.cat( + [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) + sampled_images_ema_lr = torch.cat( + [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))], + dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic') + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + ], dim=-2) + + collage_img_lr = torch.cat([ + torch.cat([i for i in images_lr.cpu()], dim=-1), + torch.cat([i for i in noised_images_lr.cpu()], dim=-1), + torch.cat([i for i in sampled_images_lr.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg') + + captions = batch['captions'] + if self.config.wandb_project is not None: + log_data = [ + [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ + wandb.Image(images[i])] for i in range(len(images))] + log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) + wandb.log({"Log": log_table}) + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) + plt.ylabel('Raw Loss') + plt.ylabel('LogSNR') + wandb.log({"Loss/LogSRN": plt}) + + + models.generator.train() + models.train_norm.train() + print('finish sampling') + + + + def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False): + + + models.generator.eval() + models.trans_inr.eval() + with torch.no_grad(): + + if self.is_main_node: + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, conditions, + hr_shape, lr_shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + if models.generator_ema is not None: + + *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample( + models.generator_ema, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + else: + sampled_ema = sampled + sampled_ema_lr = sampled_lr + + + return sampled, sampled_lr +def main_worker(rank, cfg): + print("Launching Script in main worker") + warpcore = WurstCore( + config_file_path=cfg, rank=rank, world_size = get_world_size() + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore(get_world_size()==1) + +if __name__ == '__main__': + + if get_master_ip() == "127.0.0.1": + + mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, )) + else: + main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, ) diff --git a/train/train_t2i.py b/train/train_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..e28cd7542393cb1b7ee3454cc8b28f30710bae79 --- /dev/null +++ b/train/train_t2i.py @@ -0,0 +1,807 @@ +import torch +import json +import yaml +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import torch.multiprocessing as mp +import numpy as np +import os +import sys +sys.path.append(os.path.abspath('./')) +from dataclasses import dataclass +from torch.distributed import init_process_group, destroy_process_group, barrier +from gdf import GDF_dual_fixlrt as GDF +from gdf import EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop +from fractions import Fraction +from modules.effnet import EfficientNetEncoder + +from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer +from core.data import Bucketeer +from train.base import DataCore, TrainingCore +from tqdm import tqdm +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager +from train.dist_core import * +import glob +from torch.utils.data import DataLoader, Dataset +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from PIL import Image +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from core.utils import Base +from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm +import torch.nn.functional as F +import functools +import math +import copy +import random +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken +Image.MAX_IMAGE_PIXELS = None +torch.manual_seed(23) +random.seed(23) +np.random.seed(23) +#7978026 + +class Null_Model(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + pass + + + + +def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x +def check_nan_inmodel(model, meta=''): + for name, param in model.named_parameters(): + if torch.isnan(param).any(): + print(f"nan detected in {name}", meta) + return True + print('no nan', meta) + return False +class mydist_dataset(Dataset): + def __init__(self, rootpath, img_processor=None): + + self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg')) + self.img_processor = img_processor + self.length = len( self.img_pathlist) + + + + def __getitem__(self, idx): + + imgpath = self.img_pathlist[idx] + json_file = imgpath.replace('.jpg', '.json') + + with open(json_file, 'r') as file: + info = json.load(file) + txt = info['caption'] + if txt is None: + txt = ' ' + try: + img = Image.open(imgpath).convert('RGB') + w, h = img.size + if self.img_processor is not None: + img = self.img_processor(img) + + except: + print('exception', imgpath) + return self.__getitem__(random.randint(0, self.length -1 ) ) + return dict(captions=txt, images=img) + def __len__(self): + return self.length + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + use_ddp: bool=EXPECTED + + + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + sampler: DistributedSampler = EXPECTED + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + train_norm: nn.Module = EXPECTED + + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: # configure model + + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 + + # EfficientNet encoderin + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: # default setting + generator_ema = StageC() + elif self.config.model_version == '1B': + print('in line 155 1b light model', self.config.model_version ) + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + + if self.config.ema_start_iters is not None and self.config.training: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + + + if loading_context is dummy_context: + generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + + generator._init_extra_parameter() + generator = generator.to(torch.bfloat16).to(self.device) + + + train_norm = nn.ModuleList() + cnt_norm = 0 + for mm in generator.modules(): + if isinstance(mm, GlobalResponseNorm): + + train_norm.append(Null_Model()) + cnt_norm += 1 + + train_norm.append(generator.agg_net) + train_norm.append(generator.agg_net_up) + total = sum([ param.nelement() for param in train_norm.parameters()]) + print('Trainable parameter', total / 1048576) + + if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')): + sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_norm.load_state_dict(collect_sd, strict=True) + + + train_norm.to(self.device).train().requires_grad_(True) + train_norm_ema = copy.deepcopy(train_norm) + train_norm_ema.to(self.device).eval().requires_grad_(False) + if generator_ema is not None: + + generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + generator_ema._init_extra_parameter() + + + pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') + if os.path.exists(pretrained_pth): + print(pretrained_pth, 'exists') + generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) + + + generator_ema.eval().requires_grad_(False) + + + + + check_nan_inmodel(generator, 'generator') + + + + if self.config.use_ddp and self.config.training: + + train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True) + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, previewer=previewer, train_norm = train_norm, + generator=generator, tokenizer=tokenizer, text_model=text_model, image_model=image_model, + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + params = [] + params += list(models.train_norm.module.parameters()) + + optimizer = optim.AdamW(params, lr=self.config.lr) + + return self.Optimizers(generator=optimizer) + + def ema_update(self, ema_model, source_model, beta): + for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): + param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) + + def sync_ema(self, ema_model): + for param in ema_model.parameters(): + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) + param.data /= torch.distributed.get_world_size() + def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + optimizer = optim.AdamW( + models.generator.up_blocks.parameters() , + lr=self.config.lr) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.config.webdataset_path + dataset = mydist_dataset(dataset_path, \ + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ + else extras.transforms) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + + sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None, + sampler = sampler + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) + + + def models_to_save(self): + pass + def setup_ddp(self, experiment_id, single_gpu=False, rank=0): + + if not single_gpu: + local_rank = rank + process_id = rank + world_size = get_world_size() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '41443' + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=local_rank, + world_size=world_size, + ) + print(f"[GPU {process_id}] READY") + else: + self.is_main_node = rank == 0 + self.process_id = rank + self.device = torch.device('cuda:0') + self.world_size = 1 + print("Running in single thread, DDP not enabled.") + # Training loop -------------------------------- + def get_target_lr_size(self, ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w * 32) + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + #batch = next(data.iterator) + batch = data + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + shape_lr = self.get_target_lr_size(ratio) + #print('in line 485', shape_lr, ratio, batch['images'].shape) + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) + + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + # 768 1536 + require_cond = True + + with torch.no_grad(): + _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) + + + pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + + loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps + + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + + + if update: + + torch.distributed.barrier() + loss_adjusted.backward() + + grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0) + + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + + loss_adjusted.backward() + + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + + def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: + + images = batch['images'].to(self.device) + if target_size is not None: + images = F.interpolate(images, target_size) + + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): + + self.is_main_node = (rank == 0) + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) + self.info: self.Info = self.setup_info() + + + + def __call__(self, single_gpu=False): + + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + + + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + + models.generator.train() + + iter_cnt = 0 + epoch_cnt = 0 + models.train_norm.train() + while True: + epoch_cnt += 1 + if self.world_size > 1: + + data.sampler.set_epoch(epoch_cnt) + for ggg in range(len(data.dataloader)): + iter_cnt += 1 + loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models) + grad_norm = self.backward_pass( + iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted, + models, optimizers, schedulers + ) + + self.info.iter = iter_cnt + + + # UPDATE LOSS METRICS + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss) + if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): + print(f" NaN value encountered in training run {self.info.wandb_run_id}", \ + f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'backward_loss': loss_adjusted.mean().item(), + 'ema_loss': self.info.ema_loss, + 'raw_ori_loss': loss.mean().item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, + 'total_steps': self.info.total_steps, + } + if iter_cnt % (self.config.save_every) == 0: + + print(iter_cnt, max_iters, logs, epoch_cnt, ) + + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters: + + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + print(f"NaN value encountered in training run {self.info.wandb_run_id}", \ + f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") + + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + + + + if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: + print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) + torch.save(models.train_norm.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors') + + torch.save(models.train_norm.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors') + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: + + if self.is_main_node: + + self.sample(models, data, extras) + + + if self.info.iter >= max_iters: + break + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + + + models.generator.eval() + models.train_norm.eval() + with torch.no_grad(): + batch = next(data.iterator) + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + + shape_lr = self.get_target_lr_size(ratio) + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) + + + if self.is_main_node: + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + + + + if self.is_main_node: + print('sampling results hr latent shape', latents.shape, 'lr latent shape', latents_lr.shape, ) + noised_images = torch.cat( + [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) + + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) + + + noised_images_lr = torch.cat( + [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) + + sampled_images_lr = torch.cat( + [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic') + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + ], dim=-2) + + collage_img_lr = torch.cat([ + torch.cat([i for i in images_lr.cpu()], dim=-1), + torch.cat([i for i in noised_images_lr.cpu()], dim=-1), + torch.cat([i for i in sampled_images_lr.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg') + + + models.generator.train() + models.train_norm.train() + print('finish sampling') + + + + def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False): + + + models.generator.eval() + + with torch.no_grad(): + + if self.is_main_node: + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, conditions, + hr_shape, lr_shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + if models.generator_ema is not None: + + *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample( + models.generator_ema, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + else: + sampled_ema = sampled + sampled_ema_lr = sampled_lr + + return sampled, sampled_lr +def main_worker(rank, cfg): + print("Launching Script in main worker") + + warpcore = WurstCore( + config_file_path=cfg, rank=rank, world_size = get_world_size() + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore(get_world_size()==1) + +if __name__ == '__main__': + print('launch multi process') + # os.environ["OMP_NUM_THREADS"] = "1" + # os.environ["MKL_NUM_THREADS"] = "1" + #dist.init_process_group(backend="nccl") + #torch.backends.cudnn.benchmark = True +#train/train_c_my.py + #mp.set_sharing_strategy('file_system') + + if get_master_ip() == "127.0.0.1": + # manually launch distributed processes + mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, )) + else: + main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, ) diff --git a/train/train_ultrapixel_control.py b/train/train_ultrapixel_control.py new file mode 100644 index 0000000000000000000000000000000000000000..97001a62b84f9bdb369d9f7948c6dbf8028d2b63 --- /dev/null +++ b/train/train_ultrapixel_control.py @@ -0,0 +1,928 @@ +import torch +import json +import yaml +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import torch.multiprocessing as mp +import numpy as np +import sys + +import os +from dataclasses import dataclass +from torch.distributed import init_process_group, destroy_process_group, barrier +from gdf import GDF_dual_fixlrt as GDF +from gdf import EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop +from fractions import Fraction +from modules.effnet import EfficientNetEncoder + +from modules.model_4stage_lite import StageC + +from modules.model_4stage_lite import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.common_ckpt import GlobalResponseNorm +from modules.previewer import Previewer +from core.data import Bucketeer +from train.base import DataCore, TrainingCore +from tqdm import tqdm +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail +from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager +from train.dist_core import * +import glob +from torch.utils.data import DataLoader, Dataset +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from PIL import Image +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from core.utils import Base +from modules.common import LayerNorm2d +import torch.nn.functional as F +import functools +import math +import copy +import random +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken +from modules import ControlNet, ControlNetDeliverer +from modules import controlnet_filters + +Image.MAX_IMAGE_PIXELS = None +torch.manual_seed(8432) +random.seed(8432) +np.random.seed(8432) +#7978026 + +class Null_Model(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + pass + + +def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x +def check_nan_inmodel(model, meta=''): + for name, param in model.named_parameters(): + if torch.isnan(param).any(): + print(f"nan detected in {name}", meta) + return True + print('no nan', meta) + return False + + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + #trans_inr_ckpt: str = EXPECTED + generator_checkpoint_path: str = None + controlnet_checkpoint_path: str = EXPECTED + + # controlnet settings + controlnet_blocks: list = EXPECTED + controlnet_filter: str = EXPECTED + controlnet_filter_params: dict = None + controlnet_bottleneck_mode: str = None + + + # gdf customization + adaptive_loss_weight: str = None + + #module_filters: list = EXPECTED + #rank: int = EXPECTED + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + sampler: DistributedSampler = EXPECTED + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + train_norm: nn.Module = EXPECTED + train_norm_ema: nn.Module = EXPECTED + controlnet: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + controlnet_filter: controlnet_filters.BaseFilter = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + controlnet_filter = getattr(controlnet_filters, self.config.controlnet_filter)( + self.device, + **(self.config.controlnet_filter_params if self.config.controlnet_filter_params is not None else {}) + ) + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess, + controlnet_filter=controlnet_filter + ) + def get_cnet(self, batch: dict, models: Models, extras: Extras, cnet_input=None, target_size=None, **kwargs): + images = batch['images'] + if target_size is not None: + images = Image.resize(images, target_size) + with torch.no_grad(): + if cnet_input is None: + cnet_input = extras.controlnet_filter(images, **kwargs) + if isinstance(cnet_input, tuple): + cnet_input, cnet_input_preview = cnet_input + else: + cnet_input_preview = cnet_input + cnet_input, cnet_input_preview = cnet_input.to(self.device), cnet_input_preview.to(self.device) + cnet = models.controlnet(cnet_input) + return cnet, cnet_input_preview + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: # configure model + + + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 + + # EfficientNet encoderin + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: # default setting + generator_ema = StageC() + elif self.config.model_version == '1B': + + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + + if self.config.ema_start_iters is not None and self.config.training: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + + + if loading_context is dummy_context: + generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + + generator._init_extra_parameter() + + + + + generator = generator.to(torch.bfloat16).to(self.device) + + train_norm = nn.ModuleList() + + + cnt_norm = 0 + for mm in generator.modules(): + if isinstance(mm, GlobalResponseNorm): + + train_norm.append(Null_Model()) + cnt_norm += 1 + + + + + train_norm.append(generator.agg_net) + train_norm.append(generator.agg_net_up) + + + + + if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')): + sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_norm.load_state_dict(collect_sd, strict=True) + + + train_norm.to(self.device).train().requires_grad_(True) + train_norm_ema = copy.deepcopy(train_norm) + train_norm_ema.to(self.device).eval().requires_grad_(False) + if generator_ema is not None: + + generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + generator_ema._init_extra_parameter() + + pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') + if os.path.exists(pretrained_pth): + print(pretrained_pth, 'exists') + generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) + + generator_ema.eval().requires_grad_(False) + + check_nan_inmodel(generator, 'generator') + + + + if self.config.use_fsdp and self.config.training: + train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True) + + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + controlnet = ControlNet( + c_in=extras.controlnet_filter.num_channels(), + proj_blocks=self.config.controlnet_blocks, + bottleneck_mode=self.config.controlnet_bottleneck_mode + ) + controlnet = controlnet.to(dtype).to(self.device) + controlnet = self.load_model(controlnet, 'controlnet') + controlnet.backbone.eval().requires_grad_(True) + + + return self.Models( + effnet=effnet, previewer=previewer, train_norm = train_norm, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model, image_model=image_model, + train_norm_ema=train_norm_ema, controlnet =controlnet + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + +# + + params = [] + params += list(models.train_norm.module.parameters()) + + optimizer = optim.AdamW(params, lr=self.config.lr) + + return self.Optimizers(generator=optimizer) + + def ema_update(self, ema_model, source_model, beta): + for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): + param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) + + def sync_ema(self, ema_model): + print('sync ema', torch.distributed.get_world_size()) + for param in ema_model.parameters(): + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) + param.data /= torch.distributed.get_world_size() + def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + optimizer = optim.AdamW( + models.generator.up_blocks.parameters() , + lr=self.config.lr) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.config.webdataset_path + print('in line 96', dataset_path, type(dataset_path)) + + dataset = mydist_dataset(dataset_path, \ + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ + else extras.transforms) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + print('in line 119', self.process_id, real_batch_size) + sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None, + sampler = sampler + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) + + + + + + def setup_ddp(self, experiment_id, single_gpu=False, rank=0): + + if not single_gpu: + local_rank = rank + process_id = rank + world_size = get_world_size() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '41443' + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=local_rank, + world_size=world_size, + # init_method=init_method, + ) + print(f"[GPU {process_id}] READY") + else: + self.is_main_node = rank == 0 + self.process_id = rank + self.device = torch.device('cuda:0') + self.world_size = 1 + print("Running in single thread, DDP not enabled.") + # Training loop -------------------------------- + def get_target_lr_size(self, ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w * 32) + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + #batch = next(data.iterator) + batch = data + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + shape_lr = self.get_target_lr_size(ratio) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) + + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + require_cond = True + + with torch.no_grad(): + _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) + + + pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + + loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps + # + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + + if update: + + torch.distributed.barrier() + loss_adjusted.backward() + + + grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0) + + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + #print('in line 457', loss_adjusted) + loss_adjusted.backward() + #torch.distributed.barrier() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: + + images = batch['images'].to(self.device) + if target_size is not None: + images = F.interpolate(images, target_size) + #images = apply_degradations(images) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): + # Temporary setup, will be overriden by setup_ddp if required + # self.device = device + # self.process_id = 0 + # self.is_main_node = True + # self.world_size = 1 + # ---- + # self.world_size = world_size + # self.process_id = rank + # self.device=device + self.is_main_node = (rank == 0) + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) + self.info: self.Info = self.setup_info() + print('in line 292', self.config.experiment_id, rank, world_size <= 1) + p = [i for i in range( 2 * 768 // 32)] + p = [num / sum(p) for num in p] + self.rand_pro = p + self.res_list = [o for o in range(800, 2336, 32)] + + #[32, 40, 48] + #in line 292 stage_c_3b_finetuning False + + def __call__(self, single_gpu=False): + # this will change the device to the CUDA rank + #self.setup_wandb() + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device ) + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + if self.config.wandb_project is not None: + wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") + + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + if 'generator' in self.models_to_save(): + models.generator.train() + #initial_params = {name: param.clone() for name, param in models.train_norm.named_parameters()} + iter_cnt = 0 + epoch_cnt = 0 + models.train_norm.train() + while True: + epoch_cnt += 1 + if self.world_size > 1: + print('sampler set epoch', epoch_cnt) + data.sampler.set_epoch(epoch_cnt) + for ggg in range(len(data.dataloader)): + iter_cnt += 1 + # FORWARD PASS + #print('in line 414 before forward', iter_cnt, batch['captions'][0], self.process_id) + #loss, loss_adjusted, loss_extra = self.forward_pass(batch, extras, models) + loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models) + + #print('in line 416', loss, iter_cnt) + # # BACKWARD PASS + + grad_norm = self.backward_pass( + iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted, + models, optimizers, schedulers + ) + + + + self.info.iter = iter_cnt + + # UPDATE EMA + if iter_cnt % self.config.ema_iters == 0: + + with torch.no_grad(): + print('in line 890 ema update', self.config.ema_iters, iter_cnt) + self.ema_update(models.train_norm_ema, models.train_norm, self.config.ema_beta) + #generator.module.agg_net. + #self.ema_update(models.generator_ema.agg_net, models.generator.module.agg_net, self.config.ema_beta) + #self.ema_update(models.generator_ema.agg_net_up, models.generator.module.agg_net_up, self.config.ema_beta) + + # UPDATE LOSS METRICS + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss) + if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): + print(f"gggg NaN value encountered in training run {self.info.wandb_run_id}", \ + f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'backward_loss': loss_adjusted.mean().item(), + #'raw_extra_loss': loss_extra.mean().item(), + 'ema_loss': self.info.ema_loss, + 'raw_ori_loss': loss.mean().item(), + #'raw_rec_loss': loss_rec.mean().item(), + #'raw_lr_loss': loss_lr.mean().item(), + #'reg_loss':loss_reg.item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, + 'total_steps': self.info.total_steps, + } + if iter_cnt % (self.config.save_every) == 0: + + print(iter_cnt, max_iters, logs, epoch_cnt, ) + #pbar.set_postfix(logs) + + + #if iter_cnt % 10 == 0: + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters: + #if True: + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + print(f"NaN value encountered in training run {self.info.wandb_run_id}", \ + f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") + + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + #self.save_checkpoints(models, optimizers) + + #torch.save(models.trans_inr.module.state_dict(), \ + #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr.safetensors') + #torch.save(models.trans_inr_ema.state_dict(), \ + #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr_ema.safetensors') + + + if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: + print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) + torch.save(models.train_norm.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors') + + #self.sync_ema(models.train_norm_ema) + torch.save(models.train_norm_ema.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm_ema.safetensors') + #if self.is_main_node and iter_cnt % (4 * self.config.save_every * self.config.grad_accum_steps) == 0: + torch.save(models.train_norm.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors') + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: + + if self.is_main_node: + #check_nan_inmodel(models.generator, 'generator') + #check_nan_inmodel(models.generator_ema, 'generator_ema') + self.sample(models, data, extras) + if False: + param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()} + threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10% + important_params = [name for name, change in param_changes.items() if change > threshold] + print(important_params, threshold, len(param_changes), self.process_id) + json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4) + + + if self.info.iter >= max_iters: + break + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + + #if 'generator' in self.models_to_save(): + models.generator.eval() + models.train_norm.eval() + with torch.no_grad(): + batch = next(data.iterator) + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + #batch['images'] = batch['images'].to(torch.float16) + shape_lr = self.get_target_lr_size(ratio) + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + cnet, cnet_input = self.get_cnet(batch, models, extras) + conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet} + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) + + if self.is_main_node: + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + #print('in line 366 on v100 switch to tf16') + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, models.trans_inr, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + + + #else: + sampled_ema = sampled + sampled_ema_lr = sampled_lr + + + if self.is_main_node: + print('sampling results', latents.shape, latents_lr.shape, ) + noised_images = torch.cat( + [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) + + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + noised_images_lr = torch.cat( + [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) + + sampled_images_lr = torch.cat( + [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) + sampled_images_ema_lr = torch.cat( + [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))], + dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic') + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + ], dim=-2) + + collage_img_lr = torch.cat([ + torch.cat([i for i in images_lr.cpu()], dim=-1), + torch.cat([i for i in noised_images_lr.cpu()], dim=-1), + torch.cat([i for i in sampled_images_lr.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg') + #torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') + + captions = batch['captions'] + if self.config.wandb_project is not None: + log_data = [ + [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ + wandb.Image(images[i])] for i in range(len(images))] + log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) + wandb.log({"Log": log_table}) + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) + plt.ylabel('Raw Loss') + plt.ylabel('LogSNR') + wandb.log({"Loss/LogSRN": plt}) + + #if 'generator' in self.models_to_save(): + models.generator.train() + models.train_norm.train() + print('finishe sampling in line 901') + + + + def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False): + + #if 'generator' in self.models_to_save(): + models.generator.eval() + models.trans_inr.eval() + models.controlnet.eval() + with torch.no_grad(): + + if self.is_main_node: + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + cnet, cnet_input = self.get_cnet(batch, models, extras, target_size = lr_shape) + conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet} + + #print('in line 885', self.is_main_node) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + #print('in line 366 on v100 switch to tf16') + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, models.trans_inr, conditions, + hr_shape, lr_shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + if models.generator_ema is not None: + + *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample( + models.generator_ema, models.trans_inr_ema, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + else: + sampled_ema = sampled + sampled_ema_lr = sampled_lr + #x0, x, epsilon, x0_lr, x_lr, pred_lr) + #sampled, _ = models.trans_inr(sampled, None, sampled) + #sampled_lr, _ = models.trans_inr(sampled, None, sampled_lr) + + return sampled, sampled_lr +def main_worker(rank, cfg): + print("Launching Script in main worker") + print('in line 467', rank) + warpcore = WurstCore( + config_file_path=cfg, rank=rank, world_size = get_world_size() + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore(get_world_size()==1) + +if __name__ == '__main__': + print('launch multi process') + # os.environ["OMP_NUM_THREADS"] = "1" + # os.environ["MKL_NUM_THREADS"] = "1" + #dist.init_process_group(backend="nccl") + #torch.backends.cudnn.benchmark = True +#train/train_c_my.py + #mp.set_sharing_strategy('file_system') + print('in line 481', sys.argv[1] if len(sys.argv) > 1 else None) + print('in line 481',get_master_ip(), get_world_size() ) + print('in line 484', get_world_size()) + if get_master_ip() == "127.0.0.1": + # manually launch distributed processes + mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, )) + else: + main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )