File size: 4,550 Bytes
1504958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
from pathlib import Path
from typing import Dict
import safetensors.torch
import torch
import json
import shutil


def load_text_encoder(index_path: Path) -> Dict:
    """Load every tensor referenced by a sharded safetensors index file.

    The index is a JSON file whose ``"weight_map"`` entry maps tensor names
    to shard file names. Shards are resolved relative to the index file's
    own directory and loaded onto the CPU.

    Args:
        index_path: Path to the JSON index file (``str`` is also accepted).

    Returns:
        A dict mapping tensor name to tensor, merged across all shards. An
        index without a ``"weight_map"`` entry yields an empty dict.
    """
    index_path = Path(index_path)
    with open(index_path, "r", encoding="utf-8") as f:
        index: Dict = json.load(f)

    loaded_tensors = {}
    # Sort the de-duplicated shard names so shards always load in the same
    # order: iterating a ``set`` of strings varies between interpreter runs
    # (hash randomization), which made key order and the winner of any
    # duplicate tensor name non-deterministic.
    for part_file in sorted(set(index.get("weight_map", {}).values())):
        tensors = safetensors.torch.load_file(
            index_path.parent / part_file, device="cpu"
        )
        loaded_tensors.update(tensors)

    return loaded_tensors


def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Return the UNet state dict, optionally namespaced for a merged checkpoint.

    With ``add_prefix`` true, a new dict is returned in which every key is
    prefixed by ``"model.diffusion_model."``. Otherwise the input mapping is
    returned unchanged: the very same object, not a copy.
    """
    if not add_prefix:
        return unet
    prefix = "model.diffusion_model."
    return dict((prefix + name, tensor) for name, tensor in unet.items())


def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Load a VAE checkpoint directory into a flat state dict.

    Reads ``autoencoder.pth`` from ``vae_path`` and, when it exists, also
    reads ``per_channel_statistics.json``. That JSON is read as a
    ``"columns"`` list plus a row-major ``"data"`` table. Each column becomes
    one tensor stored under the key ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory containing the VAE files (``str`` is also accepted).
        add_prefix: When true, prefix every key with ``"vae."`` so the result
            can be merged into a single-file checkpoint.

    Returns:
        Dict of tensors: the autoencoder weights first, then any per-channel
        statistics. A statistics key overwrites a weight key on collision.
    """
    vae_path = Path(vae_path)
    prefix = "vae." if add_prefix else ""

    # map_location="cpu" lets checkpoints saved from CUDA tensors load on
    # CPU-only machines; without it torch.load restores the saved device.
    state_dict = torch.load(
        vae_path / "autoencoder.pth", map_location="cpu", weights_only=True
    )
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Transpose the row-major table so each column's values are grouped.
        column_values = zip(*data["data"])
        for col, vals in zip(data["columns"], column_values):
            result[f"{prefix}per_channel_statistics.{col}"] = torch.tensor(vals)

    return result


def convert_encoder(encoder: Dict) -> Dict:
    """Namespace text-encoder weights under the T5-XXL transformer prefix.

    Returns a new dict whose keys are ``"text_encoders.t5xxl.transformer."``
    followed by the original key. Values are passed through untouched.
    """
    namespaced = {}
    for name, weight in encoder.items():
        namespaced["text_encoders.t5xxl.transformer." + name] = weight
    return namespaced


def save_config(config_src: str, config_dst: str):
    """Copy a configuration file into the output layout.

    Thin wrapper around ``shutil.copy``, which copies both the file contents
    and its permission bits. Returns ``None``.
    """
    shutil.copy(src=config_src, dst=config_dst)
    return None


def load_vae_config(vae_path: Path) -> str:
    """Return the path of the VAE's ``config.json`` as a string.

    Raises:
        FileNotFoundError: if ``vae_path / "config.json"`` does not exist.
    """
    candidate = vae_path / "config.json"
    if candidate.exists():
        return str(candidate)
    raise FileNotFoundError(f"VAE config file {candidate} not found.")


def main(
    unet_path: str,
    vae_path: str,
    out_path: str,
    mode: str,
    unet_config_path: str = None,
    scheduler_config_path: str = None,
) -> None:
    """Convert UNet and VAE checkpoints to safetensors.

    Args:
        unet_path: UNet checkpoint file readable by ``torch.load``.
        vae_path: VAE directory as consumed by ``convert_vae``. In
            ``"separate"`` mode it must also contain ``config.json``.
        out_path: In ``"single"`` mode, the output ``.safetensors`` file. In
            ``"separate"`` mode, the root directory that receives ``unet/``,
            ``vae/`` and ``scheduler/`` subdirectories.
        mode: ``"single"`` merges UNet and VAE into one file with prefixed
            keys. ``"separate"`` writes unprefixed weights per component and
            copies the config files.
        unet_config_path: Optional UNet config, copied to ``unet/config.json``
            (separate mode only).
        scheduler_config_path: Optional scheduler config, copied to
            ``scheduler/scheduler_config.json`` (separate mode only).

    Raises:
        ValueError: if ``mode`` is neither ``"single"`` nor ``"separate"``.
        FileNotFoundError: in separate mode, if the VAE ``config.json`` is missing.
    """
    # Validate before any (potentially slow) checkpoint loading. Previously an
    # unknown mode loaded everything and then silently wrote nothing.
    if mode not in ("single", "separate"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'single' or 'separate'."
        )

    add_prefix = mode == "single"
    vae_src_dir = Path(vae_path)

    # Only separate mode uses the VAE config. Look it up first so a missing
    # file fails fast, and don't require it at all in single mode.
    vae_config_path = load_vae_config(vae_src_dir) if mode == "separate" else None

    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    unet = convert_unet(
        torch.load(unet_path, map_location="cpu", weights_only=True),
        add_prefix=add_prefix,
    )
    vae = convert_vae(vae_src_dir, add_prefix=add_prefix)

    if mode == "single":
        safetensors.torch.save_file({**unet, **vae}, str(out_path))
        return

    # Separate mode: one directory per component under out_path.
    root = Path(out_path)
    unet_dir = root / "unet"
    vae_dir = root / "vae"
    scheduler_dir = root / "scheduler"
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the weight files carry a component prefix (unet_/vae_)
    # rather than the plain "diffusion_pytorch_model.safetensors"; kept as-is
    # since downstream loaders may expect these exact names — confirm.
    safetensors.torch.save_file(
        unet, str(unet_dir / "unet_diffusion_pytorch_model.safetensors")
    )
    safetensors.torch.save_file(
        vae, str(vae_dir / "vae_diffusion_pytorch_model.safetensors")
    )

    # The VAE config is mandatory here (checked above); the others are optional.
    if unet_config_path:
        save_config(unet_config_path, unet_dir / "config.json")
    save_config(vae_config_path, vae_dir / "config.json")
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")


if __name__ == "__main__":
    # Command-line entry point: each option maps 1:1 onto a keyword of main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
    cli.add_argument("--vae_path", "-v", type=str, default="vae/")
    cli.add_argument("--out_path", "-o", type=str, default="xora.safetensors")
    cli.add_argument(
        "--mode",
        "-m",
        type=str,
        choices=["single", "separate"],
        default="single",
        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
    )
    # main() only reads these two config paths in "separate" mode.
    cli.add_argument(
        "--unet_config_path",
        type=str,
        help="Path to the UNet config file (for separate mode)",
    )
    cli.add_argument(
        "--scheduler_config_path",
        type=str,
        help="Path to the Scheduler config file (for separate mode)",
    )

    parsed = cli.parse_args()
    main(**vars(parsed))