diff --git a/app.py b/app.py
index 05736cdd3e0b6fa96d47356442dc0be3b7de9262..243078fee26fc7708d998045dbb18e20ec0f5e81 100644
--- a/app.py
+++ b/app.py
@@ -1,161 +1,74 @@
-import spaces
-import os
-import requests
-import yaml
-import torch
import gradio as gr
-from PIL import Image
-import sys
-sys.path.append(os.path.abspath('./'))
-from inference.utils import *
-from core.utils import load_or_fail
-from train import WurstCoreB
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from train import WurstCore_t2i as WurstCoreC
-import torch.nn.functional as F
-from core.utils import load_or_fail
-import numpy as np
-import random
-import math
-from einops import rearrange
-
-def download_file(url, folder_path, filename):
- if not os.path.exists(folder_path):
- os.makedirs(folder_path)
- file_path = os.path.join(folder_path, filename)
-
- if os.path.isfile(file_path):
- print(f"File already exists: {file_path}")
- else:
- response = requests.get(url, stream=True)
- if response.status_code == 200:
- with open(file_path, 'wb') as file:
- for chunk in response.iter_content(chunk_size=1024):
- file.write(chunk)
- print(f"File successfully downloaded and saved: {file_path}")
- else:
- print(f"Error downloading the file. Status code: {response.status_code}")
-
-def download_models():
- models = {
- "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/", "stage_a.safetensors"),
- "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/", "previewer.safetensors"),
- "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/", "effnet_encoder.safetensors"),
- "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/", "stage_b_lite_bf16.safetensors"),
- "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/", "stage_c_bf16.safetensors"),
- "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/", "ultrapixel_t2i.safetensors"),
- "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/", "lora_cat.safetensors"),
- }
-
- for model, (url, folder, filename) in models.items():
- download_file(url, folder, filename)
-
-download_models()
-
-# Global variables
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-dtype = torch.bfloat16
-
-# Load configs and setup models
-with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
- config_c = yaml.safe_load(file)
-
-with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
- config_b = yaml.safe_load(file)
-
-core = WurstCoreC(config_dict=config_c, device=device, training=False)
-core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
-
-extras = core.setup_extras_pre()
-models = core.setup_models(extras)
-models.generator.eval().requires_grad_(False)
-
-extras_b = core_b.setup_extras_pre()
-models_b = core_b.setup_models(extras_b, skip_clip=True)
-models_b = WurstCoreB.Models(
- **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
-)
-models_b.generator.bfloat16().eval().requires_grad_(False)
-
-# Load pretrained model
-pretrained_path = "models/ultrapixel_t2i.safetensors"
-sdd = torch.load(pretrained_path, map_location='cpu')
-collect_sd = {k[7:]: v for k, v in sdd.items()}
-models.train_norm.load_state_dict(collect_sd)
-models.generator.eval()
-models.train_norm.eval()
-
-# Set up sampling configurations
-extras.sampling_configs.update({
- 'cfg': 4,
- 'shift': 1,
- 'timesteps': 20,
- 't_start': 1.0,
- 'sampler': DDPMSampler(extras.gdf)
-})
-
-extras_b.sampling_configs.update({
- 'cfg': 1.1,
- 'shift': 1,
- 'timesteps': 10,
- 't_start': 1.0
-})
-
-@spaces.GPU(duration=180)
-def generate_images(prompt, height, width, seed, num_images):
- torch.manual_seed(seed)
- random.seed(seed)
- np.random.seed(seed)
-
- batch_size = num_images
- height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
- stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
-
- batch = {'captions': [prompt] * batch_size}
- conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
- with torch.no_grad():
- models.generator.cuda()
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
-
- models.generator.cpu()
- torch.cuda.empty_cache()
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
- conditions_b['effnet'] = sampled_c
- unconditions_b['effnet'] = torch.zeros_like(sampled_c)
-
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
-
- torch.cuda.empty_cache()
- imgs = show_images(sampled)
- return imgs
-
-iface = gr.Interface(
- fn=generate_images,
- inputs=[
- gr.Textbox(label="Prompt"),
- gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
- gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
- gr.Number(label="Seed", value=42),
- gr.Slider(minimum=1, maximum=10, step=1, label="Number of Images", value=1)
- ],
- outputs=gr.Gallery(label="Generated Images", columns=5, rows=2),
- title="UltraPixel Image Generation",
- description="Generate high-resolution images using UltraPixel model.",
- theme='bethecloud/storj_theme',
- examples=[
- ["The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.", 1024, 1024, 42, 1]
- ],
- cache_examples=True
-)
-
-iface.launch()
\ No newline at end of file
+from transformers import AutoProcessor, AutoModelForCausalLM
+import spaces
+from PIL import Image
+
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+models = {
+ 'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True).eval(),
+ 'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True).eval(),
+}
+
+processors = {
+ 'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True),
+ 'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True),
+}
+
+title = """
Florence-2 Captioner for Flux Prompts
+
+[Florence-2 Flux Large]
+[Florence-2 Flux Base]
+
+"""
+
+@spaces.GPU
+def run_example(image, model_name='gokaygokay/Florence-2-Flux-Large'):
+ image = Image.fromarray(image)
+ task_prompt = ""
+ prompt = task_prompt + "Describe this image in great detail."
+
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+
+ model = models[model_name]
+ processor = processors[model_name]
+
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
+ generated_ids = model.generate(
+ input_ids=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ max_new_tokens=1024,
+ num_beams=3,
+ repetition_penalty=1.10,
+ )
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+ parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
+ return parsed_answer[""]
+
+with gr.Blocks(theme='bethecloud/storj_theme') as demo:
+ gr.HTML(title)
+
+ with gr.Row():
+ with gr.Column():
+ input_img = gr.Image(label="Input Picture")
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='gokaygokay/Florence-2-Flux-Large')
+ submit_btn = gr.Button(value="Submit")
+ with gr.Column():
+ output_text = gr.Textbox(label="Output Text")
+
+ gr.Examples(
+ [["image1.jpg"],
+ ["image2.jpg"],
+ ["image3.png"],
+ ["image5.jpg"]],
+ inputs=[input_img, model_selector],
+ outputs=[output_text],
+ fn=run_example,
+ label='Try captioning on below examples'
+ )
+
+ submit_btn.click(run_example, [input_img, model_selector], [output_text])
+
+demo.launch(debug=True)
\ No newline at end of file
diff --git a/configs/inference/controlnet_c_3b_canny.yaml b/configs/inference/controlnet_c_3b_canny.yaml
deleted file mode 100644
index 286d7a6c8017e922a020d6ae5633cc3e27f9b702..0000000000000000000000000000000000000000
--- a/configs/inference/controlnet_c_3b_canny.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-# GLOBAL STUFF
-model_version: 3.6B
-dtype: bfloat16
-
-# ControlNet specific
-controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
-controlnet_filter: CannyFilter
-controlnet_filter_params:
- resize: 224
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
-controlnet_checkpoint_path: models/canny.safetensors
diff --git a/configs/inference/controlnet_c_3b_identity.yaml b/configs/inference/controlnet_c_3b_identity.yaml
deleted file mode 100644
index 8a20fa860fed5f6eea1d33113535c2633205e327..0000000000000000000000000000000000000000
--- a/configs/inference/controlnet_c_3b_identity.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-# GLOBAL STUFF
-model_version: 3.6B
-dtype: bfloat16
-
-# ControlNet specific
-controlnet_bottleneck_mode: 'simple'
-controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
-controlnet_filter: IdentityFilter
-controlnet_filter_params:
- max_faces: 4
- p_drop: 0.00
- p_full: 0.0
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
-controlnet_checkpoint_path:
diff --git a/configs/inference/controlnet_c_3b_inpainting.yaml b/configs/inference/controlnet_c_3b_inpainting.yaml
deleted file mode 100644
index a94bd7953dfa407184d9094b481a56cdbbb73549..0000000000000000000000000000000000000000
--- a/configs/inference/controlnet_c_3b_inpainting.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-# GLOBAL STUFF
-model_version: 3.6B
-dtype: bfloat16
-
-# ControlNet specific
-controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
-controlnet_filter: InpaintFilter
-controlnet_filter_params:
- thresold: [0.04, 0.4]
- p_outpaint: 0.4
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
-controlnet_checkpoint_path: models/inpainting.safetensors
diff --git a/configs/inference/controlnet_c_3b_sr.yaml b/configs/inference/controlnet_c_3b_sr.yaml
deleted file mode 100644
index 13c4a2cd2dcd2a3cf87fb32bd6e34269e796a747..0000000000000000000000000000000000000000
--- a/configs/inference/controlnet_c_3b_sr.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-# GLOBAL STUFF
-model_version: 3.6B
-dtype: bfloat16
-
-# ControlNet specific
-controlnet_bottleneck_mode: 'large'
-controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
-controlnet_filter: SREffnetFilter
-controlnet_filter_params:
- scale_factor: 0.5
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
-controlnet_checkpoint_path: models/super_resolution.safetensors
diff --git a/configs/inference/lora_c_3b.yaml b/configs/inference/lora_c_3b.yaml
deleted file mode 100644
index 7468078c657c1f569c6c052a14b265d69082ab25..0000000000000000000000000000000000000000
--- a/configs/inference/lora_c_3b.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-# GLOBAL STUFF
-model_version: 3.6B
-dtype: bfloat16
-
-# LoRA specific
-module_filters: ['.attn']
-rank: 4
-train_tokens:
- # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
- - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
-lora_checkpoint_path: models/lora_fernando_10k.safetensors
diff --git a/configs/inference/stage_b_1b.yaml b/configs/inference/stage_b_1b.yaml
deleted file mode 100644
index 0811cae75622614e91de6532262acb2c062bf344..0000000000000000000000000000000000000000
--- a/configs/inference/stage_b_1b.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-# GLOBAL STUFF
-model_version: 700M
-dtype: bfloat16
-
-# For demonstration purposes in reconstruct_images.ipynb
-webdataset_path: path to your dataset
-batch_size: 1
-image_size: 2048
-grad_accum_steps: 1
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-stage_a_checkpoint_path: models/stage_a.safetensors
-generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
\ No newline at end of file
diff --git a/configs/inference/stage_b_3b.yaml b/configs/inference/stage_b_3b.yaml
deleted file mode 100644
index 840268980103e0c629599b966705043d6a616578..0000000000000000000000000000000000000000
--- a/configs/inference/stage_b_3b.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-# GLOBAL STUFF
-model_version: 3B
-dtype: bfloat16
-
-# For demonstration purposes in reconstruct_images.ipynb
-webdataset_path: path to your dataset
-batch_size: 4
-image_size: 1024
-grad_accum_steps: 1
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-stage_a_checkpoint_path: models/stage_a.safetensors
-generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
\ No newline at end of file
diff --git a/configs/inference/stage_c_1b.yaml b/configs/inference/stage_c_1b.yaml
deleted file mode 100644
index 781886e515d80e7870abb89bf8fd0ce7c7c8d4b6..0000000000000000000000000000000000000000
--- a/configs/inference/stage_c_1b.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-# GLOBAL STUFF
-model_version: 1B
-dtype: bfloat16
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
\ No newline at end of file
diff --git a/configs/inference/stage_c_3b.yaml b/configs/inference/stage_c_3b.yaml
deleted file mode 100644
index b22897e71996ad78f3832af78f5bc44ca06d206d..0000000000000000000000000000000000000000
--- a/configs/inference/stage_c_3b.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-# GLOBAL STUFF
-model_version: 3.6B
-dtype: bfloat16
-
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
\ No newline at end of file
diff --git a/configs/training/cfg_control_lr.yaml b/configs/training/cfg_control_lr.yaml
deleted file mode 100644
index 2955b6a925504525b981e7004b65a33573c08aef..0000000000000000000000000000000000000000
--- a/configs/training/cfg_control_lr.yaml
+++ /dev/null
@@ -1,47 +0,0 @@
-# GLOBAL STUFF
-experiment_id: Ultrapixel_controlnet
-
-checkpoint_path: checkpoint output path
-output_path: visual results output path
-model_version: 3.6B
-dtype: float32
-# # WandB
-# wandb_project: StableCascade
-# wandb_entity: wandb_username
-#module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ]
-#rank: 32
-# TRAINING PARAMS
-lr: 1.0e-4
-batch_size: 12
-#image_size: [1536, 2048, 2560, 3072, 4096]
-image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
-#image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
-#image_size: [ 1024, 1280]
-multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
-grad_accum_steps: 2
-updates: 40000
-backup_every: 5000
-save_every: 256
-warmup_updates: 1
-use_fsdp: True
-
-# ControlNet specific
-controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
-controlnet_filter: CannyFilter
-controlnet_filter_params:
- resize: 224
-# offset_noise: 0.1
-
-# GDF
-adaptive_loss_weight: True
-
-ema_start_iters: 10
-ema_iters: 50
-ema_beta: 0.9
-
-webdataset_path: path to your training dataset
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
-controlnet_checkpoint_path: pretrained controlnet path
-
diff --git a/configs/training/lora_personalization.yaml b/configs/training/lora_personalization.yaml
deleted file mode 100644
index 857795e6d37e9cb61bd76aa588f432978ed90ad2..0000000000000000000000000000000000000000
--- a/configs/training/lora_personalization.yaml
+++ /dev/null
@@ -1,37 +0,0 @@
-# GLOBAL STUFF
-experiment_id: roubao_cat_personalized
-
-checkpoint_path: checkpoint output path
-output_path: visual results output path
-model_version: 3.6B
-dtype: float32
-
-module_filters: [ '.attn']
-rank: 4
-train_tokens:
- # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
- - ['[roubaobao]', '^cat'] # custom token [snail], initialize as avg of snail & snails
-# TRAINING PARAMS
-lr: 1.0e-4
-batch_size: 4
-
-image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
-multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
-grad_accum_steps: 2
-updates: 40000
-backup_every: 5000
-save_every: 512
-warmup_updates: 1
-use_ddp: True
-
-# GDF
-adaptive_loss_weight: True
-
-
-tmp_prompt: a photo of a cat [roubaobao]
-webdataset_path: path to your personalized training dataset
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
-ultrapixel_path: models/ultrapixel_t2i.safetensors
-
diff --git a/configs/training/t2i.yaml b/configs/training/t2i.yaml
deleted file mode 100644
index 8a0ceaca0ad8813e3c9b998661ac3e9b3c0937fd..0000000000000000000000000000000000000000
--- a/configs/training/t2i.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-# GLOBAL STUFF
-experiment_id: ultrapixel_t2i
-#strc_fixlrt_norm3_lite_1024_hrft_newdata
-checkpoint_path: checkpoint output path #output model directory
-output_path: visual results output path #experiment output directory
-model_version: 3.6B # finetune large stage c model of stablecascade
-dtype: float32
-
-
-# TRAINING PARAMS
-lr: 1.0e-4
-batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps
-image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution
-multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
-grad_accum_steps: 2
-updates: 40000
-backup_every: 5000
-save_every: 256
-warmup_updates: 1
-use_ddp: True
-
-# GDF
-adaptive_loss_weight: True
-
-
-webdataset_path: path to your personalized training dataset
-effnet_checkpoint_path: models/effnet_encoder.safetensors
-previewer_checkpoint_path: models/previewer.safetensors
-generator_checkpoint_path: models/stage_c_bf16.safetensors
\ No newline at end of file
diff --git a/core/__init__.py b/core/__init__.py
deleted file mode 100644
index ed382f1907ddc86c7e9a9618c21441755a6221a9..0000000000000000000000000000000000000000
--- a/core/__init__.py
+++ /dev/null
@@ -1,372 +0,0 @@
-import os
-import yaml
-import torch
-from torch import nn
-import wandb
-import json
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from torch.utils.data import Dataset, DataLoader
-
-from torch.distributed import init_process_group, destroy_process_group, barrier
-from torch.distributed.fsdp import (
- FullyShardedDataParallel as FSDP,
- FullStateDictConfig,
- MixedPrecision,
- ShardingStrategy,
- StateDictType
-)
-
-from .utils import Base, EXPECTED, EXPECTED_TRAIN
-from .utils import create_folder_if_necessary, safe_save, load_or_fail
-
-# pylint: disable=unused-argument
-class WarpCore(ABC):
- @dataclass(frozen=True)
- class Config(Base):
- experiment_id: str = EXPECTED_TRAIN
- checkpoint_path: str = EXPECTED_TRAIN
- output_path: str = EXPECTED_TRAIN
- checkpoint_extension: str = "safetensors"
- dist_file_subfolder: str = ""
- allow_tf32: bool = True
-
- wandb_project: str = None
- wandb_entity: str = None
-
- @dataclass() # not frozen, means that fields are mutable
- class Info(): # not inheriting from Base, because we don't want to enforce the default fields
- wandb_run_id: str = None
- total_steps: int = 0
- iter: int = 0
-
- @dataclass(frozen=True)
- class Data(Base):
- dataset: Dataset = EXPECTED
- dataloader: DataLoader = EXPECTED
- iterator: any = EXPECTED
-
- @dataclass(frozen=True)
- class Models(Base):
- pass
-
- @dataclass(frozen=True)
- class Optimizers(Base):
- pass
-
- @dataclass(frozen=True)
- class Schedulers(Base):
- pass
-
- @dataclass(frozen=True)
- class Extras(Base):
- pass
- # ---------------------------------------
- info: Info
- config: Config
-
- # FSDP stuff
- fsdp_defaults = {
- "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
- "cpu_offload": None,
- "mixed_precision": MixedPrecision(
- param_dtype=torch.bfloat16,
- reduce_dtype=torch.bfloat16,
- buffer_dtype=torch.bfloat16,
- ),
- "limit_all_gathers": True,
- }
- fsdp_fullstate_save_policy = FullStateDictConfig(
- offload_to_cpu=True, rank0_only=True
- )
- # ------------
-
- # OVERRIDEABLE METHODS
-
- # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
- def setup_extras_pre(self) -> Extras:
- return self.Extras()
-
- # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator
- @abstractmethod
- def setup_data(self, extras: Extras) -> Data:
- raise NotImplementedError("This method needs to be overriden")
-
- # return a dict with all models that are going to be used in the training
- @abstractmethod
- def setup_models(self, extras: Extras) -> Models:
- raise NotImplementedError("This method needs to be overriden")
-
- # return a dict with all optimizers that are going to be used in the training
- @abstractmethod
- def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
- raise NotImplementedError("This method needs to be overriden")
-
- # [optionally] return a dict with all schedulers that are going to be used in the training
- def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
- return self.Schedulers()
-
- # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
- def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
- return self.Extras.from_dict(extras.to_dict())
-
- # perform the training here
- @abstractmethod
- def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
- raise NotImplementedError("This method needs to be overriden")
- # ------------
-
- def setup_info(self, full_path=None) -> Info:
- if full_path is None:
- full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
- info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
- info_dto = self.Info(**info_dict)
- if info_dto.total_steps > 0 and self.is_main_node:
- print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
- return info_dto
-
- def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
- if config_file_path is not None:
- if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
- with open(config_file_path, "r", encoding="utf-8") as file:
- loaded_config = yaml.safe_load(file)
- elif config_file_path.endswith(".json"):
- with open(config_file_path, "r", encoding="utf-8") as file:
- loaded_config = json.load(file)
- else:
- raise ValueError("Config file must be either a .yml|.yaml or .json file")
- return self.Config.from_dict({**loaded_config, 'training': training})
- if config_dict is not None:
- return self.Config.from_dict({**config_dict, 'training': training})
- return self.Config(training=training)
-
- def setup_ddp(self, experiment_id, single_gpu=False):
- if not single_gpu:
- local_rank = int(os.environ.get("SLURM_LOCALID"))
- process_id = int(os.environ.get("SLURM_PROCID"))
- world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()
-
- self.process_id = process_id
- self.is_main_node = process_id == 0
- self.device = torch.device(local_rank)
- self.world_size = world_size
-
- dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
- # if os.path.exists(dist_file_path) and self.is_main_node:
- # os.remove(dist_file_path)
-
- torch.cuda.set_device(local_rank)
- init_process_group(
- backend="nccl",
- rank=process_id,
- world_size=world_size,
- init_method=f"file://{dist_file_path}",
- )
- print(f"[GPU {process_id}] READY")
- else:
- print("Running in single thread, DDP not enabled.")
-
- def setup_wandb(self):
- if self.is_main_node and self.config.wandb_project is not None:
- self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
- wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())
-
- if self.info.total_steps > 0:
- wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
- else:
- wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
-
- # LOAD UTILITIES ----------
- def load_model(self, model, model_id=None, full_path=None, strict=True):
- print('in line 181 load model', type(model), model_id, full_path, strict)
- if model_id is not None and full_path is None:
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
- elif full_path is None and model_id is None:
- raise ValueError(
- "This method expects either 'model_id' or 'full_path' to be defined"
- )
-
- checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
- if checkpoint is not None:
- model.load_state_dict(checkpoint, strict=strict)
- del checkpoint
-
- return model
-
- def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
- if optim_id is not None and full_path is None:
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
- elif full_path is None and optim_id is None:
- raise ValueError(
- "This method expects either 'optim_id' or 'full_path' to be defined"
- )
-
- checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
- if checkpoint is not None:
- try:
- if fsdp_model is not None:
- sharded_optimizer_state_dict = (
- FSDP.scatter_full_optim_state_dict( # <---- FSDP
- checkpoint
- if (
- self.is_main_node
- or self.fsdp_defaults["sharding_strategy"]
- == ShardingStrategy.NO_SHARD
- )
- else None,
- fsdp_model,
- )
- )
- optim.load_state_dict(sharded_optimizer_state_dict)
- del checkpoint, sharded_optimizer_state_dict
- else:
- optim.load_state_dict(checkpoint)
- # pylint: disable=broad-except
- except Exception as e:
- print("!!! Failed loading optimizer, skipping... Exception:", e)
-
- return optim
-
- # SAVE UTILITIES ----------
- def save_info(self, info, suffix=""):
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
- create_folder_if_necessary(full_path)
- if self.is_main_node:
- safe_save(vars(self.info), full_path)
-
- def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
- if model_id is not None and full_path is None:
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
- elif full_path is None and model_id is None:
- raise ValueError(
- "This method expects either 'model_id' or 'full_path' to be defined"
- )
- create_folder_if_necessary(full_path)
- if is_fsdp:
- with FSDP.summon_full_params(model):
- pass
- with FSDP.state_dict_type(
- model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
- ):
- checkpoint = model.state_dict()
- if self.is_main_node:
- safe_save(checkpoint, full_path)
- del checkpoint
- else:
- if self.is_main_node:
- checkpoint = model.state_dict()
- safe_save(checkpoint, full_path)
- del checkpoint
-
- def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
- if optim_id is not None and full_path is None:
- full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
- elif full_path is None and optim_id is None:
- raise ValueError(
- "This method expects either 'optim_id' or 'full_path' to be defined"
- )
- create_folder_if_necessary(full_path)
- if fsdp_model is not None:
- optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
- if self.is_main_node:
- safe_save(optim_statedict, full_path)
- del optim_statedict
- else:
- if self.is_main_node:
- checkpoint = optim.state_dict()
- safe_save(checkpoint, full_path)
- del checkpoint
- # -----
-
- def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
- # Temporary setup, will be overriden by setup_ddp if required
- self.device = device
- self.process_id = 0
- self.is_main_node = True
- self.world_size = 1
- # ----
-
- self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
- self.info: self.Info = self.setup_info()
-
- def __call__(self, single_gpu=False):
- self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank
- self.setup_wandb()
- if self.config.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
- if self.is_main_node:
- print()
- print("**STARTIG JOB WITH CONFIG:**")
- print(yaml.dump(self.config.to_dict(), default_flow_style=False))
- print("------------------------------------")
- print()
- print("**INFO:**")
- print(yaml.dump(vars(self.info), default_flow_style=False))
- print("------------------------------------")
- print()
-
- # SETUP STUFF
- extras = self.setup_extras_pre()
- assert extras is not None, "setup_extras_pre() must return a DTO"
-
- data = self.setup_data(extras)
- assert data is not None, "setup_data() must return a DTO"
- if self.is_main_node:
- print("**DATA:**")
- print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- models = self.setup_models(extras)
- assert models is not None, "setup_models() must return a DTO"
- if self.is_main_node:
- print("**MODELS:**")
- print(yaml.dump({
- k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
- }, default_flow_style=False))
- print("------------------------------------")
- print()
-
- optimizers = self.setup_optimizers(extras, models)
- assert optimizers is not None, "setup_optimizers() must return a DTO"
- if self.is_main_node:
- print("**OPTIMIZERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- schedulers = self.setup_schedulers(extras, models, optimizers)
- assert schedulers is not None, "setup_schedulers() must return a DTO"
- if self.is_main_node:
- print("**SCHEDULERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
- assert post_extras is not None, "setup_extras_post() must return a DTO"
- extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
- if self.is_main_node:
- print("**EXTRAS:**")
- print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
- # -------
-
- # TRAIN
- if self.is_main_node:
- print("**TRAINING STARTING...**")
- self.train(data, extras, models, optimizers, schedulers)
-
- if single_gpu is False:
- barrier()
- destroy_process_group()
- if self.is_main_node:
- print()
- print("------------------------------------")
- print()
- print("**TRAINING COMPLETE**")
- if self.config.wandb_project is not None:
- wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
diff --git a/core/data/__init__.py b/core/data/__init__.py
deleted file mode 100644
index b687719914b2e303909f7c280347e4bdee607d13..0000000000000000000000000000000000000000
--- a/core/data/__init__.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import json
-import subprocess
-import yaml
-import os
-from .bucketeer import Bucketeer
-
-class MultiFilter():
- def __init__(self, rules, default=False):
- self.rules = rules
- self.default = default
-
- def __call__(self, x):
- try:
- x_json = x['json']
- if isinstance(x_json, bytes):
- x_json = json.loads(x_json)
- validations = []
- for k, r in self.rules.items():
- if isinstance(k, tuple):
- v = r(*[x_json[kv] for kv in k])
- else:
- v = r(x_json[k])
- validations.append(v)
- return all(validations)
- except Exception:
- return False
-
-class MultiGetter():
- def __init__(self, rules):
- self.rules = rules
-
- def __call__(self, x_json):
- if isinstance(x_json, bytes):
- x_json = json.loads(x_json)
- outputs = []
- for k, r in self.rules.items():
- if isinstance(k, tuple):
- v = r(*[x_json[kv] for kv in k])
- else:
- v = r(x_json[k])
- outputs.append(v)
- if len(outputs) == 1:
- outputs = outputs[0]
- return outputs
-
-def setup_webdataset_path(paths, cache_path=None):
- if cache_path is None or not os.path.exists(cache_path):
- tar_paths = []
- if isinstance(paths, str):
- paths = [paths]
- for path in paths:
- if path.strip().endswith(".tar"):
- # Avoid looking up s3 if we already have a tar file
- tar_paths.append(path)
- continue
- bucket = "/".join(path.split("/")[:3])
- result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
- files = result.stdout.decode('utf-8').split()
- files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
- tar_paths += files
-
- with open(cache_path, 'w', encoding='utf-8') as outfile:
- yaml.dump(tar_paths, outfile, default_flow_style=False)
- else:
- with open(cache_path, 'r', encoding='utf-8') as file:
- tar_paths = yaml.safe_load(file)
-
- tar_paths_str = ",".join([f"{p}" for p in tar_paths])
- return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
diff --git a/core/data/bucketeer.py b/core/data/bucketeer.py
deleted file mode 100644
index 131e6ba4293bd7c00399f08609aba184b712d5e8..0000000000000000000000000000000000000000
--- a/core/data/bucketeer.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import torch
-import torchvision
-import numpy as np
-from torchtools.transforms import SmartCrop
-import math
-
-class Bucketeer():
- def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
- assert crop_mode in ['center', 'random', 'smart']
- self.crop_mode = crop_mode
- self.ratios = ratios
- if reverse_list:
- for r in list(ratios):
- if 1/r not in self.ratios:
- self.ratios.append(1/r)
- self.sizes = {}
- for dd in density:
- self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
-
- self.batch_size = dataloader.batch_size
- self.iterator = iter(dataloader)
- all_sizes = []
- for k, vs in self.sizes.items():
- all_sizes += vs
- self.buckets = {s: [] for s in all_sizes}
- self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
- self.p_random_ratio = p_random_ratio
- self.interpolate_nearest = interpolate_nearest
-
- def get_available_batch(self):
- for b in self.buckets:
- if len(self.buckets[b]) >= self.batch_size:
- batch = self.buckets[b][:self.batch_size]
- self.buckets[b] = self.buckets[b][self.batch_size:]
- return batch
- return None
-
- def get_closest_size(self, x):
- w, h = x.size(-1), x.size(-2)
-
-
- best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
- find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
- min_ = find_dict[list(find_dict.keys())[0]]
- find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
- for dd, val in find_dict.items():
- if val < min_:
- min_ = val
- find_size = self.sizes[dd][best_size_idx]
-
- return find_size
-
- def get_resize_size(self, orig_size, tgt_size):
- if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
- alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
- resize_size = max(alt_min, min(tgt_size))
- else:
- alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
- resize_size = max(alt_max, max(tgt_size))
-
- return resize_size
-
- def __next__(self):
- batch = self.get_available_batch()
- while batch is None:
- elements = next(self.iterator)
- for dct in elements:
- img = dct['images']
- size = self.get_closest_size(img)
- resize_size = self.get_resize_size(img.shape[-2:], size)
-
- if self.interpolate_nearest:
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
- else:
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
- if self.crop_mode == 'center':
- img = torchvision.transforms.functional.center_crop(img, size)
- elif self.crop_mode == 'random':
- img = torchvision.transforms.RandomCrop(size)(img)
- elif self.crop_mode == 'smart':
- self.smartcrop.output_size = size
- img = self.smartcrop(img)
-
- self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
- batch = self.get_available_batch()
-
- out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
- return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
diff --git a/core/data/bucketeer_deg.py b/core/data/bucketeer_deg.py
deleted file mode 100644
index 7206ccf08932f617abb811221cc7bbe1d126f184..0000000000000000000000000000000000000000
--- a/core/data/bucketeer_deg.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import torch
-import torchvision
-import numpy as np
-from torchtools.transforms import SmartCrop
-import math
-
-class Bucketeer():
- def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
- assert crop_mode in ['center', 'random', 'smart']
- self.crop_mode = crop_mode
- self.ratios = ratios
- if reverse_list:
- for r in list(ratios):
- if 1/r not in self.ratios:
- self.ratios.append(1/r)
- self.sizes = {}
- for dd in density:
- self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
- print('in line 17 buckteer', self.sizes)
- self.batch_size = dataloader.batch_size
- self.iterator = iter(dataloader)
- all_sizes = []
- for k, vs in self.sizes.items():
- all_sizes += vs
- self.buckets = {s: [] for s in all_sizes}
- self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
- self.p_random_ratio = p_random_ratio
- self.interpolate_nearest = interpolate_nearest
-
- def get_available_batch(self):
- for b in self.buckets:
- if len(self.buckets[b]) >= self.batch_size:
- batch = self.buckets[b][:self.batch_size]
- self.buckets[b] = self.buckets[b][self.batch_size:]
- return batch
- return None
-
- def get_closest_size(self, x):
- w, h = x.size(-1), x.size(-2)
- #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
- # best_size_idx = np.random.randint(len(self.ratios))
- #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio)
- #else:
-
- best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
- find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
- min_ = find_dict[list(find_dict.keys())[0]]
- find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
- for dd, val in find_dict.items():
- if val < min_:
- min_ = val
- find_size = self.sizes[dd][best_size_idx]
-
- return find_size
-
- def get_resize_size(self, orig_size, tgt_size):
- if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
- alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
- resize_size = max(alt_min, min(tgt_size))
- else:
- alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
- resize_size = max(alt_max, max(tgt_size))
- #print('in line 50', orig_size, tgt_size, resize_size)
- return resize_size
-
- def __next__(self):
- batch = self.get_available_batch()
- while batch is None:
- elements = next(self.iterator)
- for dct in elements:
- img = dct['images']
- size = self.get_closest_size(img)
- resize_size = self.get_resize_size(img.shape[-2:], size)
- #print('in line 74', img.size(), resize_size)
- if self.interpolate_nearest:
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
- else:
- img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
- if self.crop_mode == 'center':
- img = torchvision.transforms.functional.center_crop(img, size)
- elif self.crop_mode == 'random':
- img = torchvision.transforms.RandomCrop(size)(img)
- elif self.crop_mode == 'smart':
- self.smartcrop.output_size = size
- img = self.smartcrop(img)
- print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img))
- self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
- batch = self.get_available_batch()
-
- out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
- return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
diff --git a/core/data/deg_kair_utils/utils_alignfaces.py b/core/data/deg_kair_utils/utils_alignfaces.py
deleted file mode 100644
index fa74e8a2e8984f5075d0cbd06afd494c9661a015..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_alignfaces.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Apr 24 15:43:29 2017
-@author: zhaoy
-"""
-import cv2
-import numpy as np
-from skimage import transform as trans
-
-# reference facial points, a list of coordinates (x,y)
-REFERENCE_FACIAL_POINTS = [
- [30.29459953, 51.69630051],
- [65.53179932, 51.50139999],
- [48.02519989, 71.73660278],
- [33.54930115, 92.3655014],
- [62.72990036, 92.20410156]
-]
-
-DEFAULT_CROP_SIZE = (96, 112)
-
-
-def _umeyama(src, dst, estimate_scale=True, scale=1.0):
- """Estimate N-D similarity transformation with or without scaling.
- Parameters
- ----------
- src : (M, N) array
- Source coordinates.
- dst : (M, N) array
- Destination coordinates.
- estimate_scale : bool
- Whether to estimate scaling factor.
- Returns
- -------
- T : (N + 1, N + 1)
- The homogeneous similarity transformation matrix. The matrix contains
- NaN values only if the problem is not well-conditioned.
- References
- ----------
- .. [1] "Least-squares estimation of transformation parameters between two
- point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
- """
-
- num = src.shape[0]
- dim = src.shape[1]
-
- # Compute mean of src and dst.
- src_mean = src.mean(axis=0)
- dst_mean = dst.mean(axis=0)
-
- # Subtract mean from src and dst.
- src_demean = src - src_mean
- dst_demean = dst - dst_mean
-
- # Eq. (38).
- A = dst_demean.T @ src_demean / num
-
- # Eq. (39).
- d = np.ones((dim,), dtype=np.double)
- if np.linalg.det(A) < 0:
- d[dim - 1] = -1
-
- T = np.eye(dim + 1, dtype=np.double)
-
- U, S, V = np.linalg.svd(A)
-
- # Eq. (40) and (43).
- rank = np.linalg.matrix_rank(A)
- if rank == 0:
- return np.nan * T
- elif rank == dim - 1:
- if np.linalg.det(U) * np.linalg.det(V) > 0:
- T[:dim, :dim] = U @ V
- else:
- s = d[dim - 1]
- d[dim - 1] = -1
- T[:dim, :dim] = U @ np.diag(d) @ V
- d[dim - 1] = s
- else:
- T[:dim, :dim] = U @ np.diag(d) @ V
-
- if estimate_scale:
- # Eq. (41) and (42).
- scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
- else:
- scale = scale
-
- T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
- T[:dim, :dim] *= scale
-
- return T, scale
-
-
-class FaceWarpException(Exception):
- def __str__(self):
- return 'In File {}:{}'.format(
- __file__, super.__str__(self))
-
-
-def get_reference_facial_points(output_size=None,
- inner_padding_factor=0.0,
- outer_padding=(0, 0),
- default_square=False):
- tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
- tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
-
- # 0) make the inner region a square
- if default_square:
- size_diff = max(tmp_crop_size) - tmp_crop_size
- tmp_5pts += size_diff / 2
- tmp_crop_size += size_diff
-
- if (output_size and
- output_size[0] == tmp_crop_size[0] and
- output_size[1] == tmp_crop_size[1]):
- print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
- return tmp_5pts
-
- if (inner_padding_factor == 0 and
- outer_padding == (0, 0)):
- if output_size is None:
- print('No paddings to do: return default reference points')
- return tmp_5pts
- else:
- raise FaceWarpException(
- 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
-
- # check output size
- if not (0 <= inner_padding_factor <= 1.0):
- raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
-
- if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
- and output_size is None):
- output_size = tmp_crop_size * \
- (1 + inner_padding_factor * 2).astype(np.int32)
- output_size += np.array(outer_padding)
- print(' deduced from paddings, output_size = ', output_size)
-
- if not (outer_padding[0] < output_size[0]
- and outer_padding[1] < output_size[1]):
- raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
- 'and outer_padding[1] < output_size[1])')
-
- # 1) pad the inner region according inner_padding_factor
- # print('---> STEP1: pad the inner region according inner_padding_factor')
- if inner_padding_factor > 0:
- size_diff = tmp_crop_size * inner_padding_factor * 2
- tmp_5pts += size_diff / 2
- tmp_crop_size += np.round(size_diff).astype(np.int32)
-
- # print(' crop_size = ', tmp_crop_size)
- # print(' reference_5pts = ', tmp_5pts)
-
- # 2) resize the padded inner region
- # print('---> STEP2: resize the padded inner region')
- size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
- # print(' crop_size = ', tmp_crop_size)
- # print(' size_bf_outer_pad = ', size_bf_outer_pad)
-
- if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
- raise FaceWarpException('Must have (output_size - outer_padding)'
- '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
-
- scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
- # print(' resize scale_factor = ', scale_factor)
- tmp_5pts = tmp_5pts * scale_factor
- # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
- # tmp_5pts = tmp_5pts + size_diff / 2
- tmp_crop_size = size_bf_outer_pad
- # print(' crop_size = ', tmp_crop_size)
- # print(' reference_5pts = ', tmp_5pts)
-
- # 3) add outer_padding to make output_size
- reference_5point = tmp_5pts + np.array(outer_padding)
- tmp_crop_size = output_size
- # print('---> STEP3: add outer_padding to make output_size')
- # print(' crop_size = ', tmp_crop_size)
- # print(' reference_5pts = ', tmp_5pts)
- #
- # print('===> end get_reference_facial_points\n')
-
- return reference_5point
-
-
-def get_affine_transform_matrix(src_pts, dst_pts):
- tfm = np.float32([[1, 0, 0], [0, 1, 0]])
- n_pts = src_pts.shape[0]
- ones = np.ones((n_pts, 1), src_pts.dtype)
- src_pts_ = np.hstack([src_pts, ones])
- dst_pts_ = np.hstack([dst_pts, ones])
-
- A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
-
- if rank == 3:
- tfm = np.float32([
- [A[0, 0], A[1, 0], A[2, 0]],
- [A[0, 1], A[1, 1], A[2, 1]]
- ])
- elif rank == 2:
- tfm = np.float32([
- [A[0, 0], A[1, 0], 0],
- [A[0, 1], A[1, 1], 0]
- ])
-
- return tfm
-
-
-def warp_and_crop_face(src_img,
- facial_pts,
- reference_pts=None,
- crop_size=(96, 112),
- align_type='smilarity'): #smilarity cv2_affine affine
- if reference_pts is None:
- if crop_size[0] == 96 and crop_size[1] == 112:
- reference_pts = REFERENCE_FACIAL_POINTS
- else:
- default_square = False
- inner_padding_factor = 0
- outer_padding = (0, 0)
- output_size = crop_size
-
- reference_pts = get_reference_facial_points(output_size,
- inner_padding_factor,
- outer_padding,
- default_square)
-
- ref_pts = np.float32(reference_pts)
- ref_pts_shp = ref_pts.shape
- if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
- raise FaceWarpException(
- 'reference_pts.shape must be (K,2) or (2,K) and K>2')
-
- if ref_pts_shp[0] == 2:
- ref_pts = ref_pts.T
-
- src_pts = np.float32(facial_pts)
- src_pts_shp = src_pts.shape
- if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
- raise FaceWarpException(
- 'facial_pts.shape must be (K,2) or (2,K) and K>2')
-
- if src_pts_shp[0] == 2:
- src_pts = src_pts.T
-
- if src_pts.shape != ref_pts.shape:
- raise FaceWarpException(
- 'facial_pts and reference_pts must have the same shape')
-
- if align_type is 'cv2_affine':
- tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
- tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
- elif align_type is 'affine':
- tfm = get_affine_transform_matrix(src_pts, ref_pts)
- tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
- else:
- params, scale = _umeyama(src_pts, ref_pts)
- tfm = params[:2, :]
-
- params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale)
- tfm_inv = params[:2, :]
-
- face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
-
- return face_img, tfm_inv
diff --git a/core/data/deg_kair_utils/utils_blindsr.py b/core/data/deg_kair_utils/utils_blindsr.py
deleted file mode 100644
index 9a1a7baf99473043e216c16f464f4e168cbd94ab..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_blindsr.py
+++ /dev/null
@@ -1,631 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import cv2
-import torch
-
-from core.data.deg_kair_utils import utils_image as util
-
-import random
-from scipy import ndimage
-import scipy
-import scipy.stats as ss
-from scipy.interpolate import interp2d
-from scipy.linalg import orth
-
-
-
-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# From 2019/03--2021/08
-# --------------------------------------------
-"""
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
-
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-def analytic_kernel(k):
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
- k_size = k.shape[0]
- # Calculate the big kernels size
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
- # Loop over the small kernel to fill the big one
- for r in range(k_size):
- for c in range(k_size):
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
- crop = k_size // 2
- cropped_big_k = big_k[crop:-crop, crop:-crop]
- # Normalize to 1
- return cropped_big_k / cropped_big_k.sum()
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
-
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf-1)*0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w-1)
- y1 = np.clip(y1, 0, h-1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-def blur(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, Nx1xhxw
- '''
- n, c = x.shape[:2]
- p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
- k = k.repeat(1,c,1,1)
- k = k.view(-1, 1, k.shape[2], k.shape[3])
- x = x.view(1, -1, x.shape[2], x.shape[3])
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c)
- x = x.view(n, c, x.shape[2], x.shape[3])
-
- return x
-
-
-
-def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
- # Calcualte Gaussian for every pixel of the kernel
- ZZ = Z-MU
- ZZ_t = ZZ.transpose(0,1,3,2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
- arg = -(x*x + y*y)/(2*std*std)
- h = np.exp(arg)
- h[h < scipy.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h/sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha,1])])
- h1 = alpha/(alpha+1)
- h2 = (1-alpha)/(alpha+1)
- h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
-
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1/sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
-
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
-
- Return:
- downsampled LR image
-
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
-
- ''' bicubic downsampling + blur
-
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
-
- Return:
- downsampled LR image
-
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
-
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
-
- Return:
- downsampled LR image
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def add_sharpening(img, weight=0.5, radius=50, threshold=10):
- """USM sharpening. borrowed from real-ESRGAN
- Input image: I; Blurry image: B.
- 1. K = I + weight * (I - B)
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
- 3. Blur mask:
- 4. Out = Mask * K + (1 - Mask) * I
- Args:
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
- weight (float): Sharp weight. Default: 1.
- radius (float): Kernel size of Gaussian blur. Default: 50.
- threshold (int):
- """
- if radius % 2 == 0:
- radius += 1
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
- residual = img - blur
- mask = np.abs(residual) * 255 > threshold
- mask = mask.astype('float32')
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
-
- K = img + weight * residual
- K = np.clip(K, 0, 1)
- return soft_mask * K + (1 - soft_mask) * img
-
-
-def add_blur(img, sf=4):
- wd2 = 4.0 + sf
- wd = 2.0 + 0.2*sf
- if random.random() < 0.5:
- l1 = wd2*random.random()
- l2 = wd2*random.random()
- k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2)
- else:
- k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random())
- img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
-
- return img
-
-
-def add_resize(img, sf=4):
- rnum = np.random.rand()
- if rnum > 0.8: # up
- sf1 = random.uniform(1, 2)
- elif rnum < 0.7: # down
- sf1 = random.uniform(0.5/sf, 1)
- else:
- sf1 = 1.0
- img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- return img
-
-
-def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- rnum = np.random.rand()
- if rnum > 0.6: # add color Gaussian noise
- img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
- elif rnum < 0.4: # add grayscale Gaussian noise
- img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
- else: # add noise
- L = noise_level2/255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3,3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_speckle_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- img = np.clip(img, 0.0, 1.0)
- rnum = random.random()
- if rnum > 0.6:
- img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
- elif rnum < 0.4:
- img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
- else:
- L = noise_level2/255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3,3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_Poisson_noise(img):
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
- vals = 10**(2*random.random()+2.0) # [2, 4]
- if random.random() < 0.5:
- img = np.random.poisson(img * vals).astype(np.float32) / vals
- else:
- img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114])
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- img += noise_gray[:, :, np.newaxis]
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_JPEG_noise(img):
- quality_factor = random.randint(30, 95)
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
- img = cv2.imdecode(encimg, 1)
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
- return img
-
-
-def random_crop(lq, hq, sf=4, lq_patchsize=64):
- h, w = lq.shape[:2]
- rnd_h = random.randint(0, h-lq_patchsize)
- rnd_w = random.randint(0, w-lq_patchsize)
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
-
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :]
- return lq, hq
-
-
-def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- isp_model: camera ISP model
-
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize*sf or w < lq_patchsize*sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- hq = img.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
- else:
- img = util.imresize_np(img, 1/2, True)
- img = np.clip(img, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- img = add_blur(img, sf=sf)
-
- elif i == 1:
- img = add_blur(img, sf=sf)
-
- elif i == 2:
- a, b = img.shape[1], img.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1,2*sf)
- img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel
- img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
- img = img[0::sf, 0::sf, ...] # nearest downsampling
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- img = add_JPEG_noise(img)
-
- elif i == 6:
- # add processed camera sensor noise
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
-
- return img, hq
-
-
-
-
-def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
- """
- This is an extended degradation model by combining
- the degradation models of BSRGAN and Real-ESRGAN
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- use_shuffle: the degradation shuffle
- use_sharp: sharpening the img
-
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize*sf or w < lq_patchsize*sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- if use_sharp:
- img = add_sharpening(img)
- hq = img.copy()
-
- if random.random() < shuffle_prob:
- shuffle_order = random.sample(range(13), 13)
- else:
- shuffle_order = list(range(13))
- # local shuffle for noise, JPEG is always the last one
- shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
- shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
-
- poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
-
- for i in shuffle_order:
- if i == 0:
- img = add_blur(img, sf=sf)
- elif i == 1:
- img = add_resize(img, sf=sf)
- elif i == 2:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 3:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 4:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 5:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- elif i == 6:
- img = add_JPEG_noise(img)
- elif i == 7:
- img = add_blur(img, sf=sf)
- elif i == 8:
- img = add_resize(img, sf=sf)
- elif i == 9:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 10:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 11:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 12:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- else:
- print('check the shuffle!')
-
- # resize to desired size
- img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3]))
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf, lq_patchsize)
-
- return img, hq
-
-
-
-if __name__ == '__main__':
- img = util.imread_uint('utils/test.png', 3)
- img = util.uint2single(img)
- sf = 4
-
- for i in range(20):
- img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
- print(i)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
- img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i)+'.png')
-
-# for i in range(10):
-# img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
-# print(i)
-# lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
-# img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
-# util.imsave(img_concat, str(i)+'.png')
-
-# run utils/utils_blindsr.py
diff --git a/core/data/deg_kair_utils/utils_bnorm.py b/core/data/deg_kair_utils/utils_bnorm.py
deleted file mode 100644
index 9bd346e05b66efd074f81f1961068e2de45ac5da..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_bnorm.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-"""
-# --------------------------------------------
-# Batch Normalization
-# --------------------------------------------
-
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# 01/Jan/2019
-# --------------------------------------------
-"""
-
-
-# --------------------------------------------
-# remove/delete specified layer
-# --------------------------------------------
-def deleteLayer(model, layer_type=nn.BatchNorm2d):
- ''' Kai Zhang, 11/Jan/2019.
- '''
- for k, m in list(model.named_children()):
- if isinstance(m, layer_type):
- del model._modules[k]
- deleteLayer(m, layer_type)
-
-
-# --------------------------------------------
-# merge bn, "conv+bn" --> "conv"
-# --------------------------------------------
-def merge_bn(model):
- ''' Kai Zhang, 11/Jan/2019.
- merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
- based on https://github.com/pytorch/pytorch/pull/901
- '''
- prev_m = None
- for k, m in list(model.named_children()):
- if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):
-
- w = prev_m.weight.data
-
- if prev_m.bias is None:
- zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
- prev_m.bias = nn.Parameter(zeros)
- b = prev_m.bias.data
-
- invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
- if isinstance(prev_m, nn.ConvTranspose2d):
- w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
- else:
- w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
- b.add_(-m.running_mean).mul_(invstd)
- if m.affine:
- if isinstance(prev_m, nn.ConvTranspose2d):
- w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
- else:
- w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
- b.mul_(m.weight.data).add_(m.bias.data)
-
- del model._modules[k]
- prev_m = m
- merge_bn(m)
-
-
-# --------------------------------------------
-# add bn, "conv" --> "conv+bn"
-# --------------------------------------------
-def add_bn(model):
- ''' Kai Zhang, 11/Jan/2019.
- '''
- for k, m in list(model.named_children()):
- if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
- b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
- b.weight.data.fill_(1)
- new_m = nn.Sequential(model._modules[k], b)
- model._modules[k] = new_m
- add_bn(m)
-
-
-# --------------------------------------------
-# tidy model after removing bn
-# --------------------------------------------
-def tidy_sequential(model):
- ''' Kai Zhang, 11/Jan/2019.
- '''
- for k, m in list(model.named_children()):
- if isinstance(m, nn.Sequential):
- if m.__len__() == 1:
- model._modules[k] = m.__getitem__(0)
- tidy_sequential(m)
diff --git a/core/data/deg_kair_utils/utils_deblur.py b/core/data/deg_kair_utils/utils_deblur.py
deleted file mode 100644
index 8ab5852d0cb334627abcd9476409d632740be389..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_deblur.py
+++ /dev/null
@@ -1,655 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import scipy
-from scipy import fftpack
-import torch
-
-from math import cos, sin
-from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
-from numpy.random import randn, rand
-from scipy.signal import convolve2d
-import cv2
-import random
-# import utils_image as util
-
-'''
-modified by Kai Zhang (github: https://github.com/cszn)
-03/03/2019
-'''
-
-
-def get_uperleft_denominator(img, kernel):
- '''
- img: HxWxC
- kernel: hxw
- denominator: HxWx1
- upperleft: HxWxC
- '''
- V = psf2otf(kernel, img.shape[:2])
- denominator = np.expand_dims(np.abs(V)**2, axis=2)
- upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
- return upperleft, denominator
-
-
-def get_uperleft_denominator_pytorch(img, kernel):
- '''
- img: NxCxHxW
- kernel: Nx1xhxw
- denominator: Nx1xHxW
- upperleft: NxCxHxWx2
- '''
- V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2
- denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW
- upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2
- return upperleft, denominator
-
-
-def c2c(x):
- return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
-
-
-def r2c(x):
- return torch.stack([x, torch.zeros_like(x)], -1)
-
-
-def cdiv(x, y):
- a, b = x[..., 0], x[..., 1]
- c, d = y[..., 0], y[..., 1]
- cd2 = c**2 + d**2
- return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
-
-
-def cabs(x):
- return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
-
-
-def cmul(t1, t2):
- '''
- complex multiplication
- t1: NxCxHxWx2
- output: NxCxHxWx2
- '''
- real1, imag1 = t1[..., 0], t1[..., 1]
- real2, imag2 = t2[..., 0], t2[..., 1]
- return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
-
-
-def cconj(t, inplace=False):
- '''
- # complex's conjugation
- t: NxCxHxWx2
- output: NxCxHxWx2
- '''
- c = t.clone() if not inplace else t
- c[..., 1] *= -1
- return c
-
-
-def rfft(t):
- return torch.rfft(t, 2, onesided=False)
-
-
-def irfft(t):
- return torch.irfft(t, 2, onesided=False)
-
-
-def fft(t):
- return torch.fft(t, 2)
-
-
-def ifft(t):
- return torch.ifft(t, 2)
-
-
-def p2o(psf, shape):
- '''
- # psf: NxCxhxw
- # shape: [H,W]
- # otf: NxCxHxWx2
- '''
- otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
- otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
- for axis, axis_size in enumerate(psf.shape[2:]):
- otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
- otf = torch.rfft(otf, 2, onesided=False)
- n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
- otf[...,1][torch.abs(otf[...,1])= abs(y)] = abs(x)[abs(x) >= abs(y)]
- maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)]
- minxy = np.zeros(x.shape)
- minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)]
- minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)]
- m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\
- (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\
- np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2)
- m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\
- (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\
- np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2)
- h = None
- return h
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
- arg = -(x*x + y*y)/(2*std*std)
- h = np.exp(arg)
- h[h < scipy.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h/sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha,1])])
- h1 = alpha/(alpha+1)
- h2 = (1-alpha)/(alpha+1)
- h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial_log(hsize, sigma):
- raise(NotImplemented)
-
-
-def fspecial_motion(motion_len, theta):
- raise(NotImplemented)
-
-
-def fspecial_prewitt():
- return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])
-
-
-def fspecial_sobel():
- return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'average':
- return fspecial_average(*args, **kwargs)
- if filter_type == 'disk':
- return fspecial_disk(*args, **kwargs)
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
- if filter_type == 'log':
- return fspecial_log(*args, **kwargs)
- if filter_type == 'motion':
- return fspecial_motion(*args, **kwargs)
- if filter_type == 'prewitt':
- return fspecial_prewitt(*args, **kwargs)
- if filter_type == 'sobel':
- return fspecial_sobel(*args, **kwargs)
-
-
-def fspecial_gauss(size, sigma):
- x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
- g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
- return g / g.sum()
-
-
-def blurkernel_synthesis(h=37, w=None):
- # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py
- w = h if w is None else w
- kdims = [h, w]
- x = randomTrajectory(250)
- k = None
- while k is None:
- k = kernelFromTrajectory(x)
-
- # center pad to kdims
- pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2)
- pad_width = [(pad_width[0],), (pad_width[1],)]
-
- if pad_width[0][0]<0 or pad_width[1][0]<0:
- k = k[0:h, 0:h]
- else:
- k = pad(k, pad_width, "constant")
- x1,x2 = k.shape
- if np.random.randint(0, 4) == 1:
- k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR)
- y1, y2 = k.shape
- k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2]
-
- if sum(k)<0.1:
- k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
- k = k / sum(k)
- # import matplotlib.pyplot as plt
- # plt.imshow(k, interpolation="nearest", cmap="gray")
- # plt.show()
- return k
-
-
-def kernelFromTrajectory(x):
- h = 5 - log(rand()) / 0.15
- h = round(min([h, 27])).astype(int)
- h = h + 1 - h % 2
- w = h
- k = zeros((h, w))
-
- xmin = min(x[0])
- xmax = max(x[0])
- ymin = min(x[1])
- ymax = max(x[1])
- xthr = arange(xmin, xmax, (xmax - xmin) / w)
- ythr = arange(ymin, ymax, (ymax - ymin) / h)
-
- for i in range(1, xthr.size):
- for j in range(1, ythr.size):
- idx = (
- (x[0, :] >= xthr[i - 1])
- & (x[0, :] < xthr[i])
- & (x[1, :] >= ythr[j - 1])
- & (x[1, :] < ythr[j])
- )
- k[i - 1, j - 1] = sum(idx)
- if sum(k) == 0:
- return
- k = k / sum(k)
- k = convolve2d(k, fspecial_gauss(3, 1), "same")
- k = k / sum(k)
- return k
-
-
-def randomTrajectory(T):
- x = zeros((3, T))
- v = randn(3, T)
- r = zeros((3, T))
- trv = 1 / 1
- trr = 2 * pi / T
- for t in range(1, T):
- F_rot = randn(3) / (t + 1) + r[:, t - 1]
- F_trans = randn(3) / (t + 1)
- r[:, t] = r[:, t - 1] + trr * F_rot
- v[:, t] = v[:, t - 1] + trv * F_trans
- st = v[:, t]
- st = rot3D(st, r[:, t])
- x[:, t] = x[:, t - 1] + st
- return x
-
-
-def rot3D(x, r):
- Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]])
- Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]])
- Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]])
- R = Rz @ Ry @ Rx
- x = R @ x
- return x
-
-
-if __name__ == '__main__':
- a = opt_fft_size([111])
- print(a)
-
- print(fspecial('gaussian', 5, 1))
-
- print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape)
-
- k = blurkernel_synthesis(11)
- import matplotlib.pyplot as plt
- plt.imshow(k, interpolation="nearest", cmap="gray")
- plt.show()
diff --git a/core/data/deg_kair_utils/utils_dist.py b/core/data/deg_kair_utils/utils_dist.py
deleted file mode 100644
index 88811737a8fc7cb6e12d9226a9242dbf8391d86b..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_dist.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
-import functools
-import os
-import subprocess
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-
-
-# ----------------------------------
-# init
-# ----------------------------------
-def init_dist(launcher, backend='nccl', **kwargs):
- if mp.get_start_method(allow_none=True) is None:
- mp.set_start_method('spawn')
- if launcher == 'pytorch':
- _init_dist_pytorch(backend, **kwargs)
- elif launcher == 'slurm':
- _init_dist_slurm(backend, **kwargs)
- else:
- raise ValueError(f'Invalid launcher type: {launcher}')
-
-
-def _init_dist_pytorch(backend, **kwargs):
- rank = int(os.environ['RANK'])
- num_gpus = torch.cuda.device_count()
- torch.cuda.set_device(rank % num_gpus)
- dist.init_process_group(backend=backend, **kwargs)
-
-
-def _init_dist_slurm(backend, port=None):
- """Initialize slurm distributed training environment.
- If argument ``port`` is not specified, then the master port will be system
- environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
- environment variable, then a default port ``29500`` will be used.
- Args:
- backend (str): Backend of torch.distributed.
- port (int, optional): Master port. Defaults to None.
- """
- proc_id = int(os.environ['SLURM_PROCID'])
- ntasks = int(os.environ['SLURM_NTASKS'])
- node_list = os.environ['SLURM_NODELIST']
- num_gpus = torch.cuda.device_count()
- torch.cuda.set_device(proc_id % num_gpus)
- addr = subprocess.getoutput(
- f'scontrol show hostname {node_list} | head -n1')
- # specify master port
- if port is not None:
- os.environ['MASTER_PORT'] = str(port)
- elif 'MASTER_PORT' in os.environ:
- pass # use MASTER_PORT in the environment variable
- else:
- # 29500 is torch.distributed default port
- os.environ['MASTER_PORT'] = '29500'
- os.environ['MASTER_ADDR'] = addr
- os.environ['WORLD_SIZE'] = str(ntasks)
- os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
- os.environ['RANK'] = str(proc_id)
- dist.init_process_group(backend=backend)
-
-
-
-# ----------------------------------
-# get rank and world_size
-# ----------------------------------
-def get_dist_info():
- if dist.is_available():
- initialized = dist.is_initialized()
- else:
- initialized = False
- if initialized:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- else:
- rank = 0
- world_size = 1
- return rank, world_size
-
-
-def get_rank():
- if not dist.is_available():
- return 0
-
- if not dist.is_initialized():
- return 0
-
- return dist.get_rank()
-
-
-def get_world_size():
- if not dist.is_available():
- return 1
-
- if not dist.is_initialized():
- return 1
-
- return dist.get_world_size()
-
-
-def master_only(func):
-
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- rank, _ = get_dist_info()
- if rank == 0:
- return func(*args, **kwargs)
-
- return wrapper
-
-
-
-
-
-
-# ----------------------------------
-# operation across ranks
-# ----------------------------------
-def reduce_sum(tensor):
- if not dist.is_available():
- return tensor
-
- if not dist.is_initialized():
- return tensor
-
- tensor = tensor.clone()
- dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
-
- return tensor
-
-
-def gather_grad(params):
- world_size = get_world_size()
-
- if world_size == 1:
- return
-
- for param in params:
- if param.grad is not None:
- dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
- param.grad.data.div_(world_size)
-
-
-def all_gather(data):
- world_size = get_world_size()
-
- if world_size == 1:
- return [data]
-
- buffer = pickle.dumps(data)
- storage = torch.ByteStorage.from_buffer(buffer)
- tensor = torch.ByteTensor(storage).to('cuda')
-
- local_size = torch.IntTensor([tensor.numel()]).to('cuda')
- size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
- dist.all_gather(size_list, local_size)
- size_list = [int(size.item()) for size in size_list]
- max_size = max(size_list)
-
- tensor_list = []
- for _ in size_list:
- tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
-
- if local_size != max_size:
- padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
- tensor = torch.cat((tensor, padding), 0)
-
- dist.all_gather(tensor_list, tensor)
-
- data_list = []
-
- for size, tensor in zip(size_list, tensor_list):
- buffer = tensor.cpu().numpy().tobytes()[:size]
- data_list.append(pickle.loads(buffer))
-
- return data_list
-
-
-def reduce_loss_dict(loss_dict):
- world_size = get_world_size()
-
- if world_size < 2:
- return loss_dict
-
- with torch.no_grad():
- keys = []
- losses = []
-
- for k in sorted(loss_dict.keys()):
- keys.append(k)
- losses.append(loss_dict[k])
-
- losses = torch.stack(losses, 0)
- dist.reduce(losses, dst=0)
-
- if dist.get_rank() == 0:
- losses /= world_size
-
- reduced_losses = {k: v for k, v in zip(keys, losses)}
-
- return reduced_losses
-
diff --git a/core/data/deg_kair_utils/utils_googledownload.py b/core/data/deg_kair_utils/utils_googledownload.py
deleted file mode 100644
index 25533d4e0d90bac7519874a654ffd833d16ae289..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_googledownload.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import math
-import requests
-from tqdm import tqdm
-
-
-'''
-borrowed from
-https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
-'''
-
-
-def sizeof_fmt(size, suffix='B'):
- """Get human readable file size.
- Args:
- size (int): File size.
- suffix (str): Suffix. Default: 'B'.
- Return:
- str: Formated file siz.
- """
- for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
- if abs(size) < 1024.0:
- return f'{size:3.1f} {unit}{suffix}'
- size /= 1024.0
- return f'{size:3.1f} Y{suffix}'
-
-
-def download_file_from_google_drive(file_id, save_path):
- """Download files from google drive.
- Ref:
- https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
- Args:
- file_id (str): File id.
- save_path (str): Save path.
- """
-
- session = requests.Session()
- URL = 'https://docs.google.com/uc?export=download'
- params = {'id': file_id}
-
- response = session.get(URL, params=params, stream=True)
- token = get_confirm_token(response)
- if token:
- params['confirm'] = token
- response = session.get(URL, params=params, stream=True)
-
- # get file size
- response_file_size = session.get(
- URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
- if 'Content-Range' in response_file_size.headers:
- file_size = int(
- response_file_size.headers['Content-Range'].split('/')[1])
- else:
- file_size = None
-
- save_response_content(response, save_path, file_size)
-
-
-def get_confirm_token(response):
- for key, value in response.cookies.items():
- if key.startswith('download_warning'):
- return value
- return None
-
-
-def save_response_content(response,
- destination,
- file_size=None,
- chunk_size=32768):
- if file_size is not None:
- pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
-
- readable_file_size = sizeof_fmt(file_size)
- else:
- pbar = None
-
- with open(destination, 'wb') as f:
- downloaded_size = 0
- for chunk in response.iter_content(chunk_size):
- downloaded_size += chunk_size
- if pbar is not None:
- pbar.update(1)
- pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
- f'/ {readable_file_size}')
- if chunk: # filter out keep-alive new chunks
- f.write(chunk)
- if pbar is not None:
- pbar.close()
-
-
-if __name__ == "__main__":
- file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
- save_path = 'BSRGAN.pth'
- download_file_from_google_drive(file_id, save_path)
diff --git a/core/data/deg_kair_utils/utils_image.py b/core/data/deg_kair_utils/utils_image.py
deleted file mode 100644
index 0e513a8bc1594c9ce2ba47ce3fe3b497269b7f16..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_image.py
+++ /dev/null
@@ -1,1016 +0,0 @@
-import os
-import math
-import random
-import numpy as np
-import torch
-import cv2
-from torchvision.utils import make_grid
-from datetime import datetime
-# import torchvision.transforms as transforms
-import matplotlib.pyplot as plt
-from mpl_toolkits.mplot3d import Axes3D
-os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
-
-
-'''
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 03/Mar/2019
-# --------------------------------------------
-# https://github.com/twhui/SRGAN-pyTorch
-# https://github.com/xinntao/BasicSR
-# --------------------------------------------
-'''
-
-
-IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
-
-
-def is_image_file(filename):
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
-def get_timestamp():
- return datetime.now().strftime('%y%m%d-%H%M%S')
-
-
-def imshow(x, title=None, cbar=False, figsize=None):
- plt.figure(figsize=figsize)
- plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
- if title:
- plt.title(title)
- if cbar:
- plt.colorbar()
- plt.show()
-
-
-def surf(Z, cmap='rainbow', figsize=None):
- plt.figure(figsize=figsize)
- ax3 = plt.axes(projection='3d')
-
- w, h = Z.shape[:2]
- xx = np.arange(0,w,1)
- yy = np.arange(0,h,1)
- X, Y = np.meshgrid(xx, yy)
- ax3.plot_surface(X,Y,Z,cmap=cmap)
- #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
- plt.show()
-
-
-'''
-# --------------------------------------------
-# get image pathes
-# --------------------------------------------
-'''
-
-
-def get_image_paths(dataroot):
- paths = None # return None if dataroot is None
- if isinstance(dataroot, str):
- paths = sorted(_get_paths_from_images(dataroot))
- elif isinstance(dataroot, list):
- paths = []
- for i in dataroot:
- paths += sorted(_get_paths_from_images(i))
- return paths
-
-
-def _get_paths_from_images(path):
- assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
- images = []
- for dirpath, _, fnames in sorted(os.walk(path)):
- for fname in sorted(fnames):
- if is_image_file(fname):
- img_path = os.path.join(dirpath, fname)
- images.append(img_path)
- assert images, '{:s} has no valid image file'.format(path)
- return images
-
-
-'''
-# --------------------------------------------
-# split large images into small images
-# --------------------------------------------
-'''
-
-
-def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
- w, h = img.shape[:2]
- patches = []
- if w > p_max and h > p_max:
- w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
- h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
- w1.append(w-p_size)
- h1.append(h-p_size)
- # print(w1)
- # print(h1)
- for i in w1:
- for j in h1:
- patches.append(img[i:i+p_size, j:j+p_size,:])
- else:
- patches.append(img)
-
- return patches
-
-
-def imssave(imgs, img_path):
- """
- imgs: list, N images of size WxHxC
- """
- img_name, ext = os.path.splitext(os.path.basename(img_path))
- for i, img in enumerate(imgs):
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png')
- cv2.imwrite(new_path, img)
-
-
-def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
- """
- split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
- and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
- will be splitted.
-
- Args:
- original_dataroot:
- taget_dataroot:
- p_size: size of small images
- p_overlap: patch size in training is a good choice
- p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
- """
- paths = get_image_paths(original_dataroot)
- for img_path in paths:
- # img_name, ext = os.path.splitext(os.path.basename(img_path))
- img = imread_uint(img_path, n_channels=n_channels)
- patches = patches_from_image(img, p_size, p_overlap, p_max)
- imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
- #if original_dataroot == taget_dataroot:
- #del img_path
-
-'''
-# --------------------------------------------
-# makedir
-# --------------------------------------------
-'''
-
-
-def mkdir(path):
- if not os.path.exists(path):
- os.makedirs(path)
-
-
-def mkdirs(paths):
- if isinstance(paths, str):
- mkdir(paths)
- else:
- for path in paths:
- mkdir(path)
-
-
-def mkdir_and_rename(path):
- if os.path.exists(path):
- new_name = path + '_archived_' + get_timestamp()
- print('Path already exists. Rename it to [{:s}]'.format(new_name))
- os.rename(path, new_name)
- os.makedirs(path)
-
-
-'''
-# --------------------------------------------
-# read image from path
-# opencv is fast, but read BGR numpy image
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# get uint8 image of size HxWxn_channles (RGB)
-# --------------------------------------------
-def imread_uint(path, n_channels=3):
- # input: path
- # output: HxWx3(RGB or GGG), or HxWx1 (G)
- if n_channels == 1:
- img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
- img = np.expand_dims(img, axis=2) # HxWx1
- elif n_channels == 3:
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
- if img.ndim == 2:
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
- else:
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
- return img
-
-
-# --------------------------------------------
-# matlab's imwrite
-# --------------------------------------------
-def imsave(img, img_path):
- img = np.squeeze(img)
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(img_path, img)
-
-def imwrite(img, img_path):
- img = np.squeeze(img)
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(img_path, img)
-
-
-
-# --------------------------------------------
-# get single image of size HxWxn_channles (BGR)
-# --------------------------------------------
-def read_img(path):
- # read image by cv2
- # return: Numpy float32, HWC, BGR, [0,1]
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
- img = img.astype(np.float32) / 255.
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- # some images have 4 channels
- if img.shape[2] > 3:
- img = img[:, :, :3]
- return img
-
-
-'''
-# --------------------------------------------
-# image format conversion
-# --------------------------------------------
-# numpy(single) <---> numpy(uint)
-# numpy(single) <---> tensor
-# numpy(uint) <---> tensor
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# numpy(single) [0, 1] <---> numpy(uint)
-# --------------------------------------------
-
-
-def uint2single(img):
-
- return np.float32(img/255.)
-
-
-def single2uint(img):
-
- return np.uint8((img.clip(0, 1)*255.).round())
-
-
-def uint162single(img):
-
- return np.float32(img/65535.)
-
-
-def single2uint16(img):
-
- return np.uint16((img.clip(0, 1)*65535.).round())
-
-
-# --------------------------------------------
-# numpy(uint) (HxWxC or HxW) <---> tensor
-# --------------------------------------------
-
-
-# convert uint to 4-dimensional torch tensor
-def uint2tensor4(img):
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
-
-
-# convert uint to 3-dimensional torch tensor
-def uint2tensor3(img):
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
-
-
-# convert 2/3/4-dimensional torch tensor to uint
-def tensor2uint(img):
- img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
- return np.uint8((img*255.0).round())
-
-
-# --------------------------------------------
-# numpy(single) (HxWxC) <---> tensor
-# --------------------------------------------
-
-
-# convert single (HxWxC) to 3-dimensional torch tensor
-def single2tensor3(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
-
-
-# convert single (HxWxC) to 4-dimensional torch tensor
-def single2tensor4(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
-
-
-# convert torch tensor to single
-def tensor2single(img):
- img = img.data.squeeze().float().cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
-
- return img
-
-# convert torch tensor to single
-def tensor2single3(img):
- img = img.data.squeeze().float().cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
- elif img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return img
-
-
-def single2tensor5(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
-
-
-def single32tensor5(img):
- return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
-
-
-def single42tensor4(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
-
-
-# from skimage.io import imread, imsave
-def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
- '''
- Converts a torch Tensor into an image Numpy array of BGR channel order
- Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
- Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
- '''
- tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
- tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
- n_dim = tensor.dim()
- if n_dim == 4:
- n_img = len(tensor)
- img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
- elif n_dim == 3:
- img_np = tensor.numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
- elif n_dim == 2:
- img_np = tensor.numpy()
- else:
- raise TypeError(
- 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
- if out_type == np.uint8:
- img_np = (img_np * 255.0).round()
- # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
- return img_np.astype(out_type)
-
-
-'''
-# --------------------------------------------
-# Augmentation, flipe and/or rotate
-# --------------------------------------------
-# The following two are enough.
-# (1) augmet_img: numpy image of WxHxC or WxH
-# (2) augment_img_tensor4: tensor image 1xCxWxH
-# --------------------------------------------
-'''
-
-
-def augment_img(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- if mode == 0:
- return img
- elif mode == 1:
- return np.flipud(np.rot90(img))
- elif mode == 2:
- return np.flipud(img)
- elif mode == 3:
- return np.rot90(img, k=3)
- elif mode == 4:
- return np.flipud(np.rot90(img, k=2))
- elif mode == 5:
- return np.rot90(img)
- elif mode == 6:
- return np.rot90(img, k=2)
- elif mode == 7:
- return np.flipud(np.rot90(img, k=3))
-
-
-def augment_img_tensor4(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- if mode == 0:
- return img
- elif mode == 1:
- return img.rot90(1, [2, 3]).flip([2])
- elif mode == 2:
- return img.flip([2])
- elif mode == 3:
- return img.rot90(3, [2, 3])
- elif mode == 4:
- return img.rot90(2, [2, 3]).flip([2])
- elif mode == 5:
- return img.rot90(1, [2, 3])
- elif mode == 6:
- return img.rot90(2, [2, 3])
- elif mode == 7:
- return img.rot90(3, [2, 3]).flip([2])
-
-
-def augment_img_tensor(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- img_size = img.size()
- img_np = img.data.cpu().numpy()
- if len(img_size) == 3:
- img_np = np.transpose(img_np, (1, 2, 0))
- elif len(img_size) == 4:
- img_np = np.transpose(img_np, (2, 3, 1, 0))
- img_np = augment_img(img_np, mode=mode)
- img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
- if len(img_size) == 3:
- img_tensor = img_tensor.permute(2, 0, 1)
- elif len(img_size) == 4:
- img_tensor = img_tensor.permute(3, 2, 0, 1)
-
- return img_tensor.type_as(img)
-
-
-def augment_img_np3(img, mode=0):
- if mode == 0:
- return img
- elif mode == 1:
- return img.transpose(1, 0, 2)
- elif mode == 2:
- return img[::-1, :, :]
- elif mode == 3:
- img = img[::-1, :, :]
- img = img.transpose(1, 0, 2)
- return img
- elif mode == 4:
- return img[:, ::-1, :]
- elif mode == 5:
- img = img[:, ::-1, :]
- img = img.transpose(1, 0, 2)
- return img
- elif mode == 6:
- img = img[:, ::-1, :]
- img = img[::-1, :, :]
- return img
- elif mode == 7:
- img = img[:, ::-1, :]
- img = img[::-1, :, :]
- img = img.transpose(1, 0, 2)
- return img
-
-
-def augment_imgs(img_list, hflip=True, rot=True):
- # horizontal flip OR rotate
- hflip = hflip and random.random() < 0.5
- vflip = rot and random.random() < 0.5
- rot90 = rot and random.random() < 0.5
-
- def _augment(img):
- if hflip:
- img = img[:, ::-1, :]
- if vflip:
- img = img[::-1, :, :]
- if rot90:
- img = img.transpose(1, 0, 2)
- return img
-
- return [_augment(img) for img in img_list]
-
-
-'''
-# --------------------------------------------
-# modcrop and shave
-# --------------------------------------------
-'''
-
-
-def modcrop(img_in, scale):
- # img_in: Numpy, HWC or HW
- img = np.copy(img_in)
- if img.ndim == 2:
- H, W = img.shape
- H_r, W_r = H % scale, W % scale
- img = img[:H - H_r, :W - W_r]
- elif img.ndim == 3:
- H, W, C = img.shape
- H_r, W_r = H % scale, W % scale
- img = img[:H - H_r, :W - W_r, :]
- else:
- raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
- return img
-
-
-def shave(img_in, border=0):
- # img_in: Numpy, HWC or HW
- img = np.copy(img_in)
- h, w = img.shape[:2]
- img = img[border:h-border, border:w-border]
- return img
-
-
-'''
-# --------------------------------------------
-# image processing process on numpy image
-# channel_convert(in_c, tar_type, img_list):
-# rgb2ycbcr(img, only_y=True):
-# bgr2ycbcr(img, only_y=True):
-# ycbcr2rgb(img):
-# --------------------------------------------
-'''
-
-
-def rgb2ycbcr(img, only_y=True):
- '''same as matlab rgb2ycbcr
- only_y: only return Y channel
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- if only_y:
- rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
- else:
- rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
- [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def ycbcr2rgb(img):
- '''same as matlab ycbcr2rgb
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
- [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
- rlt = np.clip(rlt, 0, 255)
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def bgr2ycbcr(img, only_y=True):
- '''bgr version of rgb2ycbcr
- only_y: only return Y channel
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- if only_y:
- rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
- else:
- rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
- [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def channel_convert(in_c, tar_type, img_list):
- # conversion among BGR, gray and y
- if in_c == 3 and tar_type == 'gray': # BGR to gray
- gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
- return [np.expand_dims(img, axis=2) for img in gray_list]
- elif in_c == 3 and tar_type == 'y': # BGR to y
- y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
- return [np.expand_dims(img, axis=2) for img in y_list]
- elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
- return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
- else:
- return img_list
-
-
-'''
-# --------------------------------------------
-# metric, PSNR, SSIM and PSNRB
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# PSNR
-# --------------------------------------------
-def calculate_psnr(img1, img2, border=0):
- # img1 and img2 have range [0, 255]
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
- if not img1.shape == img2.shape:
- raise ValueError('Input images must have the same dimensions.')
- h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
-
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
- mse = np.mean((img1 - img2)**2)
- if mse == 0:
- return float('inf')
- return 20 * math.log10(255.0 / math.sqrt(mse))
-
-
-# --------------------------------------------
-# SSIM
-# --------------------------------------------
-def calculate_ssim(img1, img2, border=0):
- '''calculate SSIM
- the same outputs as MATLAB's
- img1, img2: [0, 255]
- '''
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
- if not img1.shape == img2.shape:
- raise ValueError('Input images must have the same dimensions.')
- h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
-
- if img1.ndim == 2:
- return ssim(img1, img2)
- elif img1.ndim == 3:
- if img1.shape[2] == 3:
- ssims = []
- for i in range(3):
- ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
- return np.array(ssims).mean()
- elif img1.shape[2] == 1:
- return ssim(np.squeeze(img1), np.squeeze(img2))
- else:
- raise ValueError('Wrong input image dimensions.')
-
-
-def ssim(img1, img2):
- C1 = (0.01 * 255)**2
- C2 = (0.03 * 255)**2
-
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
- kernel = cv2.getGaussianKernel(11, 1.5)
- window = np.outer(kernel, kernel.transpose())
-
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
- mu1_sq = mu1**2
- mu2_sq = mu2**2
- mu1_mu2 = mu1 * mu2
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
-
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
- (sigma1_sq + sigma2_sq + C2))
- return ssim_map.mean()
-
-
-def _blocking_effect_factor(im):
- block_size = 8
-
- block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
- block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
-
- horizontal_block_difference = (
- (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
- 3).sum(2).sum(1)
- vertical_block_difference = (
- (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
- 2).sum(1)
-
- nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
- nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
-
- horizontal_nonblock_difference = (
- (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
- 3).sum(2).sum(1)
- vertical_nonblock_difference = (
- (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
- 3).sum(2).sum(1)
-
- n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
- n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
- boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
- n_boundary_horiz + n_boundary_vert)
-
- n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
- n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
- nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
- n_nonboundary_horiz + n_nonboundary_vert)
-
- scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
- bef = scaler * (boundary_difference - nonboundary_difference)
-
- bef[boundary_difference <= nonboundary_difference] = 0
- return bef
-
-
-def calculate_psnrb(img1, img2, border=0):
- """Calculate PSNR-B (Peak Signal-to-Noise Ratio).
- Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
- # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
- Args:
- img1 (ndarray): Images with range [0, 255].
- img2 (ndarray): Images with range [0, 255].
- border (int): Cropped pixels in each edge of an image. These
- pixels are not involved in the PSNR calculation.
- test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
- Returns:
- float: psnr result.
- """
-
- if not img1.shape == img2.shape:
- raise ValueError('Input images must have the same dimensions.')
-
- if img1.ndim == 2:
- img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)
-
- h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
-
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
-
- # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
- img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
- img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
-
- total = 0
- for c in range(img1.shape[1]):
- mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
- bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
-
- mse = mse.view(mse.shape[0], -1).mean(1)
- total += 10 * torch.log10(1 / (mse + bef))
-
- return float(total) / img1.shape[1]
-
-'''
-# --------------------------------------------
-# matlab's bicubic imresize (numpy and torch) [0, 1]
-# --------------------------------------------
-'''
-
-
-# matlab 'imresize' function, now only support 'bicubic'
-def cubic(x):
- absx = torch.abs(x)
- absx2 = absx**2
- absx3 = absx**3
- return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
- (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
-
-
-def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
- if (scale < 1) and (antialiasing):
- # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
- kernel_width = kernel_width / scale
-
- # Output-space coordinates
- x = torch.linspace(1, out_length, out_length)
-
- # Input-space coordinates. Calculate the inverse mapping such that 0.5
- # in output space maps to 0.5 in input space, and 0.5+scale in output
- # space maps to 1.5 in input space.
- u = x / scale + 0.5 * (1 - 1 / scale)
-
- # What is the left-most pixel that can be involved in the computation?
- left = torch.floor(u - kernel_width / 2)
-
- # What is the maximum number of pixels that can be involved in the
- # computation? Note: it's OK to use an extra pixel here; if the
- # corresponding weights are all zero, it will be eliminated at the end
- # of this function.
- P = math.ceil(kernel_width) + 2
-
- # The indices of the input pixels involved in computing the k-th output
- # pixel are in row k of the indices matrix.
- indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
- 1, P).expand(out_length, P)
-
- # The weights used to compute the k-th output pixel are in row k of the
- # weights matrix.
- distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
- # apply cubic kernel
- if (scale < 1) and (antialiasing):
- weights = scale * cubic(distance_to_center * scale)
- else:
- weights = cubic(distance_to_center)
- # Normalize the weights matrix so that each row sums to 1.
- weights_sum = torch.sum(weights, 1).view(out_length, 1)
- weights = weights / weights_sum.expand(out_length, P)
-
- # If a column in weights is all zero, get rid of it. only consider the first and last column.
- weights_zero_tmp = torch.sum((weights == 0), 0)
- if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
- indices = indices.narrow(1, 1, P - 2)
- weights = weights.narrow(1, 1, P - 2)
- if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
- indices = indices.narrow(1, 0, P - 2)
- weights = weights.narrow(1, 0, P - 2)
- weights = weights.contiguous()
- indices = indices.contiguous()
- sym_len_s = -indices.min() + 1
- sym_len_e = indices.max() - in_length
- indices = indices + sym_len_s - 1
- return weights, indices, int(sym_len_s), int(sym_len_e)
-
-
-# --------------------------------------------
-# imresize for tensor image [0, 1]
-# --------------------------------------------
-def imresize(img, scale, antialiasing=True):
- # Now the scale should be the same for H and W
- # input: img: pytorch tensor, CHW or HW [0,1]
- # output: CHW or HW [0,1] w/o round
- need_squeeze = True if img.dim() == 2 else False
- if need_squeeze:
- img.unsqueeze_(0)
- in_C, in_H, in_W = img.size()
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
- kernel_width = 4
- kernel = 'cubic'
-
- # Return the desired dimension order for performing the resize. The
- # strategy is to perform the resize first along the dimension with the
- # smallest scale factor.
- # Now we do not support this.
-
- # get weights and indices
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
- # process H dimension
- # symmetric copying
- img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
- img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
-
- sym_patch = img[:, :sym_len_Hs, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
-
- sym_patch = img[:, -sym_len_He:, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
-
- out_1 = torch.FloatTensor(in_C, out_H, in_W)
- kernel_width = weights_H.size(1)
- for i in range(out_H):
- idx = int(indices_H[i][0])
- for j in range(out_C):
- out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
-
- # process W dimension
- # symmetric copying
- out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
- out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
-
- sym_patch = out_1[:, :, :sym_len_Ws]
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
- out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
-
- sym_patch = out_1[:, :, -sym_len_We:]
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
- out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
-
- out_2 = torch.FloatTensor(in_C, out_H, out_W)
- kernel_width = weights_W.size(1)
- for i in range(out_W):
- idx = int(indices_W[i][0])
- for j in range(out_C):
- out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
- if need_squeeze:
- out_2.squeeze_()
- return out_2
-
-
-# --------------------------------------------
-# imresize for numpy image [0, 1]
-# --------------------------------------------
-def imresize_np(img, scale, antialiasing=True):
- # Now the scale should be the same for H and W
- # input: img: Numpy, HWC or HW [0,1]
- # output: HWC or HW [0,1] w/o round
- img = torch.from_numpy(img)
- need_squeeze = True if img.dim() == 2 else False
- if need_squeeze:
- img.unsqueeze_(2)
-
- in_H, in_W, in_C = img.size()
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
- kernel_width = 4
- kernel = 'cubic'
-
- # Return the desired dimension order for performing the resize. The
- # strategy is to perform the resize first along the dimension with the
- # smallest scale factor.
- # Now we do not support this.
-
- # get weights and indices
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
- # process H dimension
- # symmetric copying
- img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
- img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
-
- sym_patch = img[:sym_len_Hs, :, :]
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
- img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
-
- sym_patch = img[-sym_len_He:, :, :]
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
- img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
-
- out_1 = torch.FloatTensor(out_H, in_W, in_C)
- kernel_width = weights_H.size(1)
- for i in range(out_H):
- idx = int(indices_H[i][0])
- for j in range(out_C):
- out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
-
- # process W dimension
- # symmetric copying
- out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
- out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
-
- sym_patch = out_1[:, :sym_len_Ws, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
-
- sym_patch = out_1[:, -sym_len_We:, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
-
- out_2 = torch.FloatTensor(out_H, out_W, in_C)
- kernel_width = weights_W.size(1)
- for i in range(out_W):
- idx = int(indices_W[i][0])
- for j in range(out_C):
- out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
- if need_squeeze:
- out_2.squeeze_()
-
- return out_2.numpy()
-
-
-if __name__ == '__main__':
- img = imread_uint('test.bmp', 3)
-# img = uint2single(img)
-# img_bicubic = imresize_np(img, 1/4)
-# imshow(single2uint(img_bicubic))
-#
-# img_tensor = single2tensor4(img)
-# for i in range(8):
-# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))
-
-# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
-# imssave(patches,'a.png')
-
-
-
-
-
-
-
diff --git a/core/data/deg_kair_utils/utils_lmdb.py b/core/data/deg_kair_utils/utils_lmdb.py
deleted file mode 100644
index 75192c346bb9c0b96f8b09635ed548bd6e797d89..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_lmdb.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import cv2
-import lmdb
-import sys
-from multiprocessing import Pool
-from os import path as osp
-from tqdm import tqdm
-
-
-def make_lmdb_from_imgs(data_path,
- lmdb_path,
- img_path_list,
- keys,
- batch=5000,
- compress_level=1,
- multiprocessing_read=False,
- n_thread=40,
- map_size=None):
- """Make lmdb from images.
-
- Contents of lmdb. The file structure is:
- example.lmdb
- ├── data.mdb
- ├── lock.mdb
- ├── meta_info.txt
-
- The data.mdb and lock.mdb are standard lmdb files and you can refer to
- https://lmdb.readthedocs.io/en/release/ for more details.
-
- The meta_info.txt is a specified txt file to record the meta information
- of our datasets. It will be automatically created when preparing
- datasets by our provided dataset tools.
- Each line in the txt file records 1)image name (with extension),
- 2)image shape, and 3)compression level, separated by a white space.
-
- For example, the meta information could be:
- `000_00000000.png (720,1280,3) 1`, which means:
- 1) image name (with extension): 000_00000000.png;
- 2) image shape: (720,1280,3);
- 3) compression level: 1
-
- We use the image name without extension as the lmdb key.
-
- If `multiprocessing_read` is True, it will read all the images to memory
- using multiprocessing. Thus, your server needs to have enough memory.
-
- Args:
- data_path (str): Data path for reading images.
- lmdb_path (str): Lmdb save path.
- img_path_list (str): Image path list.
- keys (str): Used for lmdb keys.
- batch (int): After processing batch images, lmdb commits.
- Default: 5000.
- compress_level (int): Compress level when encoding images. Default: 1.
- multiprocessing_read (bool): Whether use multiprocessing to read all
- the images to memory. Default: False.
- n_thread (int): For multiprocessing.
- map_size (int | None): Map size for lmdb env. If None, use the
- estimated size from images. Default: None
- """
-
- assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
- f'but got {len(img_path_list)} and {len(keys)}')
- print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
- print(f'Totoal images: {len(img_path_list)}')
- if not lmdb_path.endswith('.lmdb'):
- raise ValueError("lmdb_path must end with '.lmdb'.")
- if osp.exists(lmdb_path):
- print(f'Folder {lmdb_path} already exists. Exit.')
- sys.exit(1)
-
- if multiprocessing_read:
- # read all the images to memory (multiprocessing)
- dataset = {} # use dict to keep the order for multiprocessing
- shapes = {}
- print(f'Read images with multiprocessing, #thread: {n_thread} ...')
- pbar = tqdm(total=len(img_path_list), unit='image')
-
- def callback(arg):
- """get the image data and update pbar."""
- key, dataset[key], shapes[key] = arg
- pbar.update(1)
- pbar.set_description(f'Read {key}')
-
- pool = Pool(n_thread)
- for path, key in zip(img_path_list, keys):
- pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
- pool.close()
- pool.join()
- pbar.close()
- print(f'Finish reading {len(img_path_list)} images.')
-
- # create lmdb environment
- if map_size is None:
- # obtain data size for one image
- img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
- _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
- data_size_per_img = img_byte.nbytes
- print('Data size per image is: ', data_size_per_img)
- data_size = data_size_per_img * len(img_path_list)
- map_size = data_size * 10
-
- env = lmdb.open(lmdb_path, map_size=map_size)
-
- # write data to lmdb
- pbar = tqdm(total=len(img_path_list), unit='chunk')
- txn = env.begin(write=True)
- txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
- for idx, (path, key) in enumerate(zip(img_path_list, keys)):
- pbar.update(1)
- pbar.set_description(f'Write {key}')
- key_byte = key.encode('ascii')
- if multiprocessing_read:
- img_byte = dataset[key]
- h, w, c = shapes[key]
- else:
- _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
- h, w, c = img_shape
-
- txn.put(key_byte, img_byte)
- # write meta information
- txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
- if idx % batch == 0:
- txn.commit()
- txn = env.begin(write=True)
- pbar.close()
- txn.commit()
- env.close()
- txt_file.close()
- print('\nFinish writing lmdb.')
-
-
-def read_img_worker(path, key, compress_level):
- """Read image worker.
-
- Args:
- path (str): Image path.
- key (str): Image key.
- compress_level (int): Compress level when encoding images.
-
- Returns:
- str: Image key.
- byte: Image byte.
- tuple[int]: Image shape.
- """
-
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
- # deal with `libpng error: Read Error`
- if img is None:
- print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
- from PIL import Image
- import numpy as np
- img = Image.open(path)
- img = np.asanyarray(img)
- img = img[:, :, [2, 1, 0]]
-
- if img.ndim == 2:
- h, w = img.shape
- c = 1
- else:
- h, w, c = img.shape
- _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
- return (key, img_byte, (h, w, c))
-
-
-class LmdbMaker():
- """LMDB Maker.
-
- Args:
- lmdb_path (str): Lmdb save path.
- map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
- batch (int): After processing batch images, lmdb commits.
- Default: 5000.
- compress_level (int): Compress level when encoding images. Default: 1.
- """
-
- def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
- if not lmdb_path.endswith('.lmdb'):
- raise ValueError("lmdb_path must end with '.lmdb'.")
- if osp.exists(lmdb_path):
- print(f'Folder {lmdb_path} already exists. Exit.')
- sys.exit(1)
-
- self.lmdb_path = lmdb_path
- self.batch = batch
- self.compress_level = compress_level
- self.env = lmdb.open(lmdb_path, map_size=map_size)
- self.txn = self.env.begin(write=True)
- self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
- self.counter = 0
-
- def put(self, img_byte, key, img_shape):
- self.counter += 1
- key_byte = key.encode('ascii')
- self.txn.put(key_byte, img_byte)
- # write meta information
- h, w, c = img_shape
- self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
- if self.counter % self.batch == 0:
- self.txn.commit()
- self.txn = self.env.begin(write=True)
-
- def close(self):
- self.txn.commit()
- self.env.close()
- self.txt_file.close()
diff --git a/core/data/deg_kair_utils/utils_logger.py b/core/data/deg_kair_utils/utils_logger.py
deleted file mode 100644
index 3067190e1b09b244814e0ccc4496b18f06e22b54..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_logger.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import sys
-import datetime
-import logging
-
-
-'''
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 03/Mar/2019
-# --------------------------------------------
-# https://github.com/xinntao/BasicSR
-# --------------------------------------------
-'''
-
-
-def log(*args, **kwargs):
- print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)
-
-
-'''
-# --------------------------------------------
-# logger
-# --------------------------------------------
-'''
-
-
-def logger_info(logger_name, log_path='default_logger.log'):
- ''' set up logger
- modified by Kai Zhang (github: https://github.com/cszn)
- '''
- log = logging.getLogger(logger_name)
- if log.hasHandlers():
- print('LogHandlers exist!')
- else:
- print('LogHandlers setup!')
- level = logging.INFO
- formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
- fh = logging.FileHandler(log_path, mode='a')
- fh.setFormatter(formatter)
- log.setLevel(level)
- log.addHandler(fh)
- # print(len(log.handlers))
-
- sh = logging.StreamHandler()
- sh.setFormatter(formatter)
- log.addHandler(sh)
-
-
-'''
-# --------------------------------------------
-# print to file and std_out simultaneously
-# --------------------------------------------
-'''
-
-
-class logger_print(object):
- def __init__(self, log_path="default.log"):
- self.terminal = sys.stdout
- self.log = open(log_path, 'a')
-
- def write(self, message):
- self.terminal.write(message)
- self.log.write(message) # write the message
-
- def flush(self):
- pass
diff --git a/core/data/deg_kair_utils/utils_mat.py b/core/data/deg_kair_utils/utils_mat.py
deleted file mode 100644
index cd25d500c0eae77a3b815b8e956205b737ee43d4..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_mat.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import os
-import json
-import scipy.io as spio
-import pandas as pd
-
-
-def loadmat(filename):
- '''
- this function should be called instead of direct spio.loadmat
- as it cures the problem of not properly recovering python dictionaries
- from mat files. It calls the function check keys to cure all entries
- which are still mat-objects
- '''
- data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
- return dict_to_nonedict(_check_keys(data))
-
-def _check_keys(dict):
- '''
- checks if entries in dictionary are mat-objects. If yes
- todict is called to change them to nested dictionaries
- '''
- for key in dict:
- if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
- dict[key] = _todict(dict[key])
- return dict
-
-def _todict(matobj):
- '''
- A recursive function which constructs from matobjects nested dictionaries
- '''
- dict = {}
- for strg in matobj._fieldnames:
- elem = matobj.__dict__[strg]
- if isinstance(elem, spio.matlab.mio5_params.mat_struct):
- dict[strg] = _todict(elem)
- else:
- dict[strg] = elem
- return dict
-
-
-def dict_to_nonedict(opt):
- if isinstance(opt, dict):
- new_opt = dict()
- for key, sub_opt in opt.items():
- new_opt[key] = dict_to_nonedict(sub_opt)
- return NoneDict(**new_opt)
- elif isinstance(opt, list):
- return [dict_to_nonedict(sub_opt) for sub_opt in opt]
- else:
- return opt
-
-
-class NoneDict(dict):
- def __missing__(self, key):
- return None
-
-
-def mat2json(mat_path=None, filepath = None):
- """
- Converts .mat file to .json and writes new file
- Parameters
- ----------
- mat_path: Str
- path/filename .mat存放路径
- filepath: Str
- 如果需要保存成json, 添加这一路径. 否则不保存
- Returns
- 返回转化的字典
- -------
- None
- Examples
- --------
- >>> mat2json(blah blah)
- """
-
- matlabFile = loadmat(mat_path)
- #pop all those dumb fields that don't let you jsonize file
- matlabFile.pop('__header__')
- matlabFile.pop('__version__')
- matlabFile.pop('__globals__')
- #jsonize the file - orientation is 'index'
- matlabFile = pd.Series(matlabFile).to_json()
-
- if filepath:
- json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
- with open(json_path, 'w') as f:
- f.write(matlabFile)
- return matlabFile
\ No newline at end of file
diff --git a/core/data/deg_kair_utils/utils_matconvnet.py b/core/data/deg_kair_utils/utils_matconvnet.py
deleted file mode 100644
index 506dc47805ae07976022b236ca64c98e9a6f78b3..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_matconvnet.py
+++ /dev/null
@@ -1,197 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import torch
-from collections import OrderedDict
-
-# import scipy.io as io
-import hdf5storage
-
-"""
-# --------------------------------------------
-# Convert matconvnet SimpleNN model into pytorch model
-# --------------------------------------------
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# 28/Nov/2019
-# --------------------------------------------
-"""
-
-
-def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
- """Modified version of https://github.com/albanie/pytorch-mcn
- Adjust memory layout and load weights as torch tensor
- Args:
- x (ndaray): a numpy array, corresponding to a set of network weights
- stored in column major order
- squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
- singletons from the trailing dimensions. So after converting to
- pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1)
- it will be reshaped to a matrix with shape (A,B).
- in_features (int :: None): used to reshape weights for a linear block.
- out_features (int :: None): used to reshape weights for a linear block.
- Returns:
- torch.tensor: a permuted sets of weights, matching the pytorch layout
- convention
- """
- if x.ndim == 4:
- x = x.transpose((3, 2, 0, 1))
-# for FFDNet, pixel-shuffle layer
-# if x.shape[1]==13:
-# x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:]
-# if x.shape[0]==12:
-# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
-# if x.shape[1]==5:
-# x=x[:,[0,2,1,3, 4],:,:]
-# if x.shape[0]==4:
-# x=x[[0,2,1,3],:,:,:]
-## for SRMD, pixel-shuffle layer
-# if x.shape[0]==12:
-# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
-# if x.shape[0]==27:
-# x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:]
-# if x.shape[0]==48:
-# x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:]
-
- elif x.ndim == 3: # add by Kai
- x = x[:,:,:,None]
- x = x.transpose((3, 2, 0, 1))
- elif x.ndim == 2:
- if x.shape[1] == 1:
- x = x.flatten()
- if squeeze:
- if in_features and out_features:
- x = x.reshape((out_features, in_features))
- x = np.squeeze(x)
- return torch.from_numpy(np.ascontiguousarray(x))
-
-
-def save_model(network, save_path):
- state_dict = network.state_dict()
- for key, param in state_dict.items():
- state_dict[key] = param.cpu()
- torch.save(state_dict, save_path)
-
-
-if __name__ == '__main__':
-
-
-# from utils import utils_logger
-# import logging
-# utils_logger.logger_info('a', 'a.log')
-# logger = logging.getLogger('a')
-#
- # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
- mcn = hdf5storage.loadmat('models/modelcolor.mat')
-
-
- #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])
-
- mat_net = OrderedDict()
- for idx in range(25):
- mat_net[str(idx)] = OrderedDict()
- count = -1
-
- print(idx)
- for i in range(13):
-
- if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv':
-
- count += 1
- w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0]
- # print(w.shape)
- w = weights2tensor(w)
- # print(w.shape)
-
- b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1]
- b = weights2tensor(b)
- print(b.shape)
-
- mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w
- mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b
-
- torch.save(mat_net, 'model_zoo/modelcolor.pth')
-
-
-
-# from models.network_dncnn import IRCNN as net
-# network = net(in_nc=3, out_nc=3, nc=64)
-# state_dict = network.state_dict()
-#
-# #show_kv(state_dict)
-#
-# for i in range(len(mcn['net'][0][0][0])):
-# print(mcn['net'][0][0][0][i][0][0][0][0])
-#
-# count = -1
-# mat_net = OrderedDict()
-# for i in range(len(mcn['net'][0][0][0])):
-# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
-#
-# count += 1
-# w = mcn['net'][0][0][0][i][0][1][0][0]
-# print(w.shape)
-# w = weights2tensor(w)
-# print(w.shape)
-#
-# b = mcn['net'][0][0][0][i][0][1][0][1]
-# b = weights2tensor(b)
-# print(b.shape)
-#
-# mat_net['model.{:d}.weight'.format(count*2)] = w
-# mat_net['model.{:d}.bias'.format(count*2)] = b
-#
-# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
-#
-#
-#
-# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
-# def show_kv(net):
-# for k, v in net.items():
-# print(k)
-#
-# show_kv(crt_net)
-
-
-# from models.network_dncnn import DnCNN as net
-# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
-
-# from models.network_srmd import SRMD as net
-# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
-# network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
-#
-# from models.network_rrdb import RRDB as net
-# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
-#
-# state_dict = network.state_dict()
-# for key, param in state_dict.items():
-# print(key)
-# from models.network_imdn import IMDN as net
-# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
-# state_dict = network.state_dict()
-# mat_net = OrderedDict()
-# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
-# mat_net[key] = param2
-# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
-#
-
-# net_old = torch.load('net_old.pth')
-# def show_kv(net):
-# for k, v in net.items():
-# print(k)
-#
-# show_kv(net_old)
-# from models.network_dpsr import MSRResNet_prior as net
-# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
-# state_dict = network.state_dict()
-# net_new = OrderedDict()
-# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
-# net_new[key] = param_old
-# torch.save(net_new, 'net_new.pth')
-
-
- # print(key)
- # print(param.size())
-
-
-
- # run utils/utils_matconvnet.py
diff --git a/core/data/deg_kair_utils/utils_model.py b/core/data/deg_kair_utils/utils_model.py
deleted file mode 100644
index a4d9e6ac651784c7ed36e623c3a6175883123c2b..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_model.py
+++ /dev/null
@@ -1,330 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import torch
-from utils import utils_image as util
-import re
-import glob
-import os
-
-
-'''
-# --------------------------------------------
-# Model
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 03/Mar/2019
-# --------------------------------------------
-'''
-
-
-def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
- """
- # ---------------------------------------
- # Kai Zhang (github: https://github.com/cszn)
- # 03/Mar/2019
- # ---------------------------------------
- Args:
- save_dir: model folder
- net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
- pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
-
- Return:
- init_iter: iteration number
- init_path: model path
- # ---------------------------------------
- """
-
- file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
- if file_list:
- iter_exist = []
- for file_ in file_list:
- iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
- iter_exist.append(int(iter_current[0]))
- init_iter = max(iter_exist)
- init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
- else:
- init_iter = 0
- init_path = pretrained_path
- return init_iter, init_path
-
-
-def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
- '''
- # ---------------------------------------
- # Kai Zhang (github: https://github.com/cszn)
- # 03/Mar/2019
- # ---------------------------------------
- Args:
- model: trained model
- L: input Low-quality image
- mode:
- (0) normal: test(model, L)
- (1) pad: test_pad(model, L, modulo=16)
- (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
- (3) x8: test_x8(model, L, modulo=1) ^_^
- (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
- refield: effective receptive filed of the network, 32 is enough
- useful when split, i.e., mode=2, 4
- min_size: min_sizeXmin_size image, e.g., 256X256 image
- useful when split, i.e., mode=2, 4
- sf: scale factor for super-resolution, otherwise 1
- modulo: 1 if split
- useful when pad, i.e., mode=1
-
- Returns:
- E: estimated image
- # ---------------------------------------
- '''
- if mode == 0:
- E = test(model, L)
- elif mode == 1:
- E = test_pad(model, L, modulo, sf)
- elif mode == 2:
- E = test_split(model, L, refield, min_size, sf, modulo)
- elif mode == 3:
- E = test_x8(model, L, modulo, sf)
- elif mode == 4:
- E = test_split_x8(model, L, refield, min_size, sf, modulo)
- return E
-
-
-'''
-# --------------------------------------------
-# normal (0)
-# --------------------------------------------
-'''
-
-
-def test(model, L):
- E = model(L)
- return E
-
-
-'''
-# --------------------------------------------
-# pad (1)
-# --------------------------------------------
-'''
-
-
-def test_pad(model, L, modulo=16, sf=1):
- h, w = L.size()[-2:]
- paddingBottom = int(np.ceil(h/modulo)*modulo-h)
- paddingRight = int(np.ceil(w/modulo)*modulo-w)
- L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
- E = model(L)
- E = E[..., :h*sf, :w*sf]
- return E
-
-
-'''
-# --------------------------------------------
-# split (function)
-# --------------------------------------------
-'''
-
-
-def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
- """
- Args:
- model: trained model
- L: input Low-quality image
- refield: effective receptive filed of the network, 32 is enough
- min_size: min_sizeXmin_size image, e.g., 256X256 image
- sf: scale factor for super-resolution, otherwise 1
- modulo: 1 if split
-
- Returns:
- E: estimated result
- """
- h, w = L.size()[-2:]
- if h*w <= min_size**2:
- L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
- E = model(L)
- E = E[..., :h*sf, :w*sf]
- else:
- top = slice(0, (h//2//refield+1)*refield)
- bottom = slice(h - (h//2//refield+1)*refield, h)
- left = slice(0, (w//2//refield+1)*refield)
- right = slice(w - (w//2//refield+1)*refield, w)
- Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
-
- if h * w <= 4*(min_size**2):
- Es = [model(Ls[i]) for i in range(4)]
- else:
- Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
-
- b, c = Es[0].size()[:2]
- E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
-
- E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
- E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
- E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
- E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
- return E
-
-
-'''
-# --------------------------------------------
-# split (2)
-# --------------------------------------------
-'''
-
-
-def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
- E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
- return E
-
-
-'''
-# --------------------------------------------
-# x8 (3)
-# --------------------------------------------
-'''
-
-
-def test_x8(model, L, modulo=1, sf=1):
- E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)]
- for i in range(len(E_list)):
- if i == 3 or i == 5:
- E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
- else:
- E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
- output_cat = torch.stack(E_list, dim=0)
- E = output_cat.mean(dim=0, keepdim=False)
- return E
-
-
-'''
-# --------------------------------------------
-# split and x8 (4)
-# --------------------------------------------
-'''
-
-
-def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
- E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
- for k, i in enumerate(range(len(E_list))):
- if i==3 or i==5:
- E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i)
- else:
- E_list[k] = util.augment_img_tensor4(E_list[k], mode=i)
- output_cat = torch.stack(E_list, dim=0)
- E = output_cat.mean(dim=0, keepdim=False)
- return E
-
-
-'''
-# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
-# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
-# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
-'''
-
-
-'''
-# --------------------------------------------
-# print
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# print model
-# --------------------------------------------
-def print_model(model):
- msg = describe_model(model)
- print(msg)
-
-
-# --------------------------------------------
-# print params
-# --------------------------------------------
-def print_params(model):
- msg = describe_params(model)
- print(msg)
-
-
-'''
-# --------------------------------------------
-# information
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# model inforation
-# --------------------------------------------
-def info_model(model):
- msg = describe_model(model)
- return msg
-
-
-# --------------------------------------------
-# params inforation
-# --------------------------------------------
-def info_params(model):
- msg = describe_params(model)
- return msg
-
-
-'''
-# --------------------------------------------
-# description
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# model name and total number of parameters
-# --------------------------------------------
-def describe_model(model):
- if isinstance(model, torch.nn.DataParallel):
- model = model.module
- msg = '\n'
- msg += 'models name: {}'.format(model.__class__.__name__) + '\n'
- msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
- msg += 'Net structure:\n{}'.format(str(model)) + '\n'
- return msg
-
-
-# --------------------------------------------
-# parameters description
-# --------------------------------------------
-def describe_params(model):
- if isinstance(model, torch.nn.DataParallel):
- model = model.module
- msg = '\n'
- msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
- for name, param in model.state_dict().items():
- if not 'num_batches_tracked' in name:
- v = param.data.clone().float()
- msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
- return msg
-
-
-if __name__ == '__main__':
-
- class Net(torch.nn.Module):
- def __init__(self, in_channels=3, out_channels=3):
- super(Net, self).__init__()
- self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
-
- def forward(self, x):
- x = self.conv(x)
- return x
-
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
-
- model = Net()
- model = model.eval()
- print_model(model)
- print_params(model)
- x = torch.randn((2,3,401,401))
- torch.cuda.empty_cache()
- with torch.no_grad():
- for mode in range(5):
- y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
- print(y.shape)
-
- # run utils/utils_model.py
diff --git a/core/data/deg_kair_utils/utils_modelsummary.py b/core/data/deg_kair_utils/utils_modelsummary.py
deleted file mode 100644
index 5e040e31d8ddffbb8b7b2e2dc4ddf0b9cdca6a23..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_modelsummary.py
+++ /dev/null
@@ -1,485 +0,0 @@
-import torch.nn as nn
-import torch
-import numpy as np
-
-'''
----- 1) FLOPs: floating point operations
----- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
----- 3) #Conv2d: the number of ‘Conv2d’ layers
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 21/July/2020
-# --------------------------------------------
-# Reference
-https://github.com/sovrasov/flops-counter.pytorch.git
-
-# If you use this code, please consider the following citation:
-
-@inproceedings{zhang2020aim, %
- title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
- author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
- booktitle={European Conference on Computer Vision Workshops},
- year={2020}
-}
-# --------------------------------------------
-'''
-
-def get_model_flops(model, input_res, print_per_layer_stat=True,
- input_constructor=None):
- assert type(input_res) is tuple, 'Please provide the size of the input image.'
- assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
- flops_model = add_flops_counting_methods(model)
- flops_model.eval().start_flops_count()
- if input_constructor:
- input = input_constructor(input_res)
- _ = flops_model(**input)
- else:
- device = list(flops_model.parameters())[-1].device
- batch = torch.FloatTensor(1, *input_res).to(device)
- _ = flops_model(batch)
-
- if print_per_layer_stat:
- print_model_with_flops(flops_model)
- flops_count = flops_model.compute_average_flops_cost()
- flops_model.stop_flops_count()
-
- return flops_count
-
-def get_model_activation(model, input_res, input_constructor=None):
- assert type(input_res) is tuple, 'Please provide the size of the input image.'
- assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
- activation_model = add_activation_counting_methods(model)
- activation_model.eval().start_activation_count()
- if input_constructor:
- input = input_constructor(input_res)
- _ = activation_model(**input)
- else:
- device = list(activation_model.parameters())[-1].device
- batch = torch.FloatTensor(1, *input_res).to(device)
- _ = activation_model(batch)
-
- activation_count, num_conv = activation_model.compute_average_activation_cost()
- activation_model.stop_activation_count()
-
- return activation_count, num_conv
-
-
-def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
- input_constructor=None):
- assert type(input_res) is tuple
- assert len(input_res) >= 3
- flops_model = add_flops_counting_methods(model)
- flops_model.eval().start_flops_count()
- if input_constructor:
- input = input_constructor(input_res)
- _ = flops_model(**input)
- else:
- batch = torch.FloatTensor(1, *input_res)
- _ = flops_model(batch)
-
- if print_per_layer_stat:
- print_model_with_flops(flops_model)
- flops_count = flops_model.compute_average_flops_cost()
- params_count = get_model_parameters_number(flops_model)
- flops_model.stop_flops_count()
-
- if as_strings:
- return flops_to_string(flops_count), params_to_string(params_count)
-
- return flops_count, params_count
-
-
-def flops_to_string(flops, units='GMac', precision=2):
- if units is None:
- if flops // 10**9 > 0:
- return str(round(flops / 10.**9, precision)) + ' GMac'
- elif flops // 10**6 > 0:
- return str(round(flops / 10.**6, precision)) + ' MMac'
- elif flops // 10**3 > 0:
- return str(round(flops / 10.**3, precision)) + ' KMac'
- else:
- return str(flops) + ' Mac'
- else:
- if units == 'GMac':
- return str(round(flops / 10.**9, precision)) + ' ' + units
- elif units == 'MMac':
- return str(round(flops / 10.**6, precision)) + ' ' + units
- elif units == 'KMac':
- return str(round(flops / 10.**3, precision)) + ' ' + units
- else:
- return str(flops) + ' Mac'
-
-
-def params_to_string(params_num):
- if params_num // 10 ** 6 > 0:
- return str(round(params_num / 10 ** 6, 2)) + ' M'
- elif params_num // 10 ** 3:
- return str(round(params_num / 10 ** 3, 2)) + ' k'
- else:
- return str(params_num)
-
-
-def print_model_with_flops(model, units='GMac', precision=3):
- total_flops = model.compute_average_flops_cost()
-
- def accumulate_flops(self):
- if is_supported_instance(self):
- return self.__flops__ / model.__batch_counter__
- else:
- sum = 0
- for m in self.children():
- sum += m.accumulate_flops()
- return sum
-
- def flops_repr(self):
- accumulated_flops_cost = self.accumulate_flops()
- return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
- '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
- self.original_extra_repr()])
-
- def add_extra_repr(m):
- m.accumulate_flops = accumulate_flops.__get__(m)
- flops_extra_repr = flops_repr.__get__(m)
- if m.extra_repr != flops_extra_repr:
- m.original_extra_repr = m.extra_repr
- m.extra_repr = flops_extra_repr
- assert m.extra_repr != m.original_extra_repr
-
- def del_extra_repr(m):
- if hasattr(m, 'original_extra_repr'):
- m.extra_repr = m.original_extra_repr
- del m.original_extra_repr
- if hasattr(m, 'accumulate_flops'):
- del m.accumulate_flops
-
- model.apply(add_extra_repr)
- print(model)
- model.apply(del_extra_repr)
-
-
-def get_model_parameters_number(model):
- params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
- return params_num
-
-
-def add_flops_counting_methods(net_main_module):
- # adding additional methods to the existing module object,
- # this is done this way so that each function has access to self object
- # embed()
- net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
- net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
- net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
- net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
-
- net_main_module.reset_flops_count()
- return net_main_module
-
-
-def compute_average_flops_cost(self):
- """
- A method that will be available after add_flops_counting_methods() is called
- on a desired net object.
-
- Returns current mean flops consumption per image.
-
- """
-
- flops_sum = 0
- for module in self.modules():
- if is_supported_instance(module):
- flops_sum += module.__flops__
-
- return flops_sum
-
-
-def start_flops_count(self):
- """
- A method that will be available after add_flops_counting_methods() is called
- on a desired net object.
-
- Activates the computation of mean flops consumption per image.
- Call it before you run the network.
-
- """
- self.apply(add_flops_counter_hook_function)
-
-
-def stop_flops_count(self):
- """
- A method that will be available after add_flops_counting_methods() is called
- on a desired net object.
-
- Stops computing the mean flops consumption per image.
- Call whenever you want to pause the computation.
-
- """
- self.apply(remove_flops_counter_hook_function)
-
-
-def reset_flops_count(self):
- """
- A method that will be available after add_flops_counting_methods() is called
- on a desired net object.
-
- Resets statistics computed so far.
-
- """
- self.apply(add_flops_counter_variable_or_reset)
-
-
-def add_flops_counter_hook_function(module):
- if is_supported_instance(module):
- if hasattr(module, '__flops_handle__'):
- return
-
- if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
- handle = module.register_forward_hook(conv_flops_counter_hook)
- elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
- handle = module.register_forward_hook(relu_flops_counter_hook)
- elif isinstance(module, nn.Linear):
- handle = module.register_forward_hook(linear_flops_counter_hook)
- elif isinstance(module, (nn.BatchNorm2d)):
- handle = module.register_forward_hook(bn_flops_counter_hook)
- else:
- handle = module.register_forward_hook(empty_flops_counter_hook)
- module.__flops_handle__ = handle
-
-
-def remove_flops_counter_hook_function(module):
- if is_supported_instance(module):
- if hasattr(module, '__flops_handle__'):
- module.__flops_handle__.remove()
- del module.__flops_handle__
-
-
-def add_flops_counter_variable_or_reset(module):
- if is_supported_instance(module):
- module.__flops__ = 0
-
-
-# ---- Internal functions
-def is_supported_instance(module):
- if isinstance(module,
- (
- nn.Conv2d, nn.ConvTranspose2d,
- nn.BatchNorm2d,
- nn.Linear,
- nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
- )):
- return True
-
- return False
-
-
-def conv_flops_counter_hook(conv_module, input, output):
- # Can have multiple inputs, getting the first one
- # input = input[0]
-
- batch_size = output.shape[0]
- output_dims = list(output.shape[2:])
-
- kernel_dims = list(conv_module.kernel_size)
- in_channels = conv_module.in_channels
- out_channels = conv_module.out_channels
- groups = conv_module.groups
-
- filters_per_channel = out_channels // groups
- conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel
-
- active_elements_count = batch_size * np.prod(output_dims)
- overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count)
-
- # overall_flops = overall_conv_flops
-
- conv_module.__flops__ += int(overall_conv_flops)
- # conv_module.__output_dims__ = output_dims
-
-
-def relu_flops_counter_hook(module, input, output):
- active_elements_count = output.numel()
- module.__flops__ += int(active_elements_count)
- # print(module.__flops__, id(module))
- # print(module)
-
-
-def linear_flops_counter_hook(module, input, output):
- input = input[0]
- if len(input.shape) == 1:
- batch_size = 1
- module.__flops__ += int(batch_size * input.shape[0] * output.shape[0])
- else:
- batch_size = input.shape[0]
- module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
-
-
-def bn_flops_counter_hook(module, input, output):
- # input = input[0]
- # TODO: need to check here
- # batch_flops = np.prod(input.shape)
- # if module.affine:
- # batch_flops *= 2
- # module.__flops__ += int(batch_flops)
- batch = output.shape[0]
- output_dims = output.shape[2:]
- channels = module.num_features
- batch_flops = batch * channels * np.prod(output_dims)
- if module.affine:
- batch_flops *= 2
- module.__flops__ += int(batch_flops)
-
-
-# ---- Count the number of convolutional layers and the activation
-def add_activation_counting_methods(net_main_module):
- # adding additional methods to the existing module object,
- # this is done this way so that each function has access to self object
- # embed()
- net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
- net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
- net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
- net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)
-
- net_main_module.reset_activation_count()
- return net_main_module
-
-
-def compute_average_activation_cost(self):
- """
- A method that will be available after add_activation_counting_methods() is called
- on a desired net object.
-
- Returns current mean activation consumption per image.
-
- """
-
- activation_sum = 0
- num_conv = 0
- for module in self.modules():
- if is_supported_instance_for_activation(module):
- activation_sum += module.__activation__
- num_conv += module.__num_conv__
- return activation_sum, num_conv
-
-
-def start_activation_count(self):
- """
- A method that will be available after add_activation_counting_methods() is called
- on a desired net object.
-
- Activates the computation of mean activation consumption per image.
- Call it before you run the network.
-
- """
- self.apply(add_activation_counter_hook_function)
-
-
-def stop_activation_count(self):
- """
- A method that will be available after add_activation_counting_methods() is called
- on a desired net object.
-
- Stops computing the mean activation consumption per image.
- Call whenever you want to pause the computation.
-
- """
- self.apply(remove_activation_counter_hook_function)
-
-
-def reset_activation_count(self):
- """
- A method that will be available after add_activation_counting_methods() is called
- on a desired net object.
-
- Resets statistics computed so far.
-
- """
- self.apply(add_activation_counter_variable_or_reset)
-
-
-def add_activation_counter_hook_function(module):
- if is_supported_instance_for_activation(module):
- if hasattr(module, '__activation_handle__'):
- return
-
- if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
- handle = module.register_forward_hook(conv_activation_counter_hook)
- module.__activation_handle__ = handle
-
-
-def remove_activation_counter_hook_function(module):
- if is_supported_instance_for_activation(module):
- if hasattr(module, '__activation_handle__'):
- module.__activation_handle__.remove()
- del module.__activation_handle__
-
-
-def add_activation_counter_variable_or_reset(module):
- if is_supported_instance_for_activation(module):
- module.__activation__ = 0
- module.__num_conv__ = 0
-
-
-def is_supported_instance_for_activation(module):
- if isinstance(module,
- (
- nn.Conv2d, nn.ConvTranspose2d,
- )):
- return True
-
- return False
-
-def conv_activation_counter_hook(module, input, output):
- """
- Calculate the activations in the convolutional operation.
- Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
- :param module:
- :param input:
- :param output:
- :return:
- """
- module.__activation__ += output.numel()
- module.__num_conv__ += 1
-
-
-def empty_flops_counter_hook(module, input, output):
- module.__flops__ += 0
-
-
-def upsample_flops_counter_hook(module, input, output):
- output_size = output[0]
- batch_size = output_size.shape[0]
- output_elements_count = batch_size
- for val in output_size.shape[1:]:
- output_elements_count *= val
- module.__flops__ += int(output_elements_count)
-
-
-def pool_flops_counter_hook(module, input, output):
- input = input[0]
- module.__flops__ += int(np.prod(input.shape))
-
-
-def dconv_flops_counter_hook(dconv_module, input, output):
- input = input[0]
-
- batch_size = input.shape[0]
- output_dims = list(output.shape[2:])
-
- m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
- out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
- # groups = dconv_module.groups
-
- # filters_per_channel = out_channels // groups
- conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
- conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
- active_elements_count = batch_size * np.prod(output_dims)
-
- overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
- overall_flops = overall_conv_flops
-
- dconv_module.__flops__ += int(overall_flops)
- # dconv_module.__output_dims__ = output_dims
-
-
-
-
-
diff --git a/core/data/deg_kair_utils/utils_option.py b/core/data/deg_kair_utils/utils_option.py
deleted file mode 100644
index cf096210e2d8ea553b06a91ac5cdaa21127d837c..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_option.py
+++ /dev/null
@@ -1,255 +0,0 @@
-import os
-from collections import OrderedDict
-from datetime import datetime
-import json
-import re
-import glob
-
-
-'''
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 03/Mar/2019
-# --------------------------------------------
-# https://github.com/xinntao/BasicSR
-# --------------------------------------------
-'''
-
-
-def get_timestamp():
- return datetime.now().strftime('_%y%m%d_%H%M%S')
-
-
-def parse(opt_path, is_train=True):
-
- # ----------------------------------------
- # remove comments starting with '//'
- # ----------------------------------------
- json_str = ''
- with open(opt_path, 'r') as f:
- for line in f:
- line = line.split('//')[0] + '\n'
- json_str += line
-
- # ----------------------------------------
- # initialize opt
- # ----------------------------------------
- opt = json.loads(json_str, object_pairs_hook=OrderedDict)
-
- opt['opt_path'] = opt_path
- opt['is_train'] = is_train
-
- # ----------------------------------------
- # set default
- # ----------------------------------------
- if 'merge_bn' not in opt:
- opt['merge_bn'] = False
- opt['merge_bn_startpoint'] = -1
-
- if 'scale' not in opt:
- opt['scale'] = 1
-
- # ----------------------------------------
- # datasets
- # ----------------------------------------
- for phase, dataset in opt['datasets'].items():
- phase = phase.split('_')[0]
- dataset['phase'] = phase
- dataset['scale'] = opt['scale'] # broadcast
- dataset['n_channels'] = opt['n_channels'] # broadcast
- if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None:
- dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H'])
- if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None:
- dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L'])
-
- # ----------------------------------------
- # path
- # ----------------------------------------
- for key, path in opt['path'].items():
- if path and key in opt['path']:
- opt['path'][key] = os.path.expanduser(path)
-
- path_task = os.path.join(opt['path']['root'], opt['task'])
- opt['path']['task'] = path_task
- opt['path']['log'] = path_task
- opt['path']['options'] = os.path.join(path_task, 'options')
-
- if is_train:
- opt['path']['models'] = os.path.join(path_task, 'models')
- opt['path']['images'] = os.path.join(path_task, 'images')
- else: # test
- opt['path']['images'] = os.path.join(path_task, 'test_images')
-
- # ----------------------------------------
- # network
- # ----------------------------------------
- opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
-
- # ----------------------------------------
- # GPU devices
- # ----------------------------------------
- gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
- os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
- print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
-
- # ----------------------------------------
- # default setting for distributeddataparallel
- # ----------------------------------------
- if 'find_unused_parameters' not in opt:
- opt['find_unused_parameters'] = True
- if 'use_static_graph' not in opt:
- opt['use_static_graph'] = False
- if 'dist' not in opt:
- opt['dist'] = False
- opt['num_gpu'] = len(opt['gpu_ids'])
- print('number of GPUs is: ' + str(opt['num_gpu']))
-
- # ----------------------------------------
- # default setting for perceptual loss
- # ----------------------------------------
- if 'F_feature_layer' not in opt['train']:
- opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34]
- if 'F_weights' not in opt['train']:
- opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0]
- if 'F_lossfn_type' not in opt['train']:
- opt['train']['F_lossfn_type'] = 'l1'
- if 'F_use_input_norm' not in opt['train']:
- opt['train']['F_use_input_norm'] = True
- if 'F_use_range_norm' not in opt['train']:
- opt['train']['F_use_range_norm'] = False
-
- # ----------------------------------------
- # default setting for optimizer
- # ----------------------------------------
- if 'G_optimizer_type' not in opt['train']:
- opt['train']['G_optimizer_type'] = "adam"
- if 'G_optimizer_betas' not in opt['train']:
- opt['train']['G_optimizer_betas'] = [0.9,0.999]
- if 'G_scheduler_restart_weights' not in opt['train']:
- opt['train']['G_scheduler_restart_weights'] = 1
- if 'G_optimizer_wd' not in opt['train']:
- opt['train']['G_optimizer_wd'] = 0
- if 'G_optimizer_reuse' not in opt['train']:
- opt['train']['G_optimizer_reuse'] = False
- if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
- opt['train']['D_optimizer_reuse'] = False
-
- # ----------------------------------------
- # default setting of strict for model loading
- # ----------------------------------------
- if 'G_param_strict' not in opt['train']:
- opt['train']['G_param_strict'] = True
- if 'netD' in opt and 'D_param_strict' not in opt['path']:
- opt['train']['D_param_strict'] = True
- if 'E_param_strict' not in opt['path']:
- opt['train']['E_param_strict'] = True
-
- # ----------------------------------------
- # Exponential Moving Average
- # ----------------------------------------
- if 'E_decay' not in opt['train']:
- opt['train']['E_decay'] = 0
-
- # ----------------------------------------
- # default setting for discriminator
- # ----------------------------------------
- if 'netD' in opt:
- if 'net_type' not in opt['netD']:
- opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet
- if 'in_nc' not in opt['netD']:
- opt['netD']['in_nc'] = 3
- if 'base_nc' not in opt['netD']:
- opt['netD']['base_nc'] = 64
- if 'n_layers' not in opt['netD']:
- opt['netD']['n_layers'] = 3
- if 'norm_type' not in opt['netD']:
- opt['netD']['norm_type'] = 'spectral'
-
-
- return opt
-
-
-def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
- """
- Args:
- save_dir: model folder
- net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
- pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
-
- Return:
- init_iter: iteration number
- init_path: model path
- """
- file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
- if file_list:
- iter_exist = []
- for file_ in file_list:
- iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
- iter_exist.append(int(iter_current[0]))
- init_iter = max(iter_exist)
- init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
- else:
- init_iter = 0
- init_path = pretrained_path
- return init_iter, init_path
-
-
-'''
-# --------------------------------------------
-# convert the opt into json file
-# --------------------------------------------
-'''
-
-
-def save(opt):
- opt_path = opt['opt_path']
- opt_path_copy = opt['path']['options']
- dirname, filename_ext = os.path.split(opt_path)
- filename, ext = os.path.splitext(filename_ext)
- dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext)
- with open(dump_path, 'w') as dump_file:
- json.dump(opt, dump_file, indent=2)
-
-
-'''
-# --------------------------------------------
-# dict to string for logger
-# --------------------------------------------
-'''
-
-
-def dict2str(opt, indent_l=1):
- msg = ''
- for k, v in opt.items():
- if isinstance(v, dict):
- msg += ' ' * (indent_l * 2) + k + ':[\n'
- msg += dict2str(v, indent_l + 1)
- msg += ' ' * (indent_l * 2) + ']\n'
- else:
- msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
- return msg
-
-
-'''
-# --------------------------------------------
-# convert OrderedDict to NoneDict,
-# return None for missing key
-# --------------------------------------------
-'''
-
-
-def dict_to_nonedict(opt):
- if isinstance(opt, dict):
- new_opt = dict()
- for key, sub_opt in opt.items():
- new_opt[key] = dict_to_nonedict(sub_opt)
- return NoneDict(**new_opt)
- elif isinstance(opt, list):
- return [dict_to_nonedict(sub_opt) for sub_opt in opt]
- else:
- return opt
-
-
-class NoneDict(dict):
- def __missing__(self, key):
- return None
diff --git a/core/data/deg_kair_utils/utils_params.py b/core/data/deg_kair_utils/utils_params.py
deleted file mode 100644
index def1cb79e11472b9b8ebbaae4bd83e7216af2ccb..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_params.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import torch
-
-import torchvision
-
-from models import basicblock as B
-
-def show_kv(net):
- for k, v in net.items():
- print(k)
-
-# should run train debug mode first to get an initial model
-#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
-#
-#for k, v in crt_net.items():
-# print(k)
-#for k, v in crt_net.items():
-# if k in pretrained_net:
-# crt_net[k] = pretrained_net[k]
-# print('replace ... ', k)
-
-# x2 -> x4
-#crt_net['model.5.weight'] = pretrained_net['model.2.weight']
-#crt_net['model.5.bias'] = pretrained_net['model.2.bias']
-#crt_net['model.8.weight'] = pretrained_net['model.5.weight']
-#crt_net['model.8.bias'] = pretrained_net['model.5.bias']
-#crt_net['model.10.weight'] = pretrained_net['model.7.weight']
-#crt_net['model.10.bias'] = pretrained_net['model.7.bias']
-#torch.save(crt_net, '../pretrained_tmp.pth')
-
-# x2 -> x3
-'''
-in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
-new_filter = torch.Tensor(576, 64, 3, 3)
-new_filter[0:256, :, :, :] = in_filter
-new_filter[256:512, :, :, :] = in_filter
-new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
-crt_net['model.2.weight'] = new_filter
-
-in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
-new_bias = torch.Tensor(576)
-new_bias[0:256] = in_bias
-new_bias[256:512] = in_bias
-new_bias[512:] = in_bias[0:576 - 512]
-crt_net['model.2.bias'] = new_bias
-
-torch.save(crt_net, '../pretrained_tmp.pth')
-'''
-
-# x2 -> x8
-'''
-crt_net['model.5.weight'] = pretrained_net['model.2.weight']
-crt_net['model.5.bias'] = pretrained_net['model.2.bias']
-crt_net['model.8.weight'] = pretrained_net['model.2.weight']
-crt_net['model.8.bias'] = pretrained_net['model.2.bias']
-crt_net['model.11.weight'] = pretrained_net['model.5.weight']
-crt_net['model.11.bias'] = pretrained_net['model.5.bias']
-crt_net['model.13.weight'] = pretrained_net['model.7.weight']
-crt_net['model.13.bias'] = pretrained_net['model.7.bias']
-torch.save(crt_net, '../pretrained_tmp.pth')
-'''
-
-# x3/4/8 RGB -> Y
-
-def rgb2gray_net(net, only_input=True):
-
- if only_input:
- in_filter = net['0.weight']
- in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114
- in_new_filter.unsqueeze_(1)
- net['0.weight'] = in_new_filter
-
-# out_filter = pretrained_net['model.13.weight']
-# out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
-# out_filter[2, :, :, :] * 0.114
-# out_new_filter.unsqueeze_(0)
-# crt_net['model.13.weight'] = out_new_filter
-# out_bias = pretrained_net['model.13.bias']
-# out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
-# out_new_bias = torch.Tensor(1).fill_(out_new_bias)
-# crt_net['model.13.bias'] = out_new_bias
-
-# torch.save(crt_net, '../pretrained_tmp.pth')
-
- return net
-
-
-
-if __name__ == '__main__':
-
- net = torchvision.models.vgg19(pretrained=True)
- for k,v in net.features.named_parameters():
- if k=='0.weight':
- in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114
- in_new_filter.unsqueeze_(1)
- v = in_new_filter
- print(v.shape)
- print(v[0,0,0,0])
- if k=='0.bias':
- in_new_bias = v
- print(v[0])
-
- print(net.features[0])
-
- net.features[0] = B.conv(1, 64, mode='C')
-
- print(net.features[0])
- net.features[0].weight.data=in_new_filter
- net.features[0].bias.data=in_new_bias
-
- for k,v in net.features.named_parameters():
- if k=='0.weight':
- print(v[0,0,0,0])
- if k=='0.bias':
- print(v[0])
-
- # transfer parameters of old model to new one
- model_old = torch.load(model_path)
- state_dict = model.state_dict()
- for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()):
- state_dict[key2] = param
- print([key, key2])
- # print([param.size(), param2.size()])
- torch.save(state_dict, 'model_new.pth')
-
-
- # rgb2gray_net(net)
-
-
-
-
-
-
-
-
-
diff --git a/core/data/deg_kair_utils/utils_receptivefield.py b/core/data/deg_kair_utils/utils_receptivefield.py
deleted file mode 100644
index 82ad613b9e744189e13b721a558dbc0f42c57b30..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_receptivefield.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# online calculation: https://fomoro.com/research/article/receptive-field-calculator#
-
-# [filter size, stride, padding]
-#Assume the two dimensions are the same
-#Each kernel requires the following parameters:
-# - k_i: kernel size
-# - s_i: stride
-# - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow)
-#
-#Each layer i requires the following parameters to be fully represented:
-# - n_i: number of feature (data layer has n_1 = imagesize )
-# - j_i: distance (projected to image pixel distance) between center of two adjacent features
-# - r_i: receptive field of a feature in layer i
-# - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding)
-
-import math
-
-def outFromIn(conv, layerIn):
- n_in = layerIn[0]
- j_in = layerIn[1]
- r_in = layerIn[2]
- start_in = layerIn[3]
- k = conv[0]
- s = conv[1]
- p = conv[2]
-
- n_out = math.floor((n_in - k + 2*p)/s) + 1
- actualP = (n_out-1)*s - n_in + k
- pR = math.ceil(actualP/2)
- pL = math.floor(actualP/2)
-
- j_out = j_in * s
- r_out = r_in + (k - 1)*j_in
- start_out = start_in + ((k-1)/2 - pL)*j_in
- return n_out, j_out, r_out, start_out
-
-def printLayer(layer, layer_name):
- print(layer_name + ":")
- print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3]))
-
-
-
-layerInfos = []
-if __name__ == '__main__':
-
- convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
- layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
- imsize = 128
-
- print ("-------Net summary------")
- currentLayer = [imsize, 1, 1, 0.5]
- printLayer(currentLayer, "input image")
- for i in range(len(convnet)):
- currentLayer = outFromIn(convnet[i], currentLayer)
- layerInfos.append(currentLayer)
- printLayer(currentLayer, layer_names[i])
-
-
-# run utils/utils_receptivefield.py
-
\ No newline at end of file
diff --git a/core/data/deg_kair_utils/utils_regularizers.py b/core/data/deg_kair_utils/utils_regularizers.py
deleted file mode 100644
index 17e7c8524b716f36e10b41d72fee2e375af69454..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_regularizers.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-'''
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 03/Mar/2019
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# SVD Orthogonal Regularization
-# --------------------------------------------
-def regularizer_orth(m):
- """
- # ----------------------------------------
- # SVD Orthogonal Regularization
- # ----------------------------------------
- # Applies regularization to the training by performing the
- # orthogonalization technique described in the paper
- # This function is to be called by the torch.nn.Module.apply() method,
- # which applies svd_orthogonalization() to every layer of the model.
- # usage: net.apply(regularizer_orth)
- # ----------------------------------------
- """
- classname = m.__class__.__name__
- if classname.find('Conv') != -1:
- w = m.weight.data.clone()
- c_out, c_in, f1, f2 = w.size()
- # dtype = m.weight.data.type()
- w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
- # self.netG.apply(svd_orthogonalization)
- u, s, v = torch.svd(w)
- s[s > 1.5] = s[s > 1.5] - 1e-4
- s[s < 0.5] = s[s < 0.5] + 1e-4
- w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
- m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
- else:
- pass
-
-
-# --------------------------------------------
-# SVD Orthogonal Regularization
-# --------------------------------------------
-def regularizer_orth2(m):
- """
- # ----------------------------------------
- # Applies regularization to the training by performing the
- # orthogonalization technique described in the paper
- # This function is to be called by the torch.nn.Module.apply() method,
- # which applies svd_orthogonalization() to every layer of the model.
- # usage: net.apply(regularizer_orth2)
- # ----------------------------------------
- """
- classname = m.__class__.__name__
- if classname.find('Conv') != -1:
- w = m.weight.data.clone()
- c_out, c_in, f1, f2 = w.size()
- # dtype = m.weight.data.type()
- w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
- u, s, v = torch.svd(w)
- s_mean = s.mean()
- s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
- s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
- w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
- m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
- else:
- pass
-
-
-
-def regularizer_clip(m):
- """
- # ----------------------------------------
- # usage: net.apply(regularizer_clip)
- # ----------------------------------------
- """
- eps = 1e-4
- c_min = -1.5
- c_max = 1.5
-
- classname = m.__class__.__name__
- if classname.find('Conv') != -1 or classname.find('Linear') != -1:
- w = m.weight.data.clone()
- w[w > c_max] -= eps
- w[w < c_min] += eps
- m.weight.data = w
-
- if m.bias is not None:
- b = m.bias.data.clone()
- b[b > c_max] -= eps
- b[b < c_min] += eps
- m.bias.data = b
-
-# elif classname.find('BatchNorm2d') != -1:
-#
-# rv = m.running_var.data.clone()
-# rm = m.running_mean.data.clone()
-#
-# if m.affine:
-# m.weight.data
-# m.bias.data
diff --git a/core/data/deg_kair_utils/utils_sisr.py b/core/data/deg_kair_utils/utils_sisr.py
deleted file mode 100644
index e9edbd72ce53351d9e306c9774073a0e2eb0bdb3..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_sisr.py
+++ /dev/null
@@ -1,848 +0,0 @@
-# -*- coding: utf-8 -*-
-from utils import utils_image as util
-import random
-
-import scipy
-import scipy.stats as ss
-import scipy.io as io
-from scipy import ndimage
-from scipy.interpolate import interp2d
-
-import numpy as np
-import torch
-
-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# modified by Kai Zhang (github: https://github.com/cszn)
-# 03/03/2020
-# --------------------------------------------
-"""
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-"""
-# --------------------------------------------
-# calculate PCA projection matrix
-# --------------------------------------------
-"""
-
-
-def get_pca_matrix(x, dim_pca=15):
- """
- Args:
- x: 225x10000 matrix
- dim_pca: 15
- Returns:
- pca_matrix: 15x225
- """
- C = np.dot(x, x.T)
- w, v = scipy.linalg.eigh(C)
- pca_matrix = v[:, -dim_pca:].T
-
- return pca_matrix
-
-
-def show_pca(x):
- """
- x: PCA projection matrix, e.g., 15x225
- """
- for i in range(x.shape[0]):
- xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F")
- util.surf(xc)
-
-
-def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
- kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
- for i in range(num_samples):
-
- theta = np.pi*np.random.rand(1)
- l1 = 0.1+l_max*np.random.rand(1)
- l2 = 0.1+(l1-0.1)*np.random.rand(1)
-
- k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])
-
- # util.imshow(k)
-
- kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F')
-
- # io.savemat('k.mat', {'k': kernels})
-
- pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
-
- io.savemat(path, {'p': pca_matrix})
-
- return pca_matrix
-
-
-"""
-# --------------------------------------------
-# shifted anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
- # Calcualte Gaussian for every pixel of the kernel
- ZZ = Z-MU
- ZZ_t = ZZ.transpose(0,1,3,2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- sf = random.choice([1, 2, 3, 4])
- scale_factor = np.array([sf, sf])
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
- # Calcualte Gaussian for every pixel of the kernel
- ZZ = Z-MU
- ZZ_t = ZZ.transpose(0,1,3,2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1/sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
-
- ''' bicubic downsampling + blur
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
-
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
-
- Return:
- downsampled LR image
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-'''
-# =================
-# Numpy
-# =================
-'''
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH, image or kernel
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf-1)*0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w-1)
- y1 = np.clip(y1, 0, h-1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-'''
-# =================
-# pytorch
-# =================
-'''
-
-
-def splits(a, sf):
- '''
- a: tensor NxCxWxHx2
- sf: scale factor
- out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
- '''
- b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
- b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
- return b
-
-
-def c2c(x):
- return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
-
-
-def r2c(x):
- return torch.stack([x, torch.zeros_like(x)], -1)
-
-
-def cdiv(x, y):
- a, b = x[..., 0], x[..., 1]
- c, d = y[..., 0], y[..., 1]
- cd2 = c**2 + d**2
- return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
-
-
-def csum(x, y):
- return torch.stack([x[..., 0] + y, x[..., 1]], -1)
-
-
-def cabs(x):
- return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
-
-
-def cmul(t1, t2):
- '''
- complex multiplication
- t1: NxCxHxWx2
- output: NxCxHxWx2
- '''
- real1, imag1 = t1[..., 0], t1[..., 1]
- real2, imag2 = t2[..., 0], t2[..., 1]
- return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
-
-
-def cconj(t, inplace=False):
- '''
- # complex's conjugation
- t: NxCxHxWx2
- output: NxCxHxWx2
- '''
- c = t.clone() if not inplace else t
- c[..., 1] *= -1
- return c
-
-
-def rfft(t):
- return torch.rfft(t, 2, onesided=False)
-
-
-def irfft(t):
- return torch.irfft(t, 2, onesided=False)
-
-
-def fft(t):
- return torch.fft(t, 2)
-
-
-def ifft(t):
- return torch.ifft(t, 2)
-
-
-def p2o(psf, shape):
- '''
- Args:
- psf: NxCxhxw
- shape: [H,W]
-
- Returns:
- otf: NxCxHxWx2
- '''
- otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
- otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
- for axis, axis_size in enumerate(psf.shape[2:]):
- otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
- otf = torch.rfft(otf, 2, onesided=False)
- n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
- otf[...,1][torch.abs(otf[...,1]) x[N, 1, W + 2 pad, H + 2 pad] (pariodic padding)
- '''
- x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
- x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
- x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
- x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
- return x
-
-
-def pad_circular(input, padding):
- # type: (Tensor, List[int]) -> Tensor
- """
- Arguments
- :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
- :param padding: (tuple): m-elem tuple where m is the degree of convolution
- Returns
- :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
- H + 2 * padding[1]], W + 2 * padding[2]))`
- """
- offset = 3
- for dimension in range(input.dim() - offset + 1):
- input = dim_pad_circular(input, padding[dimension], dimension + offset)
- return input
-
-
-def dim_pad_circular(input, padding, dimension):
- # type: (Tensor, int, int) -> Tensor
- input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
- [slice(0, padding)]]], dim=dimension - 1)
- input = torch.cat([input[[slice(None)] * (dimension - 1) +
- [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
- return input
-
-
-def imfilter(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, cx1xhxw
- '''
- x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
- x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
- return x
-
-
-def G(x, k, sf=3, center=False):
- '''
- x: image, NxcxHxW
- k: kernel, cx1xhxw
- sf: scale factor
- center: the first one or the moddle one
-
- Matlab function:
- tmp = imfilter(x,h,'circular');
- y = downsample2(tmp,K);
- '''
- x = downsample(imfilter(x, k), sf=sf, center=center)
- return x
-
-
-def Gt(x, k, sf=3, center=False):
- '''
- x: image, NxcxHxW
- k: kernel, cx1xhxw
- sf: scale factor
- center: the first one or the moddle one
-
- Matlab function:
- tmp = upsample2(x,K);
- y = imfilter(tmp,h,'circular');
- '''
- x = imfilter(upsample(x, sf=sf, center=center), k)
- return x
-
-
-def interpolation_down(x, sf, center=False):
- mask = torch.zeros_like(x)
- if center:
- start = torch.tensor((sf-1)//2)
- mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
- LR = x[..., start::sf, start::sf]
- else:
- mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
- LR = x[..., ::sf, ::sf]
- y = x.mul(mask)
-
- return LR, y, mask
-
-
-'''
-# =================
-Numpy
-# =================
-'''
-
-
-def blockproc(im, blocksize, fun):
- xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0)
- xblocks_proc = []
- for xb in xblocks:
- yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1)
- yblocks_proc = []
- for yb in yblocks:
- yb_proc = fun(yb)
- yblocks_proc.append(yb_proc)
- xblocks_proc.append(np.concatenate(yblocks_proc, axis=1))
-
- proc = np.concatenate(xblocks_proc, axis=0)
-
- return proc
-
-
-def fun_reshape(a):
- return np.reshape(a, (-1,1,a.shape[-1]), order='F')
-
-
-def fun_mul(a, b):
- return a*b
-
-
-def BlockMM(nr, nc, Nb, m, x1):
- '''
- myfun = @(block_struct) reshape(block_struct.data,m,1);
- x1 = blockproc(x1,[nr nc],myfun);
- x1 = reshape(x1,m,Nb);
- x1 = sum(x1,2);
- x = reshape(x1,nr,nc);
- '''
- fun = fun_reshape
- x1 = blockproc(x1, blocksize=(nr, nc), fun=fun)
- x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F')
- x1 = np.sum(x1, 1)
- x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F')
- return x
-
-
-def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
- '''
- x1 = FB.*FR;
- FBR = BlockMM(nr,nc,Nb,m,x1);
- invW = BlockMM(nr,nc,Nb,m,F2B);
- invWBR = FBR./(invW + tau*Nb);
- fun = @(block_struct) block_struct.data.*invWBR;
- FCBinvWBR = blockproc(FBC,[nr,nc],fun);
- FX = (FR-FCBinvWBR)/tau;
- Xest = real(ifft2(FX));
- '''
- x1 = FB*FR
- FBR = BlockMM(nr, nc, Nb, m, x1)
- invW = BlockMM(nr, nc, Nb, m, F2B)
- invWBR = FBR/(invW + tau*Nb)
- FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
- FX = (FR-FCBinvWBR)/tau
- Xest = np.real(np.fft.ifft2(FX, axes=(0, 1)))
- return Xest
-
-
-def psf2otf(psf, shape=None):
- """
- Convert point-spread function to optical transfer function.
- Compute the Fast Fourier Transform (FFT) of the point-spread
- function (PSF) array and creates the optical transfer function (OTF)
- array that is not influenced by the PSF off-centering.
- By default, the OTF array is the same size as the PSF array.
- To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
- post-pads the PSF array (down or to the right) with zeros to match
- dimensions specified in OUTSIZE, then circularly shifts the values of
- the PSF array up (or to the left) until the central pixel reaches (1,1)
- position.
- Parameters
- ----------
- psf : `numpy.ndarray`
- PSF array
- shape : int
- Output shape of the OTF array
- Returns
- -------
- otf : `numpy.ndarray`
- OTF array
- Notes
- -----
- Adapted from MATLAB psf2otf function
- """
- if type(shape) == type(None):
- shape = psf.shape
- shape = np.array(shape)
- if np.all(psf == 0):
- # return np.zeros_like(psf)
- return np.zeros(shape)
- if len(psf.shape) == 1:
- psf = psf.reshape((1, psf.shape[0]))
- inshape = psf.shape
- psf = zero_pad(psf, shape, position='corner')
- for axis, axis_size in enumerate(inshape):
- psf = np.roll(psf, -int(axis_size / 2), axis=axis)
- # Compute the OTF
- otf = np.fft.fft2(psf, axes=(0, 1))
- # Estimate the rough number of operations involved in the FFT
- # and discard the PSF imaginary part if within roundoff error
- # roundoff error = machine epsilon = sys.float_info.epsilon
- # or np.finfo().eps
- n_ops = np.sum(psf.size * np.log2(psf.shape))
- otf = np.real_if_close(otf, tol=n_ops)
- return otf
-
-
-def zero_pad(image, shape, position='corner'):
- """
- Extends image to a certain size with zeros
- Parameters
- ----------
- image: real 2d `numpy.ndarray`
- Input image
- shape: tuple of int
- Desired output shape of the image
- position : str, optional
- The position of the input image in the output one:
- * 'corner'
- top-left corner (default)
- * 'center'
- centered
- Returns
- -------
- padded_img: real `numpy.ndarray`
- The zero-padded image
- """
- shape = np.asarray(shape, dtype=int)
- imshape = np.asarray(image.shape, dtype=int)
- if np.alltrue(imshape == shape):
- return image
- if np.any(shape <= 0):
- raise ValueError("ZERO_PAD: null or negative shape given")
- dshape = shape - imshape
- if np.any(dshape < 0):
- raise ValueError("ZERO_PAD: target size smaller than source one")
- pad_img = np.zeros(shape, dtype=image.dtype)
- idx, idy = np.indices(imshape)
- if position == 'center':
- if np.any(dshape % 2 != 0):
- raise ValueError("ZERO_PAD: source and target shapes "
- "have different parity.")
- offx, offy = dshape // 2
- else:
- offx, offy = (0, 0)
- pad_img[idx + offx, idy + offy] = image
- return pad_img
-
-
-def upsample_np(x, sf=3, center=False):
- st = (sf-1)//2 if center else 0
- z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2]))
- z[st::sf, st::sf, ...] = x
- return z
-
-
-def downsample_np(x, sf=3, center=False):
- st = (sf-1)//2 if center else 0
- return x[st::sf, st::sf, ...]
-
-
-def imfilter_np(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, cx1xhxw
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def G_np(x, k, sf=3, center=False):
- '''
- x: image, NxcxHxW
- k: kernel, cx1xhxw
-
- Matlab function:
- tmp = imfilter(x,h,'circular');
- y = downsample2(tmp,K);
- '''
- x = downsample_np(imfilter_np(x, k), sf=sf, center=center)
- return x
-
-
-def Gt_np(x, k, sf=3, center=False):
- '''
- x: image, NxcxHxW
- k: kernel, cx1xhxw
-
- Matlab function:
- tmp = upsample2(x,K);
- y = imfilter(tmp,h,'circular');
- '''
- x = imfilter_np(upsample_np(x, sf=sf, center=center), k)
- return x
-
-
-if __name__ == '__main__':
- img = util.imread_uint('test.bmp', 3)
-
- img = util.uint2single(img)
- k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
- util.imshow(k*10)
-
-
- for sf in [2, 3, 4]:
-
- # modcrop
- img = modcrop_np(img, sf=sf)
-
- # 1) bicubic degradation
- img_b = bicubic_degradation(img, sf=sf)
- print(img_b.shape)
-
- # 2) srmd degradation
- img_s = srmd_degradation(img, k, sf=sf)
- print(img_s.shape)
-
- # 3) dpsr degradation
- img_d = dpsr_degradation(img, k, sf=sf)
- print(img_d.shape)
-
- # 4) classical degradation
- img_d = classical_degradation(img, k, sf=sf)
- print(img_d.shape)
-
- k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
- #print(k)
-# util.imshow(k*10)
-
- k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
-# util.imshow(k*10)
-
-
- # PCA
-# pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
-# print(pca_matrix.shape)
-# show_pca(pca_matrix)
- # run utils/utils_sisr.py
- # run utils_sisr.py
-
-
-
-
-
-
-
diff --git a/core/data/deg_kair_utils/utils_video.py b/core/data/deg_kair_utils/utils_video.py
deleted file mode 100644
index 596dd4203098cf7b36f3d8499ccbf299623381ae..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_video.py
+++ /dev/null
@@ -1,493 +0,0 @@
-import os
-import cv2
-import numpy as np
-import torch
-import random
-from os import path as osp
-from torch.nn import functional as F
-from abc import ABCMeta, abstractmethod
-
-
-def scandir(dir_path, suffix=None, recursive=False, full_path=False):
- """Scan a directory to find the interested files.
-
- Args:
- dir_path (str): Path of the directory.
- suffix (str | tuple(str), optional): File suffix that we are
- interested in. Default: None.
- recursive (bool, optional): If set to True, recursively scan the
- directory. Default: False.
- full_path (bool, optional): If set to True, include the dir_path.
- Default: False.
-
- Returns:
- A generator for all the interested files with relative paths.
- """
-
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
- raise TypeError('"suffix" must be a string or tuple of strings')
-
- root = dir_path
-
- def _scandir(dir_path, suffix, recursive):
- for entry in os.scandir(dir_path):
- if not entry.name.startswith('.') and entry.is_file():
- if full_path:
- return_path = entry.path
- else:
- return_path = osp.relpath(entry.path, root)
-
- if suffix is None:
- yield return_path
- elif return_path.endswith(suffix):
- yield return_path
- else:
- if recursive:
- yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
- else:
- continue
-
- return _scandir(dir_path, suffix=suffix, recursive=recursive)
-
-
-def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
- """Read a sequence of images from a given folder path.
-
- Args:
- path (list[str] | str): List of image paths or image folder path.
- require_mod_crop (bool): Require mod crop for each image.
- Default: False.
- scale (int): Scale factor for mod_crop. Default: 1.
- return_imgname(bool): Whether return image names. Default False.
-
- Returns:
- Tensor: size (t, c, h, w), RGB, [0, 1].
- list[str]: Returned image name list.
- """
- if isinstance(path, list):
- img_paths = path
- else:
- img_paths = sorted(list(scandir(path, full_path=True)))
- imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
-
- if require_mod_crop:
- imgs = [mod_crop(img, scale) for img in imgs]
- imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
- imgs = torch.stack(imgs, dim=0)
-
- if return_imgname:
- imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
- return imgs, imgnames
- else:
- return imgs
-
-
-def img2tensor(imgs, bgr2rgb=True, float32=True):
- """Numpy array to tensor.
-
- Args:
- imgs (list[ndarray] | ndarray): Input images.
- bgr2rgb (bool): Whether to change bgr to rgb.
- float32 (bool): Whether to change to float32.
-
- Returns:
- list[tensor] | tensor: Tensor images. If returned results only have
- one element, just return tensor.
- """
-
- def _totensor(img, bgr2rgb, float32):
- if img.shape[2] == 3 and bgr2rgb:
- if img.dtype == 'float64':
- img = img.astype('float32')
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- img = torch.from_numpy(img.transpose(2, 0, 1))
- if float32:
- img = img.float()
- return img
-
- if isinstance(imgs, list):
- return [_totensor(img, bgr2rgb, float32) for img in imgs]
- else:
- return _totensor(imgs, bgr2rgb, float32)
-
-
-def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
- """Convert torch Tensors into image numpy arrays.
-
- After clamping to [min, max], values will be normalized to [0, 1].
-
- Args:
- tensor (Tensor or list[Tensor]): Accept shapes:
- 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
- 2) 3D Tensor of shape (3/1 x H x W);
- 3) 2D Tensor of shape (H x W).
- Tensor channel should be in RGB order.
- rgb2bgr (bool): Whether to change rgb to bgr.
- out_type (numpy type): output types. If ``np.uint8``, transform outputs
- to uint8 type with range [0, 255]; otherwise, float type with
- range [0, 1]. Default: ``np.uint8``.
- min_max (tuple[int]): min and max values for clamp.
-
- Returns:
- (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
- shape (H x W). The channel order is BGR.
- """
- if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
- raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
-
- if torch.is_tensor(tensor):
- tensor = [tensor]
- result = []
- for _tensor in tensor:
- _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
- _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
-
- n_dim = _tensor.dim()
- if n_dim == 4:
- img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
- img_np = img_np.transpose(1, 2, 0)
- if rgb2bgr:
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
- elif n_dim == 3:
- img_np = _tensor.numpy()
- img_np = img_np.transpose(1, 2, 0)
- if img_np.shape[2] == 1: # gray image
- img_np = np.squeeze(img_np, axis=2)
- else:
- if rgb2bgr:
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
- elif n_dim == 2:
- img_np = _tensor.numpy()
- else:
- raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
- if out_type == np.uint8:
- # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
- img_np = (img_np * 255.0).round()
- img_np = img_np.astype(out_type)
- result.append(img_np)
- if len(result) == 1:
- result = result[0]
- return result
-
-
-def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
- """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
-
- We use vertical flip and transpose for rotation implementation.
- All the images in the list use the same augmentation.
-
- Args:
- imgs (list[ndarray] | ndarray): Images to be augmented. If the input
- is an ndarray, it will be transformed to a list.
- hflip (bool): Horizontal flip. Default: True.
- rotation (bool): Ratotation. Default: True.
- flows (list[ndarray]: Flows to be augmented. If the input is an
- ndarray, it will be transformed to a list.
- Dimension is (h, w, 2). Default: None.
- return_status (bool): Return the status of flip and rotation.
- Default: False.
-
- Returns:
- list[ndarray] | ndarray: Augmented images and flows. If returned
- results only have one element, just return ndarray.
-
- """
- hflip = hflip and random.random() < 0.5
- vflip = rotation and random.random() < 0.5
- rot90 = rotation and random.random() < 0.5
-
- def _augment(img):
- if hflip: # horizontal
- cv2.flip(img, 1, img)
- if vflip: # vertical
- cv2.flip(img, 0, img)
- if rot90:
- img = img.transpose(1, 0, 2)
- return img
-
- def _augment_flow(flow):
- if hflip: # horizontal
- cv2.flip(flow, 1, flow)
- flow[:, :, 0] *= -1
- if vflip: # vertical
- cv2.flip(flow, 0, flow)
- flow[:, :, 1] *= -1
- if rot90:
- flow = flow.transpose(1, 0, 2)
- flow = flow[:, :, [1, 0]]
- return flow
-
- if not isinstance(imgs, list):
- imgs = [imgs]
- imgs = [_augment(img) for img in imgs]
- if len(imgs) == 1:
- imgs = imgs[0]
-
- if flows is not None:
- if not isinstance(flows, list):
- flows = [flows]
- flows = [_augment_flow(flow) for flow in flows]
- if len(flows) == 1:
- flows = flows[0]
- return imgs, flows
- else:
- if return_status:
- return imgs, (hflip, vflip, rot90)
- else:
- return imgs
-
-
-def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
- """Paired random crop. Support Numpy array and Tensor inputs.
-
- It crops lists of lq and gt images with corresponding locations.
-
- Args:
- img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
- should have the same shape. If the input is an ndarray, it will
- be transformed to a list containing itself.
- img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
- should have the same shape. If the input is an ndarray, it will
- be transformed to a list containing itself.
- gt_patch_size (int): GT patch size.
- scale (int): Scale factor.
- gt_path (str): Path to ground-truth. Default: None.
-
- Returns:
- list[ndarray] | ndarray: GT images and LQ images. If returned results
- only have one element, just return ndarray.
- """
-
- if not isinstance(img_gts, list):
- img_gts = [img_gts]
- if not isinstance(img_lqs, list):
- img_lqs = [img_lqs]
-
- # determine input type: Numpy array or Tensor
- input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
-
- if input_type == 'Tensor':
- h_lq, w_lq = img_lqs[0].size()[-2:]
- h_gt, w_gt = img_gts[0].size()[-2:]
- else:
- h_lq, w_lq = img_lqs[0].shape[0:2]
- h_gt, w_gt = img_gts[0].shape[0:2]
- lq_patch_size = gt_patch_size // scale
-
- if h_gt != h_lq * scale or w_gt != w_lq * scale:
- raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
- f'multiplication of LQ ({h_lq}, {w_lq}).')
- if h_lq < lq_patch_size or w_lq < lq_patch_size:
- raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
- f'({lq_patch_size}, {lq_patch_size}). '
- f'Please remove {gt_path}.')
-
- # randomly choose top and left coordinates for lq patch
- top = random.randint(0, h_lq - lq_patch_size)
- left = random.randint(0, w_lq - lq_patch_size)
-
- # crop lq patch
- if input_type == 'Tensor':
- img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
- else:
- img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
-
- # crop corresponding gt patch
- top_gt, left_gt = int(top * scale), int(left * scale)
- if input_type == 'Tensor':
- img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
- else:
- img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
- if len(img_gts) == 1:
- img_gts = img_gts[0]
- if len(img_lqs) == 1:
- img_lqs = img_lqs[0]
- return img_gts, img_lqs
-
-
-# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
-class BaseStorageBackend(metaclass=ABCMeta):
- """Abstract class of storage backends.
-
- All backends need to implement two apis: ``get()`` and ``get_text()``.
- ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
- as texts.
- """
-
- @abstractmethod
- def get(self, filepath):
- pass
-
- @abstractmethod
- def get_text(self, filepath):
- pass
-
-
-class MemcachedBackend(BaseStorageBackend):
- """Memcached storage backend.
-
- Attributes:
- server_list_cfg (str): Config file for memcached server list.
- client_cfg (str): Config file for memcached client.
- sys_path (str | None): Additional path to be appended to `sys.path`.
- Default: None.
- """
-
- def __init__(self, server_list_cfg, client_cfg, sys_path=None):
- if sys_path is not None:
- import sys
- sys.path.append(sys_path)
- try:
- import mc
- except ImportError:
- raise ImportError('Please install memcached to enable MemcachedBackend.')
-
- self.server_list_cfg = server_list_cfg
- self.client_cfg = client_cfg
- self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
- # mc.pyvector servers as a point which points to a memory cache
- self._mc_buffer = mc.pyvector()
-
- def get(self, filepath):
- filepath = str(filepath)
- import mc
- self._client.Get(filepath, self._mc_buffer)
- value_buf = mc.ConvertBuffer(self._mc_buffer)
- return value_buf
-
- def get_text(self, filepath):
- raise NotImplementedError
-
-
-class HardDiskBackend(BaseStorageBackend):
- """Raw hard disks storage backend."""
-
- def get(self, filepath):
- filepath = str(filepath)
- with open(filepath, 'rb') as f:
- value_buf = f.read()
- return value_buf
-
- def get_text(self, filepath):
- filepath = str(filepath)
- with open(filepath, 'r') as f:
- value_buf = f.read()
- return value_buf
-
-
-class LmdbBackend(BaseStorageBackend):
- """Lmdb storage backend.
-
- Args:
- db_paths (str | list[str]): Lmdb database paths.
- client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
- readonly (bool, optional): Lmdb environment parameter. If True,
- disallow any write operations. Default: True.
- lock (bool, optional): Lmdb environment parameter. If False, when
- concurrent access occurs, do not lock the database. Default: False.
- readahead (bool, optional): Lmdb environment parameter. If False,
- disable the OS filesystem readahead mechanism, which may improve
- random read performance when a database is larger than RAM.
- Default: False.
-
- Attributes:
- db_paths (list): Lmdb database path.
- _client (list): A list of several lmdb envs.
- """
-
- def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
- try:
- import lmdb
- except ImportError:
- raise ImportError('Please install lmdb to enable LmdbBackend.')
-
- if isinstance(client_keys, str):
- client_keys = [client_keys]
-
- if isinstance(db_paths, list):
- self.db_paths = [str(v) for v in db_paths]
- elif isinstance(db_paths, str):
- self.db_paths = [str(db_paths)]
- assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
- f'but received {len(client_keys)} and {len(self.db_paths)}.')
-
- self._client = {}
- for client, path in zip(client_keys, self.db_paths):
- self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
-
- def get(self, filepath, client_key):
- """Get values according to the filepath from one lmdb named client_key.
-
- Args:
- filepath (str | obj:`Path`): Here, filepath is the lmdb key.
- client_key (str): Used for distinguishing different lmdb envs.
- """
- filepath = str(filepath)
- assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
- client = self._client[client_key]
- with client.begin(write=False) as txn:
- value_buf = txn.get(filepath.encode('ascii'))
- return value_buf
-
- def get_text(self, filepath):
- raise NotImplementedError
-
-
-class FileClient(object):
- """A general file client to access files in different backend.
-
- The client loads a file or text in a specified backend from its path
- and return it as a binary file. it can also register other backend
- accessor with a given name and backend class.
-
- Attributes:
- backend (str): The storage backend type. Options are "disk",
- "memcached" and "lmdb".
- client (:obj:`BaseStorageBackend`): The backend object.
- """
-
- _backends = {
- 'disk': HardDiskBackend,
- 'memcached': MemcachedBackend,
- 'lmdb': LmdbBackend,
- }
-
- def __init__(self, backend='disk', **kwargs):
- if backend not in self._backends:
- raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
- f' are {list(self._backends.keys())}')
- self.backend = backend
- self.client = self._backends[backend](**kwargs)
-
- def get(self, filepath, client_key='default'):
- # client_key is used only for lmdb, where different fileclients have
- # different lmdb environments.
- if self.backend == 'lmdb':
- return self.client.get(filepath, client_key)
- else:
- return self.client.get(filepath)
-
- def get_text(self, filepath):
- return self.client.get_text(filepath)
-
-
-def imfrombytes(content, flag='color', float32=False):
- """Read an image from bytes.
-
- Args:
- content (bytes): Image bytes got from files or other streams.
- flag (str): Flags specifying the color type of a loaded image,
- candidates are `color`, `grayscale` and `unchanged`.
- float32 (bool): Whether to change to float32., If True, will also norm
- to [0, 1]. Default: False.
-
- Returns:
- ndarray: Loaded image array.
- """
- img_np = np.frombuffer(content, np.uint8)
- imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
- img = cv2.imdecode(img_np, imread_flags[flag])
- if float32:
- img = img.astype(np.float32) / 255.
- return img
-
diff --git a/core/data/deg_kair_utils/utils_videoio.py b/core/data/deg_kair_utils/utils_videoio.py
deleted file mode 100644
index 5be8c7f06802d5aaa7155a1cdcb27d2838a0882c..0000000000000000000000000000000000000000
--- a/core/data/deg_kair_utils/utils_videoio.py
+++ /dev/null
@@ -1,555 +0,0 @@
-import os
-import cv2
-import numpy as np
-import torch
-import random
-from os import path as osp
-from torchvision.utils import make_grid
-import sys
-from pathlib import Path
-import six
-from collections import OrderedDict
-import math
-import glob
-import av
-import io
-from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
- CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
- CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
-
-if sys.version_info <= (3, 3):
- FileNotFoundError = IOError
-else:
- FileNotFoundError = FileNotFoundError
-
-
-def is_str(x):
- """Whether the input is an string instance."""
- return isinstance(x, six.string_types)
-
-
-def is_filepath(x):
- return is_str(x) or isinstance(x, Path)
-
-
-def fopen(filepath, *args, **kwargs):
- if is_str(filepath):
- return open(filepath, *args, **kwargs)
- elif isinstance(filepath, Path):
- return filepath.open(*args, **kwargs)
- raise ValueError('`filepath` should be a string or a Path')
-
-
-def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
- if not osp.isfile(filename):
- raise FileNotFoundError(msg_tmpl.format(filename))
-
-
-def mkdir_or_exist(dir_name, mode=0o777):
- if dir_name == '':
- return
- dir_name = osp.expanduser(dir_name)
- os.makedirs(dir_name, mode=mode, exist_ok=True)
-
-
-def symlink(src, dst, overwrite=True, **kwargs):
- if os.path.lexists(dst) and overwrite:
- os.remove(dst)
- os.symlink(src, dst, **kwargs)
-
-
-def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
- """Scan a directory to find the interested files.
- Args:
- dir_path (str | :obj:`Path`): Path of the directory.
- suffix (str | tuple(str), optional): File suffix that we are
- interested in. Default: None.
- recursive (bool, optional): If set to True, recursively scan the
- directory. Default: False.
- case_sensitive (bool, optional) : If set to False, ignore the case of
- suffix. Default: True.
- Returns:
- A generator for all the interested files with relative paths.
- """
- if isinstance(dir_path, (str, Path)):
- dir_path = str(dir_path)
- else:
- raise TypeError('"dir_path" must be a string or Path object')
-
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
- raise TypeError('"suffix" must be a string or tuple of strings')
-
- if suffix is not None and not case_sensitive:
- suffix = suffix.lower() if isinstance(suffix, str) else tuple(
- item.lower() for item in suffix)
-
- root = dir_path
-
- def _scandir(dir_path, suffix, recursive, case_sensitive):
- for entry in os.scandir(dir_path):
- if not entry.name.startswith('.') and entry.is_file():
- rel_path = osp.relpath(entry.path, root)
- _rel_path = rel_path if case_sensitive else rel_path.lower()
- if suffix is None or _rel_path.endswith(suffix):
- yield rel_path
- elif recursive and os.path.isdir(entry.path):
- # scan recursively if entry.path is a directory
- yield from _scandir(entry.path, suffix, recursive,
- case_sensitive)
-
- return _scandir(dir_path, suffix, recursive, case_sensitive)
-
-
-class Cache:
-
- def __init__(self, capacity):
- self._cache = OrderedDict()
- self._capacity = int(capacity)
- if capacity <= 0:
- raise ValueError('capacity must be a positive integer')
-
- @property
- def capacity(self):
- return self._capacity
-
- @property
- def size(self):
- return len(self._cache)
-
- def put(self, key, val):
- if key in self._cache:
- return
- if len(self._cache) >= self.capacity:
- self._cache.popitem(last=False)
- self._cache[key] = val
-
- def get(self, key, default=None):
- val = self._cache[key] if key in self._cache else default
- return val
-
-
-class VideoReader:
- """Video class with similar usage to a list object.
-
- This video warpper class provides convenient apis to access frames.
- There exists an issue of OpenCV's VideoCapture class that jumping to a
- certain frame may be inaccurate. It is fixed in this class by checking
- the position after jumping each time.
- Cache is used when decoding videos. So if the same frame is visited for
- the second time, there is no need to decode again if it is stored in the
- cache.
-
- """
-
- def __init__(self, filename, cache_capacity=10):
- # Check whether the video path is a url
- if not filename.startswith(('https://', 'http://')):
- check_file_exist(filename, 'Video file not found: ' + filename)
- self._vcap = cv2.VideoCapture(filename)
- assert cache_capacity > 0
- self._cache = Cache(cache_capacity)
- self._position = 0
- # get basic info
- self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
- self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
- self._fps = self._vcap.get(CAP_PROP_FPS)
- self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
- self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
-
- @property
- def vcap(self):
- """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
- return self._vcap
-
- @property
- def opened(self):
- """bool: Indicate whether the video is opened."""
- return self._vcap.isOpened()
-
- @property
- def width(self):
- """int: Width of video frames."""
- return self._width
-
- @property
- def height(self):
- """int: Height of video frames."""
- return self._height
-
- @property
- def resolution(self):
- """tuple: Video resolution (width, height)."""
- return (self._width, self._height)
-
- @property
- def fps(self):
- """float: FPS of the video."""
- return self._fps
-
- @property
- def frame_cnt(self):
- """int: Total frames of the video."""
- return self._frame_cnt
-
- @property
- def fourcc(self):
- """str: "Four character code" of the video."""
- return self._fourcc
-
- @property
- def position(self):
- """int: Current cursor position, indicating frame decoded."""
- return self._position
-
- def _get_real_position(self):
- return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
-
- def _set_real_position(self, frame_id):
- self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
- pos = self._get_real_position()
- for _ in range(frame_id - pos):
- self._vcap.read()
- self._position = frame_id
-
- def read(self):
- """Read the next frame.
-
- If the next frame have been decoded before and in the cache, then
- return it directly, otherwise decode, cache and return it.
-
- Returns:
- ndarray or None: Return the frame if successful, otherwise None.
- """
- # pos = self._position
- if self._cache:
- img = self._cache.get(self._position)
- if img is not None:
- ret = True
- else:
- if self._position != self._get_real_position():
- self._set_real_position(self._position)
- ret, img = self._vcap.read()
- if ret:
- self._cache.put(self._position, img)
- else:
- ret, img = self._vcap.read()
- if ret:
- self._position += 1
- return img
-
- def get_frame(self, frame_id):
- """Get frame by index.
-
- Args:
- frame_id (int): Index of the expected frame, 0-based.
-
- Returns:
- ndarray or None: Return the frame if successful, otherwise None.
- """
- if frame_id < 0 or frame_id >= self._frame_cnt:
- raise IndexError(
- f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
- if frame_id == self._position:
- return self.read()
- if self._cache:
- img = self._cache.get(frame_id)
- if img is not None:
- self._position = frame_id + 1
- return img
- self._set_real_position(frame_id)
- ret, img = self._vcap.read()
- if ret:
- if self._cache:
- self._cache.put(self._position, img)
- self._position += 1
- return img
-
- def current_frame(self):
- """Get the current frame (frame that is just visited).
-
- Returns:
- ndarray or None: If the video is fresh, return None, otherwise
- return the frame.
- """
- if self._position == 0:
- return None
- return self._cache.get(self._position - 1)
-
- def cvt2frames(self,
- frame_dir,
- file_start=0,
- filename_tmpl='{:06d}.jpg',
- start=0,
- max_num=0,
- show_progress=False):
- """Convert a video to frame images.
-
- Args:
- frame_dir (str): Output directory to store all the frame images.
- file_start (int): Filenames will start from the specified number.
- filename_tmpl (str): Filename template with the index as the
- placeholder.
- start (int): The starting frame index.
- max_num (int): Maximum number of frames to be written.
- show_progress (bool): Whether to show a progress bar.
- """
- mkdir_or_exist(frame_dir)
- if max_num == 0:
- task_num = self.frame_cnt - start
- else:
- task_num = min(self.frame_cnt - start, max_num)
- if task_num <= 0:
- raise ValueError('start must be less than total frame number')
- if start > 0:
- self._set_real_position(start)
-
- def write_frame(file_idx):
- img = self.read()
- if img is None:
- return
- filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
- cv2.imwrite(filename, img)
-
- if show_progress:
- pass
- #track_progress(write_frame, range(file_start,file_start + task_num))
- else:
- for i in range(task_num):
- write_frame(file_start + i)
-
- def __len__(self):
- return self.frame_cnt
-
- def __getitem__(self, index):
- if isinstance(index, slice):
- return [
- self.get_frame(i)
- for i in range(*index.indices(self.frame_cnt))
- ]
- # support negative indexing
- if index < 0:
- index += self.frame_cnt
- if index < 0:
- raise IndexError('index out of range')
- return self.get_frame(index)
-
- def __iter__(self):
- self._set_real_position(0)
- return self
-
- def __next__(self):
- img = self.read()
- if img is not None:
- return img
- else:
- raise StopIteration
-
- next = __next__
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self._vcap.release()
-
-
-def frames2video(frame_dir,
- video_file,
- fps=30,
- fourcc='XVID',
- filename_tmpl='{:06d}.jpg',
- start=0,
- end=0,
- show_progress=False):
- """Read the frame images from a directory and join them as a video.
-
- Args:
- frame_dir (str): The directory containing video frames.
- video_file (str): Output filename.
- fps (float): FPS of the output video.
- fourcc (str): Fourcc of the output video, this should be compatible
- with the output file type.
- filename_tmpl (str): Filename template with the index as the variable.
- start (int): Starting frame index.
- end (int): Ending frame index.
- show_progress (bool): Whether to show a progress bar.
- """
- if end == 0:
- ext = filename_tmpl.split('.')[-1]
- end = len([name for name in scandir(frame_dir, ext)])
- first_file = osp.join(frame_dir, filename_tmpl.format(start))
- check_file_exist(first_file, 'The start frame not found: ' + first_file)
- img = cv2.imread(first_file)
- height, width = img.shape[:2]
- resolution = (width, height)
- vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
- resolution)
-
- def write_frame(file_idx):
- filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
- img = cv2.imread(filename)
- vwriter.write(img)
-
- if show_progress:
- pass
- # track_progress(write_frame, range(start, end))
- else:
- for i in range(start, end):
- write_frame(i)
- vwriter.release()
-
-
-def video2images(video_path, output_dir):
- vidcap = cv2.VideoCapture(video_path)
- in_fps = vidcap.get(cv2.CAP_PROP_FPS)
- print('video fps:', in_fps)
- if not os.path.isdir(output_dir):
- os.makedirs(output_dir)
- loaded, frame = vidcap.read()
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
- print(f'number of total frames is: {total_frames:06}')
- for i_frame in range(total_frames):
- if i_frame % 100 == 0:
- print(f'{i_frame:06} / {total_frames:06}')
- frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
- cv2.imwrite(frame_name, frame)
- loaded, frame = vidcap.read()
-
-
-def images2video(image_dir, video_path, fps=24, image_ext='png'):
- '''
- #codec = cv2.VideoWriter_fourcc(*'XVID')
- #codec = cv2.VideoWriter_fourcc('A','V','C','1')
- #codec = cv2.VideoWriter_fourcc('Y','U','V','1')
- #codec = cv2.VideoWriter_fourcc('P','I','M','1')
- #codec = cv2.VideoWriter_fourcc('M','J','P','G')
- codec = cv2.VideoWriter_fourcc('M','P','4','2')
- #codec = cv2.VideoWriter_fourcc('D','I','V','3')
- #codec = cv2.VideoWriter_fourcc('D','I','V','X')
- #codec = cv2.VideoWriter_fourcc('U','2','6','3')
- #codec = cv2.VideoWriter_fourcc('I','2','6','3')
- #codec = cv2.VideoWriter_fourcc('F','L','V','1')
- #codec = cv2.VideoWriter_fourcc('H','2','6','4')
- #codec = cv2.VideoWriter_fourcc('A','Y','U','V')
- #codec = cv2.VideoWriter_fourcc('I','U','Y','V')
- 编码器常用的几种:
- cv2.VideoWriter_fourcc("I", "4", "2", "0")
- 压缩的yuv颜色编码器,4:2:0色彩度子采样 兼容性好,产生很大的视频 avi
- cv2.VideoWriter_fourcc("P", I", "M", "1")
- 采用mpeg-1编码,文件为avi
- cv2.VideoWriter_fourcc("X", "V", "T", "D")
- 采用mpeg-4编码,得到视频大小平均 拓展名avi
- cv2.VideoWriter_fourcc("T", "H", "E", "O")
- Ogg Vorbis, 拓展名为ogv
- cv2.VideoWriter_fourcc("F", "L", "V", "1")
- FLASH视频,拓展名为.flv
- '''
- image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
- print(len(image_files))
- height, width, _ = cv2.imread(image_files[0]).shape
- out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V')
- out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))
-
- for image_file in image_files:
- img = cv2.imread(image_file)
- img = cv2.resize(img, (width, height), interpolation=3)
- out_video.write(img)
- out_video.release()
-
-
-def add_video_compression(imgs):
- codec_type = ['libx264', 'h264', 'mpeg4']
- codec_prob = [1 / 3., 1 / 3., 1 / 3.]
- codec = random.choices(codec_type, codec_prob)[0]
- # codec = 'mpeg4'
- bitrate = [1e4, 1e5]
- bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
-
- buf = io.BytesIO()
- with av.open(buf, 'w', 'mp4') as container:
- stream = container.add_stream(codec, rate=1)
- stream.height = imgs[0].shape[0]
- stream.width = imgs[0].shape[1]
- stream.pix_fmt = 'yuv420p'
- stream.bit_rate = bitrate
-
- for img in imgs:
- img = np.uint8((img.clip(0, 1)*255.).round())
- frame = av.VideoFrame.from_ndarray(img, format='rgb24')
- frame.pict_type = 'NONE'
- # pdb.set_trace()
- for packet in stream.encode(frame):
- container.mux(packet)
-
- # Flush stream
- for packet in stream.encode():
- container.mux(packet)
-
- outputs = []
- with av.open(buf, 'r', 'mp4') as container:
- if container.streams.video:
- for frame in container.decode(**{'video': 0}):
- outputs.append(
- frame.to_rgb().to_ndarray().astype(np.float32) / 255.)
-
- #outputs = np.stack(outputs, axis=0)
- return outputs
-
-
-if __name__ == '__main__':
-
- # -----------------------------------
- # test VideoReader(filename, cache_capacity=10)
- # -----------------------------------
-# video_reader = VideoReader('utils/test.mp4')
-# from utils import utils_image as util
-# inputs = []
-# for frame in video_reader:
-# print(frame.dtype)
-# util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-# #util.imshow(np.flip(frame, axis=2))
-
- # -----------------------------------
- # test video2images(video_path, output_dir)
- # -----------------------------------
-# video2images('utils/test.mp4', 'frames')
-
- # -----------------------------------
- # test images2video(image_dir, video_path, fps=24, image_ext='png')
- # -----------------------------------
-# images2video('frames', 'video_02.mp4', fps=30, image_ext='png')
-
-
- # -----------------------------------
- # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
- # -----------------------------------
-# frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')
-
-
- # -----------------------------------
- # test add_video_compression(imgs)
- # -----------------------------------
-# imgs = []
-# image_ext = 'png'
-# frames = 'frames'
-# from utils import utils_image as util
-# image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
-# for i, image_file in enumerate(image_files):
-# if i < 7:
-# img = util.imread_uint(image_file, 3)
-# img = util.uint2single(img)
-# imgs.append(img)
-#
-# results = add_video_compression(imgs)
-# for i, img in enumerate(results):
-# util.imshow(util.single2uint(img))
-# util.imsave(util.single2uint(img),f'{i:05}.png')
-
- # run utils/utils_video.py
-
-
-
-
-
-
-
diff --git a/core/scripts/__init__.py b/core/scripts/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/core/scripts/cli.py b/core/scripts/cli.py
deleted file mode 100644
index bfe3ecc330ecf9f0b3af1e7dc6b3758673712cc7..0000000000000000000000000000000000000000
--- a/core/scripts/cli.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import sys
-import argparse
-from .. import WarpCore
-from .. import templates
-
-
-def template_init(args):
- return ''''
-
-
- '''.strip()
-
-
-def init_template(args):
- parser = argparse.ArgumentParser(description='WarpCore template init tool')
- parser.add_argument('-t', '--template', type=str, default='WarpCore')
- args = parser.parse_args(args)
-
- if args.template == 'WarpCore':
- template_cls = WarpCore
- else:
- try:
- template_cls = __import__(args.template)
- except ModuleNotFoundError:
- template_cls = getattr(templates, args.template)
- print(template_cls)
-
-
-def main():
- if len(sys.argv) < 2:
- print('Usage: core ')
- sys.exit(1)
- if sys.argv[1] == 'init':
- init_template(sys.argv[2:])
- else:
- print('Unknown command')
- sys.exit(1)
-
-
-if __name__ == '__main__':
- main()
diff --git a/core/templates/__init__.py b/core/templates/__init__.py
deleted file mode 100644
index 570f16de78bcce68aa49ff0a5d0fad63284f6948..0000000000000000000000000000000000000000
--- a/core/templates/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .diffusion import DiffusionCore
\ No newline at end of file
diff --git a/core/templates/diffusion.py b/core/templates/diffusion.py
deleted file mode 100644
index f36dc3f5efa14669cc36cc3c0cffcc8def037289..0000000000000000000000000000000000000000
--- a/core/templates/diffusion.py
+++ /dev/null
@@ -1,236 +0,0 @@
-from .. import WarpCore
-from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
-from abc import abstractmethod
-from dataclasses import dataclass
-import torch
-from torch import nn
-from torch.utils.data import DataLoader
-from gdf import GDF
-import numpy as np
-from tqdm import tqdm
-import wandb
-
-import webdataset as wds
-from webdataset.handlers import warn_and_continue
-from torch.distributed import barrier
-from enum import Enum
-
-class TargetReparametrization(Enum):
- EPSILON = 'epsilon'
- X0 = 'x0'
-
-class DiffusionCore(WarpCore):
- @dataclass(frozen=True)
- class Config(WarpCore.Config):
- # TRAINING PARAMS
- lr: float = EXPECTED_TRAIN
- grad_accum_steps: int = EXPECTED_TRAIN
- batch_size: int = EXPECTED_TRAIN
- updates: int = EXPECTED_TRAIN
- warmup_updates: int = EXPECTED_TRAIN
- save_every: int = 500
- backup_every: int = 20000
- use_fsdp: bool = True
-
- # EMA UPDATE
- ema_start_iters: int = None
- ema_iters: int = None
- ema_beta: float = None
-
- # GDF setting
- gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
-
- @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
- class Info(WarpCore.Info):
- ema_loss: float = None
-
- @dataclass(frozen=True)
- class Models(WarpCore.Models):
- generator : nn.Module = EXPECTED
- generator_ema : nn.Module = None # optional
-
- @dataclass(frozen=True)
- class Optimizers(WarpCore.Optimizers):
- generator : any = EXPECTED
-
- @dataclass(frozen=True)
- class Schedulers(WarpCore.Schedulers):
- generator: any = None
-
- @dataclass(frozen=True)
- class Extras(WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
-
- # --------------------------------------------
- info: Info
- config: Config
-
- @abstractmethod
- def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def webdataset_path(self, extras: Extras):
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def webdataset_filters(self, extras: Extras):
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def webdataset_preprocessors(self, extras: Extras):
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
- raise NotImplementedError("This method needs to be overriden")
- # -------------
-
- def setup_data(self, extras: Extras) -> WarpCore.Data:
- # SETUP DATASET
- dataset_path = self.webdataset_path(extras)
- preprocessors = self.webdataset_preprocessors(extras)
- filters = self.webdataset_filters(extras)
-
- handler = warn_and_continue # None
- # handler = None
- dataset = wds.WebDataset(
- dataset_path, resampled=True, handler=handler
- ).select(filters).shuffle(690, handler=handler).decode(
- "pilrgb", handler=handler
- ).to_tuple(
- *[p[0] for p in preprocessors], handler=handler
- ).map_tuple(
- *[p[1] for p in preprocessors], handler=handler
- ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})
-
- # SETUP DATALOADER
- real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
- dataloader = DataLoader(
- dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
- )
-
- return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
-
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
- batch = next(data.iterator)
-
- with torch.no_grad():
- conditions = self.get_conditions(batch, models, extras)
- latents = self.encode_latents(batch, models, extras)
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
-
- # FORWARD PASS
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- pred = models.generator(noised, noise_cond, **conditions)
- if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
- pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
- target = noise
- elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
- pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
- target = latents
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
- loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
-
- return loss, loss_adjusted
-
- def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
- start_iter = self.info.iter+1
- max_iters = self.config.updates * self.config.grad_accum_steps
- if self.is_main_node:
- print(f"STARTING AT STEP: {start_iter}/{max_iters}")
-
- pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
- models.generator.train()
- for i in pbar:
- # FORWARD PASS
- loss, loss_adjusted = self.forward_pass(data, extras, models)
-
- # BACKWARD PASS
- if i % self.config.grad_accum_steps == 0 or i == max_iters:
- loss_adjusted.backward()
- grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
- optimizers_dict = optimizers.to_dict()
- for k in optimizers_dict:
- optimizers_dict[k].step()
- schedulers_dict = schedulers.to_dict()
- for k in schedulers_dict:
- schedulers_dict[k].step()
- models.generator.zero_grad(set_to_none=True)
- self.info.total_steps += 1
- else:
- with models.generator.no_sync():
- loss_adjusted.backward()
- self.info.iter = i
-
- # UPDATE EMA
- if models.generator_ema is not None and i % self.config.ema_iters == 0:
- update_weights_ema(
- models.generator_ema, models.generator,
- beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
- )
-
- # UPDATE LOSS METRICS
- self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
-
- if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
- wandb.alert(
- title=f"NaN value encountered in training run {self.info.wandb_run_id}",
- text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
- wait_duration=60*30
- )
-
- if self.is_main_node:
- logs = {
- 'loss': self.info.ema_loss,
- 'raw_loss': loss.mean().item(),
- 'grad_norm': grad_norm.item(),
- 'lr': optimizers.generator.param_groups[0]['lr'],
- 'total_steps': self.info.total_steps,
- }
-
- pbar.set_postfix(logs)
- if self.config.wandb_project is not None:
- wandb.log(logs)
-
- if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
- # SAVE AND CHECKPOINT STUFF
- if np.isnan(loss.mean().item()):
- if self.is_main_node and self.config.wandb_project is not None:
- tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
- wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
- else:
- self.save_checkpoints(models, optimizers)
- if self.is_main_node:
- create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
- self.sample(models, data, extras)
-
- def models_to_save(self):
- return ['generator', 'generator_ema']
-
- def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
- barrier()
- suffix = '' if suffix is None else suffix
- self.save_info(self.info, suffix=suffix)
- models_dict = models.to_dict()
- optimizers_dict = optimizers.to_dict()
- for key in self.models_to_save():
- model = models_dict[key]
- if model is not None:
- self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
- for key in optimizers_dict:
- optimizer = optimizers_dict[key]
- if optimizer is not None:
- self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
- if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
- self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
- torch.cuda.empty_cache()
diff --git a/core/utils/__init__.py b/core/utils/__init__.py
deleted file mode 100644
index 2e71b37e8d1690a00ab1e0958320775bc822b6f5..0000000000000000000000000000000000000000
--- a/core/utils/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
-from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
-
-# MOVE IT SOMERWHERE ELSE
-def update_weights_ema(tgt_model, src_model, beta=0.999):
- for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
- self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
- for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
- self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)
\ No newline at end of file
diff --git a/core/utils/__pycache__/__init__.cpython-310.pyc b/core/utils/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 63c0a7e0fbf358f557d6bea755a0f550b4010a48..0000000000000000000000000000000000000000
Binary files a/core/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/core/utils/__pycache__/__init__.cpython-39.pyc b/core/utils/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 6f18d6921da3c9d93087c1b6d8eacd7a5e46a8e5..0000000000000000000000000000000000000000
Binary files a/core/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/core/utils/__pycache__/base_dto.cpython-310.pyc b/core/utils/__pycache__/base_dto.cpython-310.pyc
deleted file mode 100644
index de093eb65813d4abf69edfbb6923f2cabab21ad7..0000000000000000000000000000000000000000
Binary files a/core/utils/__pycache__/base_dto.cpython-310.pyc and /dev/null differ
diff --git a/core/utils/__pycache__/base_dto.cpython-39.pyc b/core/utils/__pycache__/base_dto.cpython-39.pyc
deleted file mode 100644
index b80d348c7959338709ec24c3ac24dfc4f6dab3dc..0000000000000000000000000000000000000000
Binary files a/core/utils/__pycache__/base_dto.cpython-39.pyc and /dev/null differ
diff --git a/core/utils/__pycache__/save_and_load.cpython-310.pyc b/core/utils/__pycache__/save_and_load.cpython-310.pyc
deleted file mode 100644
index a7a0f63ac8bbaf073dcd8a046ed112cec181d33a..0000000000000000000000000000000000000000
Binary files a/core/utils/__pycache__/save_and_load.cpython-310.pyc and /dev/null differ
diff --git a/core/utils/__pycache__/save_and_load.cpython-39.pyc b/core/utils/__pycache__/save_and_load.cpython-39.pyc
deleted file mode 100644
index ec04e9aba6f83ab76f0bbc243bb95fda07ad8d16..0000000000000000000000000000000000000000
Binary files a/core/utils/__pycache__/save_and_load.cpython-39.pyc and /dev/null differ
diff --git a/core/utils/base_dto.py b/core/utils/base_dto.py
deleted file mode 100644
index 7cf185f00e5c6f56d23774cec8591b8d4554971e..0000000000000000000000000000000000000000
--- a/core/utils/base_dto.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import dataclasses
-from dataclasses import dataclass, _MISSING_TYPE
-from munch import Munch
-
-EXPECTED = "___REQUIRED___"
-EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
-
-# pylint: disable=invalid-field-call
-def nested_dto(x, raw=False):
- return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x))
-
-@dataclass(frozen=True)
-class Base:
- training: bool = None
- def __new__(cls, **kwargs):
- training = kwargs.get('training', True)
- setteable_fields = cls.setteable_fields(**kwargs)
- mandatory_fields = cls.mandatory_fields(**kwargs)
- invalid_kwargs = [
- {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
- ]
- print(mandatory_fields)
- assert (
- len(invalid_kwargs) == 0
- ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
- missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
- assert (
- len(missing_kwargs) == 0
- ), f"Required fields missing initializing this DTO: {missing_kwargs}."
- return object.__new__(cls)
-
-
- @classmethod
- def setteable_fields(cls, **kwargs):
- return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
-
- @classmethod
- def mandatory_fields(cls, **kwargs):
- training = kwargs.get('training', True)
- return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
-
- @classmethod
- def from_dict(cls, kwargs):
- for k in kwargs:
- if isinstance(kwargs[k], (dict, list, tuple)):
- kwargs[k] = Munch.fromDict(kwargs[k])
- return cls(**kwargs)
-
- def to_dict(self):
- # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
- selfdict = {}
- for k in dataclasses.fields(self):
- selfdict[k.name] = getattr(self, k.name)
- if isinstance(selfdict[k.name], Munch):
- selfdict[k.name] = selfdict[k.name].toDict()
- return selfdict
diff --git a/core/utils/save_and_load.py b/core/utils/save_and_load.py
deleted file mode 100644
index 0215f664f5a8e738147d0828b6a7e65b9c3a8507..0000000000000000000000000000000000000000
--- a/core/utils/save_and_load.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import os
-import torch
-import json
-from pathlib import Path
-import safetensors
-import wandb
-
-
-def create_folder_if_necessary(path):
- path = "/".join(path.split("/")[:-1])
- Path(path).mkdir(parents=True, exist_ok=True)
-
-
-def safe_save(ckpt, path):
- try:
- os.remove(f"{path}.bak")
- except OSError:
- pass
- try:
- os.rename(path, f"{path}.bak")
- except OSError:
- pass
- if path.endswith(".pt") or path.endswith(".ckpt"):
- torch.save(ckpt, path)
- elif path.endswith(".json"):
- with open(path, "w", encoding="utf-8") as f:
- json.dump(ckpt, f, indent=4)
- elif path.endswith(".safetensors"):
- safetensors.torch.save_file(ckpt, path)
- else:
- raise ValueError(f"File extension not supported: {path}")
-
-
-def load_or_fail(path, wandb_run_id=None):
- accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
- try:
- assert any(
- [path.endswith(ext) for ext in accepted_extensions]
- ), f"Automatic loading not supported for this extension: {path}"
- if not os.path.exists(path):
- checkpoint = None
- elif path.endswith(".pt") or path.endswith(".ckpt"):
- checkpoint = torch.load(path, map_location="cpu")
- elif path.endswith(".json"):
- with open(path, "r", encoding="utf-8") as f:
- checkpoint = json.load(f)
- elif path.endswith(".safetensors"):
- checkpoint = {}
- with safetensors.safe_open(path, framework="pt", device="cpu") as f:
- for key in f.keys():
- checkpoint[key] = f.get_tensor(key)
- return checkpoint
- except Exception as e:
- if wandb_run_id is not None:
- wandb.alert(
- title=f"Corrupt checkpoint for run {wandb_run_id}",
- text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
- )
- raise e
diff --git a/gdf/__init__.py b/gdf/__init__.py
deleted file mode 100644
index 753b52e2e07e2540385594627a6faf4f6091b0a0..0000000000000000000000000000000000000000
--- a/gdf/__init__.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import torch
-from .scalers import *
-from .targets import *
-from .schedulers import *
-from .noise_conditions import *
-from .loss_weights import *
-from .samplers import *
-import torch.nn.functional as F
-import math
-class GDF():
- def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
- self.schedule = schedule
- self.input_scaler = input_scaler
- self.target = target
- self.noise_cond = noise_cond
- self.loss_weight = loss_weight
- self.offset_noise = offset_noise
-
- def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
- stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)
- return stretched_limits
-
- def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
- if epsilon is None:
- epsilon = torch.randn_like(x0)
- if self.offset_noise > 0:
- if offset is None:
- offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device)
- epsilon = epsilon + offset * self.offset_noise
- logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device)
- a, b = self.input_scaler(logSNR) # B
- if len(a.shape) == 1:
- a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) # BxCxHxW
- #print('in line 33 a b', a.shape, b.shape, x0.shape, logSNR.shape, logSNR, self.noise_cond(logSNR))
- target = self.target(x0, epsilon, logSNR, a, b)
-
- # noised, noise, logSNR, t_cond
- #noised, noise, target, logSNR, noise_cond, loss_weight
- return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)
-
- def undiffuse(self, x, logSNR, pred):
- a, b = self.input_scaler(logSNR)
- if len(a.shape) == 1:
- a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1))
- return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b)
-
- def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
- sampler_params = {} if sampler_params is None else sampler_params
- if sampler is None:
- sampler = DDPMSampler(self)
- r_range = torch.linspace(t_start, t_end, timesteps+1)
- schedule = self.schedule if schedule is None else schedule
- logSNR_range = schedule(r_range, shift=shift)[:, None].expand(
- -1, shape[0] if x_init is None else x_init.size(0)
- ).to(device)
-
- x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
-
- if cfg is not None:
- if unconditional_inputs is None:
- unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
- model_inputs = {
- k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor)
- else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list)
- else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict)
- else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
- }
-
- for i in range(0, timesteps):
- noise_cond = self.noise_cond(logSNR_range[i])
- if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
- cfg_val = cfg
- if isinstance(cfg_val, (list, tuple)):
- assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
- cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())
-
- pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)
-
- pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
- if cfg_rho > 0:
- std_pos, std_cfg = pred.std(), pred_cfg.std()
- pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
- else:
- pred = pred_cfg
- else:
- pred = model(x, noise_cond, **model_inputs)
- x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
- x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
- #print('in line 86', x0.shape, x.shape, i, )
- altered_vars = yield (x0, x, pred)
-
- # Update some running variables if the user wants
- if altered_vars is not None:
- cfg = altered_vars.get('cfg', cfg)
- cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
- sampler = altered_vars.get('sampler', sampler)
- model_inputs = altered_vars.get('model_inputs', model_inputs)
- x = altered_vars.get('x', x)
- x_init = altered_vars.get('x_init', x_init)
-
-class GDF_dual_fixlrt(GDF):
- def ref_noise(self, noised, x0, logSNR):
- a, b = self.input_scaler(logSNR)
- if len(a.shape) == 1:
- a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1))
- #print('in line 210', a.shape, b.shape, x0.shape, noised.shape)
- return self.target.noise_givenx0_noised(x0, noised, logSNR, a, b)
-
- def sample(self, model, model_inputs, shape, shape_lr, unconditional_inputs=None, sampler=None,
- schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None,
- cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
- sampler_params = {} if sampler_params is None else sampler_params
- if sampler is None:
- sampler = DDPMSampler(self)
- r_range = torch.linspace(t_start, t_end, timesteps+1)
- schedule = self.schedule if schedule is None else schedule
- logSNR_range = schedule(r_range, shift=shift)[:, None].expand(
- -1, shape[0] if x_init is None else x_init.size(0)
- ).to(device)
-
- x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
- x_lr = sampler.init_x(shape_lr).to(device) if x_init is None else x_init.clone()
- if cfg is not None:
- if unconditional_inputs is None:
- unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
- model_inputs = {
- k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor)
- else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list)
- else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict)
- else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
- }
-
- ###############################################lr sampling
-
- guide_feas = [None] * timesteps
-
- for i in range(0, timesteps):
- noise_cond = self.noise_cond(logSNR_range[i])
- if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
- cfg_val = cfg
- if isinstance(cfg_val, (list, tuple)):
- assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
- cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())
-
-
-
- if i == timesteps -1 :
- output, guide_lr_enc, guide_lr_dec = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs)
- guide_feas[i] = ([f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_enc], [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_dec])
- else:
- output, _, _ = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs)
-
- pred, pred_unconditional = output.chunk(2)
-
-
- pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
- if cfg_rho > 0:
- std_pos, std_cfg = pred.std(), pred_cfg.std()
- pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
- else:
- pred = pred_cfg
- else:
- pred = model(x_lr, noise_cond, **model_inputs)
- x0_lr, epsilon_lr = self.undiffuse(x_lr, logSNR_range[i], pred)
- x_lr = sampler(x_lr, x0_lr, epsilon_lr, logSNR_range[i], logSNR_range[i+1], **sampler_params)
-
- ###############################################hr HR sampling
- for i in range(0, timesteps):
- noise_cond = self.noise_cond(logSNR_range[i])
- if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
- cfg_val = cfg
- if isinstance(cfg_val, (list, tuple)):
- assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
- cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())
-
- out_pred, t_emb = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), \
- lr_guide=guide_feas[timesteps -1] if i <=19 else None , **model_inputs, require_t=True, guide_weight=1 - i/timesteps)
- pred, pred_unconditional = out_pred.chunk(2)
- pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
- if cfg_rho > 0:
- std_pos, std_cfg = pred.std(), pred_cfg.std()
- pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
- else:
- pred = pred_cfg
- else:
- pred = model(x, noise_cond, guide_lr=(guide_lr_enc, guide_lr_dec), **model_inputs)
- x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
-
- x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
- altered_vars = yield (x0, x, pred, x_lr)
-
-
-
- # Update some running variables if the user wants
- if altered_vars is not None:
- cfg = altered_vars.get('cfg', cfg)
- cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
- sampler = altered_vars.get('sampler', sampler)
- model_inputs = altered_vars.get('model_inputs', model_inputs)
- x = altered_vars.get('x', x)
- x_init = altered_vars.get('x_init', x_init)
-
-
-
-
diff --git a/gdf/loss_weights.py b/gdf/loss_weights.py
deleted file mode 100644
index d14ddaadeeb3f8de6c68aea4c364d9b852f2f15c..0000000000000000000000000000000000000000
--- a/gdf/loss_weights.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import torch
-import numpy as np
-
-# --- Loss Weighting
-class BaseLossWeight():
- def weight(self, logSNR):
- raise NotImplementedError("this method needs to be overridden")
-
- def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
- clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
- if shift != 1:
- logSNR = logSNR.clone() + 2 * np.log(shift)
- return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range)
-
-class ComposedLossWeight(BaseLossWeight):
- def __init__(self, div, mul):
- self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
- self.div = [div] if isinstance(div, BaseLossWeight) else div
-
- def weight(self, logSNR):
- prod, div = 1, 1
- for m in self.mul:
- prod *= m.weight(logSNR)
- for d in self.div:
- div *= d.weight(logSNR)
- return prod/div
-
-class ConstantLossWeight(BaseLossWeight):
- def __init__(self, v=1):
- self.v = v
-
- def weight(self, logSNR):
- return torch.ones_like(logSNR) * self.v
-
-class SNRLossWeight(BaseLossWeight):
- def weight(self, logSNR):
- return logSNR.exp()
-
-class P2LossWeight(BaseLossWeight):
- def __init__(self, k=1.0, gamma=1.0, s=1.0):
- self.k, self.gamma, self.s = k, gamma, s
-
- def weight(self, logSNR):
- return (self.k + (logSNR * self.s).exp()) ** -self.gamma
-
-class SNRPlusOneLossWeight(BaseLossWeight):
- def weight(self, logSNR):
- return logSNR.exp() + 1
-
-class MinSNRLossWeight(BaseLossWeight):
- def __init__(self, max_snr=5):
- self.max_snr = max_snr
-
- def weight(self, logSNR):
- return logSNR.exp().clamp(max=self.max_snr)
-
-class MinSNRPlusOneLossWeight(BaseLossWeight):
- def __init__(self, max_snr=5):
- self.max_snr = max_snr
-
- def weight(self, logSNR):
- return (logSNR.exp() + 1).clamp(max=self.max_snr)
-
-class TruncatedSNRLossWeight(BaseLossWeight):
- def __init__(self, min_snr=1):
- self.min_snr = min_snr
-
- def weight(self, logSNR):
- return logSNR.exp().clamp(min=self.min_snr)
-
-class SechLossWeight(BaseLossWeight):
- def __init__(self, div=2):
- self.div = div
-
- def weight(self, logSNR):
- return 1/(logSNR/self.div).cosh()
-
-class DebiasedLossWeight(BaseLossWeight):
- def weight(self, logSNR):
- return 1/logSNR.exp().sqrt()
-
-class SigmoidLossWeight(BaseLossWeight):
- def __init__(self, s=1):
- self.s = s
-
- def weight(self, logSNR):
- return (logSNR * self.s).sigmoid()
-
-class AdaptiveLossWeight(BaseLossWeight):
- def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]):
- self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1)
- self.bucket_losses = torch.ones(buckets)
- self.weight_range = weight_range
-
- def weight(self, logSNR):
- indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
- return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)
-
- def update_buckets(self, logSNR, loss, beta=0.99):
- indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
- self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta)
diff --git a/gdf/noise_conditions.py b/gdf/noise_conditions.py
deleted file mode 100644
index dc2791f50a6f63eff8f9bed9b827f87517cc0be8..0000000000000000000000000000000000000000
--- a/gdf/noise_conditions.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-import numpy as np
-
-class BaseNoiseCond():
- def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
- clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
- self.shift = shift
- self.clamp_range = clamp_range
- self.setup(*args, **kwargs)
-
- def setup(self, *args, **kwargs):
- pass # this method is optional, override it if required
-
- def cond(self, logSNR):
- raise NotImplementedError("this method needs to be overriden")
-
- def __call__(self, logSNR):
- if self.shift != 1:
- logSNR = logSNR.clone() + 2 * np.log(self.shift)
- return self.cond(logSNR).clamp(*self.clamp_range)
-
-class CosineTNoiseCond(BaseNoiseCond):
- def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999]
- self.s = torch.tensor([s])
- self.clamp_range = clamp_range
- self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
-
- def cond(self, logSNR):
- var = logSNR.sigmoid()
- var = var.clamp(*self.clamp_range)
- s, min_var = self.s.to(var.device), self.min_var.to(var.device)
- t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
- return t
-
-class EDMNoiseCond(BaseNoiseCond):
- def cond(self, logSNR):
- return -logSNR/8
-
-class SigmoidNoiseCond(BaseNoiseCond):
- def cond(self, logSNR):
- return (-logSNR).sigmoid()
-
-class LogSNRNoiseCond(BaseNoiseCond):
- def cond(self, logSNR):
- return logSNR
-
-class EDMSigmaNoiseCond(BaseNoiseCond):
- def setup(self, sigma_data=1):
- self.sigma_data = sigma_data
-
- def cond(self, logSNR):
- return torch.exp(-logSNR / 2) * self.sigma_data
-
-class RectifiedFlowsNoiseCond(BaseNoiseCond):
- def cond(self, logSNR):
- _a = logSNR.exp() - 1
- _a[_a == 0] = 1e-3 # Avoid division by zero
- a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a)
- return a
-
-# Any NoiseCond that cannot be described easily as a continuous function of t
-# It needs to define self.x and self.y in the setup() method
-class PiecewiseLinearNoiseCond(BaseNoiseCond):
- def setup(self):
- self.x = None
- self.y = None
-
- def piecewise_linear(self, y, xs, ys):
- indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y)
- x_min, x_max = xs[indices], xs[indices+1]
- y_min, y_max = ys[indices], ys[indices+1]
- x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min)
- return x
-
- def cond(self, logSNR):
- var = logSNR.sigmoid()
- t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0)
- return t
-
-class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond):
- def setup(self, linear_range=[0.00085, 0.012], total_steps=1000):
- self.total_steps = total_steps
- linear_range_sqrt = [r**0.5 for r in linear_range]
- self.x = torch.linspace(0, 1, total_steps+1)
-
- alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2
- self.y = alphas.cumprod(dim=-1)
-
- def cond(self, logSNR):
- return super().cond(logSNR).clamp(0, 1)
-
-class DiscreteNoiseCond(BaseNoiseCond):
- def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]):
- self.noise_cond = noise_cond
- self.steps = steps
- self.continuous_range = continuous_range
-
- def cond(self, logSNR):
- cond = self.noise_cond(logSNR)
- cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0])
- return cond.mul(self.steps).long()
-
\ No newline at end of file
diff --git a/gdf/readme.md b/gdf/readme.md
deleted file mode 100644
index 9a63691513c9da6804fba53e36acc8e0cd7f5d7f..0000000000000000000000000000000000000000
--- a/gdf/readme.md
+++ /dev/null
@@ -1,86 +0,0 @@
-# Generic Diffusion Framework (GDF)
-
-# Basic usage
-GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM
-, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine different parts of different
-frameworks
-
-Using GDF is very straighforward, first of all just define an instance of the GDF class:
-
-```python
-from gdf import GDF
-from gdf import CosineSchedule
-from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight
-
-gdf = GDF(
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- input_scaler=VPScaler(), target=EpsilonTarget(),
- noise_cond=CosineTNoiseCond(),
- loss_weight=P2LossWeight(),
-)
-```
-
-You need to define the following components:
-* **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution.
-* **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule.
-* **Input Scaler**: If you want to use Variance Preserving or LERP (rectified flows)
-* **Target**: What the target is during training, usually: epsilon, x0 or v
-* **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8`
-* **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use
-
-All of those classes are actually very simple logSNR centric definitions, for example the VPScaler is defined as just:
-```python
-class VPScaler():
- def __call__(self, logSNR):
- a_squared = logSNR.sigmoid()
- a = a_squared.sqrt()
- b = (1-a_squared).sqrt()
- return a, b
-
-```
-
-So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc...
-
-### Training
-
-When you define your training loop you can get all you need by just doing:
-```python
-shift, loss_shift = 1, 1 # this can be set to higher values as per what the Simple Diffusion paper sugested for high resolution
-for inputs, extra_conditions in dataloader_iterator:
- noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
- pred = diffusion_model(noised, noise_cond, extra_conditions)
-
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
- loss_adjusted = (loss * loss_weight).mean()
-
- loss_adjusted.backward()
- optimizer.step()
- optimizer.zero_grad(set_to_none=True)
-```
-
-And that's all, you have a diffusion model training, where it's very easy to customize the different elements of the
-training from the GDF class.
-
-### Sampling
-
-The other important part is sampling, when you want to use this framework to sample you can just do the following:
-
-```python
-from gdf import DDPMSampler
-
-shift = 1
-sampling_configs = {
- "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift,
- "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999])
-}
-
-*_, (sampled, _, _) = gdf.sample(
- diffusion_model, {"cond": extra_conditions}, latents.shape,
- unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)},
- device=device, **sampling_configs
-)
-```
-
-# Available modules
-
-TODO
diff --git a/gdf/samplers.py b/gdf/samplers.py
deleted file mode 100644
index b6048c86a261d53d0440a3b2c1591a03d9978c4f..0000000000000000000000000000000000000000
--- a/gdf/samplers.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import torch
-
-class SimpleSampler():
- def __init__(self, gdf):
- self.gdf = gdf
- self.current_step = -1
-
- def __call__(self, *args, **kwargs):
- self.current_step += 1
- return self.step(*args, **kwargs)
-
- def init_x(self, shape):
- return torch.randn(*shape)
-
- def step(self, x, x0, epsilon, logSNR, logSNR_prev):
- raise NotImplementedError("You should override the 'apply' function.")
-
-class DDIMSampler(SimpleSampler):
- def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
- a, b = self.gdf.input_scaler(logSNR)
- if len(a.shape) == 1:
- a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1))
-
- a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
- if len(a_prev.shape) == 1:
- a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1))
-
- sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0
- # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
- x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
- return x
-
-class DDPMSampler(DDIMSampler):
- def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1):
- return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta)
-
-class LCMSampler(SimpleSampler):
- def step(self, x, x0, epsilon, logSNR, logSNR_prev):
- a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
- if len(a_prev.shape) == 1:
- a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1))
- return x0 * a_prev + torch.randn_like(epsilon) * b_prev
-
\ No newline at end of file
diff --git a/gdf/scalers.py b/gdf/scalers.py
deleted file mode 100644
index b1adb8b0269667f3d006c7d7d17cbf2b7ef56ca9..0000000000000000000000000000000000000000
--- a/gdf/scalers.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-
-class BaseScaler():
- def __init__(self):
- self.stretched_limits = None
-
- def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
- min_logSNR = schedule(torch.ones(1), shift=shift)
- max_logSNR = schedule(torch.zeros(1), shift=shift)
-
- min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1]
- max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0]
- self.stretched_limits = [min_a, max_a, min_b, max_b]
- return self.stretched_limits
-
- def stretch_limits(self, a, b):
- min_a, max_a, min_b, max_b = self.stretched_limits
- return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b)
-
- def scalers(self, logSNR):
- raise NotImplementedError("this method needs to be overridden")
-
- def __call__(self, logSNR):
- a, b = self.scalers(logSNR)
- if self.stretched_limits is not None:
- a, b = self.stretch_limits(a, b)
- return a, b
-
-class VPScaler(BaseScaler):
- def scalers(self, logSNR):
- a_squared = logSNR.sigmoid()
- a = a_squared.sqrt()
- b = (1-a_squared).sqrt()
- return a, b
-
-class LERPScaler(BaseScaler):
- def scalers(self, logSNR):
- _a = logSNR.exp() - 1
- _a[_a == 0] = 1e-3 # Avoid division by zero
- a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a)
- b = 1-a
- return a, b
diff --git a/gdf/schedulers.py b/gdf/schedulers.py
deleted file mode 100644
index caa6e174da1d766ea5828616bb8113865106b628..0000000000000000000000000000000000000000
--- a/gdf/schedulers.py
+++ /dev/null
@@ -1,200 +0,0 @@
-import torch
-import numpy as np
-
-class BaseSchedule():
- def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs):
- self.setup(*args, **kwargs)
- self.limits = None
- self.discrete_steps = discrete_steps
- self.shift = shift
- if force_limits:
- self.reset_limits()
-
- def reset_limits(self, shift=1, disable=False):
- try:
- self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max
- return self.limits
- except Exception:
- print("WARNING: this schedule doesn't support t and will be unbounded")
- return None
-
- def setup(self, *args, **kwargs):
- raise NotImplementedError("this method needs to be overriden")
-
- def schedule(self, *args, **kwargs):
- raise NotImplementedError("this method needs to be overriden")
-
- def __call__(self, t, *args, shift=1, **kwargs):
- if isinstance(t, torch.Tensor):
- batch_size = None
- if self.discrete_steps is not None:
- if t.dtype != torch.long:
- t = (t * (self.discrete_steps-1)).round().long()
- t = t / (self.discrete_steps-1)
- t = t.clamp(0, 1)
- else:
- batch_size = t
- t = None
- logSNR = self.schedule(t, batch_size, *args, **kwargs)
- if shift*self.shift != 1:
- logSNR += 2 * np.log(1/(shift*self.shift))
- if self.limits is not None:
- logSNR = logSNR.clamp(*self.limits)
- return logSNR
-
-class CosineSchedule(BaseSchedule):
- def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False):
- self.s = torch.tensor([s])
- self.clamp_range = clamp_range
- self.norm_instead = norm_instead
- self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
-
- def schedule(self, t, batch_size):
- if t is None:
- t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0)
- s, min_var = self.s.to(t.device), self.min_var.to(t.device)
- var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var
- if self.norm_instead:
- var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0]
- else:
- var = var.clamp(*self.clamp_range)
- logSNR = (var/(1-var)).log()
- return logSNR
-
-class CosineSchedule2(BaseSchedule):
- def setup(self, logsnr_range=[-15, 15]):
- self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1]))
- self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0]))
-
- def schedule(self, t, batch_size):
- if t is None:
- t = 1-torch.rand(batch_size)
- return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log()
-
-class SqrtSchedule(BaseSchedule):
- def setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False):
- self.s = s
- self.clamp_range = clamp_range
- self.norm_instead = norm_instead
-
- def schedule(self, t, batch_size):
- if t is None:
- t = 1-torch.rand(batch_size)
- var = 1 - (t + self.s)**0.5
- if self.norm_instead:
- var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0]
- else:
- var = var.clamp(*self.clamp_range)
- logSNR = (var/(1-var)).log()
- return logSNR
-
-class RectifiedFlowsSchedule(BaseSchedule):
- def setup(self, logsnr_range=[-15, 15]):
- self.logsnr_range = logsnr_range
-
- def schedule(self, t, batch_size):
- if t is None:
- t = 1-torch.rand(batch_size)
- logSNR = (((1-t)**2)/(t**2)).log()
- logSNR = logSNR.clamp(*self.logsnr_range)
- return logSNR
-
-class EDMSampleSchedule(BaseSchedule):
- def setup(self, sigma_range=[0.002, 80], p=7):
- self.sigma_range = sigma_range
- self.p = p
-
- def schedule(self, t, batch_size):
- if t is None:
- t = 1-torch.rand(batch_size)
- smin, smax, p = *self.sigma_range, self.p
- sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p
- logSNR = (1/sigma**2).log()
- return logSNR
-
-class EDMTrainSchedule(BaseSchedule):
- def setup(self, mu=-1.2, std=1.2):
- self.mu = mu
- self.std = std
-
- def schedule(self, t, batch_size):
- if t is not None:
- raise Exception("EDMTrainSchedule doesn't support passing timesteps: t")
- logSNR = -2*(torch.randn(batch_size) * self.std - self.mu)
- return logSNR
-
-class LinearSchedule(BaseSchedule):
- def setup(self, logsnr_range=[-10, 10]):
- self.logsnr_range = logsnr_range
-
- def schedule(self, t, batch_size):
- if t is None:
- t = 1-torch.rand(batch_size)
- logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1]
- return logSNR
-
-# Any schedule that cannot be described easily as a continuous function of t
-# It needs to define self.x and self.y in the setup() method
-class PiecewiseLinearSchedule(BaseSchedule):
- def setup(self):
- self.x = None
- self.y = None
-
- def piecewise_linear(self, x, xs, ys):
- indices = torch.searchsorted(xs[:-1], x) - 1
- x_min, x_max = xs[indices], xs[indices+1]
- y_min, y_max = ys[indices], ys[indices+1]
- var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min)
- return var
-
- def schedule(self, t, batch_size):
- if t is None:
- t = 1-torch.rand(batch_size)
- var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device))
- logSNR = (var/(1-var)).log()
- return logSNR
-
-class StableDiffusionSchedule(PiecewiseLinearSchedule):
- def setup(self, linear_range=[0.00085, 0.012], total_steps=1000):
- linear_range_sqrt = [r**0.5 for r in linear_range]
- self.x = torch.linspace(0, 1, total_steps+1)
-
- alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2
- self.y = alphas.cumprod(dim=-1)
-
-class AdaptiveTrainSchedule(BaseSchedule):
- def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0):
- th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1)
- self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)])
- self.bucket_probs = torch.ones(buckets)
- self.min_probs = min_probs
-
- def schedule(self, t, batch_size):
- if t is not None:
- raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t")
- norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum())
- buckets = torch.multinomial(norm_probs, batch_size, replacement=True)
- ranges = self.bucket_ranges[buckets]
- logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0]
- return logSNR
-
- def update_buckets(self, logSNR, loss, beta=0.99):
- range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device)
- range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float()
- range_idx = range_mask.argmax(-1).cpu()
- self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta)
-
-class InterpolatedSchedule(BaseSchedule):
- def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]):
- self.scheduler1 = scheduler1
- self.scheduler2 = scheduler2
- self.shifts = shifts
-
- def schedule(self, t, batch_size):
- if t is None:
- t = 1-torch.rand(batch_size)
- t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan
- low_logSNR = self.scheduler1(t, shift=self.shifts[0])
- high_logSNR = self.scheduler2(t, shift=self.shifts[1])
- return low_logSNR * t + high_logSNR * (1-t)
-
diff --git a/gdf/targets.py b/gdf/targets.py
deleted file mode 100644
index 115062b6001f93082fa836e1f3742723e5972efe..0000000000000000000000000000000000000000
--- a/gdf/targets.py
+++ /dev/null
@@ -1,46 +0,0 @@
-class EpsilonTarget():
- def __call__(self, x0, epsilon, logSNR, a, b):
- return epsilon
-
- def x0(self, noised, pred, logSNR, a, b):
- return (noised - pred * b) / a
-
- def epsilon(self, noised, pred, logSNR, a, b):
- return pred
- def noise_givenx0_noised(self, x0, noised , logSNR, a, b):
- return (noised - a * x0) / b
- def xt(self, x0, noise, logSNR, a, b):
-
- return x0 * a + noise*b
-class X0Target():
- def __call__(self, x0, epsilon, logSNR, a, b):
- return x0
-
- def x0(self, noised, pred, logSNR, a, b):
- return pred
-
- def epsilon(self, noised, pred, logSNR, a, b):
- return (noised - pred * a) / b
-
-class VTarget():
- def __call__(self, x0, epsilon, logSNR, a, b):
- return a * epsilon - b * x0
-
- def x0(self, noised, pred, logSNR, a, b):
- squared_sum = a**2 + b**2
- return a/squared_sum * noised - b/squared_sum * pred
-
- def epsilon(self, noised, pred, logSNR, a, b):
- squared_sum = a**2 + b**2
- return b/squared_sum * noised + a/squared_sum * pred
-
-class RectifiedFlowsTarget():
- def __call__(self, x0, epsilon, logSNR, a, b):
- return epsilon - x0
-
- def x0(self, noised, pred, logSNR, a, b):
- return noised - pred * b
-
- def epsilon(self, noised, pred, logSNR, a, b):
- return noised + pred * a
-
\ No newline at end of file
diff --git a/inference/__init__.py b/inference/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/inference/test_controlnet.py b/inference/test_controlnet.py
deleted file mode 100644
index 250578262d2a118ece8b5a706aba1cd8115c62f5..0000000000000000000000000000000000000000
--- a/inference/test_controlnet.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import os
-import yaml
-import torch
-import torchvision
-from tqdm import tqdm
-import sys
-sys.path.append(os.path.abspath('./'))
-
-from inference.utils import *
-from core.utils import load_or_fail
-from train import WurstCore_control_lrguide, WurstCoreB
-from PIL import Image
-from core.utils import load_or_fail
-import math
-import argparse
-import time
-import random
-import numpy as np
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument( '--height', type=int, default=3840, help='image height')
- parser.add_argument('--width', type=int, default=2160, help='image width')
- parser.add_argument('--control_weight', type=float, default=0.70, help='[ 0.3, 0.8]')
- parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ')
- parser.add_argument('--seed', type=int, default=123, help='random seed')
- parser.add_argument('--config_c', type=str,
- default='configs/training/cfg_control_lr.yaml' ,help='config file for stage c, latent generation')
- parser.add_argument('--config_b', type=str,
- default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding')
- parser.add_argument( '--prompt', type=str,
- default='A peaceful lake surrounded by mountain, white cloud in the sky, high quality,', help='text prompt')
- parser.add_argument( '--num_image', type=int, default=4, help='how many images generated')
- parser.add_argument( '--output_dir', type=str, default='figures/controlnet_results/', help='output directory for generated image')
- parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory')
- parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel')
- parser.add_argument( '--canny_source_url', type=str, default="figures/California_000490.jpg", help='image used to extract canny edge map')
-
- args = parser.parse_args()
- return args
-
-
-if __name__ == "__main__":
-
- args = parse_args()
- width = args.width
- height = args.height
- torch.manual_seed(args.seed)
- random.seed(args.seed)
- np.random.seed(args.seed)
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float
-
-
- # SETUP STAGE C
- with open(args.config_c, "r", encoding="utf-8") as file:
- loaded_config = yaml.safe_load(file)
- core = WurstCore_control_lrguide(config_dict=loaded_config, device=device, training=False)
-
- # SETUP STAGE B
- with open(args.config_b, "r", encoding="utf-8") as file:
- config_file_b = yaml.safe_load(file)
-
- core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
-
- extras = core.setup_extras_pre()
- models = core.setup_models(extras)
- models.generator.eval().requires_grad_(False)
- print("CONTROLNET READY")
-
- extras_b = core_b.setup_extras_pre()
- models_b = core_b.setup_models(extras_b, skip_clip=True)
- models_b = WurstCoreB.Models(
- **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
- )
- models_b.generator.eval().requires_grad_(False)
- print("STAGE B READY")
-
- batch_size = 1
- save_dir = args.output_dir
- url = args.canny_source_url
- images = resize_image(Image.open(url).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1)
- batch = {'images': images}
-
-
-
-
-
-
- cnet_multiplier = args.control_weight # 0.8 0.6 0.3 control strength
- caption_list = [args.prompt] * args.num_image
- height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
- stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
-
-
-
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
-
-
- sdd = torch.load(args.pretrained_path, map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
- models.train_norm.load_state_dict(collect_sd, strict=True)
-
-
-
-
- models.controlnet.load_state_dict(load_or_fail(core.config.controlnet_checkpoint_path), strict=True)
- # Stage C Parameters
- extras.sampling_configs['cfg'] = 1
- extras.sampling_configs['shift'] = 2
- extras.sampling_configs['timesteps'] = 20
- extras.sampling_configs['t_start'] = 1.0
-
- # Stage B Parameters
- extras_b.sampling_configs['cfg'] = 1.1
- extras_b.sampling_configs['shift'] = 1
- extras_b.sampling_configs['timesteps'] = 10
- extras_b.sampling_configs['t_start'] = 1.0
-
- # PREPARE CONDITIONS
-
-
-
-
- for out_cnt, caption in enumerate(caption_list):
- with torch.no_grad():
-
- batch['captions'] = [caption + ' high quality'] * batch_size
- conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- cnet, cnet_input = core.get_cnet(batch, models, extras)
- cnet_uncond = cnet
- conditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet]
- unconditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet_uncond]
- edge_images = show_images(cnet_input)
- models.generator.cuda()
- for idx, img in enumerate(edge_images):
- img.save(os.path.join(save_dir, f"edge_{url.split('/')[-1]}"))
-
-
- print('STAGE C GENERATION***************************')
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions, unconditions)
- models.generator.cpu()
- torch.cuda.empty_cache()
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
- conditions_b['effnet'] = sampled_c
- unconditions_b['effnet'] = torch.zeros_like(sampled_c)
- print('STAGE B + A DECODING***************************')
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)
-
- torch.cuda.empty_cache()
- imgs = show_images(sampled)
-
- for idx, img in enumerate(imgs):
- img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(out_cnt).zfill(5) + '.jpg'))
- print('finished! Results at ', save_dir )
diff --git a/inference/test_personalized.py b/inference/test_personalized.py
deleted file mode 100644
index 840d52d0ef3b026e73c34f715b7b18ec3537e62a..0000000000000000000000000000000000000000
--- a/inference/test_personalized.py
+++ /dev/null
@@ -1,180 +0,0 @@
-
-import os
-import yaml
-import torch
-from tqdm import tqdm
-import sys
-sys.path.append(os.path.abspath('./'))
-from inference.utils import *
-from train import WurstCoreB
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from train import WurstCore_personalized as WurstCoreC
-import torch.nn.functional as F
-import numpy as np
-import random
-import math
-import argparse
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument( '--height', type=int, default=3072, help='image height')
- parser.add_argument('--width', type=int, default=4096, help='image width')
- parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ')
- parser.add_argument('--seed', type=int, default=23, help='random seed')
- parser.add_argument('--config_c', type=str,
- default="configs/training/lora_personalization.yaml" ,help='config file for stage c, latent generation')
- parser.add_argument('--config_b', type=str,
- default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding')
- parser.add_argument( '--prompt', type=str,
- default='A photo of cat [roubaobao] with sunglasses, Time Square in the background, high quality, detail rich, 8k', help='text prompt')
- parser.add_argument( '--num_image', type=int, default=4, help='how many images generated')
- parser.add_argument( '--output_dir', type=str, default='figures/personalized/', help='output directory for generated image')
- parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory')
- parser.add_argument( '--pretrained_path_lora', type=str, default='models/lora_cat.safetensors',help='pretrained path of personalized lora parameter')
- parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel')
- args = parser.parse_args()
- return args
-
-if __name__ == "__main__":
- args = parse_args()
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- torch.manual_seed(args.seed)
- random.seed(args.seed)
- np.random.seed(args.seed)
- dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float
-
-
- # SETUP STAGE C
- with open(args.config_c, "r", encoding="utf-8") as file:
- loaded_config = yaml.safe_load(file)
- core = WurstCoreC(config_dict=loaded_config, device=device, training=False)
-
- # SETUP STAGE B
- with open(args.config_b, "r", encoding="utf-8") as file:
- config_file_b = yaml.safe_load(file)
- core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
-
- extras = core.setup_extras_pre()
- models = core.setup_models(extras)
- models.generator.eval().requires_grad_(False)
- print("STAGE C READY")
-
- extras_b = core_b.setup_extras_pre()
- models_b = core_b.setup_models(extras_b, skip_clip=True)
- models_b = WurstCoreB.Models(
- **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
- )
- models_b.generator.bfloat16().eval().requires_grad_(False)
- print("STAGE B READY")
-
-
- batch_size = 1
- captions = [args.prompt] * args.num_image
- height, width = args.height, args.width
- save_dir = args.output_dir
-
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
-
-
- pretrained_pth = args.pretrained_path
- sdd = torch.load(pretrained_pth, map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
-
- models.train_norm.load_state_dict(collect_sd)
-
-
- pretrained_pth_lora = args.pretrained_path_lora
- sdd = torch.load(pretrained_pth_lora, map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
-
- models.train_lora.load_state_dict(collect_sd)
-
-
- models.generator.eval()
- models.train_norm.eval()
-
-
- height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
- stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
-
- # Stage C Parameters
-
- extras.sampling_configs['cfg'] = 4
- extras.sampling_configs['shift'] = 1
- extras.sampling_configs['timesteps'] = 20
- extras.sampling_configs['t_start'] = 1.0
- extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf)
-
-
-
- # Stage B Parameters
-
- extras_b.sampling_configs['cfg'] = 1.1
- extras_b.sampling_configs['shift'] = 1
- extras_b.sampling_configs['timesteps'] = 10
- extras_b.sampling_configs['t_start'] = 1.0
-
-
- for cnt, caption in enumerate(captions):
-
- batch = {'captions': [caption] * batch_size}
- conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-
-
-
- for cnt, caption in enumerate(captions):
-
-
- batch = {'captions': [caption] * batch_size}
- conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-
- with torch.no_grad():
-
-
- models.generator.cuda()
- print('STAGE C GENERATION***************************')
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
-
-
-
- models.generator.cpu()
- torch.cuda.empty_cache()
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
- conditions_b['effnet'] = sampled_c
- unconditions_b['effnet'] = torch.zeros_like(sampled_c)
- print('STAGE B + A DECODING***************************')
-
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)
-
- torch.cuda.empty_cache()
- imgs = show_images(sampled)
- for idx, img in enumerate(imgs):
- print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx)
- img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'))
-
-
- print('finished! Results at ', save_dir )
-
-
-
diff --git a/inference/test_t2i.py b/inference/test_t2i.py
deleted file mode 100644
index 3478f95e4c706d88a8c73688ed4e990adc9ea8d4..0000000000000000000000000000000000000000
--- a/inference/test_t2i.py
+++ /dev/null
@@ -1,170 +0,0 @@
-
-import os
-import yaml
-import torch
-from tqdm import tqdm
-import sys
-sys.path.append(os.path.abspath('./'))
-from inference.utils import *
-from core.utils import load_or_fail
-from train import WurstCoreB
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from train import WurstCore_t2i as WurstCoreC
-import torch.nn.functional as F
-from core.utils import load_or_fail
-import numpy as np
-import random
-import math
-import argparse
-from einops import rearrange
-import math
-#inrfft_3b_strc_WurstCore
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument( '--height', type=int, default=2560, help='image height')
- parser.add_argument('--width', type=int, default=5120, help='image width')
- parser.add_argument('--seed', type=int, default=123, help='random seed')
- parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ')
- parser.add_argument('--config_c', type=str,
- default='configs/training/t2i.yaml' ,help='config file for stage c, latent generation')
- parser.add_argument('--config_b', type=str,
- default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding')
- parser.add_argument( '--prompt', type=str,
- default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt')
- parser.add_argument( '--num_image', type=int, default=10, help='how many images generated')
- parser.add_argument( '--output_dir', type=str, default='figures/output_results/', help='output directory for generated image')
- parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory')
- parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel')
- args = parser.parse_args()
- return args
-
-
-
-if __name__ == "__main__":
-
- args = parse_args()
- print(args)
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- print(device)
- torch.manual_seed(args.seed)
- random.seed(args.seed)
- np.random.seed(args.seed)
- dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float
- #gdf = gdf_refine(
- # schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- # input_scaler=VPScaler(), target=EpsilonTarget(),
- # noise_cond=CosineTNoiseCond(),
- # loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
- # )
- # SETUP STAGE C
- config_file = args.config_c
- with open(config_file, "r", encoding="utf-8") as file:
- loaded_config = yaml.safe_load(file)
-
- core = WurstCoreC(config_dict=loaded_config, device=device, training=False)
-
- # SETUP STAGE B
- config_file_b = args.config_b
- with open(config_file_b, "r", encoding="utf-8") as file:
- config_file_b = yaml.safe_load(file)
-
- core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
-
- extras = core.setup_extras_pre()
- models = core.setup_models(extras)
- models.generator.eval().requires_grad_(False)
- print("STAGE C READY")
-
- extras_b = core_b.setup_extras_pre()
- models_b = core_b.setup_models(extras_b, skip_clip=True)
- models_b = WurstCoreB.Models(
- **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
- )
- models_b.generator.bfloat16().eval().requires_grad_(False)
- print("STAGE B READY")
-
- captions = [args.prompt] * args.num_image
-
-
- height, width = args.height, args.width
- save_dir = args.output_dir
-
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
-
- pretrained_path = args.pretrained_path
- sdd = torch.load(pretrained_path, map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
-
- models.train_norm.load_state_dict(collect_sd)
-
-
- models.generator.eval()
- models.train_norm.eval()
-
- batch_size=1
- height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
- stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
- stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
-
- # Stage C Parameters
- extras.sampling_configs['cfg'] = 4
- extras.sampling_configs['shift'] = 1
- extras.sampling_configs['timesteps'] = 20
- extras.sampling_configs['t_start'] = 1.0
- extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf)
-
-
-
- # Stage B Parameters
- extras_b.sampling_configs['cfg'] = 1.1
- extras_b.sampling_configs['shift'] = 1
- extras_b.sampling_configs['timesteps'] = 10
- extras_b.sampling_configs['t_start'] = 1.0
-
-
-
-
- for cnt, caption in enumerate(captions):
-
-
- batch = {'captions': [caption] * batch_size}
- conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-
- with torch.no_grad():
-
-
- models.generator.cuda()
- print('STAGE C GENERATION***************************')
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
-
-
-
- models.generator.cpu()
- torch.cuda.empty_cache()
-
- conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
- unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
- conditions_b['effnet'] = sampled_c
- unconditions_b['effnet'] = torch.zeros_like(sampled_c)
- print('STAGE B + A DECODING***************************')
-
- with torch.cuda.amp.autocast(dtype=dtype):
- sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)
-
- torch.cuda.empty_cache()
- imgs = show_images(sampled)
- for idx, img in enumerate(imgs):
- print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx)
- img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'))
-
-
- print('finished! Results at ', save_dir )
diff --git a/inference/utils.py b/inference/utils.py
deleted file mode 100644
index ab5af277069ec7803d53ff8f5fa29bed41fde29b..0000000000000000000000000000000000000000
--- a/inference/utils.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import PIL
-import torch
-import requests
-import torchvision
-from math import ceil
-from io import BytesIO
-import matplotlib.pyplot as plt
-import torchvision.transforms.functional as F
-import math
-from tqdm import tqdm
-def download_image(url):
- return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB")
-
-
-def resize_image(image, size=768):
- tensor_image = F.to_tensor(image)
- resized_image = F.resize(tensor_image, size, antialias=True)
- return resized_image
-
-
-def downscale_images(images, factor=3/4):
- scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32)
- scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST)
- return scaled_image
-
-
-
-def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
- resolution_multiple = 42.67
- latent_height = ceil(height / compression_factor_b)
- latent_width = ceil(width / compression_factor_b)
- stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
-
- latent_height = ceil(height / compression_factor_a)
- latent_width = ceil(width / compression_factor_a)
- stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
-
- return stage_c_latent_shape, stage_b_latent_shape
-
-
-def get_views(H, W, window_size=64, stride=16):
- '''
- - H, W: height and width of the latent
- '''
- num_blocks_height = (H - window_size) // stride + 1
- num_blocks_width = (W - window_size) // stride + 1
- total_num_blocks = int(num_blocks_height * num_blocks_width)
- views = []
- for i in range(total_num_blocks):
- h_start = int((i // num_blocks_width) * stride)
- h_end = h_start + window_size
- w_start = int((i % num_blocks_width) * stride)
- w_end = w_start + window_size
- views.append((h_start, h_end, w_start, w_end))
- return views
-
-
-
-def show_images(images, rows=None, cols=None, **kwargs):
- if images.size(1) == 1:
- images = images.repeat(1, 3, 1, 1)
- elif images.size(1) > 3:
- images = images[:, :3]
-
- if rows is None:
- rows = 1
- if cols is None:
- cols = images.size(0) // rows
-
- _, _, h, w = images.shape
-
- imgs = []
- for i, img in enumerate(images):
- imgs.append( torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)))
-
- return imgs
-
-
-
-def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device, \
- stage_a_tiled=False, num_instance=4, patch_size=256, stride=24):
-
-
- sampling_b = extras_b.gdf.sample(
- models_b.generator.half(), conditions_b, bshape,
- unconditions_b, device=device,
- **extras_b.sampling_configs,
- )
- models_b.generator.cuda()
- for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
- sampled_b = sampled_b
- models_b.generator.cpu()
- torch.cuda.empty_cache()
- if stage_a_tiled:
- with torch.cuda.amp.autocast(dtype=torch.float16):
- padding = (stride*2, stride*2, stride*2, stride*2)
- sampled_b = torch.nn.functional.pad(sampled_b, padding, mode='reflect')
- count = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device)
- sampled = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device)
- views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride)
-
- for view_idx, (h_start, h_end, w_start, w_end) in enumerate(tqdm(views, total=len(views))):
-
- sampled[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += models_b.stage_a.decode(sampled_b[:, :, h_start:h_end, w_start:w_end]).float()
- count[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += 1
- sampled /= count
- sampled = sampled[:, :, stride*4*2:-stride*4*2, stride*4*2:-stride*4*2]
- else:
-
- sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled)
-
- return sampled.float()
-
-
-def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None):
- if conditions is None:
- conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- if unconditions is None:
- unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
- sampling_c = extras.gdf.sample(
- models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr,
- unconditions, device=device, **extras.sampling_configs,
- )
- for idx, (sampled_c, sampled_c_curr, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])):
- sampled_c = sampled_c
- return sampled_c
-
-def get_target_lr_size(ratio, std_size=24):
- w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
- return (h * 32 , w *32 )
-
diff --git a/models/models_checklist.txt b/models/models_checklist.txt
deleted file mode 100644
index 2fdec27a72db473c51893abc64826514b1d9d065..0000000000000000000000000000000000000000
--- a/models/models_checklist.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors
-https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors
-https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors
-https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors
-https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors
-https://huggingface.co/roubaofeipi/UltraPixel/blob/main/ultrapixel_t2i.safetensors
-https://huggingface.co/roubaofeipi/UltraPixel/blob/main/lora_cat.safetensors (only required for personalization)
\ No newline at end of file
diff --git a/modules/__init__.py b/modules/__init__.py
deleted file mode 100644
index a6fcf5aa2a39061c3f4f82dde6ff063411223cb3..0000000000000000000000000000000000000000
--- a/modules/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .effnet import EfficientNetEncoder
-from .stage_c import StageC
-from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
-from .previewer import Previewer
-from .controlnet import ControlNet, ControlNetDeliverer
-from . import controlnet as controlnet_filters
diff --git a/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc
deleted file mode 100644
index 8c74bb92cb0db0876acda8aa3d102141526fd428..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc and /dev/null differ
diff --git a/modules/cnet_modules/face_id/arcface.py b/modules/cnet_modules/face_id/arcface.py
deleted file mode 100644
index 64e918bb90437f6f193a7ec384bea1fcd73c7abb..0000000000000000000000000000000000000000
--- a/modules/cnet_modules/face_id/arcface.py
+++ /dev/null
@@ -1,276 +0,0 @@
-import numpy as np
-import onnx, onnx2torch, cv2
-import torch
-from insightface.utils import face_align
-
-
-class ArcFaceRecognizer:
- def __init__(self, model_file=None, device='cpu', dtype=torch.float32):
- assert model_file is not None
- self.model_file = model_file
-
- self.device = device
- self.dtype = dtype
- self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype)
- for param in self.model.parameters():
- param.requires_grad = False
- self.model.eval()
-
- self.input_mean = 127.5
- self.input_std = 127.5
- self.input_size = (112, 112)
- self.input_shape = ['None', 3, 112, 112]
-
- def get(self, img, face):
- aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])
- face.embedding = self.get_feat(aimg).flatten()
- return face.embedding
-
- def compute_sim(self, feat1, feat2):
- from numpy.linalg import norm
- feat1 = feat1.ravel()
- feat2 = feat2.ravel()
- sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
- return sim
-
- def get_feat(self, imgs):
- if not isinstance(imgs, list):
- imgs = [imgs]
- input_size = self.input_size
-
- blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
- (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
-
- blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype)
- net_out = self.model(blob_torch)
- return net_out[0].float().cpu()
-
-
-def distance2bbox(points, distance, max_shape=None):
- """Decode distance prediction to bounding box.
-
- Args:
- points (Tensor): Shape (n, 2), [x, y].
- distance (Tensor): Distance from the given point to 4
- boundaries (left, top, right, bottom).
- max_shape (tuple): Shape of the image.
-
- Returns:
- Tensor: Decoded bboxes.
- """
- x1 = points[:, 0] - distance[:, 0]
- y1 = points[:, 1] - distance[:, 1]
- x2 = points[:, 0] + distance[:, 2]
- y2 = points[:, 1] + distance[:, 3]
- if max_shape is not None:
- x1 = x1.clamp(min=0, max=max_shape[1])
- y1 = y1.clamp(min=0, max=max_shape[0])
- x2 = x2.clamp(min=0, max=max_shape[1])
- y2 = y2.clamp(min=0, max=max_shape[0])
- return np.stack([x1, y1, x2, y2], axis=-1)
-
-
-def distance2kps(points, distance, max_shape=None):
- """Decode distance prediction to bounding box.
-
- Args:
- points (Tensor): Shape (n, 2), [x, y].
- distance (Tensor): Distance from the given point to 4
- boundaries (left, top, right, bottom).
- max_shape (tuple): Shape of the image.
-
- Returns:
- Tensor: Decoded bboxes.
- """
- preds = []
- for i in range(0, distance.shape[1], 2):
- px = points[:, i % 2] + distance[:, i]
- py = points[:, i % 2 + 1] + distance[:, i + 1]
- if max_shape is not None:
- px = px.clamp(min=0, max=max_shape[1])
- py = py.clamp(min=0, max=max_shape[0])
- preds.append(px)
- preds.append(py)
- return np.stack(preds, axis=-1)
-
-
-class FaceDetector:
- def __init__(self, model_file=None, dtype=torch.float32, device='cuda'):
- self.model_file = model_file
- self.taskname = 'detection'
- self.center_cache = {}
- self.nms_thresh = 0.4
- self.det_thresh = 0.5
-
- self.device = device
- self.dtype = dtype
- self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype)
- for param in self.model.parameters():
- param.requires_grad = False
- self.model.eval()
-
- input_shape = (320, 320)
- self.input_size = input_shape
- self.input_shape = input_shape
-
- self.input_mean = 127.5
- self.input_std = 128.0
- self._anchor_ratio = 1.0
- self._num_anchors = 1
- self.fmc = 3
- self._feat_stride_fpn = [8, 16, 32]
- self._num_anchors = 2
- self.use_kps = True
-
- self.det_thresh = 0.5
- self.nms_thresh = 0.4
-
- def forward(self, img, threshold):
- scores_list = []
- bboxes_list = []
- kpss_list = []
- input_size = tuple(img.shape[0:2][::-1])
- blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size,
- (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
- blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype)
- net_outs_torch = self.model(blob_torch)
- # print(list(map(lambda x: x.shape, net_outs_torch)))
- net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch))
-
- input_height = blob.shape[2]
- input_width = blob.shape[3]
- fmc = self.fmc
- for idx, stride in enumerate(self._feat_stride_fpn):
- scores = net_outs[idx]
- bbox_preds = net_outs[idx + fmc]
- bbox_preds = bbox_preds * stride
- if self.use_kps:
- kps_preds = net_outs[idx + fmc * 2] * stride
- height = input_height // stride
- width = input_width // stride
- K = height * width
- key = (height, width, stride)
- if key in self.center_cache:
- anchor_centers = self.center_cache[key]
- else:
- # solution-1, c style:
- # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
- # for i in range(height):
- # anchor_centers[i, :, 1] = i
- # for i in range(width):
- # anchor_centers[:, i, 0] = i
-
- # solution-2:
- # ax = np.arange(width, dtype=np.float32)
- # ay = np.arange(height, dtype=np.float32)
- # xv, yv = np.meshgrid(np.arange(width), np.arange(height))
- # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
-
- # solution-3:
- anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
- # print(anchor_centers.shape)
-
- anchor_centers = (anchor_centers * stride).reshape((-1, 2))
- if self._num_anchors > 1:
- anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2))
- if len(self.center_cache) < 100:
- self.center_cache[key] = anchor_centers
-
- pos_inds = np.where(scores >= threshold)[0]
- bboxes = distance2bbox(anchor_centers, bbox_preds)
- pos_scores = scores[pos_inds]
- pos_bboxes = bboxes[pos_inds]
- scores_list.append(pos_scores)
- bboxes_list.append(pos_bboxes)
- if self.use_kps:
- kpss = distance2kps(anchor_centers, kps_preds)
- # kpss = kps_preds
- kpss = kpss.reshape((kpss.shape[0], -1, 2))
- pos_kpss = kpss[pos_inds]
- kpss_list.append(pos_kpss)
- return scores_list, bboxes_list, kpss_list
-
- def detect(self, img, input_size=None, max_num=0, metric='default'):
- assert input_size is not None or self.input_size is not None
- input_size = self.input_size if input_size is None else input_size
-
- im_ratio = float(img.shape[0]) / img.shape[1]
- model_ratio = float(input_size[1]) / input_size[0]
- if im_ratio > model_ratio:
- new_height = input_size[1]
- new_width = int(new_height / im_ratio)
- else:
- new_width = input_size[0]
- new_height = int(new_width * im_ratio)
- det_scale = float(new_height) / img.shape[0]
- resized_img = cv2.resize(img, (new_width, new_height))
- det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
- det_img[:new_height, :new_width, :] = resized_img
-
- scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
-
- scores = np.vstack(scores_list)
- scores_ravel = scores.ravel()
- order = scores_ravel.argsort()[::-1]
- bboxes = np.vstack(bboxes_list) / det_scale
- if self.use_kps:
- kpss = np.vstack(kpss_list) / det_scale
- pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
- pre_det = pre_det[order, :]
- keep = self.nms(pre_det)
- det = pre_det[keep, :]
- if self.use_kps:
- kpss = kpss[order, :, :]
- kpss = kpss[keep, :, :]
- else:
- kpss = None
- if max_num > 0 and det.shape[0] > max_num:
- area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
- det[:, 1])
- img_center = img.shape[0] // 2, img.shape[1] // 2
- offsets = np.vstack([
- (det[:, 0] + det[:, 2]) / 2 - img_center[1],
- (det[:, 1] + det[:, 3]) / 2 - img_center[0]
- ])
- offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
- if metric == 'max':
- values = area
- else:
- values = area - offset_dist_squared * 2.0 # some extra weight on the centering
- bindex = np.argsort(
- values)[::-1] # some extra weight on the centering
- bindex = bindex[0:max_num]
- det = det[bindex, :]
- if kpss is not None:
- kpss = kpss[bindex, :]
- return det, kpss
-
- def nms(self, dets):
- thresh = self.nms_thresh
- x1 = dets[:, 0]
- y1 = dets[:, 1]
- x2 = dets[:, 2]
- y2 = dets[:, 3]
- scores = dets[:, 4]
-
- areas = (x2 - x1 + 1) * (y2 - y1 + 1)
- order = scores.argsort()[::-1]
-
- keep = []
- while order.size > 0:
- i = order[0]
- keep.append(i)
- xx1 = np.maximum(x1[i], x1[order[1:]])
- yy1 = np.maximum(y1[i], y1[order[1:]])
- xx2 = np.minimum(x2[i], x2[order[1:]])
- yy2 = np.minimum(y2[i], y2[order[1:]])
-
- w = np.maximum(0.0, xx2 - xx1 + 1)
- h = np.maximum(0.0, yy2 - yy1 + 1)
- inter = w * h
- ovr = inter / (areas[i] + areas[order[1:]] - inter)
-
- inds = np.where(ovr <= thresh)[0]
- order = order[inds + 1]
-
- return keep
diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc
deleted file mode 100644
index 8200104d6d66a1084685c76373c38d752ed9c3d4..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc and /dev/null differ
diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc
deleted file mode 100644
index ca432e5c5eed7ba17fc6cafb06a3ebe16002f67e..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc and /dev/null differ
diff --git a/modules/cnet_modules/inpainting/saliency_model.pt b/modules/cnet_modules/inpainting/saliency_model.pt
deleted file mode 100644
index e1b02cc60b2999a8f9ff90557182e3dafab63db7..0000000000000000000000000000000000000000
--- a/modules/cnet_modules/inpainting/saliency_model.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:225a602e1f2a5d159424be011a63b27d83b56343a4379a90710eca9a26bab920
-size 451123
diff --git a/modules/cnet_modules/inpainting/saliency_model.py b/modules/cnet_modules/inpainting/saliency_model.py
deleted file mode 100644
index 82355a02baead47f50fe643e57b81f8caca78f79..0000000000000000000000000000000000000000
--- a/modules/cnet_modules/inpainting/saliency_model.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-import torchvision
-from torch import nn
-from PIL import Image
-import numpy as np
-import os
-
-
-# MICRO RESNET
-class ResBlock(nn.Module):
- def __init__(self, channels):
- super(ResBlock, self).__init__()
-
- self.resblock = nn.Sequential(
- nn.ReflectionPad2d(1),
- nn.Conv2d(channels, channels, kernel_size=3),
- nn.InstanceNorm2d(channels, affine=True),
- nn.ReLU(),
- nn.ReflectionPad2d(1),
- nn.Conv2d(channels, channels, kernel_size=3),
- nn.InstanceNorm2d(channels, affine=True),
- )
-
- def forward(self, x):
- out = self.resblock(x)
- return out + x
-
-
-class Upsample2d(nn.Module):
- def __init__(self, scale_factor):
- super(Upsample2d, self).__init__()
-
- self.interp = nn.functional.interpolate
- self.scale_factor = scale_factor
-
- def forward(self, x):
- x = self.interp(x, scale_factor=self.scale_factor, mode='nearest')
- return x
-
-
-class MicroResNet(nn.Module):
- def __init__(self):
- super(MicroResNet, self).__init__()
-
- self.downsampler = nn.Sequential(
- nn.ReflectionPad2d(4),
- nn.Conv2d(3, 8, kernel_size=9, stride=4),
- nn.InstanceNorm2d(8, affine=True),
- nn.ReLU(),
- nn.ReflectionPad2d(1),
- nn.Conv2d(8, 16, kernel_size=3, stride=2),
- nn.InstanceNorm2d(16, affine=True),
- nn.ReLU(),
- nn.ReflectionPad2d(1),
- nn.Conv2d(16, 32, kernel_size=3, stride=2),
- nn.InstanceNorm2d(32, affine=True),
- nn.ReLU(),
- )
-
- self.residual = nn.Sequential(
- ResBlock(32),
- nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
- ResBlock(64),
- )
-
- self.segmentator = nn.Sequential(
- nn.ReflectionPad2d(1),
- nn.Conv2d(64, 16, kernel_size=3),
- nn.InstanceNorm2d(16, affine=True),
- nn.ReLU(),
- Upsample2d(scale_factor=2),
- nn.ReflectionPad2d(4),
- nn.Conv2d(16, 1, kernel_size=9),
- nn.Sigmoid()
- )
-
- def forward(self, x):
- out = self.downsampler(x)
- out = self.residual(out)
- out = self.segmentator(out)
- return out
diff --git a/modules/cnet_modules/pidinet/__init__.py b/modules/cnet_modules/pidinet/__init__.py
deleted file mode 100644
index a2b4625bf915cc6c4053b7d7861a22ff371bc641..0000000000000000000000000000000000000000
--- a/modules/cnet_modules/pidinet/__init__.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Pidinet
-# https://github.com/hellozhuo/pidinet
-
-import os
-import torch
-import numpy as np
-from einops import rearrange
-from .model import pidinet
-from .util import annotator_ckpts_path, safe_step
-
-
-class PidiNetDetector:
- def __init__(self, device):
- remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth"
- modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth")
- if not os.path.exists(modelpath):
- from basicsr.utils.download_util import load_file_from_url
- load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
- self.netNetwork = pidinet()
- self.netNetwork.load_state_dict(
- {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()})
- self.netNetwork.to(device).eval().requires_grad_(False)
-
- def __call__(self, input_image): # , safe=False):
- return self.netNetwork(input_image)[-1]
- # assert input_image.ndim == 3
- # input_image = input_image[:, :, ::-1].copy()
- # with torch.no_grad():
- # image_pidi = torch.from_numpy(input_image).float().cuda()
- # image_pidi = image_pidi / 255.0
- # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
- # edge = self.netNetwork(image_pidi)[-1]
-
- # if safe:
- # edge = safe_step(edge)
- # edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
- # return edge[0][0]
diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 07fca0abb9c90b7b40746b4044c4000ae69e00c7..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 5a060aa2baa87a3670aa0bf8276e2f34bafe9451..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc
deleted file mode 100644
index 2243c853d18e2a404ced3eb4ac6a95a7a9ee6874..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc and /dev/null differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc
deleted file mode 100644
index 7f70342fc64759bc7459abf0f7986ee3b7fd2126..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc and /dev/null differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc
deleted file mode 100644
index b2e7ab031924860f1262f4d44bf2eaf57ca78edd..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc and /dev/null differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc
deleted file mode 100644
index 4da8564d03f99caa7a45d9ccb1358cb282cd2711..0000000000000000000000000000000000000000
Binary files a/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc and /dev/null differ
diff --git a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth
deleted file mode 100644
index 1ceba1de87e7bb3c81961b80acbb3a106ca249c0..0000000000000000000000000000000000000000
--- a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:80860ac267258b5f27486e0ef152a211d0b08120f62aeb185a050acc30da486c
-size 2871148
diff --git a/modules/cnet_modules/pidinet/model.py b/modules/cnet_modules/pidinet/model.py
deleted file mode 100644
index 26644c6f6174c3b5407bd10c914045758cbadefe..0000000000000000000000000000000000000000
--- a/modules/cnet_modules/pidinet/model.py
+++ /dev/null
@@ -1,654 +0,0 @@
-"""
-Author: Zhuo Su, Wenzhe Liu
-Date: Feb 18, 2021
-"""
-
-import math
-
-import cv2
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-nets = {
- 'baseline': {
- 'layer0': 'cv',
- 'layer1': 'cv',
- 'layer2': 'cv',
- 'layer3': 'cv',
- 'layer4': 'cv',
- 'layer5': 'cv',
- 'layer6': 'cv',
- 'layer7': 'cv',
- 'layer8': 'cv',
- 'layer9': 'cv',
- 'layer10': 'cv',
- 'layer11': 'cv',
- 'layer12': 'cv',
- 'layer13': 'cv',
- 'layer14': 'cv',
- 'layer15': 'cv',
- },
- 'c-v15': {
- 'layer0': 'cd',
- 'layer1': 'cv',
- 'layer2': 'cv',
- 'layer3': 'cv',
- 'layer4': 'cv',
- 'layer5': 'cv',
- 'layer6': 'cv',
- 'layer7': 'cv',
- 'layer8': 'cv',
- 'layer9': 'cv',
- 'layer10': 'cv',
- 'layer11': 'cv',
- 'layer12': 'cv',
- 'layer13': 'cv',
- 'layer14': 'cv',
- 'layer15': 'cv',
- },
- 'a-v15': {
- 'layer0': 'ad',
- 'layer1': 'cv',
- 'layer2': 'cv',
- 'layer3': 'cv',
- 'layer4': 'cv',
- 'layer5': 'cv',
- 'layer6': 'cv',
- 'layer7': 'cv',
- 'layer8': 'cv',
- 'layer9': 'cv',
- 'layer10': 'cv',
- 'layer11': 'cv',
- 'layer12': 'cv',
- 'layer13': 'cv',
- 'layer14': 'cv',
- 'layer15': 'cv',
- },
- 'r-v15': {
- 'layer0': 'rd',
- 'layer1': 'cv',
- 'layer2': 'cv',
- 'layer3': 'cv',
- 'layer4': 'cv',
- 'layer5': 'cv',
- 'layer6': 'cv',
- 'layer7': 'cv',
- 'layer8': 'cv',
- 'layer9': 'cv',
- 'layer10': 'cv',
- 'layer11': 'cv',
- 'layer12': 'cv',
- 'layer13': 'cv',
- 'layer14': 'cv',
- 'layer15': 'cv',
- },
- 'cvvv4': {
- 'layer0': 'cd',
- 'layer1': 'cv',
- 'layer2': 'cv',
- 'layer3': 'cv',
- 'layer4': 'cd',
- 'layer5': 'cv',
- 'layer6': 'cv',
- 'layer7': 'cv',
- 'layer8': 'cd',
- 'layer9': 'cv',
- 'layer10': 'cv',
- 'layer11': 'cv',
- 'layer12': 'cd',
- 'layer13': 'cv',
- 'layer14': 'cv',
- 'layer15': 'cv',
- },
- 'avvv4': {
- 'layer0': 'ad',
- 'layer1': 'cv',
- 'layer2': 'cv',
- 'layer3': 'cv',
- 'layer4': 'ad',
- 'layer5': 'cv',
- 'layer6': 'cv',
- 'layer7': 'cv',
- 'layer8': 'ad',
- 'layer9': 'cv',
- 'layer10': 'cv',
- 'layer11': 'cv',
- 'layer12': 'ad',
- 'layer13': 'cv',
- 'layer14': 'cv',
- 'layer15': 'cv',
- },
- 'rvvv4': {
- 'layer0': 'rd',
- 'layer1': 'cv',
- 'layer2': 'cv',
- 'layer3': 'cv',
- 'layer4': 'rd',
- 'layer5': 'cv',
- 'layer6': 'cv',
- 'layer7': 'cv',
- 'layer8': 'rd',
- 'layer9': 'cv',
- 'layer10': 'cv',
- 'layer11': 'cv',
- 'layer12': 'rd',
- 'layer13': 'cv',
- 'layer14': 'cv',
- 'layer15': 'cv',
- },
- 'cccv4': {
- 'layer0': 'cd',
- 'layer1': 'cd',
- 'layer2': 'cd',
- 'layer3': 'cv',
- 'layer4': 'cd',
- 'layer5': 'cd',
- 'layer6': 'cd',
- 'layer7': 'cv',
- 'layer8': 'cd',
- 'layer9': 'cd',
- 'layer10': 'cd',
- 'layer11': 'cv',
- 'layer12': 'cd',
- 'layer13': 'cd',
- 'layer14': 'cd',
- 'layer15': 'cv',
- },
- 'aaav4': {
- 'layer0': 'ad',
- 'layer1': 'ad',
- 'layer2': 'ad',
- 'layer3': 'cv',
- 'layer4': 'ad',
- 'layer5': 'ad',
- 'layer6': 'ad',
- 'layer7': 'cv',
- 'layer8': 'ad',
- 'layer9': 'ad',
- 'layer10': 'ad',
- 'layer11': 'cv',
- 'layer12': 'ad',
- 'layer13': 'ad',
- 'layer14': 'ad',
- 'layer15': 'cv',
- },
- 'rrrv4': {
- 'layer0': 'rd',
- 'layer1': 'rd',
- 'layer2': 'rd',
- 'layer3': 'cv',
- 'layer4': 'rd',
- 'layer5': 'rd',
- 'layer6': 'rd',
- 'layer7': 'cv',
- 'layer8': 'rd',
- 'layer9': 'rd',
- 'layer10': 'rd',
- 'layer11': 'cv',
- 'layer12': 'rd',
- 'layer13': 'rd',
- 'layer14': 'rd',
- 'layer15': 'cv',
- },
- 'c16': {
- 'layer0': 'cd',
- 'layer1': 'cd',
- 'layer2': 'cd',
- 'layer3': 'cd',
- 'layer4': 'cd',
- 'layer5': 'cd',
- 'layer6': 'cd',
- 'layer7': 'cd',
- 'layer8': 'cd',
- 'layer9': 'cd',
- 'layer10': 'cd',
- 'layer11': 'cd',
- 'layer12': 'cd',
- 'layer13': 'cd',
- 'layer14': 'cd',
- 'layer15': 'cd',
- },
- 'a16': {
- 'layer0': 'ad',
- 'layer1': 'ad',
- 'layer2': 'ad',
- 'layer3': 'ad',
- 'layer4': 'ad',
- 'layer5': 'ad',
- 'layer6': 'ad',
- 'layer7': 'ad',
- 'layer8': 'ad',
- 'layer9': 'ad',
- 'layer10': 'ad',
- 'layer11': 'ad',
- 'layer12': 'ad',
- 'layer13': 'ad',
- 'layer14': 'ad',
- 'layer15': 'ad',
- },
- 'r16': {
- 'layer0': 'rd',
- 'layer1': 'rd',
- 'layer2': 'rd',
- 'layer3': 'rd',
- 'layer4': 'rd',
- 'layer5': 'rd',
- 'layer6': 'rd',
- 'layer7': 'rd',
- 'layer8': 'rd',
- 'layer9': 'rd',
- 'layer10': 'rd',
- 'layer11': 'rd',
- 'layer12': 'rd',
- 'layer13': 'rd',
- 'layer14': 'rd',
- 'layer15': 'rd',
- },
- 'carv4': {
- 'layer0': 'cd',
- 'layer1': 'ad',
- 'layer2': 'rd',
- 'layer3': 'cv',
- 'layer4': 'cd',
- 'layer5': 'ad',
- 'layer6': 'rd',
- 'layer7': 'cv',
- 'layer8': 'cd',
- 'layer9': 'ad',
- 'layer10': 'rd',
- 'layer11': 'cv',
- 'layer12': 'cd',
- 'layer13': 'ad',
- 'layer14': 'rd',
- 'layer15': 'cv',
- },
-}
-
-
-def createConvFunc(op_type):
- assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
- if op_type == 'cv':
- return F.conv2d
-
- if op_type == 'cd':
- def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
- assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
- assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
- assert padding == dilation, 'padding for cd_conv set wrong'
-
- weights_c = weights.sum(dim=[2, 3], keepdim=True)
- yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
- y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
- return y - yc
-
- return func
- elif op_type == 'ad':
- def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
- assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
- assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
- assert padding == dilation, 'padding for ad_conv set wrong'
-
- shape = weights.shape
- weights = weights.view(shape[0], shape[1], -1)
- weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
- y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
- return y
-
- return func
- elif op_type == 'rd':
- def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
- assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
- assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
- padding = 2 * dilation
-
- shape = weights.shape
- if weights.is_cuda:
- buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
- else:
- buffer = torch.zeros(shape[0], shape[1], 5 * 5)
- weights = weights.view(shape[0], shape[1], -1)
- buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
- buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
- buffer[:, :, 12] = 0
- buffer = buffer.view(shape[0], shape[1], 5, 5)
- y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
- return y
-
- return func
- else:
- print('impossible to be here unless you force that')
- return None
-
-
-class Conv2d(nn.Module):
- def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
- bias=False):
- super(Conv2d, self).__init__()
- if in_channels % groups != 0:
- raise ValueError('in_channels must be divisible by groups')
- if out_channels % groups != 0:
- raise ValueError('out_channels must be divisible by groups')
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.kernel_size = kernel_size
- self.stride = stride
- self.padding = padding
- self.dilation = dilation
- self.groups = groups
- self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
- if bias:
- self.bias = nn.Parameter(torch.Tensor(out_channels))
- else:
- self.register_parameter('bias', None)
- self.reset_parameters()
- self.pdc = pdc
-
- def reset_parameters(self):
- nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
- if self.bias is not None:
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
- bound = 1 / math.sqrt(fan_in)
- nn.init.uniform_(self.bias, -bound, bound)
-
- def forward(self, input):
-
- return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
-
-
-class CSAM(nn.Module):
- """
- Compact Spatial Attention Module
- """
-
- def __init__(self, channels):
- super(CSAM, self).__init__()
-
- mid_channels = 4
- self.relu1 = nn.ReLU()
- self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
- self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
- self.sigmoid = nn.Sigmoid()
- nn.init.constant_(self.conv1.bias, 0)
-
- def forward(self, x):
- y = self.relu1(x)
- y = self.conv1(y)
- y = self.conv2(y)
- y = self.sigmoid(y)
-
- return x * y
-
-
-class CDCM(nn.Module):
- """
- Compact Dilation Convolution based Module
- """
-
- def __init__(self, in_channels, out_channels):
- super(CDCM, self).__init__()
-
- self.relu1 = nn.ReLU()
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
- self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
- self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
- self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
- self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
- nn.init.constant_(self.conv1.bias, 0)
-
- def forward(self, x):
- x = self.relu1(x)
- x = self.conv1(x)
- x1 = self.conv2_1(x)
- x2 = self.conv2_2(x)
- x3 = self.conv2_3(x)
- x4 = self.conv2_4(x)
- return x1 + x2 + x3 + x4
-
-
-class MapReduce(nn.Module):
- """
- Reduce feature maps into a single edge map
- """
-
- def __init__(self, channels):
- super(MapReduce, self).__init__()
- self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
- nn.init.constant_(self.conv.bias, 0)
-
- def forward(self, x):
- return self.conv(x)
-
-
-class PDCBlock(nn.Module):
- def __init__(self, pdc, inplane, ouplane, stride=1):
- super(PDCBlock, self).__init__()
- self.stride = stride
-
- self.stride = stride
- if self.stride > 1:
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
- self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
- self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
- self.relu2 = nn.ReLU()
- self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
-
- def forward(self, x):
- if self.stride > 1:
- x = self.pool(x)
- y = self.conv1(x)
- y = self.relu2(y)
- y = self.conv2(y)
- if self.stride > 1:
- x = self.shortcut(x)
- y = y + x
- return y
-
-
-class PDCBlock_converted(nn.Module):
- """
- CPDC, APDC can be converted to vanilla 3x3 convolution
- RPDC can be converted to vanilla 5x5 convolution
- """
-
- def __init__(self, pdc, inplane, ouplane, stride=1):
- super(PDCBlock_converted, self).__init__()
- self.stride = stride
-
- if self.stride > 1:
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
- self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
- if pdc == 'rd':
- self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
- else:
- self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
- self.relu2 = nn.ReLU()
- self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
-
- def forward(self, x):
- if self.stride > 1:
- x = self.pool(x)
- y = self.conv1(x)
- y = self.relu2(y)
- y = self.conv2(y)
- if self.stride > 1:
- x = self.shortcut(x)
- y = y + x
- return y
-
-
-class PiDiNet(nn.Module):
- def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
- super(PiDiNet, self).__init__()
- self.sa = sa
- if dil is not None:
- assert isinstance(dil, int), 'dil should be an int'
- self.dil = dil
-
- self.fuseplanes = []
-
- self.inplane = inplane
- if convert:
- if pdcs[0] == 'rd':
- init_kernel_size = 5
- init_padding = 2
- else:
- init_kernel_size = 3
- init_padding = 1
- self.init_block = nn.Conv2d(3, self.inplane,
- kernel_size=init_kernel_size, padding=init_padding, bias=False)
- block_class = PDCBlock_converted
- else:
- self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
- block_class = PDCBlock
-
- self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
- self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
- self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
- self.fuseplanes.append(self.inplane) # C
-
- inplane = self.inplane
- self.inplane = self.inplane * 2
- self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
- self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
- self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
- self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
- self.fuseplanes.append(self.inplane) # 2C
-
- inplane = self.inplane
- self.inplane = self.inplane * 2
- self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
- self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
- self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
- self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
- self.fuseplanes.append(self.inplane) # 4C
-
- self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
- self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
- self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
- self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
- self.fuseplanes.append(self.inplane) # 4C
-
- self.conv_reduces = nn.ModuleList()
- if self.sa and self.dil is not None:
- self.attentions = nn.ModuleList()
- self.dilations = nn.ModuleList()
- for i in range(4):
- self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
- self.attentions.append(CSAM(self.dil))
- self.conv_reduces.append(MapReduce(self.dil))
- elif self.sa:
- self.attentions = nn.ModuleList()
- for i in range(4):
- self.attentions.append(CSAM(self.fuseplanes[i]))
- self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
- elif self.dil is not None:
- self.dilations = nn.ModuleList()
- for i in range(4):
- self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
- self.conv_reduces.append(MapReduce(self.dil))
- else:
- for i in range(4):
- self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
-
- self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
- nn.init.constant_(self.classifier.weight, 0.25)
- nn.init.constant_(self.classifier.bias, 0)
-
- # print('initialization done')
-
- def get_weights(self):
- conv_weights = []
- bn_weights = []
- relu_weights = []
- for pname, p in self.named_parameters():
- if 'bn' in pname:
- bn_weights.append(p)
- elif 'relu' in pname:
- relu_weights.append(p)
- else:
- conv_weights.append(p)
-
- return conv_weights, bn_weights, relu_weights
-
- def forward(self, x):
- H, W = x.size()[2:]
-
- x = self.init_block(x)
-
- x1 = self.block1_1(x)
- x1 = self.block1_2(x1)
- x1 = self.block1_3(x1)
-
- x2 = self.block2_1(x1)
- x2 = self.block2_2(x2)
- x2 = self.block2_3(x2)
- x2 = self.block2_4(x2)
-
- x3 = self.block3_1(x2)
- x3 = self.block3_2(x3)
- x3 = self.block3_3(x3)
- x3 = self.block3_4(x3)
-
- x4 = self.block4_1(x3)
- x4 = self.block4_2(x4)
- x4 = self.block4_3(x4)
- x4 = self.block4_4(x4)
-
- x_fuses = []
- if self.sa and self.dil is not None:
- for i, xi in enumerate([x1, x2, x3, x4]):
- x_fuses.append(self.attentions[i](self.dilations[i](xi)))
- elif self.sa:
- for i, xi in enumerate([x1, x2, x3, x4]):
- x_fuses.append(self.attentions[i](xi))
- elif self.dil is not None:
- for i, xi in enumerate([x1, x2, x3, x4]):
- x_fuses.append(self.dilations[i](xi))
- else:
- x_fuses = [x1, x2, x3, x4]
-
- e1 = self.conv_reduces[0](x_fuses[0])
- e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
-
- e2 = self.conv_reduces[1](x_fuses[1])
- e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
-
- e3 = self.conv_reduces[2](x_fuses[2])
- e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
-
- e4 = self.conv_reduces[3](x_fuses[3])
- e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
-
- outputs = [e1, e2, e3, e4]
-
- output = self.classifier(torch.cat(outputs, dim=1))
- # if not self.training:
- # return torch.sigmoid(output)
-
- outputs.append(output)
- outputs = [torch.sigmoid(r) for r in outputs]
- return outputs
-
-
-def config_model(model):
- model_options = list(nets.keys())
- assert model in model_options, \
- 'unrecognized model, please choose from %s' % str(model_options)
-
- # print(str(nets[model]))
-
- pdcs = []
- for i in range(16):
- layer_name = 'layer%d' % i
- op = nets[model][layer_name]
- pdcs.append(createConvFunc(op))
-
- return pdcs
-
-
-def pidinet():
- pdcs = config_model('carv4')
- dil = 24 # if args.dil else None
- return PiDiNet(60, pdcs, dil=dil, sa=True)
diff --git a/modules/cnet_modules/pidinet/util.py b/modules/cnet_modules/pidinet/util.py
deleted file mode 100644
index aec00770c7706f95abf3a0b9b02dbe3232930596..0000000000000000000000000000000000000000
--- a/modules/cnet_modules/pidinet/util.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import random
-
-import numpy as np
-import cv2
-import os
-
-annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
-
-
-def HWC3(x):
- assert x.dtype == np.uint8
- if x.ndim == 2:
- x = x[:, :, None]
- assert x.ndim == 3
- H, W, C = x.shape
- assert C == 1 or C == 3 or C == 4
- if C == 3:
- return x
- if C == 1:
- return np.concatenate([x, x, x], axis=2)
- if C == 4:
- color = x[:, :, 0:3].astype(np.float32)
- alpha = x[:, :, 3:4].astype(np.float32) / 255.0
- y = color * alpha + 255.0 * (1.0 - alpha)
- y = y.clip(0, 255).astype(np.uint8)
- return y
-
-
-def resize_image(input_image, resolution):
- H, W, C = input_image.shape
- H = float(H)
- W = float(W)
- k = float(resolution) / min(H, W)
- H *= k
- W *= k
- H = int(np.round(H / 64.0)) * 64
- W = int(np.round(W / 64.0)) * 64
- img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
- return img
-
-
-def nms(x, t, s):
- x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
-
- f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
- f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
- f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
- f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
-
- y = np.zeros_like(x)
-
- for f in [f1, f2, f3, f4]:
- np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
-
- z = np.zeros_like(y, dtype=np.uint8)
- z[y > t] = 255
- return z
-
-
-def make_noise_disk(H, W, C, F):
- noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
- noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
- noise = noise[F: F + H, F: F + W]
- noise -= np.min(noise)
- noise /= np.max(noise)
- if C == 1:
- noise = noise[:, :, None]
- return noise
-
-
-def min_max_norm(x):
- x -= np.min(x)
- x /= np.maximum(np.max(x), 1e-5)
- return x
-
-
-def safe_step(x, step=2):
- y = x.astype(np.float32) * float(step + 1)
- y = y.astype(np.int32).astype(np.float32) / float(step)
- return y
-
-
-def img2mask(img, H, W, low=10, high=90):
- assert img.ndim == 3 or img.ndim == 2
- assert img.dtype == np.uint8
-
- if img.ndim == 3:
- y = img[:, :, random.randrange(0, img.shape[2])]
- else:
- y = img
-
- y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
-
- if random.uniform(0, 1) < 0.5:
- y = 255 - y
-
- return y < np.percentile(y, random.randrange(low, high))
diff --git a/modules/common.py b/modules/common.py
deleted file mode 100644
index 5e4ad71649f60f2dd38947c9ebc23bc51db2b544..0000000000000000000000000000000000000000
--- a/modules/common.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import math
-from einops import rearrange
-import torch.fft as fft
-class Linear(torch.nn.Linear):
- def reset_parameters(self):
- return None
-
-class Conv2d(torch.nn.Conv2d):
- def reset_parameters(self):
- return None
-
-
-
-class Attention2D(nn.Module):
- def __init__(self, c, nhead, dropout=0.0):
- super().__init__()
- self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
-
- def forward(self, x, kv, self_attn=False):
- orig_shape = x.shape
- x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
- if self_attn:
- #print('in line 23 algong self att ', kv.shape, x.shape)
- kv = torch.cat([x, kv], dim=1)
- #if x.shape[1] >= 72 * 72:
- # x = x * math.sqrt(math.log(64*64, 24*24))
-
- x = self.attn(x, kv, kv, need_weights=False)[0]
- x = x.permute(0, 2, 1).view(*orig_shape)
- return x
-
-
-class LayerNorm2d(nn.LayerNorm):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def forward(self, x):
- return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-
-class GlobalResponseNorm(nn.Module):
- "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
- def __init__(self, dim):
- super().__init__()
- self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
- self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
-
- def forward(self, x):
- Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
- Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
- return self.gamma * (x * Nx) + self.beta + x
-
-
-class ResBlock(nn.Module):
- def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2):
- super().__init__()
- self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
- # self.depthwise = SAMBlock(c, num_heads, expansion)
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.channelwise = nn.Sequential(
- Linear(c + c_skip, c * 4),
- nn.GELU(),
- GlobalResponseNorm(c * 4),
- nn.Dropout(dropout),
- Linear(c * 4, c)
- )
-
- def forward(self, x, x_skip=None):
- x_res = x
- x = self.norm(self.depthwise(x))
- if x_skip is not None:
- x = torch.cat([x, x_skip], dim=1)
- x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
- return x + x_res
-
-
-class AttnBlock(nn.Module):
- def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
- super().__init__()
- self.self_attn = self_attn
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.attention = Attention2D(c, nhead, dropout)
- self.kv_mapper = nn.Sequential(
- nn.SiLU(),
- Linear(c_cond, c)
- )
-
- def forward(self, x, kv):
- kv = self.kv_mapper(kv)
- res = self.attention(self.norm(x), kv, self_attn=self.self_attn)
-
- #print(torch.unique(res), torch.unique(x), self.self_attn)
- #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24))
- x = x + res
-
- return x
-
-class FeedForwardBlock(nn.Module):
- def __init__(self, c, dropout=0.0):
- super().__init__()
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.channelwise = nn.Sequential(
- Linear(c, c * 4),
- nn.GELU(),
- GlobalResponseNorm(c * 4),
- nn.Dropout(dropout),
- Linear(c * 4, c)
- )
-
- def forward(self, x):
- x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
- return x
-
-
-class TimestepBlock(nn.Module):
- def __init__(self, c, c_timestep, conds=['sca']):
- super().__init__()
- self.mapper = Linear(c_timestep, c * 2)
- self.conds = conds
- for cname in conds:
- setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
-
- def forward(self, x, t):
- t = t.chunk(len(self.conds) + 1, dim=1)
- a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
- for i, c in enumerate(self.conds):
- ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
- a, b = a + ac, b + bc
- return x * (1 + a) + b
diff --git a/modules/common_ckpt.py b/modules/common_ckpt.py
deleted file mode 100644
index f64cf11790bdd2a83ca0744629336d81464b3ed0..0000000000000000000000000000000000000000
--- a/modules/common_ckpt.py
+++ /dev/null
@@ -1,360 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import math
-from einops import rearrange
-from modules.speed_util import checkpoint
-class Linear(torch.nn.Linear):
- def reset_parameters(self):
- return None
-
-class Conv2d(torch.nn.Conv2d):
- def reset_parameters(self):
- return None
-
-class AttnBlock_lrfuse_backup(nn.Module):
- def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True):
- super().__init__()
- self.self_attn = self_attn
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.attention = Attention2D(c, nhead, dropout)
- self.kv_mapper = nn.Sequential(
- nn.SiLU(),
- Linear(c_cond, c)
- )
- self.fuse_mapper = nn.Sequential(
- nn.SiLU(),
- Linear(c_cond, c)
- )
- self.use_checkpoint = use_checkpoint
-
- def forward(self, hr, lr):
- return checkpoint(self._forward, (hr, lr), self.paramters(), self.use_checkpoint)
- def _forward(self, hr, lr):
- res = hr
- hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c'))
- lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr
-
- lr_fuse = self.fuse_mapper(rearrange(lr_fuse, 'b c h w -> b (h w ) c'))
- hr = self.attention(self.norm(res), lr_fuse, self_attn=False) + res
- return hr
-
-
-class AttnBlock_lrfuse(nn.Module):
- def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, kernel_size=3, use_checkpoint=True):
- super().__init__()
- self.self_attn = self_attn
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.attention = Attention2D(c, nhead, dropout)
- self.kv_mapper = nn.Sequential(
- nn.SiLU(),
- Linear(c_cond, c)
- )
-
-
- self.depthwise = Conv2d(c, c , kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
-
- self.channelwise = nn.Sequential(
- Linear(c + c, c ),
- nn.GELU(),
- GlobalResponseNorm(c ),
- nn.Dropout(dropout),
- Linear(c , c)
- )
- self.use_checkpoint = use_checkpoint
-
-
- def forward(self, hr, lr):
- return checkpoint(self._forward, (hr, lr), self.parameters(), self.use_checkpoint)
-
- def _forward(self, hr, lr):
- res = hr
- hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c'))
- lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr
-
- lr_fuse = torch.nn.functional.interpolate(lr_fuse.float(), res.shape[2:])
- #print('in line 65', lr_fuse.shape, res.shape)
- media = torch.cat((self.depthwise(lr_fuse), res), dim=1)
- out = self.channelwise(media.permute(0,2,3,1)).permute(0,3,1,2) + res
-
- return out
-
-
-
-
-class Attention2D(nn.Module):
- def __init__(self, c, nhead, dropout=0.0):
- super().__init__()
- self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
-
- def forward(self, x, kv, self_attn=False):
- orig_shape = x.shape
- x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
- if self_attn:
- #print('in line 23 algong self att ', kv.shape, x.shape)
-
- kv = torch.cat([x, kv], dim=1)
- #if x.shape[1] > 48 * 48 and not self.training:
- # x = x * math.sqrt(math.log(x.shape[1] , 24*24))
-
- x = self.attn(x, kv, kv, need_weights=False)[0]
- x = x.permute(0, 2, 1).view(*orig_shape)
- return x
-class Attention2D_splitpatch(nn.Module):
- def __init__(self, c, nhead, dropout=0.0):
- super().__init__()
- self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
-
- def forward(self, x, kv, self_attn=False):
- orig_shape = x.shape
-
- #x = rearrange(x, 'b c h w -> b c (nh wh) (nw ww)', wh=24, ww=24, nh=orig_shape[-2] // 24, nh=orig_shape[-1] // 24,)
- x = rearrange(x, 'b c (nh wh) (nw ww) -> (b nh nw) (wh ww) c', wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24,)
- #print('in line 168', x.shape)
- #x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
- if self_attn:
- #print('in line 23 algong self att ', kv.shape, x.shape)
- num = (orig_shape[-2] // 24) * (orig_shape[-1] // 24)
- kv = torch.cat([x, kv.repeat(num, 1, 1)], dim=1)
- #if x.shape[1] > 48 * 48 and not self.training:
- # x = x * math.sqrt(math.log(x.shape[1] / math.sqrt(16), 24*24))
-
- x = self.attn(x, kv, kv, need_weights=False)[0]
- x = rearrange(x, ' (b nh nw) (wh ww) c -> b c (nh wh) (nw ww)', b=orig_shape[0], wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24)
- #x = x.permute(0, 2, 1).view(*orig_shape)
-
- return x
-class Attention2D_extra(nn.Module):
- def __init__(self, c, nhead, dropout=0.0):
- super().__init__()
- self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
-
- def forward(self, x, kv, extra_emb=None, self_attn=False):
- orig_shape = x.shape
- x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
- num_x = x.shape[1]
-
-
- if extra_emb is not None:
- ori_extra_shape = extra_emb.shape
- extra_emb = extra_emb.view(extra_emb.size(0), extra_emb.size(1), -1).permute(0, 2, 1)
- x = torch.cat((x, extra_emb), dim=1)
- if self_attn:
- #print('in line 23 algong self att ', kv.shape, x.shape)
- kv = torch.cat([x, kv], dim=1)
- x = self.attn(x, kv, kv, need_weights=False)[0]
- img = x[:, :num_x, :].permute(0, 2, 1).view(*orig_shape)
- if extra_emb is not None:
- fix = x[:, num_x:, :].permute(0, 2, 1).view(*ori_extra_shape)
- return img, fix
- else:
- return img
-class AttnBlock_extraq(nn.Module):
- def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
- super().__init__()
- self.self_attn = self_attn
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- #self.norm2 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.attention = Attention2D_extra(c, nhead, dropout)
- self.kv_mapper = nn.Sequential(
- nn.SiLU(),
- Linear(c_cond, c)
- )
- # norm2 initialization in generator in init extra parameter
- def forward(self, x, kv, extra_emb=None):
- #print('in line 84', x.shape, kv.shape, self.self_attn, extra_emb if extra_emb is None else extra_emb.shape)
- #in line 84 torch.Size([1, 1536, 32, 32]) torch.Size([1, 85, 1536]) True None
- #if extra_emb is not None:
-
- kv = self.kv_mapper(kv)
- if extra_emb is not None:
- res_x, res_extra = self.attention(self.norm(x), kv, extra_emb=self.norm2(extra_emb), self_attn=self.self_attn)
- x = x + res_x
- extra_emb = extra_emb + res_extra
- return x, extra_emb
- else:
- x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
- return x
-class AttnBlock_latent2ex(nn.Module):
- def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
- super().__init__()
- self.self_attn = self_attn
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.attention = Attention2D(c, nhead, dropout)
- self.kv_mapper = nn.Sequential(
- nn.SiLU(),
- Linear(c_cond, c)
- )
-
- def forward(self, x, kv):
- #print('in line 84', x.shape, kv.shape, self.self_attn)
- kv = F.interpolate(kv.float(), x.shape[2:])
- kv = kv.view(kv.size(0), kv.size(1), -1).permute(0, 2, 1)
- kv = self.kv_mapper(kv)
- x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
- return x
-
-class LayerNorm2d(nn.LayerNorm):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def forward(self, x):
- return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-class AttnBlock_crossbranch(nn.Module):
- def __init__(self, attnmodule, c, c_cond, nhead, self_attn=True, dropout=0.0):
- super().__init__()
- self.attn = AttnBlock(c, c_cond, nhead, self_attn, dropout)
- #print('in line 108', attnmodule.device)
- self.attn.load_state_dict(attnmodule.state_dict())
- self.norm1 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
-
- self.channelwise1 = nn.Sequential(
- Linear(c *2, c ),
- nn.GELU(),
- GlobalResponseNorm(c ),
- nn.Dropout(dropout),
- Linear(c, c)
- )
- self.channelwise2 = nn.Sequential(
- Linear(c *2, c ),
- nn.GELU(),
- GlobalResponseNorm(c ),
- nn.Dropout(dropout),
- Linear(c, c)
- )
- self.c = c
- def forward(self, x, kv, main_x):
- #print('in line 84', x.shape, kv.shape, main_x.shape, self.c)
-
- x = self.channelwise1(torch.cat((x, F.interpolate(main_x.float(), x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x
- x = self.attn(x, kv)
- main_x = self.channelwise2(torch.cat((main_x, F.interpolate(x.float(), main_x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + main_x
- return main_x, x
-
-class GlobalResponseNorm(nn.Module):
- "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
- def __init__(self, dim):
- super().__init__()
- self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
- self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
-
- def forward(self, x):
- Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
- Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
- return self.gamma * (x * Nx) + self.beta + x
-
-
-class ResBlock(nn.Module):
- def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, use_checkpoint =True): # , num_heads=4, expansion=2):
- super().__init__()
- self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
- # self.depthwise = SAMBlock(c, num_heads, expansion)
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.channelwise = nn.Sequential(
- Linear(c + c_skip, c * 4),
- nn.GELU(),
- GlobalResponseNorm(c * 4),
- nn.Dropout(dropout),
- Linear(c * 4, c)
- )
- self.use_checkpoint = use_checkpoint
- def forward(self, x, x_skip=None):
-
- if x_skip is not None:
- return checkpoint(self._forward_skip, (x, x_skip), self.parameters(), self.use_checkpoint)
- else:
- #print('in line 298', x.shape)
- return checkpoint(self._forward_woskip, (x, ), self.parameters(), self.use_checkpoint)
-
-
-
- def _forward_skip(self, x, x_skip):
- x_res = x
- x = self.norm(self.depthwise(x))
- if x_skip is not None:
- x = torch.cat([x, x_skip], dim=1)
- x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
- return x + x_res
- def _forward_woskip(self, x):
- x_res = x
- x = self.norm(self.depthwise(x))
-
- x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
- return x + x_res
-
-class AttnBlock(nn.Module):
- def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True):
- super().__init__()
- self.self_attn = self_attn
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.attention = Attention2D(c, nhead, dropout)
- self.kv_mapper = nn.Sequential(
- nn.SiLU(),
- Linear(c_cond, c)
- )
- self.use_checkpoint = use_checkpoint
- def forward(self, x, kv):
- return checkpoint(self._forward, (x, kv), self.parameters(), self.use_checkpoint)
- def _forward(self, x, kv):
- kv = self.kv_mapper(kv)
- res = self.attention(self.norm(x), kv, self_attn=self.self_attn)
-
- #print(torch.unique(res), torch.unique(x), self.self_attn)
- #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24))
- x = x + res
-
- return x
-class AttnBlock_mytest(nn.Module):
- def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
- super().__init__()
- self.self_attn = self_attn
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.attention = Attention2D(c, nhead, dropout)
- self.kv_mapper = nn.Sequential(
- nn.SiLU(),
- nn.Linear(c_cond, c)
- )
-
- def forward(self, x, kv):
- kv = self.kv_mapper(kv)
- x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
- return x
-
-class FeedForwardBlock(nn.Module):
- def __init__(self, c, dropout=0.0):
- super().__init__()
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
- self.channelwise = nn.Sequential(
- Linear(c, c * 4),
- nn.GELU(),
- GlobalResponseNorm(c * 4),
- nn.Dropout(dropout),
- Linear(c * 4, c)
- )
-
- def forward(self, x):
- x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
- return x
-
-
-class TimestepBlock(nn.Module):
- def __init__(self, c, c_timestep, conds=['sca'], use_checkpoint=True):
- super().__init__()
- self.mapper = Linear(c_timestep, c * 2)
- self.conds = conds
- for cname in conds:
- setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
-
- self.use_checkpoint = use_checkpoint
- def forward(self, x, t):
- return checkpoint(self._forward, (x, t), self.parameters(), self.use_checkpoint)
-
- def _forward(self, x, t):
- #print('in line 284', x.shape, t.shape, self.conds)
- #in line 284 torch.Size([4, 2048, 19, 29]) torch.Size([4, 192]) ['sca', 'crp']
- t = t.chunk(len(self.conds) + 1, dim=1)
- a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
- for i, c in enumerate(self.conds):
- ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
- a, b = a + ac, b + bc
- return x * (1 + a) + b
diff --git a/modules/controlnet.py b/modules/controlnet.py
deleted file mode 100644
index c187aecb725e00e19924ae308e3aac401acfdf06..0000000000000000000000000000000000000000
--- a/modules/controlnet.py
+++ /dev/null
@@ -1,349 +0,0 @@
-import torchvision
-import torch
-from torch import nn
-import numpy as np
-import kornia
-import cv2
-from core.utils import load_or_fail
-#from insightface.app.common import Face
-from .effnet import EfficientNetEncoder
-from .cnet_modules.pidinet import PidiNetDetector
-from .cnet_modules.inpainting.saliency_model import MicroResNet
-#from .cnet_modules.face_id.arcface import FaceDetector, ArcFaceRecognizer
-from .common import LayerNorm2d
-
-
-class CNetResBlock(nn.Module):
- def __init__(self, c):
- super().__init__()
- self.blocks = nn.Sequential(
- LayerNorm2d(c),
- nn.GELU(),
- nn.Conv2d(c, c, kernel_size=3, padding=1),
- LayerNorm2d(c),
- nn.GELU(),
- nn.Conv2d(c, c, kernel_size=3, padding=1),
- )
-
- def forward(self, x):
- return x + self.blocks(x)
-
-
-class ControlNet(nn.Module):
- def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None):
- super().__init__()
- if bottleneck_mode is None:
- bottleneck_mode = 'effnet'
- self.proj_blocks = proj_blocks
- if bottleneck_mode == 'effnet':
- embd_channels = 1280
- #self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval()
- self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
- if c_in != 3:
- in_weights = self.backbone[0][0].weight.data
- self.backbone[0][0] = nn.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False)
- if c_in > 3:
- nn.init.constant_(self.backbone[0][0].weight, 0)
- self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
- else:
- self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
- elif bottleneck_mode == 'simple':
- embd_channels = c_in
- self.backbone = nn.Sequential(
- nn.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1),
- )
- elif bottleneck_mode == 'large':
- self.backbone = nn.Sequential(
- nn.Conv2d(c_in, 4096 * 4, kernel_size=1),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Conv2d(4096 * 4, 1024, kernel_size=1),
- *[CNetResBlock(1024) for _ in range(8)],
- nn.Conv2d(1024, 1280, kernel_size=1),
- )
- embd_channels = 1280
- else:
- raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
- self.projections = nn.ModuleList()
- for _ in range(len(proj_blocks)):
- self.projections.append(nn.Sequential(
- nn.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False),
- ))
- nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
-
- def forward(self, x):
- x = self.backbone(x)
- proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
- for i, idx in enumerate(self.proj_blocks):
- proj_outputs[idx] = self.projections[i](x)
- return proj_outputs
-
-
-class ControlNetDeliverer():
- def __init__(self, controlnet_projections):
- self.controlnet_projections = controlnet_projections
- self.restart()
-
- def restart(self):
- self.idx = 0
- return self
-
- def __call__(self):
- if self.idx < len(self.controlnet_projections):
- output = self.controlnet_projections[self.idx]
- else:
- output = None
- self.idx += 1
- return output
-
-
-# CONTROLNET FILTERS ----------------------------------------------------
-
-class BaseFilter():
- def __init__(self, device):
- self.device = device
-
- def num_channels(self):
- return 3
-
- def __call__(self, x):
- return x
-
-
-class CannyFilter(BaseFilter):
- def __init__(self, device, resize=224):
- super().__init__(device)
- self.resize = resize
-
- def num_channels(self):
- return 1
-
- def __call__(self, x):
- orig_size = x.shape[-2:]
- if self.resize is not None:
- x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')
- edges = [cv2.Canny(x[i].mul(255).permute(1, 2, 0).cpu().numpy().astype(np.uint8), 100, 200) for i in range(len(x))]
- edges = torch.stack([torch.tensor(e).div(255).unsqueeze(0) for e in edges], dim=0)
- if self.resize is not None:
- edges = nn.functional.interpolate(edges, size=orig_size, mode='bilinear')
- return edges
-
-
-class QRFilter(BaseFilter):
- def __init__(self, device, resize=224, blobify=True, dilation_kernels=[3, 5, 7], blur_kernels=[15]):
- super().__init__(device)
- self.resize = resize
- self.blobify = blobify
- self.dilation_kernels = dilation_kernels
- self.blur_kernels = blur_kernels
-
- def num_channels(self):
- return 1
-
- def __call__(self, x):
- x = x.to(self.device)
- orig_size = x.shape[-2:]
- if self.resize is not None:
- x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')
-
- x = kornia.color.rgb_to_hsv(x)[:, -1:]
- # blobify
- if self.blobify:
- d_kernel = np.random.choice(self.dilation_kernels)
- d_blur = np.random.choice(self.blur_kernels)
- if d_blur > 0:
- x = torchvision.transforms.GaussianBlur(d_blur)(x)
- if d_kernel > 0:
- blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5,
- d_kernel).pow(2)[:,
- None]) < 0.3).float().to(self.device)
- x = kornia.morphology.dilation(x, blob_mask)
- x = kornia.morphology.erosion(x, blob_mask)
- # mask
- vmax, vmin = x.amax(dim=[2, 3], keepdim=True)[0], x.amin(dim=[2, 3], keepdim=True)[0]
- th = (vmax - vmin) * 0.33
- high_brightness, low_brightness = (x > (vmax - th)).float(), (x < (vmin + th)).float()
- mask = (torch.ones_like(x) - low_brightness + high_brightness) * 0.5
-
- if self.resize is not None:
- mask = nn.functional.interpolate(mask, size=orig_size, mode='bilinear')
- return mask.cpu()
-
-
-class PidiFilter(BaseFilter):
- def __init__(self, device, resize=224, dilation_kernels=[0, 3, 5, 7, 9], binarize=True):
- super().__init__(device)
- self.resize = resize
- self.model = PidiNetDetector(device)
- self.dilation_kernels = dilation_kernels
- self.binarize = binarize
-
- def num_channels(self):
- return 1
-
- def __call__(self, x):
- x = x.to(self.device)
- orig_size = x.shape[-2:]
- if self.resize is not None:
- x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')
-
- x = self.model(x)
- d_kernel = np.random.choice(self.dilation_kernels)
- if d_kernel > 0:
- blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, d_kernel).pow(2)[
- :, None]) < 0.3).float().to(self.device)
- x = kornia.morphology.dilation(x, blob_mask)
- if self.binarize:
- th = np.random.uniform(0.05, 0.7)
- x = (x > th).float()
-
- if self.resize is not None:
- x = nn.functional.interpolate(x, size=orig_size, mode='bilinear')
- return x.cpu()
-
-
-class SRFilter(BaseFilter):
- def __init__(self, device, scale_factor=1 / 4):
- super().__init__(device)
- self.scale_factor = scale_factor
-
- def num_channels(self):
- return 3
-
- def __call__(self, x):
- x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest")
- return torch.nn.functional.interpolate(x, scale_factor=1 / self.scale_factor, mode="nearest")
-
-
-class SREffnetFilter(BaseFilter):
- def __init__(self, device, scale_factor=1/2):
- super().__init__(device)
- self.scale_factor = scale_factor
-
- self.effnet_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Normalize(
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- )
- ])
-
- self.effnet = EfficientNetEncoder().to(self.device)
- effnet_checkpoint = load_or_fail("models/effnet_encoder.safetensors")
- self.effnet.load_state_dict(effnet_checkpoint)
- self.effnet.eval().requires_grad_(False)
-
- def num_channels(self):
- return 16
-
- def __call__(self, x):
- x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest")
- with torch.no_grad():
- effnet_embedding = self.effnet(self.effnet_preprocess(x.to(self.device))).cpu()
- effnet_embedding = torch.nn.functional.interpolate(effnet_embedding, scale_factor=1/self.scale_factor, mode="nearest")
- upscaled_image = torch.nn.functional.interpolate(x, scale_factor=1/self.scale_factor, mode="nearest")
- return effnet_embedding, upscaled_image
-
-
-class InpaintFilter(BaseFilter):
- def __init__(self, device, thresold=[0.04, 0.4], p_outpaint=0.4):
- super().__init__(device)
- self.saliency_model = MicroResNet().eval().requires_grad_(False).to(device)
- self.saliency_model.load_state_dict(load_or_fail("modules/cnet_modules/inpainting/saliency_model.pt"))
- self.thresold = thresold
- self.p_outpaint = p_outpaint
-
- def num_channels(self):
- return 4
-
- def __call__(self, x, mask=None, threshold=None, outpaint=None):
- x = x.to(self.device)
- resized_x = torchvision.transforms.functional.resize(x, 240, antialias=True)
- if threshold is None:
- threshold = np.random.uniform(self.thresold[0], self.thresold[1])
- if mask is None:
- saliency_map = self.saliency_model(resized_x) > threshold
- if outpaint is None:
- if np.random.rand() < self.p_outpaint:
- saliency_map = ~saliency_map
- else:
- if outpaint:
- saliency_map = ~saliency_map
- interpolated_saliency_map = torch.nn.functional.interpolate(saliency_map.float(), size=x.shape[2:], mode="nearest")
- saliency_map = torchvision.transforms.functional.gaussian_blur(interpolated_saliency_map, 141) > 0.5
- inpainted_images = torch.where(saliency_map, torch.ones_like(x), x)
- mask = torch.nn.functional.interpolate(saliency_map.float(), size=inpainted_images.shape[2:], mode="nearest")
- else:
- mask = mask.to(self.device)
- inpainted_images = torch.where(mask, torch.ones_like(x), x)
- c_inpaint = torch.cat([inpainted_images, mask], dim=1)
- return c_inpaint.cpu()
-
-
-# IDENTITY
-'''
-class IdentityFilter(BaseFilter):
- def __init__(self, device, max_faces=4, p_drop=0.05, p_full=0.3):
- detector_path = 'modules/cnet_modules/face_id/models/buffalo_l/det_10g.onnx'
- recognizer_path = 'modules/cnet_modules/face_id/models/buffalo_l/w600k_r50.onnx'
-
- super().__init__(device)
- self.max_faces = max_faces
- self.p_drop = p_drop
- self.p_full = p_full
-
- self.detector = FaceDetector(detector_path, device=device)
- self.recognizer = ArcFaceRecognizer(recognizer_path, device=device)
-
- self.id_colors = torch.tensor([
- [1.0, 0.0, 0.0], # RED
- [0.0, 1.0, 0.0], # GREEN
- [0.0, 0.0, 1.0], # BLUE
- [1.0, 0.0, 1.0], # PURPLE
- [0.0, 1.0, 1.0], # CYAN
- [1.0, 1.0, 0.0], # YELLOW
- [0.5, 0.0, 0.0], # DARK RED
- [0.0, 0.5, 0.0], # DARK GREEN
- [0.0, 0.0, 0.5], # DARK BLUE
- [0.5, 0.0, 0.5], # DARK PURPLE
- [0.0, 0.5, 0.5], # DARK CYAN
- [0.5, 0.5, 0.0], # DARK YELLOW
- ])
-
- def num_channels(self):
- return 512
-
- def get_faces(self, image):
- npimg = image.permute(1, 2, 0).mul(255).to(device="cpu", dtype=torch.uint8).cpu().numpy()
- bgr = cv2.cvtColor(npimg, cv2.COLOR_RGB2BGR)
- bboxes, kpss = self.detector.detect(bgr, max_num=self.max_faces)
- N = len(bboxes)
- ids = torch.zeros((N, 512), dtype=torch.float32)
- for i in range(N):
- face = Face(bbox=bboxes[i, :4], kps=kpss[i], det_score=bboxes[i, 4])
- ids[i, :] = self.recognizer.get(bgr, face)
- tbboxes = torch.tensor(bboxes[:, :4], dtype=torch.int)
-
- ids = ids / torch.linalg.norm(ids, dim=1, keepdim=True)
- return tbboxes, ids # returns bounding boxes (N x 4) and ID vectors (N x 512)
-
- def __call__(self, x):
- visual_aid = x.clone().cpu()
- face_mtx = torch.zeros(x.size(0), 512, x.size(-2) // 32, x.size(-1) // 32)
-
- for i in range(x.size(0)):
- bounding_boxes, ids = self.get_faces(x[i])
- for j in range(bounding_boxes.size(0)):
- if np.random.rand() > self.p_drop:
- sx, sy, ex, ey = (bounding_boxes[j] / 32).clamp(min=0).round().int().tolist()
- ex, ey = max(ex, sx + 1), max(ey, sy + 1)
- if bounding_boxes.size(0) == 1 and np.random.rand() < self.p_full:
- sx, sy, ex, ey = 0, 0, x.size(-1) // 32, x.size(-2) // 32
- face_mtx[i, :, sy:ey, sx:ex] = ids[j:j + 1, :, None, None]
- visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] += self.id_colors[j % 13, :,
- None, None]
- visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] *= 0.5
-
- return face_mtx.to(x.device), visual_aid.to(x.device)
-'''
diff --git a/modules/effnet.py b/modules/effnet.py
deleted file mode 100644
index 0eb2690c2547c8c7553aec8a9f9e838241f8f61c..0000000000000000000000000000000000000000
--- a/modules/effnet.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import torchvision
-from torch import nn
-
-
-# EfficientNet
-class EfficientNetEncoder(nn.Module):
- def __init__(self, c_latent=16):
- super().__init__()
- self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
- self.mapper = nn.Sequential(
- nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
- nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
- )
-
- def forward(self, x):
- return self.mapper(self.backbone(x))
-
diff --git a/modules/inr_fea_res_lite.py b/modules/inr_fea_res_lite.py
deleted file mode 100644
index 41ddfb09937f26e2c7d0193b4a65607efabde5e5..0000000000000000000000000000000000000000
--- a/modules/inr_fea_res_lite.py
+++ /dev/null
@@ -1,435 +0,0 @@
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import einops
-import numpy as np
-import models
-from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d
-#from modules.common_ckpt import AttnBlock,
-from einops import rearrange
-import torch.fft as fft
-from modules.speed_util import checkpoint
-def batched_linear_mm(x, wb):
- # x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2)
- one = torch.ones(*x.shape[:-1], 1, device=x.device)
- return torch.matmul(torch.cat([x, one], dim=-1), wb)
-def make_coord_grid(shape, range, device=None):
- """
- Args:
- shape: tuple
- range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim
- Returns:
- grid: shape (*shape, )
- """
- l_lst = []
- for i, s in enumerate(shape):
- l = (0.5 + torch.arange(s, device=device)) / s
- if isinstance(range[0], list) or isinstance(range[0], tuple):
- minv, maxv = range[i]
- else:
- minv, maxv = range
- l = minv + (maxv - minv) * l
- l_lst.append(l)
- grid = torch.meshgrid(*l_lst, indexing='ij')
- grid = torch.stack(grid, dim=-1)
- return grid
-def init_wb(shape):
- weight = torch.empty(shape[1], shape[0] - 1)
- nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
-
- bias = torch.empty(shape[1], 1)
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
- nn.init.uniform_(bias, -bound, bound)
-
- return torch.cat([weight, bias], dim=1).t().detach()
-
-def init_wb_rewrite(shape):
- weight = torch.empty(shape[1], shape[0] - 1)
-
- torch.nn.init.xavier_uniform_(weight)
-
- bias = torch.empty(shape[1], 1)
- torch.nn.init.xavier_uniform_(bias)
-
-
- return torch.cat([weight, bias], dim=1).t().detach()
-class HypoMlp(nn.Module):
-
- def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024):
- super().__init__()
- self.use_pe = use_pe
- self.pe_dim = pe_dim
- self.pe_sigma = pe_sigma
- self.depth = depth
- self.param_shapes = dict()
- if use_pe:
- last_dim = in_dim * pe_dim
- else:
- last_dim = in_dim
- for i in range(depth): # for each layer the weight
- cur_dim = hidden_dim if i < depth - 1 else out_dim
- self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim)
- last_dim = cur_dim
- self.relu = nn.ReLU()
- self.params = None
- self.out_bias = out_bias
-
- def set_params(self, params):
- self.params = params
-
- def convert_posenc(self, x):
- w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device))
- x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1)
- x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1)
- return x
-
- def forward(self, x):
- B, query_shape = x.shape[0], x.shape[1: -1]
- x = x.view(B, -1, x.shape[-1])
- if self.use_pe:
- x = self.convert_posenc(x)
- #print('in line 79 after pos embedding', x.shape)
- for i in range(self.depth):
- x = batched_linear_mm(x, self.params[f'wb{i}'])
- if i < self.depth - 1:
- x = self.relu(x)
- else:
- x = x + self.out_bias
- x = x.view(B, *query_shape, -1)
- return x
-
-
-
-class Attention(nn.Module):
-
- def __init__(self, dim, n_head, head_dim, dropout=0.):
- super().__init__()
- self.n_head = n_head
- inner_dim = n_head * head_dim
- self.to_q = nn.Sequential(
- nn.SiLU(),
- Linear(dim, inner_dim ))
- self.to_kv = nn.Sequential(
- nn.SiLU(),
- Linear(dim, inner_dim * 2))
- self.scale = head_dim ** -0.5
- # self.to_out = nn.Sequential(
- # Linear(inner_dim, dim),
- # nn.Dropout(dropout),
- # )
-
- def forward(self, fr, to=None):
- if to is None:
- to = fr
- q = self.to_q(fr)
- k, v = self.to_kv(to).chunk(2, dim=-1)
- q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v])
-
- dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
- attn = F.softmax(dots, dim=-1) # b h n n
- out = torch.matmul(attn, v)
- out = einops.rearrange(out, 'b h n d -> b n (h d)')
- return out
-
-
-class FeedForward(nn.Module):
-
- def __init__(self, dim, ff_dim, dropout=0.):
- super().__init__()
-
- self.net = nn.Sequential(
- Linear(dim, ff_dim),
- nn.GELU(),
- #GlobalResponseNorm(ff_dim),
- nn.Dropout(dropout),
- Linear(ff_dim, dim)
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-class PreNorm(nn.Module):
-
- def __init__(self, dim, fn):
- super().__init__()
- self.norm = nn.LayerNorm(dim)
- self.fn = fn
-
- def forward(self, x):
- return self.fn(self.norm(x))
-
-
-#TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds = [])
-class TransformerEncoder(nn.Module):
-
- def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.):
- super().__init__()
- self.layers = nn.ModuleList()
- for _ in range(depth):
- self.layers.append(nn.ModuleList([
- PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)),
- PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)),
- ]))
-
- def forward(self, x):
- for norm_attn, norm_ff in self.layers:
- x = x + norm_attn(x)
- x = x + norm_ff(x)
- return x
-class ImgrecTokenizer(nn.Module):
-
- def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16):
- super().__init__()
-
- if isinstance(patch_size, int):
- patch_size = (patch_size, patch_size)
- if isinstance(padding, int):
- padding = (padding, padding)
- self.patch_size = patch_size
- self.padding = padding
- self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim)
-
- self.posemb = nn.Parameter(torch.randn(input_size, dim))
-
- def forward(self, x):
- #print(x.shape)
- p = self.patch_size
- x = F.unfold(x, p, stride=p, padding=self.padding) # (B, C * p * p, L)
- #print('in line 185 after unfoding', x.shape)
- x = x.permute(0, 2, 1).contiguous()
- ttt = self.prefc(x)
-
- x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0)
- return x
-
-class SpatialAttention(nn.Module):
- def __init__(self, kernel_size=7):
- super(SpatialAttention, self).__init__()
-
- self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
- self.sigmoid = nn.Sigmoid()
-
- def forward(self, x):
- avg_out = torch.mean(x, dim=1, keepdim=True)
- max_out, _ = torch.max(x, dim=1, keepdim=True)
- x = torch.cat([avg_out, max_out], dim=1)
- x = self.conv1(x)
- return self.sigmoid(x)
-
-class TimestepBlock_res(nn.Module):
- def __init__(self, c, c_timestep, conds=['sca']):
- super().__init__()
-
- self.mapper = Linear(c_timestep, c * 2)
- self.conds = conds
- for cname in conds:
- setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
-
-
-
-
- def forward(self, x, t):
- #print(x.shape, t.shape, self.conds, 'in line 269')
- t = t.chunk(len(self.conds) + 1, dim=1)
- a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
-
- for i, c in enumerate(self.conds):
- ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
- a, b = a + ac, b + bc
- return x * (1 + a) + b
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-
-class ScaleNormalize_res(nn.Module):
- def __init__(self, c, scale_c, conds=['sca']):
- super().__init__()
- self.c_r = scale_c
- self.mapping = TimestepBlock_res(c, scale_c, conds=conds)
- self.t_conds = conds
- self.alpha = nn.Conv2d(c, c, kernel_size=1)
- self.gamma = nn.Conv2d(c, c, kernel_size=1)
- self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
-
-
- def gen_r_embedding(self, r, max_positions=10000):
- r = r * max_positions
- half_dim = self.c_r // 2
- emb = math.log(max_positions) / (half_dim - 1)
- emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
- emb = r[:, None] * emb[None, :]
- emb = torch.cat([emb.sin(), emb.cos()], dim=1)
- if self.c_r % 2 == 1: # zero pad
- emb = nn.functional.pad(emb, (0, 1), mode='constant')
- return emb
- def forward(self, x, std_size=24*24):
- scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size))
- scale_val = torch.ones(x.shape[0]).to(x.device)*scale_val
- scale_val_f = self.gen_r_embedding(scale_val)
- for c in self.t_conds:
- t_cond = torch.zeros_like(scale_val)
- scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1)
-
- f = self.mapping(x, scale_val_f)
-
- return f + x
-
-
-class TransInr_withnorm(nn.Module):
-
- def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]):
- super().__init__()
- self.input_layer= nn.Conv2d(ind, ch, 1)
- self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch)
- #self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
- #self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3*f_dim, )
-
- self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
- self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim)
- #self.transformer_encoder = TransInr( ch=ch, n_head=16, head_dim=16, n_groups=64, f_dim=ch, time_dim=time_dim, t_conds = [])
- self.base_params = nn.ParameterDict()
- n_wtokens = 0
- self.wtoken_postfc = nn.ModuleDict()
- self.wtoken_rng = dict()
- for name, shape in self.hyponet.param_shapes.items():
- self.base_params[name] = nn.Parameter(init_wb(shape))
- g = min(n_groups, shape[1])
- assert shape[1] % g == 0
- self.wtoken_postfc[name] = nn.Sequential(
- nn.LayerNorm(f_dim),
- nn.Linear(f_dim, shape[0] - 1),
- )
- self.wtoken_rng[name] = (n_wtokens, n_wtokens + g)
- n_wtokens += g
- self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim))
- self.output_layer= nn.Conv2d(ch, ind, 1)
-
-
- self.mapp_t = TimestepBlock_res( ind, time_dim, conds = t_conds)
-
-
- self.hr_norm = ScaleNormalize_res(ind, 64, conds=[])
-
- self.normalize_final = nn.Sequential(
- LayerNorm2d(ind, elementwise_affine=False, eps=1e-6),
- )
-
- self.toout = nn.Sequential(
- Linear( ind*2, ind // 4),
- nn.GELU(),
- Linear( ind // 4, ind)
- )
- self.apply(self._init_weights)
-
- mask = torch.zeros((1, 1, 32, 32))
- h, w = 32, 32
- center_h, center_w = h // 2, w // 2
- low_freq_h, low_freq_w = h // 4, w // 4
- mask[:, :, center_h-low_freq_h:center_h+low_freq_h, center_w-low_freq_w:center_w+low_freq_w] = 1
- self.mask = mask
-
-
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv2d, nn.Linear)):
- torch.nn.init.xavier_uniform_(m.weight)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- #nn.init.constant_(self.last.weight, 0)
- def adain(self, feature_a, feature_b):
- norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True)
- norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True)
- #feature_a = F.interpolate(feature_a, feature_b.shape[2:])
- feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean
- return feature_b
- def forward(self, target_shape, target, dtokens, t_emb):
- #print(target.shape, dtokens.shape, 'in line 290')
- hlr, wlr = dtokens.shape[2:]
- original = dtokens
-
- dtokens = self.input_layer(dtokens)
- dtokens = self.tokenizer(dtokens)
- B = dtokens.shape[0]
- wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B)
- #print(wtokens.shape, dtokens.shape)
- trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1))
- trans_out = trans_out[:, -len(self.wtokens):, :]
-
- params = dict()
- for name, shape in self.hyponet.param_shapes.items():
- wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B)
- w, b = wb[:, :-1, :], wb[:, -1:, :]
-
- l, r = self.wtoken_rng[name]
- x = self.wtoken_postfc[name](trans_out[:, l: r, :])
- x = x.transpose(-1, -2) # (B, shape[0] - 1, g)
- w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1)
-
- wb = torch.cat([w, b], dim=1)
- params[name] = wb
- coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device)
- coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0])
- self.hyponet.set_params(params)
- ori_up = F.interpolate(original.float(), target_shape[2:])
- hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up
- #print(hr_rec.shape, target.shape, torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1).shape, 'in line 537')
-
- output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
- #print(output.shape, 'in line 540')
- #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)* 0.3
- output = self.mapp_t(output, t_emb)
- output = self.normalize_final(output)
- output = self.hr_norm(output)
- #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
- #output = self.mapp_t(output, t_emb)
- #output = self.weight(output) * output
-
- return output
-
-
-
-
-
-
-class LayerNorm2d(nn.LayerNorm):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def forward(self, x):
- return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-
-class GlobalResponseNorm(nn.Module):
- "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
- def __init__(self, dim):
- super().__init__()
- self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
- self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
-
- def forward(self, x):
- Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
- Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
- return self.gamma * (x * Nx) + self.beta + x
-
-
-
-if __name__ == '__main__':
- #ef __init__(self, ch, n_head, head_dim, n_groups):
- trans_inr = TransInr(16, 24, 32, 64).cuda()
- input = torch.randn((1, 16, 24, 24)).cuda()
- source = torch.randn((1, 16, 16, 16)).cuda()
- t = torch.randn((1, 128)).cuda()
- output, hr = trans_inr(input, t, source)
-
- total_up = sum([ param.nelement() for param in trans_inr.parameters()])
- print(output.shape, hr.shape, total_up /1e6 )
-
diff --git a/modules/lora.py b/modules/lora.py
deleted file mode 100644
index bc0a2bd797f3669a465f6c2c4255b52fe1bda7a7..0000000000000000000000000000000000000000
--- a/modules/lora.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-from torch import nn
-
-
-class LoRA(nn.Module):
- def __init__(self, layer, name='weight', rank=16, alpha=1):
- super().__init__()
- weight = getattr(layer, name)
- self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1))))
- self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank)))
- nn.init.normal_(self.lora_up, mean=0, std=1)
-
- self.scale = alpha / rank
- self.enabled = True
-
- def forward(self, original_weights):
- if self.enabled:
- lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2)
- lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale
- return original_weights + lora_weights
- else:
- return original_weights
-
-
-def apply_lora(model, filters=None, rank=16):
- def check_parameter(module, name):
- return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(
- getattr(module, name), nn.Parameter)
-
- for name, module in model.named_modules():
- if filters is None or any([f in name for f in filters]):
- if check_parameter(module, "weight"):
- device, dtype = module.weight.device, module.weight.dtype
- torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device))
- elif check_parameter(module, "in_proj_weight"):
- device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype
- torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device))
-
-
-class ReToken(nn.Module):
- def __init__(self, indices=None):
- super().__init__()
- assert indices is not None
- self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280))
- self.register_buffer('indices', torch.tensor(indices))
- self.enabled = True
-
- def forward(self, embeddings):
- if self.enabled:
- embeddings = embeddings.clone()
- for i, idx in enumerate(self.indices):
- embeddings[idx] += self.embeddings[i]
- return embeddings
-
-
-def apply_retoken(module, indices=None):
- def check_parameter(module, name):
- return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(
- getattr(module, name), nn.Parameter)
-
- if check_parameter(module, "weight"):
- device, dtype = module.weight.device, module.weight.dtype
- torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device))
-
-
-def remove_lora(model, leave_parametrized=True):
- for module in model.modules():
- if torch.nn.utils.parametrize.is_parametrized(module, "weight"):
- nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized)
- elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"):
- nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized)
diff --git a/modules/model_4stage_lite.py b/modules/model_4stage_lite.py
deleted file mode 100644
index e77cc5d73ccda882774f447f5a8bb86fe71fe755..0000000000000000000000000000000000000000
--- a/modules/model_4stage_lite.py
+++ /dev/null
@@ -1,458 +0,0 @@
-import torch
-from torch import nn
-import numpy as np
-import math
-from modules.common_ckpt import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock
-from .controlnet import ControlNetDeliverer
-import torch.nn.functional as F
-from modules.inr_fea_res_lite import TransInr_withnorm as TransInr
-from modules.inr_fea_res_lite import ScaleNormalize_res
-from einops import rearrange
-import torch.fft as fft
-import random
-class UpDownBlock2d(nn.Module):
- def __init__(self, c_in, c_out, mode, enabled=True):
- super().__init__()
- assert mode in ['up', 'down']
- interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
- align_corners=True) if enabled else nn.Identity()
- mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
- self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x.float())
- return x
-def ada_in(a, b):
- mean_a = torch.mean(a, dim=(2, 3), keepdim=True)
- std_a = torch.std(a, dim=(2, 3), keepdim=True)
-
- mean_b = torch.mean(b, dim=(2, 3), keepdim=True)
- std_b = torch.std(b, dim=(2, 3), keepdim=True)
-
- return (b - mean_b) / (1e-8 + std_b) * std_a + mean_a
-def feature_dist_loss(x1, x2):
- mu1 = torch.mean(x1, dim=(2, 3))
- mu2 = torch.mean(x2, dim=(2, 3))
-
- std1 = torch.std(x1, dim=(2, 3))
- std2 = torch.std(x2, dim=(2, 3))
- std_loss = torch.mean(torch.abs(torch.log(std1+ 1e-8) - torch.log(std2+ 1e-8)))
- mean_loss = torch.mean(torch.abs(mu1 - mu2))
- #print('in line 36', std_loss, mean_loss)
- return std_loss + mean_loss*0.1
-class StageC(nn.Module):
- def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
- blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
- c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
- dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False],
- lr_h=24, lr_w=24):
- super().__init__()
-
- self.lr_h, self.lr_w = lr_h, lr_w
- self.block_repeat = block_repeat
- self.c_in = c_in
- self.c_cond = c_cond
- self.patch_size = patch_size
- self.c_hidden = c_hidden
- self.nhead = nhead
- self.blocks = blocks
- self.level_config = level_config
- self.kernel_size = kernel_size
- self.c_r = c_r
- self.t_conds = t_conds
- self.c_clip_seq = c_clip_seq
- if not isinstance(dropout, list):
- dropout = [dropout] * len(c_hidden)
- if not isinstance(self_attn, list):
- self_attn = [self_attn] * len(c_hidden)
- self.self_attn = self_attn
- self.dropout = dropout
- self.switch_level = switch_level
- # CONDITIONING
- self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond)
- self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq)
- self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq)
- self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
-
- self.embedding = nn.Sequential(
- nn.PixelUnshuffle(patch_size),
- nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
- )
-
- def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
- if block_type == 'C':
- return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
- elif block_type == 'A':
- return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
- elif block_type == 'F':
- return FeedForwardBlock(c_hidden, dropout=dropout)
- elif block_type == 'T':
- return TimestepBlock(c_hidden, c_r, conds=t_conds)
- else:
- raise Exception(f'Block type {block_type} not supported')
-
- # BLOCKS
- # -- down blocks
- self.down_blocks = nn.ModuleList()
- self.down_downscalers = nn.ModuleList()
- self.down_repeat_mappers = nn.ModuleList()
- for i in range(len(c_hidden)):
- if i > 0:
- self.down_downscalers.append(nn.Sequential(
- LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
- UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1])
- ))
- else:
- self.down_downscalers.append(nn.Identity())
- down_block = nn.ModuleList()
- for _ in range(blocks[0][i]):
- for block_type in level_config[i]:
- block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
- down_block.append(block)
- self.down_blocks.append(down_block)
- if block_repeat is not None:
- block_repeat_mappers = nn.ModuleList()
- for _ in range(block_repeat[0][i] - 1):
- block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
- self.down_repeat_mappers.append(block_repeat_mappers)
-
-
-
- #extra down blocks
-
-
- # -- up blocks
- self.up_blocks = nn.ModuleList()
- self.up_upscalers = nn.ModuleList()
- self.up_repeat_mappers = nn.ModuleList()
- for i in reversed(range(len(c_hidden))):
- if i > 0:
- self.up_upscalers.append(nn.Sequential(
- LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
- UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1])
- ))
- else:
- self.up_upscalers.append(nn.Identity())
- up_block = nn.ModuleList()
- for j in range(blocks[1][::-1][i]):
- for k, block_type in enumerate(level_config[i]):
- c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
- block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
- self_attn=self_attn[i])
- up_block.append(block)
- self.up_blocks.append(up_block)
- if block_repeat is not None:
- block_repeat_mappers = nn.ModuleList()
- for _ in range(block_repeat[1][::-1][i] - 1):
- block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
- self.up_repeat_mappers.append(block_repeat_mappers)
-
- # OUTPUT
- self.clf = nn.Sequential(
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
- nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
- nn.PixelShuffle(patch_size),
- )
-
- # --- WEIGHT INIT ---
- self.apply(self._init_weights) # General init
- nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
- nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
- nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
- torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
- nn.init.constant_(self.clf[1].weight, 0) # outputs
-
- # blocks
- for level_block in self.down_blocks + self.up_blocks:
- for block in level_block:
- if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
- block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
- elif isinstance(block, TimestepBlock):
- for layer in block.modules():
- if isinstance(layer, nn.Linear):
- nn.init.constant_(layer.weight, 0)
-
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv2d, nn.Linear)):
- torch.nn.init.xavier_uniform_(m.weight)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
-
-
- def _init_extra_parameter(self):
-
-
-
- self.agg_net = nn.ModuleList()
- for _ in range(2):
-
- self.agg_net.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) #
-
- self.agg_net_up = nn.ModuleList()
- for _ in range(2):
-
- self.agg_net_up.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) #
-
-
-
-
-
- self.norm_down_blocks = nn.ModuleList()
- for i in range(len(self.c_hidden)):
-
- up_blocks = nn.ModuleList()
- for j in range(self.blocks[0][i]):
- if j % 4 == 0:
- up_blocks.append(
- ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[]))
- self.norm_down_blocks.append(up_blocks)
-
-
- self.norm_up_blocks = nn.ModuleList()
- for i in reversed(range(len(self.c_hidden))):
-
- up_block = nn.ModuleList()
- for j in range(self.blocks[1][::-1][i]):
- if j % 4 == 0:
- up_block.append(ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[]))
- self.norm_up_blocks.append(up_block)
-
-
-
-
- self.agg_net.apply(self._init_weights)
- self.agg_net_up.apply(self._init_weights)
- self.norm_up_blocks.apply(self._init_weights)
- self.norm_down_blocks.apply(self._init_weights)
- for block in self.agg_net + self.agg_net_up:
- #for block in level_block:
- if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
- block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
- elif isinstance(block, TimestepBlock):
- for layer in block.modules():
- if isinstance(layer, nn.Linear):
- nn.init.constant_(layer.weight, 0)
-
-
-
-
-
- def gen_r_embedding(self, r, max_positions=10000):
- r = r * max_positions
- half_dim = self.c_r // 2
- emb = math.log(max_positions) / (half_dim - 1)
- emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
- emb = r[:, None] * emb[None, :]
- emb = torch.cat([emb.sin(), emb.cos()], dim=1)
- if self.c_r % 2 == 1: # zero pad
- emb = nn.functional.pad(emb, (0, 1), mode='constant')
- return emb
-
- def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
- clip_txt = self.clip_txt_mapper(clip_txt)
- if len(clip_txt_pooled.shape) == 2:
- clip_txt_pool = clip_txt_pooled.unsqueeze(1)
- if len(clip_img.shape) == 2:
- clip_img = clip_img.unsqueeze(1)
- clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
- clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
- clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
- clip = self.clip_norm(clip)
- return clip
-
- def _down_encode(self, x, r_embed, clip, cnet=None, require_q=False, lr_guide=None, r_emb_lite=None, guide_weight=1):
- level_outputs = []
- if require_q:
- qs = []
- block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
- for stage_cnt, (down_block, downscaler, repmap) in enumerate(block_group):
- x = downscaler(x)
- for i in range(len(repmap) + 1):
- for inner_cnt, block in enumerate(down_block):
-
-
- if isinstance(block, ResBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- ResBlock)):
- if cnet is not None and lr_guide is None:
- #if cnet is not None :
- next_cnet = cnet()
- if next_cnet is not None:
-
- x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear',
- align_corners=True)
- x = block(x)
- elif isinstance(block, AttnBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- AttnBlock)):
-
- x = block(x, clip)
- if require_q and (inner_cnt == 2 ):
- qs.append(x.clone())
- if lr_guide is not None and (inner_cnt == 2 ) :
-
- guide = self.agg_net[stage_cnt](x.shape, x, lr_guide[stage_cnt], r_emb_lite)
- x = x + guide
-
- elif isinstance(block, TimestepBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- TimestepBlock)):
- x = block(x, r_embed)
- else:
- x = block(x)
- if i < len(repmap):
- x = repmap[i](x)
- level_outputs.insert(0, x) # 0 indicate last output
- if require_q:
- return level_outputs, qs
- return level_outputs
-
-
- def _up_decode(self, level_outputs, r_embed, clip, cnet=None, require_ff=False, agg_f=None, r_emb_lite=None, guide_weight=1):
- if require_ff:
- agg_feas = []
- x = level_outputs[0]
- block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
- for i, (up_block, upscaler, repmap) in enumerate(block_group):
- for j in range(len(repmap) + 1):
- for k, block in enumerate(up_block):
-
- if isinstance(block, ResBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- ResBlock)):
- skip = level_outputs[i] if k == 0 and i > 0 else None
-
-
- if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
- x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear',
- align_corners=True)
-
- if cnet is not None and agg_f is None:
- next_cnet = cnet()
- if next_cnet is not None:
-
- x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear',
- align_corners=True)
-
-
- x = block(x, skip)
- elif isinstance(block, AttnBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- AttnBlock)):
-
-
- x = block(x, clip)
- if require_ff and (k == 2 ):
- agg_feas.append(x.clone())
- if agg_f is not None and (k == 2 ) :
-
- guide = self.agg_net_up[i](x.shape, x, agg_f[i], r_emb_lite) # training 1 test 4k 0.8 2k 0.7
- if not self.training:
- hw = x.shape[-2] * x.shape[-1]
- if hw >= 96*96:
- guide = 0.7*guide
-
- else:
-
- if hw >= 72*72:
- guide = 0.5* guide
- else:
-
- guide = 0.3* guide
-
- x = x + guide
-
-
- elif isinstance(block, TimestepBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- TimestepBlock)):
- x = block(x, r_embed)
- #if require_ff:
- # agg_feas.append(x.clone())
- else:
- x = block(x)
- if j < len(repmap):
- x = repmap[j](x)
- x = upscaler(x)
-
-
- if require_ff:
- return x, agg_feas
-
- return x
-
-
-
-
- def forward(self, x, r, clip_text, clip_text_pooled, clip_img, lr_guide=None, reuire_f=False, cnet=None, require_t=False, guide_weight=0.5, **kwargs):
-
- r_embed = self.gen_r_embedding(r)
-
- for c in self.t_conds:
- t_cond = kwargs.get(c, torch.zeros_like(r))
- r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
- clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
-
- # Model Blocks
-
- x = self.embedding(x)
-
-
-
- if cnet is not None:
- cnet = ControlNetDeliverer(cnet)
-
- if not reuire_f:
- level_outputs = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, \
- require_q=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight)
- x = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, \
- require_ff=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight)
- else:
- level_outputs, lr_enc = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, require_q=True)
- x, lr_dec = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, require_ff=True)
-
- if reuire_f and require_t:
- return self.clf(x), r_embed, lr_enc, lr_dec
- if reuire_f:
- return self.clf(x), lr_enc, lr_dec
- if require_t:
- return self.clf(x), r_embed
- return self.clf(x)
-
-
- def update_weights_ema(self, src_model, beta=0.999):
- for self_params, src_params in zip(self.parameters(), src_model.parameters()):
- self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
- for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
- self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
-
-
-
-if __name__ == '__main__':
- generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
- total_ori = sum([ param.nelement() for param in generator.parameters()])
- generator._init_extra_parameter()
- generator = generator.cuda()
- total = sum([ param.nelement() for param in generator.parameters()])
- total_down = sum([ param.nelement() for param in generator.down_blocks.parameters()])
-
- total_up = sum([ param.nelement() for param in generator.up_blocks.parameters()])
- total_pro = sum([ param.nelement() for param in generator.project.parameters()])
-
-
- print(total_ori / 1e6, total / 1e6, total_up / 1e6, total_down / 1e6, total_pro / 1e6)
-
- # for name, module in generator.down_blocks.named_modules():
- # print(name, module)
- output, out_lr = generator(
- x=torch.randn(1, 16, 24, 24).cuda(),
- x_lr=torch.randn(1, 16, 16, 16).cuda(),
- r=torch.tensor([0.7056]).cuda(),
- clip_text=torch.randn(1, 77, 1280).cuda(),
- clip_text_pooled = torch.randn(1, 1, 1280).cuda(),
- clip_img = torch.randn(1, 1, 768).cuda()
- )
- print(output.shape, out_lr.shape)
- # cnt
diff --git a/modules/previewer.py b/modules/previewer.py
deleted file mode 100644
index 51ab24292d8ac0da8d24b17d8fc0ac9e1419a3d7..0000000000000000000000000000000000000000
--- a/modules/previewer.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from torch import nn
-
-
-# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
-class Previewer(nn.Module):
- def __init__(self, c_in=16, c_hidden=512, c_out=3):
- super().__init__()
- self.blocks = nn.Sequential(
- nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
- nn.GELU(),
- nn.BatchNorm2d(c_hidden),
-
- nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
- nn.GELU(),
- nn.BatchNorm2d(c_hidden),
-
- nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
- nn.GELU(),
- nn.BatchNorm2d(c_hidden // 2),
-
- nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
- nn.GELU(),
- nn.BatchNorm2d(c_hidden // 2),
-
- nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
- nn.GELU(),
- nn.BatchNorm2d(c_hidden // 4),
-
- nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
- nn.GELU(),
- nn.BatchNorm2d(c_hidden // 4),
-
- nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
- nn.GELU(),
- nn.BatchNorm2d(c_hidden // 4),
-
- nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
- nn.GELU(),
- nn.BatchNorm2d(c_hidden // 4),
-
- nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
- )
-
- def forward(self, x):
- return self.blocks(x)
diff --git a/modules/resnet.py b/modules/resnet.py
deleted file mode 100644
index 460a808942be147d76b8b1f3baf29fec1e2a7b8d..0000000000000000000000000000000000000000
--- a/modules/resnet.py
+++ /dev/null
@@ -1,415 +0,0 @@
-import torch
-from torch import nn
-import torch.nn.functional as F
-#import fvcore.nn.weight_init as weight_init
-
-"""
-Functions for building the BottleneckBlock from Detectron2.
-# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
-"""
-
-def get_norm(norm, out_channels, num_norm_groups=32):
- """
- Args:
- norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
- or a callable that takes a channel number and returns
- the normalization layer as a nn.Module.
- Returns:
- nn.Module or None: the normalization layer
- """
- if norm is None:
- return None
- if isinstance(norm, str):
- if len(norm) == 0:
- return None
- norm = {
- "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels),
- }[norm]
- return norm(out_channels)
-
-class Conv2d(nn.Conv2d):
- """
- A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
- """
-
- def __init__(self, *args, **kwargs):
- """
- Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
- Args:
- norm (nn.Module, optional): a normalization layer
- activation (callable(Tensor) -> Tensor): a callable activation function
- It assumes that norm layer is used before activation.
- """
- norm = kwargs.pop("norm", None)
- activation = kwargs.pop("activation", None)
- super().__init__(*args, **kwargs)
-
- self.norm = norm
- self.activation = activation
-
- def forward(self, x):
- x = F.conv2d(
- x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
- )
- if self.norm is not None:
- x = self.norm(x)
- if self.activation is not None:
- x = self.activation(x)
- return x
-
-class CNNBlockBase(nn.Module):
- """
- A CNN block is assumed to have input channels, output channels and a stride.
- The input and output of `forward()` method must be NCHW tensors.
- The method can perform arbitrary computation but must match the given
- channels and stride specification.
- Attribute:
- in_channels (int):
- out_channels (int):
- stride (int):
- """
-
- def __init__(self, in_channels, out_channels, stride):
- """
- The `__init__` method of any subclass should also contain these arguments.
- Args:
- in_channels (int):
- out_channels (int):
- stride (int):
- """
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.stride = stride
-
-class BottleneckBlock(CNNBlockBase):
- """
- The standard bottleneck residual block used by ResNet-50, 101 and 152
- defined in :paper:`ResNet`. It contains 3 conv layers with kernels
- 1x1, 3x3, 1x1, and a projection shortcut if needed.
- """
-
- def __init__(
- self,
- in_channels,
- out_channels,
- *,
- bottleneck_channels,
- stride=1,
- num_groups=1,
- norm="GN",
- stride_in_1x1=False,
- dilation=1,
- num_norm_groups=32
- ):
- """
- Args:
- bottleneck_channels (int): number of output channels for the 3x3
- "bottleneck" conv layers.
- num_groups (int): number of groups for the 3x3 conv layer.
- norm (str or callable): normalization for all conv layers.
- See :func:`layers.get_norm` for supported format.
- stride_in_1x1 (bool): when stride>1, whether to put stride in the
- first 1x1 convolution or the bottleneck 3x3 convolution.
- dilation (int): the dilation rate of the 3x3 conv layer.
- """
- super().__init__(in_channels, out_channels, stride)
-
- if in_channels != out_channels:
- self.shortcut = Conv2d(
- in_channels,
- out_channels,
- kernel_size=1,
- stride=stride,
- bias=False,
- norm=get_norm(norm, out_channels, num_norm_groups),
- )
- else:
- self.shortcut = None
-
- # The original MSRA ResNet models have stride in the first 1x1 conv
- # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
- # stride in the 3x3 conv
- stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
-
- self.conv1 = Conv2d(
- in_channels,
- bottleneck_channels,
- kernel_size=1,
- stride=stride_1x1,
- bias=False,
- norm=get_norm(norm, bottleneck_channels, num_norm_groups),
- )
-
- self.conv2 = Conv2d(
- bottleneck_channels,
- bottleneck_channels,
- kernel_size=3,
- stride=stride_3x3,
- padding=1 * dilation,
- bias=False,
- groups=num_groups,
- dilation=dilation,
- norm=get_norm(norm, bottleneck_channels, num_norm_groups),
- )
-
- self.conv3 = Conv2d(
- bottleneck_channels,
- out_channels,
- kernel_size=1,
- bias=False,
- norm=get_norm(norm, out_channels, num_norm_groups),
- )
-
- #for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
- # if layer is not None: # shortcut can be None
- # weight_init.c2_msra_fill(layer)
-
- # Zero-initialize the last normalization in each residual branch,
- # so that at the beginning, the residual branch starts with zeros,
- # and each residual block behaves like an identity.
- # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
- # "For BN layers, the learnable scaling coefficient �� is initialized
- # to be 1, except for each residual block's last BN
- # where �� is initialized to be 0."
-
- # nn.init.constant_(self.conv3.norm.weight, 0)
- # TODO this somehow hurts performance when training GN models from scratch.
- # Add it as an option when we need to use this code to train a backbone.
-
- def forward(self, x):
- out = self.conv1(x)
- out = F.relu_(out)
-
- out = self.conv2(out)
- out = F.relu_(out)
-
- out = self.conv3(out)
-
- if self.shortcut is not None:
- shortcut = self.shortcut(x)
- else:
- shortcut = x
-
- out += shortcut
- out = F.relu_(out)
- return out
-
-class ResNet(nn.Module):
- """
- Implement :paper:`ResNet`.
- """
-
- def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
- """
- Args:
- stem (nn.Module): a stem module
- stages (list[list[CNNBlockBase]]): several (typically 4) stages,
- each contains multiple :class:`CNNBlockBase`.
- num_classes (None or int): if None, will not perform classification.
- Otherwise, will create a linear layer.
- out_features (list[str]): name of the layers whose outputs should
- be returned in forward. Can be anything in "stem", "linear", or "res2" ...
- If None, will return the output of the last layer.
- freeze_at (int): The number of stages at the beginning to freeze.
- see :meth:`freeze` for detailed explanation.
- """
- super().__init__()
- self.stem = stem
- self.num_classes = num_classes
-
- current_stride = self.stem.stride
- self._out_feature_strides = {"stem": current_stride}
- self._out_feature_channels = {"stem": self.stem.out_channels}
-
- self.stage_names, self.stages = [], []
-
- if out_features is not None:
- # Avoid keeping unused layers in this module. They consume extra memory
- # and may cause allreduce to fail
- num_stages = max(
- [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
- )
- stages = stages[:num_stages]
- for i, blocks in enumerate(stages):
- assert len(blocks) > 0, len(blocks)
- for block in blocks:
- assert isinstance(block, CNNBlockBase), block
-
- name = "res" + str(i + 2)
- stage = nn.Sequential(*blocks)
-
- self.add_module(name, stage)
- self.stage_names.append(name)
- self.stages.append(stage)
-
- self._out_feature_strides[name] = current_stride = int(
- current_stride * np.prod([k.stride for k in blocks])
- )
- self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
- self.stage_names = tuple(self.stage_names) # Make it static for scripting
-
- if num_classes is not None:
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- self.linear = nn.Linear(curr_channels, num_classes)
-
- # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
- # "The 1000-way fully-connected layer is initialized by
- # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
- nn.init.normal_(self.linear.weight, std=0.01)
- name = "linear"
-
- if out_features is None:
- out_features = [name]
- self._out_features = out_features
- assert len(self._out_features)
- children = [x[0] for x in self.named_children()]
- for out_feature in self._out_features:
- assert out_feature in children, "Available children: {}".format(", ".join(children))
- self.freeze(freeze_at)
-
- def forward(self, x):
- """
- Args:
- x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
- Returns:
- dict[str->Tensor]: names and the corresponding features
- """
- assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
- outputs = {}
- x = self.stem(x)
- if "stem" in self._out_features:
- outputs["stem"] = x
- for name, stage in zip(self.stage_names, self.stages):
- x = stage(x)
- if name in self._out_features:
- outputs[name] = x
- if self.num_classes is not None:
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
- x = self.linear(x)
- if "linear" in self._out_features:
- outputs["linear"] = x
- return outputs
-
- def freeze(self, freeze_at=0):
- """
- Freeze the first several stages of the ResNet. Commonly used in
- fine-tuning.
- Layers that produce the same feature map spatial size are defined as one
- "stage" by :paper:`FPN`.
- Args:
- freeze_at (int): number of stages to freeze.
- `1` means freezing the stem. `2` means freezing the stem and
- one residual stage, etc.
- Returns:
- nn.Module: this ResNet itself
- """
- if freeze_at >= 1:
- self.stem.freeze()
- for idx, stage in enumerate(self.stages, start=2):
- if freeze_at >= idx:
- for block in stage.children():
- block.freeze()
- return self
-
- @staticmethod
- def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
- """
- Create a list of blocks of the same type that forms one ResNet stage.
- Args:
- block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
- stage. A module of this type must not change spatial resolution of inputs unless its
- stride != 1.
- num_blocks (int): number of blocks in this stage
- in_channels (int): input channels of the entire stage.
- out_channels (int): output channels of **every block** in the stage.
- kwargs: other arguments passed to the constructor of
- `block_class`. If the argument name is "xx_per_block", the
- argument is a list of values to be passed to each block in the
- stage. Otherwise, the same argument is passed to every block
- in the stage.
- Returns:
- list[CNNBlockBase]: a list of block module.
- Examples:
- ::
- stage = ResNet.make_stage(
- BottleneckBlock, 3, in_channels=16, out_channels=64,
- bottleneck_channels=16, num_groups=1,
- stride_per_block=[2, 1, 1],
- dilations_per_block=[1, 1, 2]
- )
- Usually, layers that produce the same feature map spatial size are defined as one
- "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
- all be 1.
- """
- blocks = []
- for i in range(num_blocks):
- curr_kwargs = {}
- for k, v in kwargs.items():
- if k.endswith("_per_block"):
- assert len(v) == num_blocks, (
- f"Argument '{k}' of make_stage should have the "
- f"same length as num_blocks={num_blocks}."
- )
- newk = k[: -len("_per_block")]
- assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
- curr_kwargs[newk] = v[i]
- else:
- curr_kwargs[k] = v
-
- blocks.append(
- block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
- )
- in_channels = out_channels
- return blocks
-
- @staticmethod
- def make_default_stages(depth, block_class=None, **kwargs):
- """
- Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
- If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
- instead for fine-grained customization.
- Args:
- depth (int): depth of ResNet
- block_class (type): the CNN block class. Has to accept
- `bottleneck_channels` argument for depth > 50.
- By default it is BasicBlock or BottleneckBlock, based on the
- depth.
- kwargs:
- other arguments to pass to `make_stage`. Should not contain
- stride and channels, as they are predefined for each depth.
- Returns:
- list[list[CNNBlockBase]]: modules in all stages; see arguments of
- :class:`ResNet.__init__`.
- """
- num_blocks_per_stage = {
- 18: [2, 2, 2, 2],
- 34: [3, 4, 6, 3],
- 50: [3, 4, 6, 3],
- 101: [3, 4, 23, 3],
- 152: [3, 8, 36, 3],
- }[depth]
- if block_class is None:
- block_class = BasicBlock if depth < 50 else BottleneckBlock
- if depth < 50:
- in_channels = [64, 64, 128, 256]
- out_channels = [64, 128, 256, 512]
- else:
- in_channels = [64, 256, 512, 1024]
- out_channels = [256, 512, 1024, 2048]
- ret = []
- for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
- if depth >= 50:
- kwargs["bottleneck_channels"] = o // 4
- ret.append(
- ResNet.make_stage(
- block_class=block_class,
- num_blocks=n,
- stride_per_block=[s] + [1] * (n - 1),
- in_channels=i,
- out_channels=o,
- **kwargs,
- )
- )
- return ret
\ No newline at end of file
diff --git a/modules/speed_util.py b/modules/speed_util.py
deleted file mode 100644
index 3b9507c74833bec270b00bd252a3c76fcc09fab3..0000000000000000000000000000000000000000
--- a/modules/speed_util.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import os
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import repeat
-class CheckpointFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, run_function, length, *args):
- ctx.run_function = run_function
- ctx.input_tensors = list(args[:length])
- ctx.input_params = list(args[length:])
- ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
- "dtype": torch.get_autocast_gpu_dtype(),
- "cache_enabled": torch.is_autocast_cache_enabled()}
- with torch.no_grad():
- output_tensors = ctx.run_function(*ctx.input_tensors)
- return output_tensors
-
- @staticmethod
- def backward(ctx, *output_grads):
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
- with torch.enable_grad(), \
- torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
- # Fixes a bug where the first op in run_function modifies the
- # Tensor storage in place, which is not allowed for detach()'d
- # Tensors.
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
- output_tensors = ctx.run_function(*shallow_copies)
- input_grads = torch.autograd.grad(
- output_tensors,
- ctx.input_tensors + ctx.input_params,
- output_grads,
- allow_unused=True,
- )
- del ctx.input_tensors
- del ctx.input_params
- del output_tensors
- return (None, None) + input_grads
-
-def checkpoint(func, inputs, params, flag):
- """
- Evaluate a function without caching intermediate activations, allowing for
- reduced memory at the expense of extra compute in the backward pass.
- :param func: the function to evaluate.
- :param inputs: the argument sequence to pass to `func`.
- :param params: a sequence of parameters `func` depends on but does not
- explicitly take as arguments.
- :param flag: if False, disable gradient checkpointing.
- """
- if flag:
- args = tuple(inputs) + tuple(params)
- return CheckpointFunction.apply(func, len(inputs), *args)
- else:
- return func(*inputs)
\ No newline at end of file
diff --git a/modules/stage_a.py b/modules/stage_a.py
deleted file mode 100644
index 2840ef71d30e3da74954ab4a05e724fd7fef86cf..0000000000000000000000000000000000000000
--- a/modules/stage_a.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import torch
-from torch import nn
-from torchtools.nn import VectorQuantize
-from einops import rearrange
-import torch.nn.functional as F
-import math
-class ResBlock(nn.Module):
- def __init__(self, c, c_hidden):
- super().__init__()
- # depthwise/attention
- self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
- self.depthwise = nn.Sequential(
- nn.ReplicationPad2d(1),
- nn.Conv2d(c, c, kernel_size=3, groups=c)
- )
-
- # channelwise
- self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
- self.channelwise = nn.Sequential(
- nn.Linear(c, c_hidden),
- nn.GELU(),
- nn.Linear(c_hidden, c),
- )
-
- self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
-
- # Init weights
- def _basic_init(module):
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
- torch.nn.init.xavier_uniform_(module.weight)
- if module.bias is not None:
- nn.init.constant_(module.bias, 0)
-
- self.apply(_basic_init)
-
- def _norm(self, x, norm):
- return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-
- def forward(self, x):
-
- mods = self.gammas
-
- x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
-
- #x = x.to(torch.float64)
- x = x + self.depthwise(x_temp) * mods[2]
-
- x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
- x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
-
- return x
-
-
-def extract_patches(tensor, patch_size, stride):
- b, c, H, W = tensor.shape
- pad_h = (patch_size - (H - patch_size) % stride) % stride
- pad_w = (patch_size - (W - patch_size) % stride) % stride
- tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect')
-
-
- patches = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
- patches = patches.contiguous().view(b, c, -1, patch_size, patch_size)
- patches = patches.permute(0, 2, 1, 3, 4)
- return patches, (H, W)
-
-def fuse_patches(patches, patch_size, stride, H, W):
-
- b, num_patches, c, _, _ = patches.shape
- patches = patches.permute(0, 2, 1, 3, 4)
-
-
-
- pad_h = (patch_size - (H - patch_size) % stride) % stride
- pad_w = (patch_size - (W - patch_size) % stride) % stride
- out_h = H + pad_h
- out_w = W + pad_w
- patches = patches.contiguous().view(b, c , -1, patch_size*patch_size ).permute(0, 1, 3, 2)
- patches = patches.contiguous().view(b, c*patch_size*patch_size, -1)
-
- tensor = F.fold(patches, output_size=(out_h, out_w), kernel_size=patch_size, stride=stride)
- overlap_cnt = F.fold(torch.ones_like(patches), output_size=(out_h, out_w), kernel_size=patch_size, stride=stride)
- tensor = tensor / overlap_cnt
- print('end fuse patch', tensor.shape, (tensor.dtype))
- return tensor[:, :, :H, :W]
-
-
-
-class StageA(nn.Module):
- def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
- scale_factor=0.43): # 0.3764
- super().__init__()
- self.c_latent = c_latent
- self.scale_factor = scale_factor
- c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
-
- # Encoder blocks
- self.in_block = nn.Sequential(
- nn.PixelUnshuffle(2),
- nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
- )
- down_blocks = []
- for i in range(levels):
- if i > 0:
- down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
- block = ResBlock(c_levels[i], c_levels[i] * 4)
- down_blocks.append(block)
- down_blocks.append(nn.Sequential(
- nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
- nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
- ))
- self.down_blocks = nn.Sequential(*down_blocks)
- self.down_blocks[0]
-
- self.codebook_size = codebook_size
- self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
-
- # Decoder blocks
- up_blocks = [nn.Sequential(
- nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
- )]
- for i in range(levels):
- for j in range(bottleneck_blocks if i == 0 else 1):
- block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
- up_blocks.append(block)
- if i < levels - 1:
- up_blocks.append(
- nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
- padding=1))
- self.up_blocks = nn.Sequential(*up_blocks)
- self.out_block = nn.Sequential(
- nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
- nn.PixelShuffle(2),
- )
-
- def encode(self, x, quantize=False):
- x = self.in_block(x)
- x = self.down_blocks(x)
- if quantize:
- qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
- return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
- else:
- return x / self.scale_factor, None, None, None
-
-
-
- def decode(self, x, tiled_decoding=False):
- x = x * self.scale_factor
- x = self.up_blocks(x)
- x = self.out_block(x)
- return x
-
- def forward(self, x, quantize=False):
- qe, x, _, vq_loss = self.encode(x, quantize)
- x = self.decode(qe)
- return x, vq_loss
-
-
-class Discriminator(nn.Module):
- def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
- super().__init__()
- d = max(depth - 3, 3)
- layers = [
- nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
- nn.LeakyReLU(0.2),
- ]
- for i in range(depth - 1):
- c_in = c_hidden // (2 ** max((d - i), 0))
- c_out = c_hidden // (2 ** max((d - 1 - i), 0))
- layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
- layers.append(nn.InstanceNorm2d(c_out))
- layers.append(nn.LeakyReLU(0.2))
- self.encoder = nn.Sequential(*layers)
- self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
- self.logits = nn.Sigmoid()
-
- def forward(self, x, cond=None):
- x = self.encoder(x)
- if cond is not None:
- cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
- x = torch.cat([x, cond], dim=1)
- x = self.shuffle(x)
- x = self.logits(x)
- return x
diff --git a/modules/stage_b.py b/modules/stage_b.py
deleted file mode 100644
index f89b42d61327278820e164b1c093cbf8d1048ee1..0000000000000000000000000000000000000000
--- a/modules/stage_b.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import math
-import numpy as np
-import torch
-from torch import nn
-from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock
-
-
-class StageB(nn.Module):
- def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
- nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
- block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
- c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True,
- t_conds=['sca']):
- super().__init__()
- self.c_r = c_r
- self.t_conds = t_conds
- self.c_clip_seq = c_clip_seq
- if not isinstance(dropout, list):
- dropout = [dropout] * len(c_hidden)
- if not isinstance(self_attn, list):
- self_attn = [self_attn] * len(c_hidden)
-
- # CONDITIONING
- self.effnet_mapper = nn.Sequential(
- nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1),
- nn.GELU(),
- nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
- )
- self.pixels_mapper = nn.Sequential(
- nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1),
- nn.GELU(),
- nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
- )
- self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq)
- self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
-
- self.embedding = nn.Sequential(
- nn.PixelUnshuffle(patch_size),
- nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
- )
-
- def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
- if block_type == 'C':
- return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
- elif block_type == 'A':
- return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
- elif block_type == 'F':
- return FeedForwardBlock(c_hidden, dropout=dropout)
- elif block_type == 'T':
- return TimestepBlock(c_hidden, c_r, conds=t_conds)
- else:
- raise Exception(f'Block type {block_type} not supported')
-
- # BLOCKS
- # -- down blocks
- self.down_blocks = nn.ModuleList()
- self.down_downscalers = nn.ModuleList()
- self.down_repeat_mappers = nn.ModuleList()
- for i in range(len(c_hidden)):
- if i > 0:
- self.down_downscalers.append(nn.Sequential(
- LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
- nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
- ))
- else:
- self.down_downscalers.append(nn.Identity())
- down_block = nn.ModuleList()
- for _ in range(blocks[0][i]):
- for block_type in level_config[i]:
- block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
- down_block.append(block)
- self.down_blocks.append(down_block)
- if block_repeat is not None:
- block_repeat_mappers = nn.ModuleList()
- for _ in range(block_repeat[0][i] - 1):
- block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
- self.down_repeat_mappers.append(block_repeat_mappers)
-
- # -- up blocks
- self.up_blocks = nn.ModuleList()
- self.up_upscalers = nn.ModuleList()
- self.up_repeat_mappers = nn.ModuleList()
- for i in reversed(range(len(c_hidden))):
- if i > 0:
- self.up_upscalers.append(nn.Sequential(
- LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
- nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
- ))
- else:
- self.up_upscalers.append(nn.Identity())
- up_block = nn.ModuleList()
- for j in range(blocks[1][::-1][i]):
- for k, block_type in enumerate(level_config[i]):
- c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
- block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
- self_attn=self_attn[i])
- up_block.append(block)
- self.up_blocks.append(up_block)
- if block_repeat is not None:
- block_repeat_mappers = nn.ModuleList()
- for _ in range(block_repeat[1][::-1][i] - 1):
- block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
- self.up_repeat_mappers.append(block_repeat_mappers)
-
- # OUTPUT
- self.clf = nn.Sequential(
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
- nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
- nn.PixelShuffle(patch_size),
- )
-
- # --- WEIGHT INIT ---
- self.apply(self._init_weights) # General init
- nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
- nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
- nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
- nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
- nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
- torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
- nn.init.constant_(self.clf[1].weight, 0) # outputs
-
- # blocks
- for level_block in self.down_blocks + self.up_blocks:
- for block in level_block:
- if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
- block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
- elif isinstance(block, TimestepBlock):
- for layer in block.modules():
- if isinstance(layer, nn.Linear):
- nn.init.constant_(layer.weight, 0)
-
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv2d, nn.Linear)):
- torch.nn.init.xavier_uniform_(m.weight)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
-
- def gen_r_embedding(self, r, max_positions=10000):
- r = r * max_positions
- half_dim = self.c_r // 2
- emb = math.log(max_positions) / (half_dim - 1)
- emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
- emb = r[:, None] * emb[None, :]
- emb = torch.cat([emb.sin(), emb.cos()], dim=1)
- if self.c_r % 2 == 1: # zero pad
- emb = nn.functional.pad(emb, (0, 1), mode='constant')
- return emb
-
- def gen_c_embeddings(self, clip):
- if len(clip.shape) == 2:
- clip = clip.unsqueeze(1)
- clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
- clip = self.clip_norm(clip)
- return clip
-
- def _down_encode(self, x, r_embed, clip):
- level_outputs = []
- block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
- for down_block, downscaler, repmap in block_group:
- x = downscaler(x)
- for i in range(len(repmap) + 1):
- for block in down_block:
- if isinstance(block, ResBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- ResBlock)):
- x = block(x)
- elif isinstance(block, AttnBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- AttnBlock)):
- x = block(x, clip)
- elif isinstance(block, TimestepBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- TimestepBlock)):
- x = block(x, r_embed)
- else:
- x = block(x)
- if i < len(repmap):
- x = repmap[i](x)
- level_outputs.insert(0, x)
- return level_outputs
-
- def _up_decode(self, level_outputs, r_embed, clip):
- x = level_outputs[0]
- block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
- for i, (up_block, upscaler, repmap) in enumerate(block_group):
- for j in range(len(repmap) + 1):
- for k, block in enumerate(up_block):
- if isinstance(block, ResBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- ResBlock)):
- skip = level_outputs[i] if k == 0 and i > 0 else None
- if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
- x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear',
- align_corners=True)
- x = block(x, skip)
- elif isinstance(block, AttnBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- AttnBlock)):
- x = block(x, clip)
- elif isinstance(block, TimestepBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- TimestepBlock)):
- x = block(x, r_embed)
- else:
- x = block(x)
- if j < len(repmap):
- x = repmap[j](x)
- x = upscaler(x)
- return x
-
- def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
- if pixels is None:
- pixels = x.new_zeros(x.size(0), 3, 8, 8)
-
- # Process the conditioning embeddings
- r_embed = self.gen_r_embedding(r)
- for c in self.t_conds:
- t_cond = kwargs.get(c, torch.zeros_like(r))
- r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
- clip = self.gen_c_embeddings(clip)
-
- # Model Blocks
- x = self.embedding(x)
- x = x + self.effnet_mapper(
- nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode='bilinear', align_corners=True))
- x = x + nn.functional.interpolate(self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode='bilinear',
- align_corners=True)
- level_outputs = self._down_encode(x, r_embed, clip)
- x = self._up_decode(level_outputs, r_embed, clip)
- return self.clf(x)
-
- def update_weights_ema(self, src_model, beta=0.999):
- for self_params, src_params in zip(self.parameters(), src_model.parameters()):
- self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
- for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
- self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/modules/stage_c.py b/modules/stage_c.py
deleted file mode 100644
index 53b73d0197712b981ec1a154428c21af2149646a..0000000000000000000000000000000000000000
--- a/modules/stage_c.py
+++ /dev/null
@@ -1,252 +0,0 @@
-import torch
-from torch import nn
-import numpy as np
-import math
-from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock
-#from .controlnet import ControlNetDeliverer
-
-
-class UpDownBlock2d(nn.Module):
- def __init__(self, c_in, c_out, mode, enabled=True):
- super().__init__()
- assert mode in ['up', 'down']
- interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
- align_corners=True) if enabled else nn.Identity()
- mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
- self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x.float())
- return x
-
-
-class StageC(nn.Module):
- def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
- blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
- c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
- dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False]):
- super().__init__()
- self.c_r = c_r
- self.t_conds = t_conds
- self.c_clip_seq = c_clip_seq
- if not isinstance(dropout, list):
- dropout = [dropout] * len(c_hidden)
- if not isinstance(self_attn, list):
- self_attn = [self_attn] * len(c_hidden)
-
- # CONDITIONING
- self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond)
- self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq)
- self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq)
- self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
-
- self.embedding = nn.Sequential(
- nn.PixelUnshuffle(patch_size),
- nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
- )
-
- def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
- if block_type == 'C':
- return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
- elif block_type == 'A':
- return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
- elif block_type == 'F':
- return FeedForwardBlock(c_hidden, dropout=dropout)
- elif block_type == 'T':
- return TimestepBlock(c_hidden, c_r, conds=t_conds)
- else:
- raise Exception(f'Block type {block_type} not supported')
-
- # BLOCKS
- # -- down blocks
- self.down_blocks = nn.ModuleList()
- self.down_downscalers = nn.ModuleList()
- self.down_repeat_mappers = nn.ModuleList()
- for i in range(len(c_hidden)):
- if i > 0:
- self.down_downscalers.append(nn.Sequential(
- LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
- UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1])
- ))
- else:
- self.down_downscalers.append(nn.Identity())
- down_block = nn.ModuleList()
- for _ in range(blocks[0][i]):
- for block_type in level_config[i]:
- block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
- down_block.append(block)
- self.down_blocks.append(down_block)
- if block_repeat is not None:
- block_repeat_mappers = nn.ModuleList()
- for _ in range(block_repeat[0][i] - 1):
- block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
- self.down_repeat_mappers.append(block_repeat_mappers)
-
- # -- up blocks
- self.up_blocks = nn.ModuleList()
- self.up_upscalers = nn.ModuleList()
- self.up_repeat_mappers = nn.ModuleList()
- for i in reversed(range(len(c_hidden))):
- if i > 0:
- self.up_upscalers.append(nn.Sequential(
- LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
- UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1])
- ))
- else:
- self.up_upscalers.append(nn.Identity())
- up_block = nn.ModuleList()
- for j in range(blocks[1][::-1][i]):
- for k, block_type in enumerate(level_config[i]):
- c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
- block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
- self_attn=self_attn[i])
- up_block.append(block)
- self.up_blocks.append(up_block)
- if block_repeat is not None:
- block_repeat_mappers = nn.ModuleList()
- for _ in range(block_repeat[1][::-1][i] - 1):
- block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
- self.up_repeat_mappers.append(block_repeat_mappers)
-
- # OUTPUT
- self.clf = nn.Sequential(
- LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
- nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
- nn.PixelShuffle(patch_size),
- )
-
- # --- WEIGHT INIT ---
- self.apply(self._init_weights) # General init
- nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
- nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
- nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
- torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
- nn.init.constant_(self.clf[1].weight, 0) # outputs
-
- # blocks
- for level_block in self.down_blocks + self.up_blocks:
- for block in level_block:
- if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
- block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
- elif isinstance(block, TimestepBlock):
- for layer in block.modules():
- if isinstance(layer, nn.Linear):
- nn.init.constant_(layer.weight, 0)
-
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv2d, nn.Linear)):
- torch.nn.init.xavier_uniform_(m.weight)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
-
- def gen_r_embedding(self, r, max_positions=10000):
- r = r * max_positions
- half_dim = self.c_r // 2
- emb = math.log(max_positions) / (half_dim - 1)
- emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
- emb = r[:, None] * emb[None, :]
- emb = torch.cat([emb.sin(), emb.cos()], dim=1)
- if self.c_r % 2 == 1: # zero pad
- emb = nn.functional.pad(emb, (0, 1), mode='constant')
- return emb
-
- def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
- clip_txt = self.clip_txt_mapper(clip_txt)
- if len(clip_txt_pooled.shape) == 2:
- clip_txt_pool = clip_txt_pooled.unsqueeze(1)
- if len(clip_img.shape) == 2:
- clip_img = clip_img.unsqueeze(1)
- clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
- clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
- clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
- clip = self.clip_norm(clip)
- return clip
-
- def _down_encode(self, x, r_embed, clip, cnet=None):
- level_outputs = []
- block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
- for down_block, downscaler, repmap in block_group:
- x = downscaler(x)
- for i in range(len(repmap) + 1):
- for block in down_block:
- if isinstance(block, ResBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- ResBlock)):
- if cnet is not None:
- next_cnet = cnet()
- if next_cnet is not None:
- x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
- align_corners=True)
- x = block(x)
- elif isinstance(block, AttnBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- AttnBlock)):
- x = block(x, clip)
- elif isinstance(block, TimestepBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- TimestepBlock)):
- x = block(x, r_embed)
- else:
- x = block(x)
- if i < len(repmap):
- x = repmap[i](x)
- level_outputs.insert(0, x)
- return level_outputs
-
- def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
- x = level_outputs[0]
- block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
- for i, (up_block, upscaler, repmap) in enumerate(block_group):
- for j in range(len(repmap) + 1):
- for k, block in enumerate(up_block):
- if isinstance(block, ResBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- ResBlock)):
- skip = level_outputs[i] if k == 0 and i > 0 else None
- if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
- x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear',
- align_corners=True)
- if cnet is not None:
- next_cnet = cnet()
- if next_cnet is not None:
- x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
- align_corners=True)
- x = block(x, skip)
- elif isinstance(block, AttnBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- AttnBlock)):
- x = block(x, clip)
- elif isinstance(block, TimestepBlock) or (
- hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
- TimestepBlock)):
- x = block(x, r_embed)
- else:
- x = block(x)
- if j < len(repmap):
- x = repmap[j](x)
- x = upscaler(x)
- return x
-
- def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
- # Process the conditioning embeddings
- r_embed = self.gen_r_embedding(r)
- for c in self.t_conds:
- t_cond = kwargs.get(c, torch.zeros_like(r))
- r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
- clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
-
- # Model Blocks
- x = self.embedding(x)
- if cnet is not None:
- cnet = ControlNetDeliverer(cnet)
- level_outputs = self._down_encode(x, r_embed, clip, cnet)
- x = self._up_decode(level_outputs, r_embed, clip, cnet)
- return self.clf(x)
-
- def update_weights_ema(self, src_model, beta=0.999):
- for self_params, src_params in zip(self.parameters(), src_model.parameters()):
- self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
- for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
- self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/prompt_list.txt b/prompt_list.txt
deleted file mode 100644
index 27cd31b4750d2f15fdb6f2a3f4bdd117a7377267..0000000000000000000000000000000000000000
--- a/prompt_list.txt
+++ /dev/null
@@ -1,32 +0,0 @@
-A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening
-in the early morning light.
-
-A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a
-clear blue sky.
-
-A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing
-a vintage floral dress and standing in front of a blooming garden.
-
-The image features a snow-covered mountain range with a large, snow-covered mountain in the background.
-The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the
-winter season, with snow covering the ground and the trees.
-
-Crocodile in a sweater.
-
-A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school
-uniform, standing under a cherry blossom tree with petals falling around her. The background shows a
-traditional Japanese school with cherry blossoms in full bloom.
-
-A playful Labrador retriever puppy with a shiny, golden coat, chasing a red ball in a spacious backyard, with
-green grass and a wooden fence.
-
-A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm
-lights glowing from the windows, and a path of footprints leading to the front door.
-
-A highly detailed, high-quality image of the Banff National Park in Canada. The turquoise waters of Lake
-Louise are surrounded by snow-capped mountains and dense pine forests. A wooden canoe is docked at the
-edge of the lake. The sky is a clear, bright blue, and the air is crisp and fresh.
-
-A highly detailed, high-quality image of a Shih Tzu receiving a bath in a home bathroom. The dog is standing
-in a tub, covered in suds, with a slightly wet and adorable look. The background includes bathroom fixtures,
-towels, and a clean, tiled floor.
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 431ebbcb0dc492a0b05801e3f9d0f96efdd27245..1270d9d1c13425922f21302c1724fcd0e133a8b0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,19 +1,3 @@
---find-links https://download.pytorch.org/whl/torch_stable.html
-accelerate>=0.25.0
-torch==2.1.2
-torchvision==0.16.2
-transformers>=4.30.0
-numpy==1.26.4
-kornia>=0.7.0
-insightface>=0.7.3
-opencv-python>=4.8.1.78
-tqdm>=4.66.1
-matplotlib>=3.7.4
-webdataset>=0.2.79
-wandb>=0.16.2
-munch>=4.0.0
-onnxruntime>=1.16.3
-einops>=0.7.0
-onnx2torch>=1.5.13
-warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
-torchtools @ git+https://github.com/pabloppp/pytorch-tools
+timm
+transformers
+spaces
\ No newline at end of file
diff --git a/train/__init__.py b/train/__init__.py
deleted file mode 100644
index ea1331f6b933f63c99a6bdf074201fdb4b8f78c2..0000000000000000000000000000000000000000
--- a/train/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .train_b import WurstCore as WurstCoreB
-from .train_c import WurstCore as WurstCoreC
-from .train_t2i import WurstCore as WurstCore_t2i
-from .train_ultrapixel_control import WurstCore as WurstCore_control_lrguide
-from .train_personalized import WurstCore as WurstCore_personalized
\ No newline at end of file
diff --git a/train/base.py b/train/base.py
deleted file mode 100644
index 4e8a6ef306e40da8c9d8db33ceba2f8b2982a9a9..0000000000000000000000000000000000000000
--- a/train/base.py
+++ /dev/null
@@ -1,402 +0,0 @@
-import yaml
-import json
-import torch
-import wandb
-import torchvision
-import numpy as np
-from torch import nn
-from tqdm import tqdm
-from abc import abstractmethod
-from fractions import Fraction
-import matplotlib.pyplot as plt
-from dataclasses import dataclass
-from torch.distributed import barrier
-from torch.utils.data import DataLoader
-
-from gdf import GDF
-from gdf import AdaptiveLossWeight
-
-from core import WarpCore
-from core.data import setup_webdataset_path, MultiGetter, MultiFilter, Bucketeer
-from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
-
-import webdataset as wds
-from webdataset.handlers import warn_and_continue
-
-import transformers
-transformers.utils.logging.set_verbosity_error()
-
-
-class DataCore(WarpCore):
- @dataclass(frozen=True)
- class Config(WarpCore.Config):
- image_size: int = EXPECTED_TRAIN
- webdataset_path: str = EXPECTED_TRAIN
- grad_accum_steps: int = EXPECTED_TRAIN
- batch_size: int = EXPECTED_TRAIN
- multi_aspect_ratio: list = None
-
- captions_getter: list = None
- dataset_filters: list = None
-
- bucketeer_random_ratio: float = 0.05
-
- @dataclass(frozen=True)
- class Extras(WarpCore.Extras):
- transforms: torchvision.transforms.Compose = EXPECTED
- clip_preprocess: torchvision.transforms.Compose = EXPECTED
-
- @dataclass(frozen=True)
- class Models(WarpCore.Models):
- tokenizer: nn.Module = EXPECTED
- text_model: nn.Module = EXPECTED
- image_model: nn.Module = None
-
- config: Config
-
- def webdataset_path(self):
- if isinstance(self.config.webdataset_path, str) and (self.config.webdataset_path.strip().startswith(
- 'pipe:') or self.config.webdataset_path.strip().startswith('file:')):
- return self.config.webdataset_path
- else:
- dataset_path = self.config.webdataset_path
- if isinstance(self.config.webdataset_path, str) and self.config.webdataset_path.strip().endswith('.yml'):
- with open(self.config.webdataset_path, 'r', encoding='utf-8') as file:
- dataset_path = yaml.safe_load(file)
- return setup_webdataset_path(dataset_path, cache_path=f"{self.config.experiment_id}_webdataset_cache.yml")
-
- def webdataset_preprocessors(self, extras: Extras):
- def identity(x):
- if isinstance(x, bytes):
- x = x.decode('utf-8')
- return x
-
- # CUSTOM CAPTIONS GETTER -----
- def get_caption(oc, c, p_og=0.05): # cog_contexual, cog_caption
- if p_og > 0 and np.random.rand() < p_og and len(oc) > 0:
- return identity(oc)
- else:
- return identity(c)
-
- captions_getter = MultiGetter(rules={
- ('old_caption', 'caption'): lambda oc, c: get_caption(json.loads(oc)['og_caption'], c, p_og=0.05)
- })
-
- return [
- ('jpg;png',
- torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None else extras.transforms,
- 'images'),
- ('txt', identity, 'captions') if self.config.captions_getter is None else (
- self.config.captions_getter[0], eval(self.config.captions_getter[1]), 'captions'),
- ]
-
- def setup_data(self, extras: Extras) -> WarpCore.Data:
- # SETUP DATASET
- dataset_path = self.webdataset_path()
- preprocessors = self.webdataset_preprocessors(extras)
-
- handler = warn_and_continue
- dataset = wds.WebDataset(
- dataset_path, resampled=True, handler=handler
- ).select(
- MultiFilter(rules={
- f[0]: eval(f[1]) for f in self.config.dataset_filters
- }) if self.config.dataset_filters is not None else lambda _: True
- ).shuffle(690, handler=handler).decode(
- "pilrgb", handler=handler
- ).to_tuple(
- *[p[0] for p in preprocessors], handler=handler
- ).map_tuple(
- *[p[1] for p in preprocessors], handler=handler
- ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})
-
- def identity(x):
- return x
-
- # SETUP DATALOADER
- real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
- dataloader = DataLoader(
- dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True,
- collate_fn=identity if self.config.multi_aspect_ratio is not None else None
- )
- if self.is_main_node:
- print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
-
- if self.config.multi_aspect_ratio is not None:
- aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
- dataloader_iterator = Bucketeer(dataloader, density=self.config.image_size ** 2, factor=32,
- ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
- interpolate_nearest=False) # , use_smartcrop=True)
- else:
- dataloader_iterator = iter(dataloader)
-
- return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator)
-
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
- eval_image_embeds=False, return_fields=None):
- if return_fields is None:
- return_fields = ['clip_text', 'clip_text_pooled', 'clip_img']
-
- captions = batch.get('captions', None)
- images = batch.get('images', None)
- batch_size = len(captions)
-
- text_embeddings = None
- text_pooled_embeddings = None
- if 'clip_text' in return_fields or 'clip_text_pooled' in return_fields:
- if is_eval:
- if is_unconditional:
- captions_unpooled = ["" for _ in range(batch_size)]
- else:
- captions_unpooled = captions
- else:
- rand_idx = np.random.rand(batch_size) > 0.05
- captions_unpooled = [str(c) if keep else "" for c, keep in zip(captions, rand_idx)]
- clip_tokens_unpooled = models.tokenizer(captions_unpooled, truncation=True, padding="max_length",
- max_length=models.tokenizer.model_max_length,
- return_tensors="pt").to(self.device)
- text_encoder_output = models.text_model(**clip_tokens_unpooled, output_hidden_states=True)
- if 'clip_text' in return_fields:
- text_embeddings = text_encoder_output.hidden_states[-1]
- if 'clip_text_pooled' in return_fields:
- text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1)
-
- image_embeddings = None
- if 'clip_img' in return_fields:
- image_embeddings = torch.zeros(batch_size, 768, device=self.device)
- if images is not None:
- images = images.to(self.device)
- if is_eval:
- if not is_unconditional and eval_image_embeds:
- image_embeddings = models.image_model(extras.clip_preprocess(images)).image_embeds
- else:
- rand_idx = np.random.rand(batch_size) > 0.9
- if any(rand_idx):
- image_embeddings[rand_idx] = models.image_model(extras.clip_preprocess(images[rand_idx])).image_embeds
- image_embeddings = image_embeddings.unsqueeze(1)
- return {
- 'clip_text': text_embeddings,
- 'clip_text_pooled': text_pooled_embeddings,
- 'clip_img': image_embeddings
- }
-
-
-class TrainingCore(DataCore, WarpCore):
- @dataclass(frozen=True)
- class Config(DataCore.Config, WarpCore.Config):
- updates: int = EXPECTED_TRAIN
- backup_every: int = EXPECTED_TRAIN
- save_every: int = EXPECTED_TRAIN
-
- # EMA UPDATE
- ema_start_iters: int = None
- ema_iters: int = None
- ema_beta: float = None
-
- use_fsdp: bool = None
-
- @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
- class Info(WarpCore.Info):
- ema_loss: float = None
- adaptive_loss: dict = None
-
- @dataclass(frozen=True)
- class Models(WarpCore.Models):
- generator: nn.Module = EXPECTED
- generator_ema: nn.Module = None # optional
-
- @dataclass(frozen=True)
- class Optimizers(WarpCore.Optimizers):
- generator: any = EXPECTED
-
- @dataclass(frozen=True)
- class Extras(WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
-
- info: Info
- config: Config
-
- @abstractmethod
- def forward_pass(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models):
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers,
- schedulers: WarpCore.Schedulers):
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def models_to_save(self) -> list:
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- raise NotImplementedError("This method needs to be overriden")
-
- @abstractmethod
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- raise NotImplementedError("This method needs to be overriden")
-
- def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers,
- schedulers: WarpCore.Schedulers):
- start_iter = self.info.iter + 1
- max_iters = self.config.updates * self.config.grad_accum_steps
- if self.is_main_node:
- print(f"STARTING AT STEP: {start_iter}/{max_iters}")
-
- pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter,
- max_iters + 1) # <--- DDP
- if 'generator' in self.models_to_save():
- models.generator.train()
- for i in pbar:
- # FORWARD PASS
- loss, loss_adjusted = self.forward_pass(data, extras, models)
-
- # # BACKWARD PASS
- grad_norm = self.backward_pass(
- i % self.config.grad_accum_steps == 0 or i == max_iters, loss, loss_adjusted,
- models, optimizers, schedulers
- )
- self.info.iter = i
-
- # UPDATE EMA
- if models.generator_ema is not None and i % self.config.ema_iters == 0:
- update_weights_ema(
- models.generator_ema, models.generator,
- beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
- )
-
- # UPDATE LOSS METRICS
- self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
-
- if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(
- grad_norm.item()):
- wandb.alert(
- title=f"NaN value encountered in training run {self.info.wandb_run_id}",
- text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
- wait_duration=60 * 30
- )
-
- if self.is_main_node:
- logs = {
- 'loss': self.info.ema_loss,
- 'raw_loss': loss.mean().item(),
- 'grad_norm': grad_norm.item(),
- 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
- 'total_steps': self.info.total_steps,
- }
-
- pbar.set_postfix(logs)
- if self.config.wandb_project is not None:
- wandb.log(logs)
-
- if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters:
- # SAVE AND CHECKPOINT STUFF
- if np.isnan(loss.mean().item()):
- if self.is_main_node and self.config.wandb_project is not None:
- tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
- wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.wandb_run_id}",
- text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
- else:
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- self.info.adaptive_loss = {
- 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
- 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
- }
- self.save_checkpoints(models, optimizers)
- if self.is_main_node:
- create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
- self.sample(models, data, extras)
-
- def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
- barrier()
- suffix = '' if suffix is None else suffix
- self.save_info(self.info, suffix=suffix)
- models_dict = models.to_dict()
- optimizers_dict = optimizers.to_dict()
- for key in self.models_to_save():
- model = models_dict[key]
- if model is not None:
- self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
- for key in optimizers_dict:
- optimizer = optimizers_dict[key]
- if optimizer is not None:
- self.save_optimizer(optimizer, f'{key}_optim{suffix}',
- fsdp_model=models_dict[key] if self.config.use_fsdp else None)
- if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
- self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k")
- torch.cuda.empty_cache()
-
- def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
- if 'generator' in self.models_to_save():
- models.generator.eval()
- with torch.no_grad():
- batch = next(data.iterator)
-
- conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- latents = self.encode_latents(batch, models, extras)
- noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- pred = models.generator(noised, noise_cond, **conditions)
- pred = extras.gdf.undiffuse(noised, logSNR, pred)[0]
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- *_, (sampled, _, _) = extras.gdf.sample(
- models.generator, conditions,
- latents.shape, unconditions, device=self.device, **extras.sampling_configs
- )
-
- if models.generator_ema is not None:
- *_, (sampled_ema, _, _) = extras.gdf.sample(
- models.generator_ema, conditions,
- latents.shape, unconditions, device=self.device, **extras.sampling_configs
- )
- else:
- sampled_ema = sampled
-
- if self.is_main_node:
- noised_images = torch.cat(
- [self.decode_latents(noised[i:i + 1], batch, models, extras) for i in range(len(noised))], dim=0)
- pred_images = torch.cat(
- [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0)
- sampled_images = torch.cat(
- [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0)
- sampled_images_ema = torch.cat(
- [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))],
- dim=0)
-
- images = batch['images']
- if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
- images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
-
- collage_img = torch.cat([
- torch.cat([i for i in images.cpu()], dim=-1),
- torch.cat([i for i in noised_images.cpu()], dim=-1),
- torch.cat([i for i in pred_images.cpu()], dim=-1),
- torch.cat([i for i in sampled_images.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
- ], dim=-2)
-
- torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
- torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg')
-
- captions = batch['captions']
- if self.config.wandb_project is not None:
- log_data = [
- [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [
- wandb.Image(images[i])] for i in range(len(images))]
- log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
- wandb.log({"Log": log_table})
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
- plt.ylabel('Raw Loss')
- plt.ylabel('LogSNR')
- wandb.log({"Loss/LogSRN": plt})
-
- if 'generator' in self.models_to_save():
- models.generator.train()
diff --git a/train/dist_core.py b/train/dist_core.py
deleted file mode 100644
index 4e4e9e670a3b853fac345618d3557d648d813902..0000000000000000000000000000000000000000
--- a/train/dist_core.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import os
-import torch
-
-
-def get_world_size():
- """Find OMPI world size without calling mpi functions
- :rtype: int
- """
- if os.environ.get('PMI_SIZE') is not None:
- return int(os.environ.get('PMI_SIZE') or 1)
- elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
- return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
- else:
- return torch.cuda.device_count()
-
-
-def get_global_rank():
- """Find OMPI world rank without calling mpi functions
- :rtype: int
- """
- if os.environ.get('PMI_RANK') is not None:
- return int(os.environ.get('PMI_RANK') or 0)
- elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
- return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
- else:
- return 0
-
-
-def get_local_rank():
- """Find OMPI local rank without calling mpi functions
- :rtype: int
- """
- if os.environ.get('MPI_LOCALRANKID') is not None:
- return int(os.environ.get('MPI_LOCALRANKID') or 0)
- elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
- return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
- else:
- return 0
-
-
-def get_master_ip():
- if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
- return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
- elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
- return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
- else:
- return "127.0.0.1"
diff --git a/train/train_b.py b/train/train_b.py
deleted file mode 100644
index c3441a5841750a7c33b49756d2d60064a68d82d8..0000000000000000000000000000000000000000
--- a/train/train_b.py
+++ /dev/null
@@ -1,305 +0,0 @@
-import torch
-import torchvision
-from torch import nn, optim
-from transformers import AutoTokenizer, CLIPTextModelWithProjection
-from warmup_scheduler import GradualWarmupScheduler
-import numpy as np
-
-import sys
-import os
-from dataclasses import dataclass
-
-from gdf import GDF, EpsilonTarget, CosineSchedule
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from torchtools.transforms import SmartCrop
-
-from modules.effnet import EfficientNetEncoder
-from modules.stage_a import StageA
-
-from modules.stage_b import StageB
-from modules.stage_b import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
-
-from train.base import DataCore, TrainingCore
-
-from core import WarpCore
-from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
-
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from contextlib import contextmanager
-
-class WurstCore(TrainingCore, DataCore, WarpCore):
- @dataclass(frozen=True)
- class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
- # TRAINING PARAMS
- lr: float = EXPECTED_TRAIN
- warmup_updates: int = EXPECTED_TRAIN
- shift: float = EXPECTED_TRAIN
- dtype: str = None
-
- # MODEL VERSION
- model_version: str = EXPECTED # 3BB or 700M
- clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
-
- # CHECKPOINT PATHS
- stage_a_checkpoint_path: str = EXPECTED
- effnet_checkpoint_path: str = EXPECTED
- generator_checkpoint_path: str = None
-
- # gdf customization
- adaptive_loss_weight: str = None
-
- @dataclass(frozen=True)
- class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
- effnet: nn.Module = EXPECTED
- stage_a: nn.Module = EXPECTED
-
- @dataclass(frozen=True)
- class Schedulers(WarpCore.Schedulers):
- generator: any = None
-
- @dataclass(frozen=True)
- class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
- effnet_preprocess: torchvision.transforms.Compose = EXPECTED
-
- info: TrainingCore.Info
- config: Config
-
- def setup_extras_pre(self) -> Extras:
- gdf = GDF(
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- input_scaler=VPScaler(), target=EpsilonTarget(),
- noise_cond=CosineTNoiseCond(),
- loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
- )
- sampling_configs = {"cfg": 1.5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 10}
-
- if self.info.adaptive_loss is not None:
- gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
- gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
-
- effnet_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Normalize(
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- )
- ])
-
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Resize(self.config.image_size,
- interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
- antialias=True),
- SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) if self.config.training else torchvision.transforms.CenterCrop(self.config.image_size)
- ])
-
- return self.Extras(
- gdf=gdf,
- sampling_configs=sampling_configs,
- transforms=transforms,
- effnet_preprocess=effnet_preprocess,
- clip_preprocess=None
- )
-
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None):
- images = batch.get('images', None)
-
- if images is not None:
- images = images.to(self.device)
- if is_eval and not is_unconditional:
- effnet_embeddings = models.effnet(extras.effnet_preprocess(images))
- else:
- if is_eval:
- effnet_factor = 1
- else:
- effnet_factor = np.random.uniform(0.5, 1) # f64 to f32
- effnet_height, effnet_width = int(((images.size(-2)*effnet_factor)//32)*32), int(((images.size(-1)*effnet_factor)//32)*32)
-
- effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height//32, effnet_width//32, device=self.device)
- if not is_eval:
- effnet_images = torchvision.transforms.functional.resize(images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST)
- rand_idx = np.random.rand(len(images)) <= 0.9
- if any(rand_idx):
- effnet_embeddings[rand_idx] = models.effnet(extras.effnet_preprocess(effnet_images[rand_idx]))
- else:
- effnet_embeddings = None
-
- conditions = super().get_conditions(
- batch, models, extras, is_eval, is_unconditional,
- eval_image_embeds, return_fields=return_fields or ['clip_text_pooled']
- )
-
- return {'effnet': effnet_embeddings, 'clip': conditions['clip_text_pooled']}
-
- def setup_models(self, extras: Extras, skip_clip: bool = False) -> Models:
- dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32
-
- # EfficientNet encoder
- effnet = EfficientNetEncoder().to(self.device)
- effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
-
- effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
- effnet.eval().requires_grad_(False)
- del effnet_checkpoint
-
- # vqGAN
- stage_a = StageA().to(self.device)
- stage_a_checkpoint = load_or_fail(self.config.stage_a_checkpoint_path)
- stage_a.load_state_dict(stage_a_checkpoint if 'state_dict' not in stage_a_checkpoint else stage_a_checkpoint['state_dict'])
- stage_a.eval().requires_grad_(False)
- del stage_a_checkpoint
-
- @contextmanager
- def dummy_context():
- yield None
-
- loading_context = dummy_context if self.config.training else init_empty_weights
-
- # Diffusion models
- with loading_context():
- generator_ema = None
- if self.config.model_version == '3B':
- generator = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]])
- if self.config.ema_start_iters is not None:
- generator_ema = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]])
- elif self.config.model_version == '700M':
- generator = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]])
- if self.config.ema_start_iters is not None:
- generator_ema = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]])
- else:
- raise ValueError(f"Unknown model version {self.config.model_version}")
-
- if self.config.generator_checkpoint_path is not None:
- if loading_context is dummy_context:
- generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
- else:
- for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
- set_module_tensor_to_device(generator, param_name, "cpu", value=param)
- generator = generator.to(dtype).to(self.device)
- generator = self.load_model(generator, 'generator')
-
- if generator_ema is not None:
- if loading_context is dummy_context:
- generator_ema.load_state_dict(generator.state_dict())
- else:
- for param_name, param in generator.state_dict().items():
- set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param)
- generator_ema = self.load_model(generator_ema, 'generator_ema')
- generator_ema.to(dtype).to(self.device).eval().requires_grad_(False)
-
- if self.config.use_fsdp:
- fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock])
- generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
- if generator_ema is not None:
- generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
-
- if skip_clip:
- tokenizer = None
- text_model = None
- else:
- tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
- text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
-
- return self.Models(
- effnet=effnet, stage_a=stage_a,
- generator=generator, generator_ema=generator_ema,
- tokenizer=tokenizer, text_model=text_model
- )
-
- def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
- optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
- optimizer = self.load_optimizer(optimizer, 'generator_optim',
- fsdp_model=models.generator if self.config.use_fsdp else None)
- return self.Optimizers(generator=optimizer)
-
- def setup_schedulers(self, extras: Extras, models: Models,
- optimizers: TrainingCore.Optimizers) -> Schedulers:
- scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
- scheduler.last_epoch = self.info.total_steps
- return self.Schedulers(generator=scheduler)
-
- def _pyramid_noise(self, epsilon, size_range=None, levels=10, scale_mode='nearest'):
- epsilon = epsilon.clone()
- multipliers = [1]
- for i in range(1, levels):
- m = 0.75 ** i
- h, w = epsilon.size(-2) // (2 ** i), epsilon.size(-2) // (2 ** i)
- if size_range is None or (size_range[0] <= h <= size_range[1] or size_range[0] <= w <= size_range[1]):
- offset = torch.randn(epsilon.size(0), epsilon.size(1), h, w, device=self.device)
- epsilon = epsilon + torch.nn.functional.interpolate(offset, size=epsilon.shape[-2:],
- mode=scale_mode) * m
- multipliers.append(m)
- if h <= 1 or w <= 1:
- break
- epsilon = epsilon / sum([m ** 2 for m in multipliers]) ** 0.5
- # epsilon = epsilon / epsilon.std()
- return epsilon
-
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
- batch = next(data.iterator)
-
- with torch.no_grad():
- conditions = self.get_conditions(batch, models, extras)
- latents = self.encode_latents(batch, models, extras)
- epsilon = torch.randn_like(latents)
- epsilon = self._pyramid_noise(epsilon, size_range=[1, 16])
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1,
- epsilon=epsilon)
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- pred = models.generator(noised, noise_cond, **conditions)
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
- loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- extras.gdf.loss_weight.update_buckets(logSNR, loss)
-
- return loss, loss_adjusted
-
- def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers,
- schedulers: Schedulers):
- if update:
- loss_adjusted.backward()
- grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
- optimizers_dict = optimizers.to_dict()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].step()
- schedulers_dict = schedulers.to_dict()
- for k in schedulers_dict:
- if k != 'training':
- schedulers_dict[k].step()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].zero_grad(set_to_none=True)
- self.info.total_steps += 1
- else:
- loss_adjusted.backward()
- grad_norm = torch.tensor(0.0).to(self.device)
-
- return grad_norm
-
- def models_to_save(self):
- return ['generator', 'generator_ema']
-
- def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- images = batch['images'].to(self.device)
- return models.stage_a.encode(images)[0]
-
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- return models.stage_a.decode(latents.float()).clamp(0, 1)
-
-
-if __name__ == '__main__':
- print("Launching Script")
- warpcore = WurstCore(
- config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
- device=torch.device(int(os.environ.get("SLURM_LOCALID")))
- )
- # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
-
- # RUN TRAINING
- warpcore()
diff --git a/train/train_c.py b/train/train_c.py
deleted file mode 100644
index c4490c6eebc3e1c5126dd13c53603872f1459a3e..0000000000000000000000000000000000000000
--- a/train/train_c.py
+++ /dev/null
@@ -1,266 +0,0 @@
-import torch
-import torchvision
-from torch import nn, optim
-from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
-from warmup_scheduler import GradualWarmupScheduler
-
-import sys
-import os
-from dataclasses import dataclass
-
-from gdf import GDF, EpsilonTarget, CosineSchedule
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from torchtools.transforms import SmartCrop
-
-from modules.effnet import EfficientNetEncoder
-from modules.stage_c import StageC
-from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
-from modules.previewer import Previewer
-
-from train.base import DataCore, TrainingCore
-
-from core import WarpCore
-from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
-
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from contextlib import contextmanager
-
-class WurstCore(TrainingCore, DataCore, WarpCore):
- @dataclass(frozen=True)
- class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
- # TRAINING PARAMS
- lr: float = EXPECTED_TRAIN
- warmup_updates: int = EXPECTED_TRAIN
- dtype: str = None
-
- # MODEL VERSION
- model_version: str = EXPECTED # 3.6B or 1B
- clip_image_model_name: str = 'openai/clip-vit-large-patch14'
- clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
-
- # CHECKPOINT PATHS
- effnet_checkpoint_path: str = EXPECTED
- previewer_checkpoint_path: str = EXPECTED
- generator_checkpoint_path: str = None
-
- # gdf customization
- adaptive_loss_weight: str = None
-
- @dataclass(frozen=True)
- class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
- effnet: nn.Module = EXPECTED
- previewer: nn.Module = EXPECTED
-
- @dataclass(frozen=True)
- class Schedulers(WarpCore.Schedulers):
- generator: any = None
-
- @dataclass(frozen=True)
- class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
- effnet_preprocess: torchvision.transforms.Compose = EXPECTED
-
- info: TrainingCore.Info
- config: Config
-
- def setup_extras_pre(self) -> Extras:
- gdf = GDF(
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- input_scaler=VPScaler(), target=EpsilonTarget(),
- noise_cond=CosineTNoiseCond(),
- loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
- )
- sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
-
- if self.info.adaptive_loss is not None:
- gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
- gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
-
- effnet_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Normalize(
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- )
- ])
-
- clip_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
- torchvision.transforms.CenterCrop(224),
- torchvision.transforms.Normalize(
- mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
- )
- ])
-
- if self.config.training:
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
- SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
- ])
- else:
- transforms = None
-
- return self.Extras(
- gdf=gdf,
- sampling_configs=sampling_configs,
- transforms=transforms,
- effnet_preprocess=effnet_preprocess,
- clip_preprocess=clip_preprocess
- )
-
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
- eval_image_embeds=False, return_fields=None):
- conditions = super().get_conditions(
- batch, models, extras, is_eval, is_unconditional,
- eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
- )
- return conditions
-
- def setup_models(self, extras: Extras) -> Models:
- dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32
-
- # EfficientNet encoder
- effnet = EfficientNetEncoder()
- effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
- effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
- effnet.eval().requires_grad_(False).to(self.device)
- del effnet_checkpoint
-
- # Previewer
- previewer = Previewer()
- previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
- previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
- previewer.eval().requires_grad_(False).to(self.device)
- del previewer_checkpoint
-
- @contextmanager
- def dummy_context():
- yield None
-
- loading_context = dummy_context if self.config.training else init_empty_weights
-
- # Diffusion models
- with loading_context():
- generator_ema = None
- if self.config.model_version == '3.6B':
- generator = StageC()
- if self.config.ema_start_iters is not None:
- generator_ema = StageC()
- elif self.config.model_version == '1B':
- generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
- if self.config.ema_start_iters is not None:
- generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
- else:
- raise ValueError(f"Unknown model version {self.config.model_version}")
-
- if self.config.generator_checkpoint_path is not None:
- if loading_context is dummy_context:
- generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
- else:
-
- for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
- set_module_tensor_to_device(generator, param_name, "cpu", value=param)
- generator = generator.to(dtype).to(self.device)
- generator = self.load_model(generator, 'generator')
-
- if generator_ema is not None:
- if loading_context is dummy_context:
- generator_ema.load_state_dict(generator.state_dict())
- else:
- for param_name, param in generator.state_dict().items():
- set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param)
- generator_ema = self.load_model(generator_ema, 'generator_ema')
- generator_ema.to(dtype).to(self.device).eval().requires_grad_(False)
-
- if self.config.use_fsdp:
- fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock])
- generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
- if generator_ema is not None:
- generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
-
- tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
- text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
- image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
-
- return self.Models(
- effnet=effnet, previewer=previewer,
- generator=generator, generator_ema=generator_ema,
- tokenizer=tokenizer, text_model=text_model, image_model=image_model
- )
-
- def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
- optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
- optimizer = self.load_optimizer(optimizer, 'generator_optim',
- fsdp_model=models.generator if self.config.use_fsdp else None)
- return self.Optimizers(generator=optimizer)
-
- def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
- scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
- scheduler.last_epoch = self.info.total_steps
- return self.Schedulers(generator=scheduler)
-
- # Training loop --------------------------------
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
- batch = next(data.iterator)
-
- with torch.no_grad():
- conditions = self.get_conditions(batch, models, extras)
- latents = self.encode_latents(batch, models, extras)
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- pred = models.generator(noised, noise_cond, **conditions)
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
- loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- extras.gdf.loss_weight.update_buckets(logSNR, loss)
-
- return loss, loss_adjusted
-
- def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
- if update:
- loss_adjusted.backward()
- grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
- optimizers_dict = optimizers.to_dict()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].step()
- schedulers_dict = schedulers.to_dict()
- for k in schedulers_dict:
- if k != 'training':
- schedulers_dict[k].step()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].zero_grad(set_to_none=True)
- self.info.total_steps += 1
- else:
- loss_adjusted.backward()
- grad_norm = torch.tensor(0.0).to(self.device)
-
- return grad_norm
-
- def models_to_save(self):
- return ['generator', 'generator_ema']
-
- def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- images = batch['images'].to(self.device)
- return models.effnet(extras.effnet_preprocess(images))
-
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- return models.previewer(latents)
-
-
-if __name__ == '__main__':
- print("Launching Script")
- warpcore = WurstCore(
- config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
- device=torch.device(int(os.environ.get("SLURM_LOCALID")))
- )
- # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
-
- # RUN TRAINING
- warpcore()
diff --git a/train/train_c_lora.py b/train/train_c_lora.py
deleted file mode 100644
index 8b83eee0f250e5359901d39b8d4052254cfff4fa..0000000000000000000000000000000000000000
--- a/train/train_c_lora.py
+++ /dev/null
@@ -1,330 +0,0 @@
-import torch
-import torchvision
-from torch import nn, optim
-from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
-from warmup_scheduler import GradualWarmupScheduler
-
-import sys
-import os
-import re
-from dataclasses import dataclass
-
-from gdf import GDF, EpsilonTarget, CosineSchedule
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from torchtools.transforms import SmartCrop
-
-from modules.effnet import EfficientNetEncoder
-from modules.stage_c import StageC
-from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
-from modules.previewer import Previewer
-from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
-
-from train.base import DataCore, TrainingCore
-
-from core import WarpCore
-from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
-
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
-from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
-import functools
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from contextlib import contextmanager
-
-
-class WurstCore(TrainingCore, DataCore, WarpCore):
- @dataclass(frozen=True)
- class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
- # TRAINING PARAMS
- lr: float = EXPECTED_TRAIN
- warmup_updates: int = EXPECTED_TRAIN
- dtype: str = None
-
- # MODEL VERSION
- model_version: str = EXPECTED # 3.6B or 1B
- clip_image_model_name: str = 'openai/clip-vit-large-patch14'
- clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
-
- # CHECKPOINT PATHS
- effnet_checkpoint_path: str = EXPECTED
- previewer_checkpoint_path: str = EXPECTED
- generator_checkpoint_path: str = None
- lora_checkpoint_path: str = None
-
- # LoRA STUFF
- module_filters: list = EXPECTED
- rank: int = EXPECTED
- train_tokens: list = EXPECTED
-
- # gdf customization
- adaptive_loss_weight: str = None
-
- @dataclass(frozen=True)
- class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
- effnet: nn.Module = EXPECTED
- previewer: nn.Module = EXPECTED
- lora: nn.Module = EXPECTED
-
- @dataclass(frozen=True)
- class Schedulers(WarpCore.Schedulers):
- lora: any = None
-
- @dataclass(frozen=True)
- class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
- effnet_preprocess: torchvision.transforms.Compose = EXPECTED
-
- @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
- class Info(TrainingCore.Info):
- train_tokens: list = None
-
- @dataclass(frozen=True)
- class Optimizers(TrainingCore.Optimizers, WarpCore.Optimizers):
- generator: any = None
- lora: any = EXPECTED
-
- # --------------------------------------------
- info: Info
- config: Config
-
- # Extras: gdf, transforms and preprocessors --------------------------------
- def setup_extras_pre(self) -> Extras:
- gdf = GDF(
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- input_scaler=VPScaler(), target=EpsilonTarget(),
- noise_cond=CosineTNoiseCond(),
- loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
- )
- sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
-
- if self.info.adaptive_loss is not None:
- gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
- gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
-
- effnet_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Normalize(
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- )
- ])
-
- clip_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
- torchvision.transforms.CenterCrop(224),
- torchvision.transforms.Normalize(
- mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
- )
- ])
-
- if self.config.training:
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
- SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
- ])
- else:
- transforms = None
-
- return self.Extras(
- gdf=gdf,
- sampling_configs=sampling_configs,
- transforms=transforms,
- effnet_preprocess=effnet_preprocess,
- clip_preprocess=clip_preprocess
- )
-
- # Data --------------------------------
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
- eval_image_embeds=False, return_fields=None):
- conditions = super().get_conditions(
- batch, models, extras, is_eval, is_unconditional,
- eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
- )
- return conditions
-
- # Models, Optimizers & Schedulers setup --------------------------------
- def setup_models(self, extras: Extras) -> Models:
- dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32
-
- # EfficientNet encoder
- effnet = EfficientNetEncoder().to(self.device)
- effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
- effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
- effnet.eval().requires_grad_(False)
- del effnet_checkpoint
-
- # Previewer
- previewer = Previewer().to(self.device)
- previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
- previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
- previewer.eval().requires_grad_(False)
- del previewer_checkpoint
-
- @contextmanager
- def dummy_context():
- yield None
-
- loading_context = dummy_context if self.config.training else init_empty_weights
-
- with loading_context():
- # Diffusion models
- if self.config.model_version == '3.6B':
- generator = StageC()
- elif self.config.model_version == '1B':
- generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
- else:
- raise ValueError(f"Unknown model version {self.config.model_version}")
-
- if self.config.generator_checkpoint_path is not None:
- if loading_context is dummy_context:
- generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
- else:
- for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
- set_module_tensor_to_device(generator, param_name, "cpu", value=param)
- generator = generator.to(dtype).to(self.device)
- generator = self.load_model(generator, 'generator')
-
- # if self.config.use_fsdp:
- # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
- # generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
-
- # CLIP encoders
- tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
- text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
- image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
-
- # PREPARE LORA
- update_tokens = []
- for tkn_regex, aggr_regex in self.config.train_tokens:
- if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')):
- # Insert new token
- tokenizer.add_tokens([tkn_regex])
- # add new zeros embedding
- new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1]
- if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline
- aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None]
- if len(aggr_tokens) > 0:
- new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True)
- elif self.is_main_node:
- print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.")
- text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([
- text_model.text_model.embeddings.token_embedding.weight.data, new_embedding
- ], dim=0)
- selected_tokens = [len(tokenizer.vocab) - 1]
- else:
- selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None]
- update_tokens += selected_tokens
- update_tokens = list(set(update_tokens)) # remove duplicates
-
- apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens)
- apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank)
- text_model.text_model.to(self.device)
- generator.to(self.device)
- lora = nn.ModuleDict()
- lora['embeddings'] = text_model.text_model.embeddings.token_embedding.parametrizations.weight[0]
- lora['weights'] = nn.ModuleList()
- for module in generator.modules():
- if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)):
- lora['weights'].append(module)
-
- self.info.train_tokens = [(i, tokenizer.decode(i)) for i in update_tokens]
- if self.is_main_node:
- print("Updating tokens:", self.info.train_tokens)
- print(f"LoRA training {len(lora['weights'])} layers")
-
- if self.config.lora_checkpoint_path is not None:
- lora_checkpoint = load_or_fail(self.config.lora_checkpoint_path)
- lora.load_state_dict(lora_checkpoint if 'state_dict' not in lora_checkpoint else lora_checkpoint['state_dict'])
-
- lora = self.load_model(lora, 'lora')
- lora.to(self.device).train().requires_grad_(True)
- if self.config.use_fsdp:
- # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
- fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken])
- lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
-
- return self.Models(
- effnet=effnet, previewer=previewer,
- generator=generator, generator_ema=None,
- lora=lora,
- tokenizer=tokenizer, text_model=text_model, image_model=image_model
- )
-
- def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
- optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
- optimizer = self.load_optimizer(optimizer, 'lora_optim',
- fsdp_model=models.lora if self.config.use_fsdp else None)
- return self.Optimizers(generator=None, lora=optimizer)
-
- def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
- scheduler = GradualWarmupScheduler(optimizers.lora, multiplier=1, total_epoch=self.config.warmup_updates)
- scheduler.last_epoch = self.info.total_steps
- return self.Schedulers(lora=scheduler)
-
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
- batch = next(data.iterator)
-
- conditions = self.get_conditions(batch, models, extras)
- with torch.no_grad():
- latents = self.encode_latents(batch, models, extras)
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- pred = models.generator(noised, noise_cond, **conditions)
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
- loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- extras.gdf.loss_weight.update_buckets(logSNR, loss)
-
- return loss, loss_adjusted
-
- def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
- if update:
- loss_adjusted.backward()
- grad_norm = nn.utils.clip_grad_norm_(models.lora.parameters(), 1.0)
- optimizers_dict = optimizers.to_dict()
- for k in optimizers_dict:
- if optimizers_dict[k] is not None and k != 'training':
- optimizers_dict[k].step()
- schedulers_dict = schedulers.to_dict()
- for k in schedulers_dict:
- if k != 'training':
- schedulers_dict[k].step()
- for k in optimizers_dict:
- if optimizers_dict[k] is not None and k != 'training':
- optimizers_dict[k].zero_grad(set_to_none=True)
- self.info.total_steps += 1
- else:
- loss_adjusted.backward()
- grad_norm = torch.tensor(0.0).to(self.device)
-
- return grad_norm
-
- def models_to_save(self):
- return ['lora']
-
- def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
- models.lora.eval()
- super().sample(models, data, extras)
- models.lora.train(), models.generator.eval()
-
- def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- images = batch['images'].to(self.device)
- return models.effnet(extras.effnet_preprocess(images))
-
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- return models.previewer(latents)
-
-
-if __name__ == '__main__':
- print("Launching Script")
- warpcore = WurstCore(
- config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
- device=torch.device(int(os.environ.get("SLURM_LOCALID")))
- )
- warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
-
- # RUN TRAINING
- warpcore()
diff --git a/train/train_personalized.py b/train/train_personalized.py
deleted file mode 100644
index 5161b7c621a0eb9daf9d0f0566322bbeed646284..0000000000000000000000000000000000000000
--- a/train/train_personalized.py
+++ /dev/null
@@ -1,899 +0,0 @@
-import torch
-import json
-import yaml
-import torchvision
-from torch import nn, optim
-from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
-from warmup_scheduler import GradualWarmupScheduler
-import torch.multiprocessing as mp
-import os
-import numpy as np
-import re
-import sys
-sys.path.append(os.path.abspath('./'))
-
-from dataclasses import dataclass
-from torch.distributed import init_process_group, destroy_process_group, barrier
-from gdf import GDF_dual_fixlrt as GDF
-from gdf import EpsilonTarget, CosineSchedule
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from torchtools.transforms import SmartCrop
-from fractions import Fraction
-from modules.effnet import EfficientNetEncoder
-from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
-from modules.common_ckpt import GlobalResponseNorm
-from modules.previewer import Previewer
-from core.data import Bucketeer
-from train.base import DataCore, TrainingCore
-from tqdm import tqdm
-from core import WarpCore
-from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
-
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from contextlib import contextmanager
-from train.dist_core import *
-import glob
-from torch.utils.data import DataLoader, Dataset
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
-from PIL import Image
-from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
-from core.utils import Base
-import torch.nn.functional as F
-import functools
-import math
-import copy
-import random
-from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
-
-Image.MAX_IMAGE_PIXELS = None
-torch.manual_seed(23)
-random.seed(23)
-np.random.seed(23)
-#7978026
-
-class Null_Model(torch.nn.Module):
- def __init__(self):
- super().__init__()
- def forward(self, x):
- pass
-
-
-
-
-def identity(x):
- if isinstance(x, bytes):
- x = x.decode('utf-8')
- return x
-def check_nan_inmodel(model, meta=''):
- for name, param in model.named_parameters():
- if torch.isnan(param).any():
- print(f"nan detected in {name}", meta)
- return True
- print('no nan', meta)
- return False
-class mydist_dataset(Dataset):
- def __init__(self, rootpath, tmp_prompt, img_processor=None):
-
- self.img_pathlist = glob.glob(os.path.join(rootpath, '*.jpg'))
- self.img_pathlist = self.img_pathlist * 100000
- self.img_processor = img_processor
- self.length = len( self.img_pathlist)
- self.caption = tmp_prompt
-
-
- def __getitem__(self, idx):
-
- imgpath = self.img_pathlist[idx]
- txt = self.caption
-
-
-
-
- try:
- img = Image.open(imgpath).convert('RGB')
- w, h = img.size
- if self.img_processor is not None:
- img = self.img_processor(img)
-
- except:
- print('exception', imgpath)
- return self.__getitem__(random.randint(0, self.length -1 ) )
- return dict(captions=txt, images=img)
- def __len__(self):
- return self.length
-class WurstCore(TrainingCore, DataCore, WarpCore):
- @dataclass(frozen=True)
- class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
- # TRAINING PARAMS
- lr: float = EXPECTED_TRAIN
- warmup_updates: int = EXPECTED_TRAIN
- dtype: str = None
-
- # MODEL VERSION
- model_version: str = EXPECTED # 3.6B or 1B
- clip_image_model_name: str = 'openai/clip-vit-large-patch14'
- clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
-
- # CHECKPOINT PATHS
- effnet_checkpoint_path: str = EXPECTED
- previewer_checkpoint_path: str = EXPECTED
- generator_checkpoint_path: str = None
- ultrapixel_path: str = EXPECTED
-
- # gdf customization
- adaptive_loss_weight: str = None
-
- # LoRA STUFF
- module_filters: list = EXPECTED
- rank: int = EXPECTED
- train_tokens: list = EXPECTED
- use_ddp: bool=EXPECTED
- tmp_prompt: str=EXPECTED
- @dataclass(frozen=True)
- class Data(Base):
- dataset: Dataset = EXPECTED
- dataloader: DataLoader = EXPECTED
- iterator: any = EXPECTED
- sampler: DistributedSampler = EXPECTED
-
- @dataclass(frozen=True)
- class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
- effnet: nn.Module = EXPECTED
- previewer: nn.Module = EXPECTED
- train_norm: nn.Module = EXPECTED
- train_lora: nn.Module = EXPECTED
-
- @dataclass(frozen=True)
- class Schedulers(WarpCore.Schedulers):
- generator: any = None
-
- @dataclass(frozen=True)
- class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
- effnet_preprocess: torchvision.transforms.Compose = EXPECTED
-
- info: TrainingCore.Info
- config: Config
-
- def setup_extras_pre(self) -> Extras:
- gdf = GDF(
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- input_scaler=VPScaler(), target=EpsilonTarget(),
- noise_cond=CosineTNoiseCond(),
- loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
- )
- sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
-
- if self.info.adaptive_loss is not None:
- gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
- gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
-
- effnet_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Normalize(
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- )
- ])
-
- clip_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
- torchvision.transforms.CenterCrop(224),
- torchvision.transforms.Normalize(
- mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
- )
- ])
-
- if self.config.training:
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
- SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
- ])
- else:
- transforms = None
-
- return self.Extras(
- gdf=gdf,
- sampling_configs=sampling_configs,
- transforms=transforms,
- effnet_preprocess=effnet_preprocess,
- clip_preprocess=clip_preprocess
- )
-
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
- eval_image_embeds=False, return_fields=None):
- conditions = super().get_conditions(
- batch, models, extras, is_eval, is_unconditional,
- eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
- )
- return conditions
-
- def setup_models(self, extras: Extras) -> Models: # configure model
-
-
- dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
-
- # EfficientNet encoderin
- effnet = EfficientNetEncoder()
- effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
- effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
- effnet.eval().requires_grad_(False).to(self.device)
- del effnet_checkpoint
-
- # Previewer
- previewer = Previewer()
- previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
- previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
- previewer.eval().requires_grad_(False).to(self.device)
- del previewer_checkpoint
-
- @contextmanager
- def dummy_context():
- yield None
-
- loading_context = dummy_context if self.config.training else init_empty_weights
-
- # Diffusion models
- with loading_context():
- generator_ema = None
- if self.config.model_version == '3.6B':
- generator = StageC()
- if self.config.ema_start_iters is not None: # default setting
- generator_ema = StageC()
- elif self.config.model_version == '1B':
- print('in line 155 1b light model', self.config.model_version )
- generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
-
- if self.config.ema_start_iters is not None and self.config.training:
- generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
- else:
- raise ValueError(f"Unknown model version {self.config.model_version}")
-
-
-
- if loading_context is dummy_context:
- generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
- else:
- for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
- set_module_tensor_to_device(generator, param_name, "cpu", value=param)
-
- generator._init_extra_parameter()
- generator = generator.to(torch.bfloat16).to(self.device)
-
- train_norm = nn.ModuleList()
-
-
- cnt_norm = 0
- for mm in generator.modules():
- if isinstance(mm, GlobalResponseNorm):
-
- train_norm.append(Null_Model())
- cnt_norm += 1
-
-
-
-
- train_norm.append(generator.agg_net)
- train_norm.append(generator.agg_net_up)
- sdd = torch.load(self.config.ultrapixel_path, map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
- train_norm.load_state_dict(collect_sd)
-
-
-
- # CLIP encoders
- tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
- text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
- image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
-
- # PREPARE LORA
- train_lora = nn.ModuleList()
- update_tokens = []
- for tkn_regex, aggr_regex in self.config.train_tokens:
- if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')):
- # Insert new token
- tokenizer.add_tokens([tkn_regex])
- # add new zeros embedding
- new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1]
- if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline
- aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None]
- if len(aggr_tokens) > 0:
- new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True)
- elif self.is_main_node:
- print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.")
- text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([
- text_model.text_model.embeddings.token_embedding.weight.data, new_embedding
- ], dim=0)
- selected_tokens = [len(tokenizer.vocab) - 1]
- else:
- selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None]
- update_tokens += selected_tokens
- update_tokens = list(set(update_tokens)) # remove duplicates
-
- apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens)
-
- apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank)
- for module in generator.modules():
- if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)):
- train_lora.append(module)
-
-
- train_lora.append(text_model.text_model.embeddings.token_embedding.parametrizations.weight[0])
-
- if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors')):
- sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors'), map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
- train_lora.load_state_dict(collect_sd, strict=True)
-
-
- train_norm.to(self.device).train().requires_grad_(True)
-
- if generator_ema is not None:
-
- generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
- generator_ema._init_extra_parameter()
- pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
- if os.path.exists(pretrained_pth):
- generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))
-
- generator_ema.eval().requires_grad_(False)
-
- check_nan_inmodel(generator, 'generator')
-
-
-
- if self.config.use_ddp and self.config.training:
-
- train_lora = DDP(train_lora, device_ids=[self.device], find_unused_parameters=True)
-
-
-
- return self.Models(
- effnet=effnet, previewer=previewer, train_norm = train_norm,
- generator=generator, generator_ema=generator_ema,
- tokenizer=tokenizer, text_model=text_model, image_model=image_model,
- train_lora=train_lora
- )
-
- def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
-
-
- params = []
- params += list(models.train_lora.module.parameters())
- optimizer = optim.AdamW(params, lr=self.config.lr)
-
- return self.Optimizers(generator=optimizer)
-
- def ema_update(self, ema_model, source_model, beta):
- for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
- param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta)
-
- def sync_ema(self, ema_model):
- print('sync ema', torch.distributed.get_world_size())
- for param in ema_model.parameters():
- torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
- param.data /= torch.distributed.get_world_size()
- def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
-
-
- optimizer = optim.AdamW(
- models.generator.up_blocks.parameters() ,
- lr=self.config.lr)
- optimizer = self.load_optimizer(optimizer, 'generator_optim',
- fsdp_model=models.generator if self.config.use_fsdp else None)
- return self.Optimizers(generator=optimizer)
-
- def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
- scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
- scheduler.last_epoch = self.info.total_steps
- return self.Schedulers(generator=scheduler)
-
- def setup_data(self, extras: Extras) -> WarpCore.Data:
- # SETUP DATASET
- dataset_path = self.config.webdataset_path
-
-
- dataset = mydist_dataset(dataset_path, self.config.tmp_prompt, \
- torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \
- else extras.transforms)
-
- # SETUP DATALOADER
- real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
-
- sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True)
- dataloader = DataLoader(
- dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True,
- collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
- sampler = sampler
- )
- if self.is_main_node:
- print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
-
- if self.config.multi_aspect_ratio is not None:
- aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
- dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32,
- ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
- interpolate_nearest=False) # , use_smartcrop=True)
- else:
-
- dataloader_iterator = iter(dataloader)
-
- return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)
-
-
-
-
-
- def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
-
- if not single_gpu:
- local_rank = rank
- process_id = rank
- world_size = get_world_size()
-
- self.process_id = process_id
- self.is_main_node = process_id == 0
- self.device = torch.device(local_rank)
- self.world_size = world_size
-
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '14443'
- torch.cuda.set_device(local_rank)
- init_process_group(
- backend="nccl",
- rank=local_rank,
- world_size=world_size,
- # init_method=init_method,
- )
- print(f"[GPU {process_id}] READY")
- else:
- self.is_main_node = rank == 0
- self.process_id = rank
- self.device = torch.device('cuda:0')
- self.world_size = 1
- print("Running in single thread, DDP not enabled.")
- # Training loop --------------------------------
- def get_target_lr_size(self, ratio, std_size=24):
- w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
- return (h * 32 , w * 32)
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
-
- batch = data
- ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
- shape_lr = self.get_target_lr_size(ratio)
- with torch.no_grad():
- conditions = self.get_conditions(batch, models, extras)
-
- latents = self.encode_latents(batch, models, extras)
- latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr)
-
-
-
- flag_lr = random.random() < 0.5 or self.info.iter <5000
-
- if flag_lr:
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1)
- else:
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
- if not flag_lr:
- noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = \
- extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, )
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-
-
- if not flag_lr:
- with torch.no_grad():
- _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)
-
-
- pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if not flag_lr else None , **conditions)
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
-
- loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps
-
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- extras.gdf.loss_weight.update_buckets(logSNR, loss)
- return loss, loss_adjusted
-
- def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
-
- if update:
-
- torch.distributed.barrier()
- loss_adjusted.backward()
-
- grad_norm = nn.utils.clip_grad_norm_(models.train_lora.module.parameters(), 1.0)
- optimizers_dict = optimizers.to_dict()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].step()
- schedulers_dict = schedulers.to_dict()
- for k in schedulers_dict:
- if k != 'training':
- schedulers_dict[k].step()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].zero_grad(set_to_none=True)
- self.info.total_steps += 1
- else:
-
- loss_adjusted.backward()
- grad_norm = torch.tensor(0.0).to(self.device)
-
- return grad_norm
-
- def models_to_save(self):
- return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema']
-
- def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
-
- images = batch['images'].to(self.device)
- if target_size is not None:
- images = F.interpolate(images, target_size)
-
- return models.effnet(extras.effnet_preprocess(images))
-
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- return models.previewer(latents)
-
- def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ):
-
- self.is_main_node = (rank == 0)
- self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
- self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
- self.info: self.Info = self.setup_info()
- print('in line 292', self.config.experiment_id, rank, world_size <= 1)
- p = [i for i in range( 2 * 768 // 32)]
- p = [num / sum(p) for num in p]
- self.rand_pro = p
- self.res_list = [o for o in range(800, 2336, 32)]
-
-
-
- def __call__(self, single_gpu=False):
-
- if self.config.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
- if self.is_main_node:
- print()
- print("**STARTIG JOB WITH CONFIG:**")
- print(yaml.dump(self.config.to_dict(), default_flow_style=False))
- print("------------------------------------")
- print()
- print("**INFO:**")
- print(yaml.dump(vars(self.info), default_flow_style=False))
- print("------------------------------------")
- print()
- print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device )
- # SETUP STUFF
- extras = self.setup_extras_pre()
- assert extras is not None, "setup_extras_pre() must return a DTO"
-
-
-
- data = self.setup_data(extras)
- assert data is not None, "setup_data() must return a DTO"
- if self.is_main_node:
- print("**DATA:**")
- print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- models = self.setup_models(extras)
- assert models is not None, "setup_models() must return a DTO"
- if self.is_main_node:
- print("**MODELS:**")
- print(yaml.dump({
- k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
- }, default_flow_style=False))
- print("------------------------------------")
- print()
-
-
-
- optimizers = self.setup_optimizers(extras, models)
- assert optimizers is not None, "setup_optimizers() must return a DTO"
- if self.is_main_node:
- print("**OPTIMIZERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- schedulers = self.setup_schedulers(extras, models, optimizers)
- assert schedulers is not None, "setup_schedulers() must return a DTO"
- if self.is_main_node:
- print("**SCHEDULERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
- assert post_extras is not None, "setup_extras_post() must return a DTO"
- extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
- if self.is_main_node:
- print("**EXTRAS:**")
- print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
- # -------
-
- # TRAIN
- if self.is_main_node:
- print("**TRAINING STARTING...**")
- self.train(data, extras, models, optimizers, schedulers)
-
- if single_gpu is False:
- barrier()
- destroy_process_group()
- if self.is_main_node:
- print()
- print("------------------------------------")
- print()
- print("**TRAINING COMPLETE**")
- if self.config.wandb_project is not None:
- wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
-
-
- def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
- schedulers: WarpCore.Schedulers):
- start_iter = self.info.iter + 1
- max_iters = self.config.updates * self.config.grad_accum_steps
- if self.is_main_node:
- print(f"STARTING AT STEP: {start_iter}/{max_iters}")
-
-
- if self.is_main_node:
- create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
- if 'generator' in self.models_to_save():
- models.generator.train()
-
- iter_cnt = 0
- epoch_cnt = 0
- models.train_norm.train()
- while True:
- epoch_cnt += 1
- if self.world_size > 1:
-
- data.sampler.set_epoch(epoch_cnt)
- for ggg in range(len(data.dataloader)):
- iter_cnt += 1
- # FORWARD PASS
-
- loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
-
-
- # # BACKWARD PASS
-
- grad_norm = self.backward_pass(
- iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
- models, optimizers, schedulers
- )
-
-
-
- self.info.iter = iter_cnt
-
-
- self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
-
-
- if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
- print(f"gggg NaN value encountered in training run {self.info.wandb_run_id}", \
- f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
-
- if self.is_main_node:
- logs = {
- 'loss': self.info.ema_loss,
- 'backward_loss': loss_adjusted.mean().item(),
-
- 'ema_loss': self.info.ema_loss,
- 'raw_ori_loss': loss.mean().item(),
-
- 'grad_norm': grad_norm.item(),
- 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
- 'total_steps': self.info.total_steps,
- }
-
-
- print(iter_cnt, max_iters, logs, epoch_cnt, )
-
-
-
-
-
-
- if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters:
-
- if np.isnan(loss.mean().item()):
- if self.is_main_node and self.config.wandb_project is not None:
- print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
- f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
-
- else:
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- self.info.adaptive_loss = {
- 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
- 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
- }
-
-
- if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
- print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
- torch.save(models.train_lora.state_dict(), \
- f'{self.config.output_path}/{self.config.experiment_id}/train_lora.safetensors')
-
-
- torch.save(models.train_lora.state_dict(), \
- f'{self.config.output_path}/{self.config.experiment_id}/train_lora_{iter_cnt}.safetensors')
-
-
- if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
-
- if self.is_main_node:
-
- self.sample(models, data, extras)
- if False:
- param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()}
- threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10%
- important_params = [name for name, change in param_changes.items() if change > threshold]
- print(important_params, threshold, len(param_changes), self.process_id)
- json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4)
-
-
- if self.info.iter >= max_iters:
- break
-
- def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
-
-
- models.generator.eval()
- models.train_norm.eval()
- with torch.no_grad():
- batch = next(data.iterator)
- ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
-
- shape_lr = self.get_target_lr_size(ratio)
- conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- latents = self.encode_latents(batch, models, extras)
- latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
-
- if self.is_main_node:
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-
- *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
- models.generator, conditions,
- latents.shape, latents_lr.shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
-
- sampled_ema = sampled
- sampled_ema_lr = sampled_lr
-
-
- if self.is_main_node:
- print('sampling results hr latent shape ', latents.shape, 'lr latent shape', latents_lr.shape, )
- noised_images = torch.cat(
- [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
-
- sampled_images = torch.cat(
- [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
- sampled_images_ema = torch.cat(
- [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))],
- dim=0)
-
- noised_images_lr = torch.cat(
- [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
-
- sampled_images_lr = torch.cat(
- [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
- sampled_images_ema_lr = torch.cat(
- [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))],
- dim=0)
-
- images = batch['images']
- if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
- images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
- images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
-
- collage_img = torch.cat([
- torch.cat([i for i in images.cpu()], dim=-1),
- torch.cat([i for i in noised_images.cpu()], dim=-1),
- torch.cat([i for i in sampled_images.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
- ], dim=-2)
-
- collage_img_lr = torch.cat([
- torch.cat([i for i in images_lr.cpu()], dim=-1),
- torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1),
- ], dim=-2)
-
- torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
- torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
-
- captions = batch['captions']
- if self.config.wandb_project is not None:
- log_data = [
- [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [
- wandb.Image(images[i])] for i in range(len(images))]
- log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
- wandb.log({"Log": log_table})
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
- plt.ylabel('Raw Loss')
- plt.ylabel('LogSNR')
- wandb.log({"Loss/LogSRN": plt})
-
-
- models.generator.train()
- models.train_norm.train()
- print('finish sampling')
-
-
-
- def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
-
-
- models.generator.eval()
- models.trans_inr.eval()
- with torch.no_grad():
-
- if self.is_main_node:
- conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
- unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-
- *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
- models.generator, conditions,
- hr_shape, lr_shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
- if models.generator_ema is not None:
-
- *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
- models.generator_ema, conditions,
- latents.shape, latents_lr.shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
- else:
- sampled_ema = sampled
- sampled_ema_lr = sampled_lr
-
-
- return sampled, sampled_lr
-def main_worker(rank, cfg):
- print("Launching Script in main worker")
- warpcore = WurstCore(
- config_file_path=cfg, rank=rank, world_size = get_world_size()
- )
- # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
-
- # RUN TRAINING
- warpcore(get_world_size()==1)
-
-if __name__ == '__main__':
-
- if get_master_ip() == "127.0.0.1":
-
- mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
- else:
- main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )
diff --git a/train/train_t2i.py b/train/train_t2i.py
deleted file mode 100644
index 456ca4b0dd1fe8e1fc18e3e5c940797439071d1f..0000000000000000000000000000000000000000
--- a/train/train_t2i.py
+++ /dev/null
@@ -1,807 +0,0 @@
-import torch
-import json
-import yaml
-import torchvision
-from torch import nn, optim
-from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
-from warmup_scheduler import GradualWarmupScheduler
-import torch.multiprocessing as mp
-import numpy as np
-import os
-import sys
-sys.path.append(os.path.abspath('./'))
-from dataclasses import dataclass
-from torch.distributed import init_process_group, destroy_process_group, barrier
-from gdf import GDF_dual_fixlrt as GDF
-from gdf import EpsilonTarget, CosineSchedule
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from torchtools.transforms import SmartCrop
-from fractions import Fraction
-from modules.effnet import EfficientNetEncoder
-
-from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
-from modules.previewer import Previewer
-from core.data import Bucketeer
-from train.base import DataCore, TrainingCore
-from tqdm import tqdm
-from core import WarpCore
-from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
-
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from contextlib import contextmanager
-from train.dist_core import *
-import glob
-from torch.utils.data import DataLoader, Dataset
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
-from PIL import Image
-from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
-from core.utils import Base
-from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm
-import torch.nn.functional as F
-import functools
-import math
-import copy
-import random
-from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
-Image.MAX_IMAGE_PIXELS = None
-torch.manual_seed(23)
-random.seed(23)
-np.random.seed(23)
-#7978026
-
-class Null_Model(torch.nn.Module):
- def __init__(self):
- super().__init__()
- def forward(self, x):
- pass
-
-
-
-
-def identity(x):
- if isinstance(x, bytes):
- x = x.decode('utf-8')
- return x
-def check_nan_inmodel(model, meta=''):
- for name, param in model.named_parameters():
- if torch.isnan(param).any():
- print(f"nan detected in {name}", meta)
- return True
- print('no nan', meta)
- return False
-class mydist_dataset(Dataset):
- def __init__(self, rootpath, img_processor=None):
-
- self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg'))
- self.img_processor = img_processor
- self.length = len( self.img_pathlist)
-
-
-
- def __getitem__(self, idx):
-
- imgpath = self.img_pathlist[idx]
- json_file = imgpath.replace('.jpg', '.json')
-
- with open(json_file, 'r') as file:
- info = json.load(file)
- txt = info['caption']
- if txt is None:
- txt = ' '
- try:
- img = Image.open(imgpath).convert('RGB')
- w, h = img.size
- if self.img_processor is not None:
- img = self.img_processor(img)
-
- except:
- print('exception', imgpath)
- return self.__getitem__(random.randint(0, self.length -1 ) )
- return dict(captions=txt, images=img)
- def __len__(self):
- return self.length
-
-class WurstCore(TrainingCore, DataCore, WarpCore):
- @dataclass(frozen=True)
- class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
- # TRAINING PARAMS
- lr: float = EXPECTED_TRAIN
- warmup_updates: int = EXPECTED_TRAIN
- dtype: str = None
-
- # MODEL VERSION
- model_version: str = EXPECTED # 3.6B or 1B
- clip_image_model_name: str = 'openai/clip-vit-large-patch14'
- clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
-
- # CHECKPOINT PATHS
- effnet_checkpoint_path: str = EXPECTED
- previewer_checkpoint_path: str = EXPECTED
-
- generator_checkpoint_path: str = None
-
- # gdf customization
- adaptive_loss_weight: str = None
- use_ddp: bool=EXPECTED
-
-
- @dataclass(frozen=True)
- class Data(Base):
- dataset: Dataset = EXPECTED
- dataloader: DataLoader = EXPECTED
- iterator: any = EXPECTED
- sampler: DistributedSampler = EXPECTED
-
- @dataclass(frozen=True)
- class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
- effnet: nn.Module = EXPECTED
- previewer: nn.Module = EXPECTED
- train_norm: nn.Module = EXPECTED
-
-
- @dataclass(frozen=True)
- class Schedulers(WarpCore.Schedulers):
- generator: any = None
-
- @dataclass(frozen=True)
- class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
- effnet_preprocess: torchvision.transforms.Compose = EXPECTED
-
- info: TrainingCore.Info
- config: Config
-
- def setup_extras_pre(self) -> Extras:
- gdf = GDF(
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- input_scaler=VPScaler(), target=EpsilonTarget(),
- noise_cond=CosineTNoiseCond(),
- loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
- )
- sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
-
- if self.info.adaptive_loss is not None:
- gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
- gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
-
- effnet_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Normalize(
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- )
- ])
-
- clip_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
- torchvision.transforms.CenterCrop(224),
- torchvision.transforms.Normalize(
- mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
- )
- ])
-
- if self.config.training:
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
- SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
- ])
- else:
- transforms = None
-
- return self.Extras(
- gdf=gdf,
- sampling_configs=sampling_configs,
- transforms=transforms,
- effnet_preprocess=effnet_preprocess,
- clip_preprocess=clip_preprocess
- )
-
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
- eval_image_embeds=False, return_fields=None):
- conditions = super().get_conditions(
- batch, models, extras, is_eval, is_unconditional,
- eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
- )
- return conditions
-
- def setup_models(self, extras: Extras) -> Models: # configure model
-
- dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
-
- # EfficientNet encoderin
- effnet = EfficientNetEncoder()
- effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
- effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
- effnet.eval().requires_grad_(False).to(self.device)
- del effnet_checkpoint
-
- # Previewer
- previewer = Previewer()
- previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
- previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
- previewer.eval().requires_grad_(False).to(self.device)
- del previewer_checkpoint
-
- @contextmanager
- def dummy_context():
- yield None
-
- loading_context = dummy_context if self.config.training else init_empty_weights
-
- # Diffusion models
- with loading_context():
- generator_ema = None
- if self.config.model_version == '3.6B':
- generator = StageC()
- if self.config.ema_start_iters is not None: # default setting
- generator_ema = StageC()
- elif self.config.model_version == '1B':
- print('in line 155 1b light model', self.config.model_version )
- generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
-
- if self.config.ema_start_iters is not None and self.config.training:
- generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
- else:
- raise ValueError(f"Unknown model version {self.config.model_version}")
-
-
-
- if loading_context is dummy_context:
- generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
- else:
- for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
- set_module_tensor_to_device(generator, param_name, "cpu", value=param)
-
- generator._init_extra_parameter()
- generator = generator.to(torch.bfloat16).to(self.device)
-
-
- train_norm = nn.ModuleList()
- cnt_norm = 0
- for mm in generator.modules():
- if isinstance(mm, GlobalResponseNorm):
-
- train_norm.append(Null_Model())
- cnt_norm += 1
-
- train_norm.append(generator.agg_net)
- train_norm.append(generator.agg_net_up)
- total = sum([ param.nelement() for param in train_norm.parameters()])
- print('Trainable parameter', total / 1048576)
-
- if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')):
- sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
- train_norm.load_state_dict(collect_sd, strict=True)
-
-
- train_norm.to(self.device).train().requires_grad_(True)
- train_norm_ema = copy.deepcopy(train_norm)
- train_norm_ema.to(self.device).eval().requires_grad_(False)
- if generator_ema is not None:
-
- generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
- generator_ema._init_extra_parameter()
-
-
- pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
- if os.path.exists(pretrained_pth):
- print(pretrained_pth, 'exists')
- generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))
-
-
- generator_ema.eval().requires_grad_(False)
-
-
-
-
- check_nan_inmodel(generator, 'generator')
-
-
-
- if self.config.use_ddp and self.config.training:
-
- train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True)
-
- # CLIP encoders
- tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
- text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
- image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
-
- return self.Models(
- effnet=effnet, previewer=previewer, train_norm = train_norm,
- generator=generator, tokenizer=tokenizer, text_model=text_model, image_model=image_model,
- )
-
- def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
-
-
- params = []
- params += list(models.train_norm.module.parameters())
-
- optimizer = optim.AdamW(params, lr=self.config.lr)
-
- return self.Optimizers(generator=optimizer)
-
- def ema_update(self, ema_model, source_model, beta):
- for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
- param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta)
-
- def sync_ema(self, ema_model):
- for param in ema_model.parameters():
- torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
- param.data /= torch.distributed.get_world_size()
- def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
-
-
- optimizer = optim.AdamW(
- models.generator.up_blocks.parameters() ,
- lr=self.config.lr)
- optimizer = self.load_optimizer(optimizer, 'generator_optim',
- fsdp_model=models.generator if self.config.use_fsdp else None)
- return self.Optimizers(generator=optimizer)
-
- def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
- scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
- scheduler.last_epoch = self.info.total_steps
- return self.Schedulers(generator=scheduler)
-
- def setup_data(self, extras: Extras) -> WarpCore.Data:
- # SETUP DATASET
- dataset_path = self.config.webdataset_path
- dataset = mydist_dataset(dataset_path, \
- torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \
- else extras.transforms)
-
- # SETUP DATALOADER
- real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
-
- sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True)
- dataloader = DataLoader(
- dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True,
- collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
- sampler = sampler
- )
- if self.is_main_node:
- print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
-
- if self.config.multi_aspect_ratio is not None:
- aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
- dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32,
- ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
- interpolate_nearest=False) # , use_smartcrop=True)
- else:
-
- dataloader_iterator = iter(dataloader)
-
- return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)
-
-
- def models_to_save(self):
- pass
- def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
-
- if not single_gpu:
- local_rank = rank
- process_id = rank
- world_size = get_world_size()
-
- self.process_id = process_id
- self.is_main_node = process_id == 0
- self.device = torch.device(local_rank)
- self.world_size = world_size
-
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '41443'
- torch.cuda.set_device(local_rank)
- init_process_group(
- backend="nccl",
- rank=local_rank,
- world_size=world_size,
- )
- print(f"[GPU {process_id}] READY")
- else:
- self.is_main_node = rank == 0
- self.process_id = rank
- self.device = torch.device('cuda:0')
- self.world_size = 1
- print("Running in single thread, DDP not enabled.")
- # Training loop --------------------------------
- def get_target_lr_size(self, ratio, std_size=24):
- w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
- return (h * 32 , w * 32)
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
- #batch = next(data.iterator)
- batch = data
- ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
- shape_lr = self.get_target_lr_size(ratio)
- #print('in line 485', shape_lr, ratio, batch['images'].shape)
- with torch.no_grad():
- conditions = self.get_conditions(batch, models, extras)
-
- latents = self.encode_latents(batch, models, extras)
- latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr)
-
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
- noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, )
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- # 768 1536
- require_cond = True
-
- with torch.no_grad():
- _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)
-
-
- pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions)
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
-
- loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps
-
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- extras.gdf.loss_weight.update_buckets(logSNR, loss)
-
- return loss, loss_adjusted
-
- def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
-
-
- if update:
-
- torch.distributed.barrier()
- loss_adjusted.backward()
-
- grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0)
-
- optimizers_dict = optimizers.to_dict()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].step()
- schedulers_dict = schedulers.to_dict()
- for k in schedulers_dict:
- if k != 'training':
- schedulers_dict[k].step()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].zero_grad(set_to_none=True)
- self.info.total_steps += 1
- else:
-
- loss_adjusted.backward()
-
- grad_norm = torch.tensor(0.0).to(self.device)
-
- return grad_norm
-
-
- def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
-
- images = batch['images'].to(self.device)
- if target_size is not None:
- images = F.interpolate(images, target_size)
-
- return models.effnet(extras.effnet_preprocess(images))
-
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- return models.previewer(latents)
-
- def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ):
-
- self.is_main_node = (rank == 0)
- self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
- self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
- self.info: self.Info = self.setup_info()
-
-
-
- def __call__(self, single_gpu=False):
-
- if self.config.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
- if self.is_main_node:
- print()
- print("**STARTIG JOB WITH CONFIG:**")
- print(yaml.dump(self.config.to_dict(), default_flow_style=False))
- print("------------------------------------")
- print()
- print("**INFO:**")
- print(yaml.dump(vars(self.info), default_flow_style=False))
- print("------------------------------------")
- print()
-
- # SETUP STUFF
- extras = self.setup_extras_pre()
- assert extras is not None, "setup_extras_pre() must return a DTO"
-
-
-
- data = self.setup_data(extras)
- assert data is not None, "setup_data() must return a DTO"
- if self.is_main_node:
- print("**DATA:**")
- print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- models = self.setup_models(extras)
- assert models is not None, "setup_models() must return a DTO"
- if self.is_main_node:
- print("**MODELS:**")
- print(yaml.dump({
- k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
- }, default_flow_style=False))
- print("------------------------------------")
- print()
-
-
-
- optimizers = self.setup_optimizers(extras, models)
- assert optimizers is not None, "setup_optimizers() must return a DTO"
- if self.is_main_node:
- print("**OPTIMIZERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- schedulers = self.setup_schedulers(extras, models, optimizers)
- assert schedulers is not None, "setup_schedulers() must return a DTO"
- if self.is_main_node:
- print("**SCHEDULERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
- assert post_extras is not None, "setup_extras_post() must return a DTO"
- extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
- if self.is_main_node:
- print("**EXTRAS:**")
- print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
- # -------
-
- # TRAIN
- if self.is_main_node:
- print("**TRAINING STARTING...**")
- self.train(data, extras, models, optimizers, schedulers)
-
- if single_gpu is False:
- barrier()
- destroy_process_group()
- if self.is_main_node:
- print()
- print("------------------------------------")
- print()
- print("**TRAINING COMPLETE**")
-
-
-
- def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
- schedulers: WarpCore.Schedulers):
- start_iter = self.info.iter + 1
- max_iters = self.config.updates * self.config.grad_accum_steps
- if self.is_main_node:
- print(f"STARTING AT STEP: {start_iter}/{max_iters}")
-
-
- if self.is_main_node:
- create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
-
- models.generator.train()
-
- iter_cnt = 0
- epoch_cnt = 0
- models.train_norm.train()
- while True:
- epoch_cnt += 1
- if self.world_size > 1:
-
- data.sampler.set_epoch(epoch_cnt)
- for ggg in range(len(data.dataloader)):
- iter_cnt += 1
- loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
- grad_norm = self.backward_pass(
- iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
- models, optimizers, schedulers
- )
-
- self.info.iter = iter_cnt
-
-
- # UPDATE LOSS METRICS
- self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
-
- #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss)
- if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
- print(f" NaN value encountered in training run {self.info.wandb_run_id}", \
- f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
-
- if self.is_main_node:
- logs = {
- 'loss': self.info.ema_loss,
- 'backward_loss': loss_adjusted.mean().item(),
- 'ema_loss': self.info.ema_loss,
- 'raw_ori_loss': loss.mean().item(),
- 'grad_norm': grad_norm.item(),
- 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
- 'total_steps': self.info.total_steps,
- }
- if iter_cnt % (self.config.save_every) == 0:
-
- print(iter_cnt, max_iters, logs, epoch_cnt, )
-
-
-
- if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters:
-
- # SAVE AND CHECKPOINT STUFF
- if np.isnan(loss.mean().item()):
- if self.is_main_node and self.config.wandb_project is not None:
- print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
- f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
-
- else:
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- self.info.adaptive_loss = {
- 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
- 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
- }
-
-
-
- if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
- print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
- torch.save(models.train_norm.state_dict(), \
- f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors')
-
- torch.save(models.train_norm.state_dict(), \
- f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors')
-
-
- if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
-
- if self.is_main_node:
-
- self.sample(models, data, extras)
-
-
- if self.info.iter >= max_iters:
- break
-
- def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
-
-
- models.generator.eval()
- models.train_norm.eval()
- with torch.no_grad():
- batch = next(data.iterator)
- ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
-
- shape_lr = self.get_target_lr_size(ratio)
- conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- latents = self.encode_latents(batch, models, extras)
- latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
-
-
- if self.is_main_node:
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-
- *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
- models.generator, conditions,
- latents.shape, latents_lr.shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
-
-
-
- if self.is_main_node:
- print('sampling results hr latent shape', latents.shape, 'lr latent shape', latents_lr.shape, )
- noised_images = torch.cat(
- [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
-
- sampled_images = torch.cat(
- [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
-
-
- noised_images_lr = torch.cat(
- [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
-
- sampled_images_lr = torch.cat(
- [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
-
- images = batch['images']
- if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
- images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
- images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
-
- collage_img = torch.cat([
- torch.cat([i for i in images.cpu()], dim=-1),
- torch.cat([i for i in noised_images.cpu()], dim=-1),
- torch.cat([i for i in sampled_images.cpu()], dim=-1),
- ], dim=-2)
-
- collage_img_lr = torch.cat([
- torch.cat([i for i in images_lr.cpu()], dim=-1),
- torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
- ], dim=-2)
-
- torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
- torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
-
-
- models.generator.train()
- models.train_norm.train()
- print('finish sampling')
-
-
-
- def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
-
-
- models.generator.eval()
-
- with torch.no_grad():
-
- if self.is_main_node:
- conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
- unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-
- *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
- models.generator, conditions,
- hr_shape, lr_shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
- if models.generator_ema is not None:
-
- *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
- models.generator_ema, conditions,
- latents.shape, latents_lr.shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
- else:
- sampled_ema = sampled
- sampled_ema_lr = sampled_lr
-
- return sampled, sampled_lr
-def main_worker(rank, cfg):
- print("Launching Script in main worker")
-
- warpcore = WurstCore(
- config_file_path=cfg, rank=rank, world_size = get_world_size()
- )
- # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
-
- # RUN TRAINING
- warpcore(get_world_size()==1)
-
-if __name__ == '__main__':
- print('launch multi process')
- # os.environ["OMP_NUM_THREADS"] = "1"
- # os.environ["MKL_NUM_THREADS"] = "1"
- #dist.init_process_group(backend="nccl")
- #torch.backends.cudnn.benchmark = True
-#train/train_c_my.py
- #mp.set_sharing_strategy('file_system')
-
- if get_master_ip() == "127.0.0.1":
- # manually launch distributed processes
- mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
- else:
- main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )
diff --git a/train/train_ultrapixel_control.py b/train/train_ultrapixel_control.py
deleted file mode 100644
index cd67965973a85ed1d72c164dd0e8970f8b5ce277..0000000000000000000000000000000000000000
--- a/train/train_ultrapixel_control.py
+++ /dev/null
@@ -1,928 +0,0 @@
-import torch
-import json
-import yaml
-import torchvision
-from torch import nn, optim
-from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
-from warmup_scheduler import GradualWarmupScheduler
-import torch.multiprocessing as mp
-import numpy as np
-import sys
-
-import os
-from dataclasses import dataclass
-from torch.distributed import init_process_group, destroy_process_group, barrier
-from gdf import GDF_dual_fixlrt as GDF
-from gdf import EpsilonTarget, CosineSchedule
-from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
-from torchtools.transforms import SmartCrop
-from fractions import Fraction
-from modules.effnet import EfficientNetEncoder
-
-from modules.model_4stage_lite import StageC
-
-from modules.model_4stage_lite import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
-from modules.common_ckpt import GlobalResponseNorm
-from modules.previewer import Previewer
-from core.data import Bucketeer
-from train.base import DataCore, TrainingCore
-from tqdm import tqdm
-from core import WarpCore
-from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
-from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from contextlib import contextmanager
-from train.dist_core import *
-import glob
-from torch.utils.data import DataLoader, Dataset
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data.distributed import DistributedSampler
-from PIL import Image
-from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
-from core.utils import Base
-from modules.common import LayerNorm2d
-import torch.nn.functional as F
-import functools
-import math
-import copy
-import random
-from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
-from modules import ControlNet, ControlNetDeliverer
-from modules import controlnet_filters
-
-Image.MAX_IMAGE_PIXELS = None
-torch.manual_seed(8432)
-random.seed(8432)
-np.random.seed(8432)
-#7978026
-
-class Null_Model(torch.nn.Module):
- def __init__(self):
- super().__init__()
- def forward(self, x):
- pass
-
-
-def identity(x):
- if isinstance(x, bytes):
- x = x.decode('utf-8')
- return x
-def check_nan_inmodel(model, meta=''):
- for name, param in model.named_parameters():
- if torch.isnan(param).any():
- print(f"nan detected in {name}", meta)
- return True
- print('no nan', meta)
- return False
-
-
-class WurstCore(TrainingCore, DataCore, WarpCore):
- @dataclass(frozen=True)
- class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
- # TRAINING PARAMS
- lr: float = EXPECTED_TRAIN
- warmup_updates: int = EXPECTED_TRAIN
- dtype: str = None
-
- # MODEL VERSION
- model_version: str = EXPECTED # 3.6B or 1B
- clip_image_model_name: str = 'openai/clip-vit-large-patch14'
- clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
-
- # CHECKPOINT PATHS
- effnet_checkpoint_path: str = EXPECTED
- previewer_checkpoint_path: str = EXPECTED
- #trans_inr_ckpt: str = EXPECTED
- generator_checkpoint_path: str = None
- controlnet_checkpoint_path: str = EXPECTED
-
- # controlnet settings
- controlnet_blocks: list = EXPECTED
- controlnet_filter: str = EXPECTED
- controlnet_filter_params: dict = None
- controlnet_bottleneck_mode: str = None
-
-
- # gdf customization
- adaptive_loss_weight: str = None
-
- #module_filters: list = EXPECTED
- #rank: int = EXPECTED
- @dataclass(frozen=True)
- class Data(Base):
- dataset: Dataset = EXPECTED
- dataloader: DataLoader = EXPECTED
- iterator: any = EXPECTED
- sampler: DistributedSampler = EXPECTED
-
- @dataclass(frozen=True)
- class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
- effnet: nn.Module = EXPECTED
- previewer: nn.Module = EXPECTED
- train_norm: nn.Module = EXPECTED
- train_norm_ema: nn.Module = EXPECTED
- controlnet: nn.Module = EXPECTED
-
- @dataclass(frozen=True)
- class Schedulers(WarpCore.Schedulers):
- generator: any = None
-
- @dataclass(frozen=True)
- class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
- gdf: GDF = EXPECTED
- sampling_configs: dict = EXPECTED
- effnet_preprocess: torchvision.transforms.Compose = EXPECTED
- controlnet_filter: controlnet_filters.BaseFilter = EXPECTED
-
- info: TrainingCore.Info
- config: Config
-
- def setup_extras_pre(self) -> Extras:
- gdf = GDF(
- schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
- input_scaler=VPScaler(), target=EpsilonTarget(),
- noise_cond=CosineTNoiseCond(),
- loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
- )
- sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
-
- if self.info.adaptive_loss is not None:
- gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
- gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
-
- effnet_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Normalize(
- mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- )
- ])
-
- clip_preprocess = torchvision.transforms.Compose([
- torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
- torchvision.transforms.CenterCrop(224),
- torchvision.transforms.Normalize(
- mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
- )
- ])
-
- if self.config.training:
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
- SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
- ])
- else:
- transforms = None
- controlnet_filter = getattr(controlnet_filters, self.config.controlnet_filter)(
- self.device,
- **(self.config.controlnet_filter_params if self.config.controlnet_filter_params is not None else {})
- )
-
- return self.Extras(
- gdf=gdf,
- sampling_configs=sampling_configs,
- transforms=transforms,
- effnet_preprocess=effnet_preprocess,
- clip_preprocess=clip_preprocess,
- controlnet_filter=controlnet_filter
- )
- def get_cnet(self, batch: dict, models: Models, extras: Extras, cnet_input=None, target_size=None, **kwargs):
- images = batch['images']
- if target_size is not None:
- images = Image.resize(images, target_size)
- with torch.no_grad():
- if cnet_input is None:
- cnet_input = extras.controlnet_filter(images, **kwargs)
- if isinstance(cnet_input, tuple):
- cnet_input, cnet_input_preview = cnet_input
- else:
- cnet_input_preview = cnet_input
- cnet_input, cnet_input_preview = cnet_input.to(self.device), cnet_input_preview.to(self.device)
- cnet = models.controlnet(cnet_input)
- return cnet, cnet_input_preview
-
- def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
- eval_image_embeds=False, return_fields=None):
- conditions = super().get_conditions(
- batch, models, extras, is_eval, is_unconditional,
- eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
- )
- return conditions
-
- def setup_models(self, extras: Extras) -> Models: # configure model
-
-
- dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
-
- # EfficientNet encoderin
- effnet = EfficientNetEncoder()
- effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
- effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
- effnet.eval().requires_grad_(False).to(self.device)
- del effnet_checkpoint
-
- # Previewer
- previewer = Previewer()
- previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
- previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
- previewer.eval().requires_grad_(False).to(self.device)
- del previewer_checkpoint
-
- @contextmanager
- def dummy_context():
- yield None
-
- loading_context = dummy_context if self.config.training else init_empty_weights
-
- # Diffusion models
- with loading_context():
- generator_ema = None
- if self.config.model_version == '3.6B':
- generator = StageC()
- if self.config.ema_start_iters is not None: # default setting
- generator_ema = StageC()
- elif self.config.model_version == '1B':
-
- generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
-
- if self.config.ema_start_iters is not None and self.config.training:
- generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
- else:
- raise ValueError(f"Unknown model version {self.config.model_version}")
-
-
-
- if loading_context is dummy_context:
- generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
- else:
- for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
- set_module_tensor_to_device(generator, param_name, "cpu", value=param)
-
- generator._init_extra_parameter()
-
-
-
-
- generator = generator.to(torch.bfloat16).to(self.device)
-
- train_norm = nn.ModuleList()
-
-
- cnt_norm = 0
- for mm in generator.modules():
- if isinstance(mm, GlobalResponseNorm):
-
- train_norm.append(Null_Model())
- cnt_norm += 1
-
-
-
-
- train_norm.append(generator.agg_net)
- train_norm.append(generator.agg_net_up)
-
-
-
-
- if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')):
- sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu')
- collect_sd = {}
- for k, v in sdd.items():
- collect_sd[k[7:]] = v
- train_norm.load_state_dict(collect_sd, strict=True)
-
-
- train_norm.to(self.device).train().requires_grad_(True)
- train_norm_ema = copy.deepcopy(train_norm)
- train_norm_ema.to(self.device).eval().requires_grad_(False)
- if generator_ema is not None:
-
- generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
- generator_ema._init_extra_parameter()
-
- pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
- if os.path.exists(pretrained_pth):
- print(pretrained_pth, 'exists')
- generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))
-
- generator_ema.eval().requires_grad_(False)
-
- check_nan_inmodel(generator, 'generator')
-
-
-
- if self.config.use_fsdp and self.config.training:
- train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True)
-
-
- # CLIP encoders
- tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
- text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
- image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
-
- controlnet = ControlNet(
- c_in=extras.controlnet_filter.num_channels(),
- proj_blocks=self.config.controlnet_blocks,
- bottleneck_mode=self.config.controlnet_bottleneck_mode
- )
- controlnet = controlnet.to(dtype).to(self.device)
- controlnet = self.load_model(controlnet, 'controlnet')
- controlnet.backbone.eval().requires_grad_(True)
-
-
- return self.Models(
- effnet=effnet, previewer=previewer, train_norm = train_norm,
- generator=generator, generator_ema=generator_ema,
- tokenizer=tokenizer, text_model=text_model, image_model=image_model,
- train_norm_ema=train_norm_ema, controlnet =controlnet
- )
-
- def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
-
-#
-
- params = []
- params += list(models.train_norm.module.parameters())
-
- optimizer = optim.AdamW(params, lr=self.config.lr)
-
- return self.Optimizers(generator=optimizer)
-
- def ema_update(self, ema_model, source_model, beta):
- for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
- param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta)
-
- def sync_ema(self, ema_model):
- print('sync ema', torch.distributed.get_world_size())
- for param in ema_model.parameters():
- torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
- param.data /= torch.distributed.get_world_size()
- def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
-
-
- optimizer = optim.AdamW(
- models.generator.up_blocks.parameters() ,
- lr=self.config.lr)
- optimizer = self.load_optimizer(optimizer, 'generator_optim',
- fsdp_model=models.generator if self.config.use_fsdp else None)
- return self.Optimizers(generator=optimizer)
-
- def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
- scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
- scheduler.last_epoch = self.info.total_steps
- return self.Schedulers(generator=scheduler)
-
- def setup_data(self, extras: Extras) -> WarpCore.Data:
- # SETUP DATASET
- dataset_path = self.config.webdataset_path
- print('in line 96', dataset_path, type(dataset_path))
-
- dataset = mydist_dataset(dataset_path, \
- torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \
- else extras.transforms)
-
- # SETUP DATALOADER
- real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
- print('in line 119', self.process_id, real_batch_size)
- sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True)
- dataloader = DataLoader(
- dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True,
- collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
- sampler = sampler
- )
- if self.is_main_node:
- print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
-
- if self.config.multi_aspect_ratio is not None:
- aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
- dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32,
- ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
- interpolate_nearest=False) # , use_smartcrop=True)
- else:
-
- dataloader_iterator = iter(dataloader)
-
- return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)
-
-
-
-
-
- def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
-
- if not single_gpu:
- local_rank = rank
- process_id = rank
- world_size = get_world_size()
-
- self.process_id = process_id
- self.is_main_node = process_id == 0
- self.device = torch.device(local_rank)
- self.world_size = world_size
-
-
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '41443'
- torch.cuda.set_device(local_rank)
- init_process_group(
- backend="nccl",
- rank=local_rank,
- world_size=world_size,
- # init_method=init_method,
- )
- print(f"[GPU {process_id}] READY")
- else:
- self.is_main_node = rank == 0
- self.process_id = rank
- self.device = torch.device('cuda:0')
- self.world_size = 1
- print("Running in single thread, DDP not enabled.")
- # Training loop --------------------------------
- def get_target_lr_size(self, ratio, std_size=24):
- w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
- return (h * 32 , w * 32)
- def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
- #batch = next(data.iterator)
- batch = data
- ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
- shape_lr = self.get_target_lr_size(ratio)
-
- with torch.no_grad():
- conditions = self.get_conditions(batch, models, extras)
-
- latents = self.encode_latents(batch, models, extras)
- latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr)
-
- noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
- noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, )
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-
- require_cond = True
-
- with torch.no_grad():
- _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)
-
-
- pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions)
- loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
-
- loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps
- #
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- extras.gdf.loss_weight.update_buckets(logSNR, loss)
-
- return loss, loss_adjusted
-
- def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
-
- if update:
-
- torch.distributed.barrier()
- loss_adjusted.backward()
-
-
- grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0)
-
- optimizers_dict = optimizers.to_dict()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].step()
- schedulers_dict = schedulers.to_dict()
- for k in schedulers_dict:
- if k != 'training':
- schedulers_dict[k].step()
- for k in optimizers_dict:
- if k != 'training':
- optimizers_dict[k].zero_grad(set_to_none=True)
- self.info.total_steps += 1
- else:
- #print('in line 457', loss_adjusted)
- loss_adjusted.backward()
- #torch.distributed.barrier()
- grad_norm = torch.tensor(0.0).to(self.device)
-
- return grad_norm
-
- def models_to_save(self):
- return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema']
-
- def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
-
- images = batch['images'].to(self.device)
- if target_size is not None:
- images = F.interpolate(images, target_size)
- #images = apply_degradations(images)
- return models.effnet(extras.effnet_preprocess(images))
-
- def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
- return models.previewer(latents)
-
- def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ):
- # Temporary setup, will be overriden by setup_ddp if required
- # self.device = device
- # self.process_id = 0
- # self.is_main_node = True
- # self.world_size = 1
- # ----
- # self.world_size = world_size
- # self.process_id = rank
- # self.device=device
- self.is_main_node = (rank == 0)
- self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
- self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
- self.info: self.Info = self.setup_info()
- print('in line 292', self.config.experiment_id, rank, world_size <= 1)
- p = [i for i in range( 2 * 768 // 32)]
- p = [num / sum(p) for num in p]
- self.rand_pro = p
- self.res_list = [o for o in range(800, 2336, 32)]
-
- #[32, 40, 48]
- #in line 292 stage_c_3b_finetuning False
-
- def __call__(self, single_gpu=False):
- # this will change the device to the CUDA rank
- #self.setup_wandb()
- if self.config.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
- if self.is_main_node:
- print()
- print("**STARTIG JOB WITH CONFIG:**")
- print(yaml.dump(self.config.to_dict(), default_flow_style=False))
- print("------------------------------------")
- print()
- print("**INFO:**")
- print(yaml.dump(vars(self.info), default_flow_style=False))
- print("------------------------------------")
- print()
- print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device )
- # SETUP STUFF
- extras = self.setup_extras_pre()
- assert extras is not None, "setup_extras_pre() must return a DTO"
-
-
-
- data = self.setup_data(extras)
- assert data is not None, "setup_data() must return a DTO"
- if self.is_main_node:
- print("**DATA:**")
- print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- models = self.setup_models(extras)
- assert models is not None, "setup_models() must return a DTO"
- if self.is_main_node:
- print("**MODELS:**")
- print(yaml.dump({
- k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
- }, default_flow_style=False))
- print("------------------------------------")
- print()
-
-
-
- optimizers = self.setup_optimizers(extras, models)
- assert optimizers is not None, "setup_optimizers() must return a DTO"
- if self.is_main_node:
- print("**OPTIMIZERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- schedulers = self.setup_schedulers(extras, models, optimizers)
- assert schedulers is not None, "setup_schedulers() must return a DTO"
- if self.is_main_node:
- print("**SCHEDULERS:**")
- print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
-
- post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
- assert post_extras is not None, "setup_extras_post() must return a DTO"
- extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
- if self.is_main_node:
- print("**EXTRAS:**")
- print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
- print("------------------------------------")
- print()
- # -------
-
- # TRAIN
- if self.is_main_node:
- print("**TRAINING STARTING...**")
- self.train(data, extras, models, optimizers, schedulers)
-
- if single_gpu is False:
- barrier()
- destroy_process_group()
- if self.is_main_node:
- print()
- print("------------------------------------")
- print()
- print("**TRAINING COMPLETE**")
- if self.config.wandb_project is not None:
- wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
-
-
- def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
- schedulers: WarpCore.Schedulers):
- start_iter = self.info.iter + 1
- max_iters = self.config.updates * self.config.grad_accum_steps
- if self.is_main_node:
- print(f"STARTING AT STEP: {start_iter}/{max_iters}")
-
-
- if self.is_main_node:
- create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
- if 'generator' in self.models_to_save():
- models.generator.train()
- #initial_params = {name: param.clone() for name, param in models.train_norm.named_parameters()}
- iter_cnt = 0
- epoch_cnt = 0
- models.train_norm.train()
- while True:
- epoch_cnt += 1
- if self.world_size > 1:
- print('sampler set epoch', epoch_cnt)
- data.sampler.set_epoch(epoch_cnt)
- for ggg in range(len(data.dataloader)):
- iter_cnt += 1
- # FORWARD PASS
- #print('in line 414 before forward', iter_cnt, batch['captions'][0], self.process_id)
- #loss, loss_adjusted, loss_extra = self.forward_pass(batch, extras, models)
- loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
-
- #print('in line 416', loss, iter_cnt)
- # # BACKWARD PASS
-
- grad_norm = self.backward_pass(
- iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
- models, optimizers, schedulers
- )
-
-
-
- self.info.iter = iter_cnt
-
- # UPDATE EMA
- if iter_cnt % self.config.ema_iters == 0:
-
- with torch.no_grad():
- print('in line 890 ema update', self.config.ema_iters, iter_cnt)
- self.ema_update(models.train_norm_ema, models.train_norm, self.config.ema_beta)
- #generator.module.agg_net.
- #self.ema_update(models.generator_ema.agg_net, models.generator.module.agg_net, self.config.ema_beta)
- #self.ema_update(models.generator_ema.agg_net_up, models.generator.module.agg_net_up, self.config.ema_beta)
-
- # UPDATE LOSS METRICS
- self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
-
- #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss)
- if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
- print(f"gggg NaN value encountered in training run {self.info.wandb_run_id}", \
- f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
-
- if self.is_main_node:
- logs = {
- 'loss': self.info.ema_loss,
- 'backward_loss': loss_adjusted.mean().item(),
- #'raw_extra_loss': loss_extra.mean().item(),
- 'ema_loss': self.info.ema_loss,
- 'raw_ori_loss': loss.mean().item(),
- #'raw_rec_loss': loss_rec.mean().item(),
- #'raw_lr_loss': loss_lr.mean().item(),
- #'reg_loss':loss_reg.item(),
- 'grad_norm': grad_norm.item(),
- 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
- 'total_steps': self.info.total_steps,
- }
- if iter_cnt % (self.config.save_every) == 0:
-
- print(iter_cnt, max_iters, logs, epoch_cnt, )
- #pbar.set_postfix(logs)
-
-
- #if iter_cnt % 10 == 0:
-
-
- if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters:
- #if True:
- # SAVE AND CHECKPOINT STUFF
- if np.isnan(loss.mean().item()):
- if self.is_main_node and self.config.wandb_project is not None:
- print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
- f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
-
- else:
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- self.info.adaptive_loss = {
- 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
- 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
- }
- #self.save_checkpoints(models, optimizers)
-
- #torch.save(models.trans_inr.module.state_dict(), \
- #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr.safetensors')
- #torch.save(models.trans_inr_ema.state_dict(), \
- #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr_ema.safetensors')
-
-
- if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
- print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
- torch.save(models.train_norm.state_dict(), \
- f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors')
-
- #self.sync_ema(models.train_norm_ema)
- torch.save(models.train_norm_ema.state_dict(), \
- f'{self.config.output_path}/{self.config.experiment_id}/train_norm_ema.safetensors')
- #if self.is_main_node and iter_cnt % (4 * self.config.save_every * self.config.grad_accum_steps) == 0:
- torch.save(models.train_norm.state_dict(), \
- f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors')
-
-
- if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
-
- if self.is_main_node:
- #check_nan_inmodel(models.generator, 'generator')
- #check_nan_inmodel(models.generator_ema, 'generator_ema')
- self.sample(models, data, extras)
- if False:
- param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()}
- threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10%
- important_params = [name for name, change in param_changes.items() if change > threshold]
- print(important_params, threshold, len(param_changes), self.process_id)
- json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4)
-
-
- if self.info.iter >= max_iters:
- break
-
- def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
-
- #if 'generator' in self.models_to_save():
- models.generator.eval()
- models.train_norm.eval()
- with torch.no_grad():
- batch = next(data.iterator)
- ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
- #batch['images'] = batch['images'].to(torch.float16)
- shape_lr = self.get_target_lr_size(ratio)
- conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
- unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
- cnet, cnet_input = self.get_cnet(batch, models, extras)
- conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet}
-
- latents = self.encode_latents(batch, models, extras)
- latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
-
- if self.is_main_node:
-
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- #print('in line 366 on v100 switch to tf16')
- *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
- models.generator, models.trans_inr, conditions,
- latents.shape, latents_lr.shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
-
-
- #else:
- sampled_ema = sampled
- sampled_ema_lr = sampled_lr
-
-
- if self.is_main_node:
- print('sampling results', latents.shape, latents_lr.shape, )
- noised_images = torch.cat(
- [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
-
- sampled_images = torch.cat(
- [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
- sampled_images_ema = torch.cat(
- [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))],
- dim=0)
-
- noised_images_lr = torch.cat(
- [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
-
- sampled_images_lr = torch.cat(
- [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
- sampled_images_ema_lr = torch.cat(
- [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))],
- dim=0)
-
- images = batch['images']
- if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
- images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
- images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
-
- collage_img = torch.cat([
- torch.cat([i for i in images.cpu()], dim=-1),
- torch.cat([i for i in noised_images.cpu()], dim=-1),
- torch.cat([i for i in sampled_images.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
- ], dim=-2)
-
- collage_img_lr = torch.cat([
- torch.cat([i for i in images_lr.cpu()], dim=-1),
- torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
- torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1),
- ], dim=-2)
-
- torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
- torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
- #torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg')
-
- captions = batch['captions']
- if self.config.wandb_project is not None:
- log_data = [
- [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [
- wandb.Image(images[i])] for i in range(len(images))]
- log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
- wandb.log({"Log": log_table})
-
- if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
- plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
- plt.ylabel('Raw Loss')
- plt.ylabel('LogSNR')
- wandb.log({"Loss/LogSRN": plt})
-
- #if 'generator' in self.models_to_save():
- models.generator.train()
- models.train_norm.train()
- print('finishe sampling in line 901')
-
-
-
- def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
-
- #if 'generator' in self.models_to_save():
- models.generator.eval()
- models.trans_inr.eval()
- models.controlnet.eval()
- with torch.no_grad():
-
- if self.is_main_node:
- conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
- unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
- cnet, cnet_input = self.get_cnet(batch, models, extras, target_size = lr_shape)
- conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet}
-
- #print('in line 885', self.is_main_node)
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
- #print('in line 366 on v100 switch to tf16')
- *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
- models.generator, models.trans_inr, conditions,
- hr_shape, lr_shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
- if models.generator_ema is not None:
-
- *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
- models.generator_ema, models.trans_inr_ema, conditions,
- latents.shape, latents_lr.shape,
- unconditions, device=self.device, **extras.sampling_configs
- )
-
- else:
- sampled_ema = sampled
- sampled_ema_lr = sampled_lr
- #x0, x, epsilon, x0_lr, x_lr, pred_lr)
- #sampled, _ = models.trans_inr(sampled, None, sampled)
- #sampled_lr, _ = models.trans_inr(sampled, None, sampled_lr)
-
- return sampled, sampled_lr
-def main_worker(rank, cfg):
- print("Launching Script in main worker")
- print('in line 467', rank)
- warpcore = WurstCore(
- config_file_path=cfg, rank=rank, world_size = get_world_size()
- )
- # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
-
- # RUN TRAINING
- warpcore(get_world_size()==1)
-
-if __name__ == '__main__':
- print('launch multi process')
- # os.environ["OMP_NUM_THREADS"] = "1"
- # os.environ["MKL_NUM_THREADS"] = "1"
- #dist.init_process_group(backend="nccl")
- #torch.backends.cudnn.benchmark = True
-#train/train_c_my.py
- #mp.set_sharing_strategy('file_system')
- print('in line 481', sys.argv[1] if len(sys.argv) > 1 else None)
- print('in line 481',get_master_ip(), get_world_size() )
- print('in line 484', get_world_size())
- if get_master_ip() == "127.0.0.1":
- # manually launch distributed processes
- mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
- else:
- main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )