Sapir commited on
Commit
bebbcd0
·
1 Parent(s): fc02e02

Ckpt conversion: script + usage examples updated.

Browse files
scripts/to_safetensors.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ from typing import Any, Dict
4
+ import safetensors.torch
5
+ import torch
6
+ import json
7
+ import shutil
8
+
9
+
10
+ def load_text_encoder(index_path: Path) -> Dict:
11
+ with open(index_path, 'r') as f:
12
+ index: Dict = json.load(f)
13
+
14
+ loaded_tensors = {}
15
+ for part_file in set(index.get("weight_map", {}).values()):
16
+ tensors = safetensors.torch.load_file(index_path.parent / part_file, device='cpu')
17
+ for tensor_name in tensors:
18
+ loaded_tensors[tensor_name] = tensors[tensor_name]
19
+
20
+ return loaded_tensors
21
+
22
+
23
+ def convert_unet(unet: Dict, add_prefix=True) -> Dict:
24
+ if add_prefix:
25
+ return {"model.diffusion_model." + key: value for key, value in unet.items()}
26
+ return unet
27
+
28
+
29
+ def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
30
+ state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
31
+ stats_path = vae_path / "per_channel_statistics.json"
32
+ if stats_path.exists():
33
+ with open(stats_path, 'r') as f:
34
+ data = json.load(f)
35
+ transposed_data = list(zip(*data["data"]))
36
+ data_dict = {
37
+ f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(vals)
38
+ for col, vals in zip(data["columns"], transposed_data)
39
+ }
40
+ else:
41
+ data_dict = {}
42
+
43
+ result = {("vae." if add_prefix else "") + key: value for key, value in state_dict.items()}
44
+ result.update(data_dict)
45
+ return result
46
+
47
+
48
+ def convert_encoder(encoder: Dict) -> Dict:
49
+ return {"text_encoders.t5xxl.transformer." + key: value for key, value in encoder.items()}
50
+
51
+
52
+ def save_config(config_src: str, config_dst: str):
53
+ shutil.copy(config_src, config_dst)
54
+
55
+
56
+ def load_vae_config(vae_path: Path) -> str:
57
+ config_path = vae_path / "config.json"
58
+ if not config_path.exists():
59
+ raise FileNotFoundError(f"VAE config file {config_path} not found.")
60
+ return str(config_path)
61
+
62
+
63
+ def main(unet_path: str, vae_path: str, t5_path: str, out_path: str, mode: str,
64
+ unet_config_path: str = None, scheduler_config_path: str = None) -> None:
65
+ unet = convert_unet(torch.load(unet_path, weights_only=True), add_prefix=(mode == 'single'))
66
+
67
+ # Load VAE from directory and config
68
+ vae = convert_vae(Path(vae_path), add_prefix=(mode == 'single'))
69
+ vae_config_path = load_vae_config(Path(vae_path))
70
+
71
+ if mode == 'single':
72
+ result = {**unet, **vae}
73
+ safetensors.torch.save_file(result, out_path)
74
+ elif mode == 'separate':
75
+ # Create directories for unet, vae, and scheduler
76
+ unet_dir = Path(out_path) / 'unet'
77
+ vae_dir = Path(out_path) / 'vae'
78
+ scheduler_dir = Path(out_path) / 'scheduler'
79
+
80
+ unet_dir.mkdir(parents=True, exist_ok=True)
81
+ vae_dir.mkdir(parents=True, exist_ok=True)
82
+ scheduler_dir.mkdir(parents=True, exist_ok=True)
83
+
84
+ # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
85
+ safetensors.torch.save_file(unet, unet_dir / 'diffusion_pytorch_model.safetensors')
86
+ safetensors.torch.save_file(vae, vae_dir / 'diffusion_pytorch_model.safetensors')
87
+
88
+ # Save config files for unet, vae, and scheduler
89
+ if unet_config_path:
90
+ save_config(unet_config_path, unet_dir / 'config.json')
91
+ if vae_config_path:
92
+ save_config(vae_config_path, vae_dir / 'config.json')
93
+ if scheduler_config_path:
94
+ save_config(scheduler_config_path, scheduler_dir / 'scheduler_config.json')
95
+
96
+
97
+ if __name__ == '__main__':
98
+ parser = argparse.ArgumentParser()
99
+ parser.add_argument('--unet_path', '-u', type=str, default='unet/ema-002.pt')
100
+ parser.add_argument('--vae_path', '-v', type=str, default='vae/')
101
+ parser.add_argument('--t5_path', '-t', type=str, default='t5/PixArt-XL-2-1024-MS/')
102
+ parser.add_argument('--out_path', '-o', type=str, default='xora.safetensors')
103
+ parser.add_argument('--mode', '-m', type=str, choices=['single', 'separate'], default='single',
104
+ help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.")
105
+ parser.add_argument('--unet_config_path', type=str, help="Path to the UNet config file (for separate mode)")
106
+ parser.add_argument('--scheduler_config_path', type=str,
107
+ help="Path to the Scheduler config file (for separate mode)")
108
+
109
+ args = parser.parse_args()
110
+ main(**args.__dict__)
xora/examples/image_to_video.py CHANGED
@@ -5,32 +5,46 @@ from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
  from xora.schedulers.rf import RectifiedFlowScheduler
6
  from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
  from pathlib import Path
8
- from transformers import T5EncoderModel
 
9
 
 
 
 
 
 
10
 
11
- model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
12
- vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
13
- dtype = torch.float32
14
- vae = CausalVideoAutoencoder.from_pretrained(
15
- pretrained_model_name_or_path=vae_local_path,
16
- revision=False,
17
- torch_dtype=torch.bfloat16,
18
- load_in_8bit=False,
 
 
19
  ).cuda()
20
- transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
21
- transformer_config = Transformer3DModel.load_config(transformer_config_path)
 
 
 
22
  transformer = Transformer3DModel.from_config(transformer_config)
23
- transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-first-frame-cond-4k-seq/ckpt/01822000/model.pt")
24
- transformer_ckpt_state_dict = torch.load(transformer_local_path)
25
- transformer.load_state_dict(transformer_ckpt_state_dict, True)
26
  transformer = transformer.cuda()
27
  unet = transformer
28
- scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
 
 
29
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
30
  scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
 
 
31
  patchifier = SymmetricPatchifier(patch_size=1)
32
- # text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
33
 
 
34
  submodel_dict = {
35
  "unet": unet,
36
  "transformer": transformer,
@@ -38,36 +52,33 @@ submodel_dict = {
38
  "text_encoder": None,
39
  "scheduler": scheduler,
40
  "vae": vae,
41
-
42
  }
43
 
 
44
  pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
45
  safety_checker=None,
46
- revision=None,
47
- torch_dtype=dtype,
48
- **submodel_dict,
49
- )
50
 
51
- num_inference_steps=20
52
- num_images_per_prompt=2
53
- guidance_scale=3
54
- height=512
55
- width=768
56
- num_frames=57
57
- frame_rate=25
58
- # sample = {
59
- # "prompt": "A cat", # (B, L, E)
60
- # 'prompt_attention_mask': None, # (B , L)
61
- # 'negative_prompt': "Ugly deformed",
62
- # 'negative_prompt_attention_mask': None # (B , L)
63
- # }
64
 
 
65
  sample = torch.load("/opt/sample.pt")
66
- for _, item in sample.items():
67
  if item is not None:
68
- item = item.cuda()
 
69
  media_items = torch.load("/opt/sample_media.pt")
70
 
 
71
  images = pipeline(
72
  num_inference_steps=num_inference_steps,
73
  num_images_per_prompt=num_images_per_prompt,
@@ -84,4 +95,4 @@ images = pipeline(
84
  vae_per_channel_normalize=True,
85
  ).images
86
 
87
- print()
 
5
  from xora.schedulers.rf import RectifiedFlowScheduler
6
  from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
  from pathlib import Path
8
+ import safetensors.torch
9
+ import json
10
 
11
+ # Paths for the separate mode directories
12
+ separate_dir = Path("/opt/models/xora-img2video")
13
+ unet_dir = separate_dir / 'unet'
14
+ vae_dir = separate_dir / 'vae'
15
+ scheduler_dir = separate_dir / 'scheduler'
16
 
17
+ # Load VAE from separate mode
18
+ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
19
+ vae_config_path = vae_dir / "config.json"
20
+ with open(vae_config_path, 'r') as f:
21
+ vae_config = json.load(f)
22
+ vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
23
+ vae = CausalVideoAutoencoder.from_pretrained_conf(
24
+ config=vae_config,
25
+ state_dict=vae_state_dict,
26
+ torch_dtype=torch.bfloat16
27
  ).cuda()
28
+
29
+ # Load UNet (Transformer) from separate mode
30
+ unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
31
+ unet_config_path = unet_dir / "config.json"
32
+ transformer_config = Transformer3DModel.load_config(unet_config_path)
33
  transformer = Transformer3DModel.from_config(transformer_config)
34
+ unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
35
+ transformer.load_state_dict(unet_state_dict, strict=True)
 
36
  transformer = transformer.cuda()
37
  unet = transformer
38
+
39
+ # Load Scheduler from separate mode
40
+ scheduler_config_path = scheduler_dir / "scheduler_config.json"
41
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
42
  scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
43
+
44
+ # Patchifier (remains the same)
45
  patchifier = SymmetricPatchifier(patch_size=1)
 
46
 
47
+ # Use submodels for the pipeline
48
  submodel_dict = {
49
  "unet": unet,
50
  "transformer": transformer,
 
52
  "text_encoder": None,
53
  "scheduler": scheduler,
54
  "vae": vae,
 
55
  }
56
 
57
+ model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
58
  pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
59
  safety_checker=None,
60
+ revision=None,
61
+ torch_dtype=torch.float32, # dtype adjusted
62
+ **submodel_dict,
63
+ ).to("cuda")
64
 
65
+ num_inference_steps = 20
66
+ num_images_per_prompt = 2
67
+ guidance_scale = 3
68
+ height = 512
69
+ width = 768
70
+ num_frames = 57
71
+ frame_rate = 25
 
 
 
 
 
 
72
 
73
+ # Assuming sample is a dict loaded from a .pt file
74
  sample = torch.load("/opt/sample.pt")
75
+ for key, item in sample.items():
76
  if item is not None:
77
+ sample[key] = item.cuda()
78
+
79
  media_items = torch.load("/opt/sample_media.pt")
80
 
81
+ # Generate images (video frames)
82
  images = pipeline(
83
  num_inference_steps=num_inference_steps,
84
  num_images_per_prompt=num_images_per_prompt,
 
95
  vae_per_channel_normalize=True,
96
  ).images
97
 
98
+ print("Generated video frames.")
xora/examples/text_to_video.py CHANGED
@@ -6,69 +6,78 @@ from xora.schedulers.rf import RectifiedFlowScheduler
6
  from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
  from pathlib import Path
8
  from transformers import T5EncoderModel
 
 
9
 
 
 
 
 
 
10
 
11
- model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
12
- vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
13
- dtype = torch.float32
14
- vae = CausalVideoAutoencoder.from_pretrained(
15
- pretrained_model_name_or_path=vae_local_path,
16
- revision=False,
17
- torch_dtype=torch.bfloat16,
18
- load_in_8bit=False,
 
 
19
  ).cuda()
20
- transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
21
- transformer_config = Transformer3DModel.load_config(transformer_config_path)
 
 
 
22
  transformer = Transformer3DModel.from_config(transformer_config)
23
- transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-nl/ckpt/01760000/model.pt")
24
- transformer_ckpt_state_dict = torch.load(transformer_local_path)
25
- transformer.load_state_dict(transformer_ckpt_state_dict, True)
26
  transformer = transformer.cuda()
27
  unet = transformer
28
- scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
 
 
29
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
30
  scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
 
 
31
  patchifier = SymmetricPatchifier(patch_size=1)
32
- # text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
33
 
 
34
  submodel_dict = {
35
  "unet": unet,
36
  "transformer": transformer,
37
  "patchifier": patchifier,
38
- "text_encoder": None,
39
  "scheduler": scheduler,
40
  "vae": vae,
41
-
42
  }
43
-
44
  pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
45
  safety_checker=None,
46
  revision=None,
47
- torch_dtype=dtype,
48
  **submodel_dict,
49
- )
50
-
51
- num_inference_steps=20
52
- num_images_per_prompt=2
53
- guidance_scale=3
54
- height=512
55
- width=768
56
- num_frames=57
57
- frame_rate=25
58
- # sample = {
59
- # "prompt": "A cat", # (B, L, E)
60
- # 'prompt_attention_mask': None, # (B , L)
61
- # 'negative_prompt': "Ugly deformed",
62
- # 'negative_prompt_attention_mask': None # (B , L)
63
- # }
64
-
65
- sample = torch.load("/opt/sample.pt")
66
- for _, item in sample.items():
67
- if item is not None:
68
- item = item.cuda()
69
-
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
 
72
  images = pipeline(
73
  num_inference_steps=num_inference_steps,
74
  num_images_per_prompt=num_images_per_prompt,
@@ -85,4 +94,4 @@ images = pipeline(
85
  vae_per_channel_normalize=True,
86
  ).images
87
 
88
- print()
 
6
  from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
  from pathlib import Path
8
  from transformers import T5EncoderModel
9
+ import safetensors.torch
10
+ import json
11
 
12
+ # Paths for the separate mode directories
13
+ separate_dir = Path("/opt/models/xora-txt2video")
14
+ unet_dir = separate_dir / 'unet'
15
+ vae_dir = separate_dir / 'vae'
16
+ scheduler_dir = separate_dir / 'scheduler'
17
 
18
+ # Load VAE from separate mode
19
+ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
20
+ vae_config_path = vae_dir / "config.json"
21
+ with open(vae_config_path, 'r') as f:
22
+ vae_config = json.load(f)
23
+ vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
24
+ vae = CausalVideoAutoencoder.from_pretrained_conf(
25
+ config=vae_config,
26
+ state_dict=vae_state_dict,
27
+ torch_dtype=torch.bfloat16
28
  ).cuda()
29
+
30
+ # Load UNet (Transformer) from separate mode
31
+ unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
32
+ unet_config_path = unet_dir / "config.json"
33
+ transformer_config = Transformer3DModel.load_config(unet_config_path)
34
  transformer = Transformer3DModel.from_config(transformer_config)
35
+ unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
36
+ transformer.load_state_dict(unet_state_dict, strict=True)
 
37
  transformer = transformer.cuda()
38
  unet = transformer
39
+
40
+ # Load Scheduler from separate mode
41
+ scheduler_config_path = scheduler_dir / "scheduler_config.json"
42
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
43
  scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
44
+
45
+ # Patchifier (remains the same)
46
  patchifier = SymmetricPatchifier(patch_size=1)
 
47
 
48
+ # Use submodels for the pipeline
49
  submodel_dict = {
50
  "unet": unet,
51
  "transformer": transformer,
52
  "patchifier": patchifier,
 
53
  "scheduler": scheduler,
54
  "vae": vae,
 
55
  }
56
+ model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
57
  pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
58
  safety_checker=None,
59
  revision=None,
60
+ torch_dtype=torch.float32,
61
  **submodel_dict,
62
+ ).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ # Sample input
65
+ num_inference_steps = 20
66
+ num_images_per_prompt = 2
67
+ guidance_scale = 3
68
+ height = 512
69
+ width = 768
70
+ num_frames = 57
71
+ frame_rate = 25
72
+ sample = {
73
+ "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
74
+ "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
75
+ 'prompt_attention_mask': None, # Adjust attention masks as needed
76
+ 'negative_prompt': "Ugly deformed",
77
+ 'negative_prompt_attention_mask': None
78
+ }
79
 
80
+ # Generate images (video frames)
81
  images = pipeline(
82
  num_inference_steps=num_inference_steps,
83
  num_images_per_prompt=num_images_per_prompt,
 
94
  vae_per_channel_normalize=True,
95
  ).images
96
 
97
+ print("Generated images (video frames).")
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -41,6 +41,35 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
41
 
42
  return video_vae
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  @staticmethod
45
  def from_config(config):
46
  assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
 
41
 
42
  return video_vae
43
 
44
+ @classmethod
45
+ def from_pretrained_conf(cls, config, state_dict, torch_dtype=torch.float32):
46
+ video_vae = cls.from_config(config)
47
+ video_vae.to(torch_dtype)
48
+
49
+ per_channel_statistics_prefix = "per_channel_statistics."
50
+ ckpt_state_dict = {
51
+ key: value
52
+ for key, value in state_dict.items()
53
+ if not key.startswith(per_channel_statistics_prefix)
54
+ }
55
+ video_vae.load_state_dict(ckpt_state_dict)
56
+
57
+ data_dict = {
58
+ key.removeprefix(per_channel_statistics_prefix): value
59
+ for key, value in state_dict.items()
60
+ if key.startswith(per_channel_statistics_prefix)
61
+ }
62
+ if len(data_dict) > 0:
63
+ video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
64
+ video_vae.register_buffer(
65
+ "mean_of_means",
66
+ data_dict.get(
67
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
68
+ ),
69
+ )
70
+
71
+ return video_vae
72
+
73
  @staticmethod
74
  def from_config(config):
75
  assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"