Baraaqasem committed
Commit 413d4d0 · verified · 1 Parent(s): 14ee1a9

Upload 49 files

Files changed (49)
  1. src/videogen_hub/__init__.py +12 -0
  2. src/videogen_hub/_version.py +1 -0
  3. src/videogen_hub/base/__init__.py +0 -0
  4. src/videogen_hub/common/__init__.py +0 -0
  5. src/videogen_hub/common/lvdm/__init__.py +0 -0
  6. src/videogen_hub/common/lvdm/models/__init__.py +0 -0
  7. src/videogen_hub/common/lvdm/models/samplers/__init__.py +0 -0
  8. src/videogen_hub/common/lvdm/modules/__init__.py +0 -0
  9. src/videogen_hub/common/lvdm/modules/encoders/__init__.py +0 -0
  10. src/videogen_hub/depend/__init__.py +0 -0
  11. src/videogen_hub/depend/icetk/__init__.py +5 -0
  12. src/videogen_hub/depend/icetk/ice_tokenizer.py +116 -0
  13. src/videogen_hub/depend/icetk/image_tokenizer.py +77 -0
  14. src/videogen_hub/depend/icetk/sentencepiece_model_pb2.py +722 -0
  15. src/videogen_hub/depend/icetk/text_tokenizer.py +77 -0
  16. src/videogen_hub/depend/icetk/utils.py +46 -0
  17. src/videogen_hub/depend/icetk/vqvae/__init__.py +5 -0
  18. src/videogen_hub/depend/icetk/vqvae/api.py +93 -0
  19. src/videogen_hub/depend/icetk/vqvae/enc_dec.py +386 -0
  20. src/videogen_hub/depend/icetk/vqvae/quantize.py +156 -0
  21. src/videogen_hub/depend/icetk/vqvae/vqvae_hierarchical.py +97 -0
  22. src/videogen_hub/infermodels/__init__.py +59 -0
  23. src/videogen_hub/infermodels/cogvideo.py +54 -0
  24. src/videogen_hub/infermodels/cogvideox.py +48 -0
  25. src/videogen_hub/infermodels/consisti2v.py +116 -0
  26. src/videogen_hub/infermodels/dynamicrafter.py +104 -0
  27. src/videogen_hub/infermodels/i2vgen_xl.py +57 -0
  28. src/videogen_hub/infermodels/lavie.py +103 -0
  29. src/videogen_hub/infermodels/modelscope.py +62 -0
  30. src/videogen_hub/infermodels/opensora.py +134 -0
  31. src/videogen_hub/infermodels/opensora_12.py +139 -0
  32. src/videogen_hub/infermodels/opensora_plan.py +73 -0
  33. src/videogen_hub/infermodels/seine.py +52 -0
  34. src/videogen_hub/infermodels/show_one.py +79 -0
  35. src/videogen_hub/infermodels/streamingt2v.py +49 -0
  36. src/videogen_hub/infermodels/t2v_turbo.py +147 -0
  37. src/videogen_hub/infermodels/videocrafter.py +63 -0
  38. src/videogen_hub/metrics/__init__.py +0 -0
  39. src/videogen_hub/metrics/brisque_metric.py +47 -0
  40. src/videogen_hub/metrics/clip-sim_metric.py +63 -0
  41. src/videogen_hub/metrics/clipscore_metric.py +65 -0
  42. src/videogen_hub/metrics/dino-sim_metric.py +71 -0
  43. src/videogen_hub/metrics/mse-dyn_metric.py +59 -0
  44. src/videogen_hub/metrics/piqe_metric.py +49 -0
  45. src/videogen_hub/metrics/ssim-dyn_metric.py +60 -0
  46. src/videogen_hub/metrics/ssim-sim_metric.py +59 -0
  47. src/videogen_hub/metrics/xclipscore_metric.py +72 -0
  48. src/videogen_hub/utils/__init__.py +17 -0
  49. src/videogen_hub/utils/file_helper.py +24 -0
src/videogen_hub/__init__.py ADDED
@@ -0,0 +1,12 @@
+ import os
+
+ from ._version import __version__
+ MODEL_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "checkpoints"))
+ if os.environ.get("VIDEO_MODEL_PATH"):
+     MODEL_PATH = os.environ.get("VIDEO_MODEL_PATH")
+
+ # (cogVideo) Set the SAT_HOME env variable to MODEL_PATH if not set
+ if not os.environ.get("SAT_HOME"):
+     os.environ["SAT_HOME"] = MODEL_PATH
+
+ from .infermodels import load, get_model, load_model
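
For orientation, a minimal usage sketch of the package entry points added above. The environment-variable value and the model name are illustrative assumptions; only MODEL_PATH, load, get_model and load_model come from this file.

import os
os.environ["VIDEO_MODEL_PATH"] = "/data/checkpoints"  # hypothetical override; must be set before the import below

import videogen_hub
print(videogen_hub.MODEL_PATH)              # resolved checkpoint directory
model = videogen_hub.load("SomeModelName")  # hypothetical model name; valid names live in infermodels/__init__.py
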
src/videogen_hub/_version.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.1.4a0"
src/videogen_hub/base/__init__.py ADDED
File without changes
src/videogen_hub/common/__init__.py ADDED
File without changes
src/videogen_hub/common/lvdm/__init__.py ADDED
File without changes
src/videogen_hub/common/lvdm/models/__init__.py ADDED
File without changes
src/videogen_hub/common/lvdm/models/samplers/__init__.py ADDED
File without changes
src/videogen_hub/common/lvdm/modules/__init__.py ADDED
File without changes
src/videogen_hub/common/lvdm/modules/encoders/__init__.py ADDED
File without changes
src/videogen_hub/depend/__init__.py ADDED
File without changes
src/videogen_hub/depend/icetk/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .ice_tokenizer import IceTokenizer
+
+ icetk = IceTokenizer()
+
+ __all__ = ['icetk']
src/videogen_hub/depend/icetk/ice_tokenizer.py ADDED
@@ -0,0 +1,116 @@
+ # -*- encoding: utf-8 -*-
+
+ import os
+ import sys
+ import math
+ import random
+ from typing import List, Tuple, Union
+
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.transforms.functional import pil_to_tensor
+
+ from .text_tokenizer import TextTokenizer
+ from .image_tokenizer import ImageTokenizer
+ from .utils import auto_create
+
+ class IceTokenizer:
+     def __init__(self, path='~/.icetk_models', device='cuda', fp16=True):
+         self.configure(path, device, fp16)
+
+     def configure(self, path=None, device=None, fp16=None):
+         if path is not None:
+             self.path = os.path.expanduser(path)
+         if device is not None:
+             self.device = device
+         if fp16 is not None:
+             self.fp16 = fp16
+
+     @property
+     def text_tokenizer(self):
+         if not hasattr(self, '_text_tokenizer'):
+             fp = os.path.join(self.path, 'ice_text.model')
+             auto_create(fp)
+             self._text_tokenizer = TextTokenizer(fp)
+         return self._text_tokenizer
+
+     @property
+     def image_tokenizer(self):
+         if not hasattr(self, '_image_tokenizer'):
+             fp = os.path.join(self.path, 'ice_image.pt')
+             auto_create(fp)
+             self._image_tokenizer = ImageTokenizer(fp, device=self.device, fp16=self.fp16)
+         return self._image_tokenizer
+
+     @property
+     def num_image_tokens(self):
+         return 20000  # self.image_tokenizer.num_tokens  # allow not load
+
+     @property
+     def num_text_tokens(self):
+         return self.text_tokenizer.num_tokens
+
+     @property
+     def num_tokens(self):
+         return self.num_image_tokens + self.num_text_tokens
+
+     def add_special_tokens(self, special_tokens: List[str]):
+         self.text_tokenizer.add_special_tokens(special_tokens)
+
+     def encode(self, text=None,
+                image_path=None, image_pil=None, image_torch=None,
+                image_size: int=None, compress_rate=8, ignore_linebreak=True):
+         assert (text is None) + (image_path is None) + (image_pil is None) + (image_torch is None) == 3
+         assert int(compress_rate) in [4, 8, 16]
+         if text is not None:
+             if not ignore_linebreak:
+                 text = text.replace('\n', '<n>')
+             tmp = self.text_tokenizer.encode(text)
+             return [x + self.num_image_tokens for x in tmp]
+         else:
+             need_norm_to_1 = False
+             if image_path is not None:
+                 image_pil = Image.open(image_path)
+             if image_torch is None:
+                 image_torch = pil_to_tensor(image_pil)
+                 need_norm_to_1 = True
+             if image_size is not None:
+                 # for speed in large-scale preprocessing, set this to None and transform in Dataloader.
+                 # TODO: test speed
+                 tr = transforms.Compose([
+                     transforms.Resize(image_size),
+                     transforms.CenterCrop(image_size),
+                 ])
+                 image_torch = tr(image_torch)
+             image_torch = image_torch.to(self.image_tokenizer.device).float()
+             if need_norm_to_1:
+                 image_torch /= 255.
+             return self.image_tokenizer.encode(image_torch, l=int(math.log2(compress_rate)) - 2)
+
+     def decode(self, text_ids: List[int]=None, image_ids: Union[List[int], torch.LongTensor]=None, compress_rate=8):
+         assert (text_ids is None) + (image_ids is None) == 1
+         if text_ids is not None:
+             ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+             return self.text_tokenizer.decode(ids).replace('<n>', '\n')
+         else:
+             return self.image_tokenizer.decode(image_ids, l=int(math.log2(compress_rate)) - 2)
+
+     def tokenize(self, text):
+         return self.text_tokenizer.tokenize(text)
+
+     def __getitem__(self, x):
+         if isinstance(x, int):
+             if x < self.num_image_tokens:
+                 return '<image_{}>'.format(x)
+             else:
+                 return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
+         elif isinstance(x, str):
+             if x.startswith('<image_') and x.endswith('>') and x[7:-1].isdigit():
+                 return int(x[7:-1])
+             else:
+                 return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
+         else:
+             raise ValueError('The key should be str or int.')
+
+
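
A brief usage sketch of the tokenizer above, using the icetk singleton exported by depend/icetk/__init__.py. It assumes network access for the first-use checkpoint download performed by auto_create; the prompt string and image path are illustrative.

from videogen_hub.depend.icetk import icetk

text_ids = icetk.encode(text="a red panda eating bamboo")   # text ids come back offset by num_image_tokens (20000)
print(icetk.decode(text_ids=text_ids))

# Image tokenization lazily downloads/loads the VQ-VAE checkpoint (ice_image.pt) on first use.
image_ids = icetk.encode(image_path="example.png", image_size=256, compress_rate=8)
reconstruction = icetk.decode(image_ids=image_ids, compress_rate=8)   # [b, c, h, w] image tensor
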
src/videogen_hub/depend/icetk/image_tokenizer.py ADDED
@@ -0,0 +1,77 @@
+ # -*- encoding: utf-8 -*-
+ '''
+ @File : image_tokenizer.py
+ @Time : 2021/12/20 14:19:49
+ @Author : Ming Ding
+ @Contact : [email protected]
+ '''
+
+ # here put the import lib
+ import os
+ import sys
+ import math
+ import random
+
+ import torch
+ import torch.nn.functional as F
+ from torchvision import transforms
+
+ from .vqvae import load_default_HVQVAE, load_ckpt
+
+ class ImageTokenizer(object):
+     def __init__(self,
+                  model_path,
+                  device='cuda',
+                  fp16=True):
+         model = load_default_HVQVAE()
+         model = load_ckpt(model, model_path)
+         model = model.to(device)
+         model.eval()
+
+         self.tr_normalize = transforms.Normalize(
+             [0.79093, 0.76271, 0.75340],
+             [0.30379, 0.32279, 0.32800]
+         )
+
+         self.model = model
+         self.device = device
+         self.fp16 = fp16
+         self.num_tokens = model.quantize.n_embed
+
+         if fp16:
+             model = model.half()
+
+     def __len__(self):
+         return self.num_tokens
+
+     def encode(self, image_torch, l=1):
+         '''Convert a batch of img to code
+         Args:
+             model: The tokenizer model.
+             img: [b, c, h, w]
+         '''
+         if len(image_torch.shape) == 3:
+             image_torch = image_torch.unsqueeze(0)
+         img = self.tr_normalize(image_torch).to(self.device)
+         if self.fp16:
+             img = img.half()
+         with torch.no_grad():
+             quant, diff, id = self.model.single_encode(img, l)
+         return id.view(img.shape[0], -1)
+
+     def decode(self, codes, l=1):
+         '''Convert a batch of code to imgs
+         Args:
+             codes : [b, h, w] or [b, h*w] or [h*w] LongTensor / list
+         '''
+         if isinstance(codes, list):
+             codes = torch.tensor(codes, dtype=torch.long, device=self.device)
+         if len(codes.shape) == 1:
+             codes = codes.unsqueeze(0)
+         if len(codes.shape) == 2:
+             s = int(math.sqrt(len(codes.view(-1))) + 1e-5)
+             codes = codes.view(codes.shape[0], s, s)
+         with torch.no_grad():
+             out = self.model.single_decode_code(codes, l)
+             out = out * torch.tensor([0.30379, 0.32279, 0.32800], device=out.device).view(1, -1, 1, 1) + torch.tensor([0.79093, 0.76271, 0.75340], device=out.device).view(1, -1, 1, 1)
+         return out
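
The same round trip can be done directly with ImageTokenizer. A sketch assuming a local copy of the ice_image.pt checkpoint and a CUDA device; the random input tensor follows the shape documented in encode.

import os
import torch
from videogen_hub.depend.icetk.image_tokenizer import ImageTokenizer

ckpt = os.path.expanduser('~/.icetk_models/ice_image.pt')   # downloaded by icetk.utils.auto_create
tok = ImageTokenizer(ckpt, device='cuda', fp16=True)
img = torch.rand(1, 3, 256, 256)   # [b, c, h, w], values in [0, 1]
codes = tok.encode(img, l=1)       # flattened token ids per image
recon = tok.decode(codes, l=1)     # normalization above is undone before returning
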
src/videogen_hub/depend/icetk/sentencepiece_model_pb2.py ADDED
@@ -0,0 +1,722 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: sentencepiece_model.proto
4
+ """Generated protocol buffer code."""
5
+ from google.protobuf import descriptor as _descriptor
6
+ from google.protobuf import message as _message
7
+ from google.protobuf import reflection as _reflection
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ # @@protoc_insertion_point(imports)
10
+
11
+ _sym_db = _symbol_database.Default()
12
+
13
+
14
+
15
+
16
+ DESCRIPTOR = _descriptor.FileDescriptor(
17
+ name='sentencepiece_model.proto',
18
+ package='sentencepiece',
19
+ syntax='proto2',
20
+ serialized_options=b'H\003',
21
+ create_key=_descriptor._internal_create_key,
22
+ serialized_pb=b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
23
+ )
24
+
25
+
26
+
27
+ _TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor(
28
+ name='ModelType',
29
+ full_name='sentencepiece.TrainerSpec.ModelType',
30
+ filename=None,
31
+ file=DESCRIPTOR,
32
+ create_key=_descriptor._internal_create_key,
33
+ values=[
34
+ _descriptor.EnumValueDescriptor(
35
+ name='UNIGRAM', index=0, number=1,
36
+ serialized_options=None,
37
+ type=None,
38
+ create_key=_descriptor._internal_create_key),
39
+ _descriptor.EnumValueDescriptor(
40
+ name='BPE', index=1, number=2,
41
+ serialized_options=None,
42
+ type=None,
43
+ create_key=_descriptor._internal_create_key),
44
+ _descriptor.EnumValueDescriptor(
45
+ name='WORD', index=2, number=3,
46
+ serialized_options=None,
47
+ type=None,
48
+ create_key=_descriptor._internal_create_key),
49
+ _descriptor.EnumValueDescriptor(
50
+ name='CHAR', index=3, number=4,
51
+ serialized_options=None,
52
+ type=None,
53
+ create_key=_descriptor._internal_create_key),
54
+ ],
55
+ containing_type=None,
56
+ serialized_options=None,
57
+ serialized_start=1294,
58
+ serialized_end=1347,
59
+ )
60
+ _sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE)
61
+
62
+ _MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor(
63
+ name='Type',
64
+ full_name='sentencepiece.ModelProto.SentencePiece.Type',
65
+ filename=None,
66
+ file=DESCRIPTOR,
67
+ create_key=_descriptor._internal_create_key,
68
+ values=[
69
+ _descriptor.EnumValueDescriptor(
70
+ name='NORMAL', index=0, number=1,
71
+ serialized_options=None,
72
+ type=None,
73
+ create_key=_descriptor._internal_create_key),
74
+ _descriptor.EnumValueDescriptor(
75
+ name='UNKNOWN', index=1, number=2,
76
+ serialized_options=None,
77
+ type=None,
78
+ create_key=_descriptor._internal_create_key),
79
+ _descriptor.EnumValueDescriptor(
80
+ name='CONTROL', index=2, number=3,
81
+ serialized_options=None,
82
+ type=None,
83
+ create_key=_descriptor._internal_create_key),
84
+ _descriptor.EnumValueDescriptor(
85
+ name='USER_DEFINED', index=3, number=4,
86
+ serialized_options=None,
87
+ type=None,
88
+ create_key=_descriptor._internal_create_key),
89
+ _descriptor.EnumValueDescriptor(
90
+ name='BYTE', index=4, number=6,
91
+ serialized_options=None,
92
+ type=None,
93
+ create_key=_descriptor._internal_create_key),
94
+ _descriptor.EnumValueDescriptor(
95
+ name='UNUSED', index=5, number=5,
96
+ serialized_options=None,
97
+ type=None,
98
+ create_key=_descriptor._internal_create_key),
99
+ ],
100
+ containing_type=None,
101
+ serialized_options=None,
102
+ serialized_start=2100,
103
+ serialized_end=2184,
104
+ )
105
+ _sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE)
106
+
107
+
108
+ _TRAINERSPEC = _descriptor.Descriptor(
109
+ name='TrainerSpec',
110
+ full_name='sentencepiece.TrainerSpec',
111
+ filename=None,
112
+ file=DESCRIPTOR,
113
+ containing_type=None,
114
+ create_key=_descriptor._internal_create_key,
115
+ fields=[
116
+ _descriptor.FieldDescriptor(
117
+ name='input', full_name='sentencepiece.TrainerSpec.input', index=0,
118
+ number=1, type=9, cpp_type=9, label=3,
119
+ has_default_value=False, default_value=[],
120
+ message_type=None, enum_type=None, containing_type=None,
121
+ is_extension=False, extension_scope=None,
122
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
123
+ _descriptor.FieldDescriptor(
124
+ name='input_format', full_name='sentencepiece.TrainerSpec.input_format', index=1,
125
+ number=7, type=9, cpp_type=9, label=1,
126
+ has_default_value=False, default_value=b"".decode('utf-8'),
127
+ message_type=None, enum_type=None, containing_type=None,
128
+ is_extension=False, extension_scope=None,
129
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
130
+ _descriptor.FieldDescriptor(
131
+ name='model_prefix', full_name='sentencepiece.TrainerSpec.model_prefix', index=2,
132
+ number=2, type=9, cpp_type=9, label=1,
133
+ has_default_value=False, default_value=b"".decode('utf-8'),
134
+ message_type=None, enum_type=None, containing_type=None,
135
+ is_extension=False, extension_scope=None,
136
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
137
+ _descriptor.FieldDescriptor(
138
+ name='model_type', full_name='sentencepiece.TrainerSpec.model_type', index=3,
139
+ number=3, type=14, cpp_type=8, label=1,
140
+ has_default_value=True, default_value=1,
141
+ message_type=None, enum_type=None, containing_type=None,
142
+ is_extension=False, extension_scope=None,
143
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
144
+ _descriptor.FieldDescriptor(
145
+ name='vocab_size', full_name='sentencepiece.TrainerSpec.vocab_size', index=4,
146
+ number=4, type=5, cpp_type=1, label=1,
147
+ has_default_value=True, default_value=8000,
148
+ message_type=None, enum_type=None, containing_type=None,
149
+ is_extension=False, extension_scope=None,
150
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
151
+ _descriptor.FieldDescriptor(
152
+ name='accept_language', full_name='sentencepiece.TrainerSpec.accept_language', index=5,
153
+ number=5, type=9, cpp_type=9, label=3,
154
+ has_default_value=False, default_value=[],
155
+ message_type=None, enum_type=None, containing_type=None,
156
+ is_extension=False, extension_scope=None,
157
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
158
+ _descriptor.FieldDescriptor(
159
+ name='self_test_sample_size', full_name='sentencepiece.TrainerSpec.self_test_sample_size', index=6,
160
+ number=6, type=5, cpp_type=1, label=1,
161
+ has_default_value=True, default_value=0,
162
+ message_type=None, enum_type=None, containing_type=None,
163
+ is_extension=False, extension_scope=None,
164
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
165
+ _descriptor.FieldDescriptor(
166
+ name='character_coverage', full_name='sentencepiece.TrainerSpec.character_coverage', index=7,
167
+ number=10, type=2, cpp_type=6, label=1,
168
+ has_default_value=True, default_value=float(0.9995),
169
+ message_type=None, enum_type=None, containing_type=None,
170
+ is_extension=False, extension_scope=None,
171
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
172
+ _descriptor.FieldDescriptor(
173
+ name='input_sentence_size', full_name='sentencepiece.TrainerSpec.input_sentence_size', index=8,
174
+ number=11, type=4, cpp_type=4, label=1,
175
+ has_default_value=True, default_value=0,
176
+ message_type=None, enum_type=None, containing_type=None,
177
+ is_extension=False, extension_scope=None,
178
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
179
+ _descriptor.FieldDescriptor(
180
+ name='shuffle_input_sentence', full_name='sentencepiece.TrainerSpec.shuffle_input_sentence', index=9,
181
+ number=19, type=8, cpp_type=7, label=1,
182
+ has_default_value=True, default_value=True,
183
+ message_type=None, enum_type=None, containing_type=None,
184
+ is_extension=False, extension_scope=None,
185
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
186
+ _descriptor.FieldDescriptor(
187
+ name='mining_sentence_size', full_name='sentencepiece.TrainerSpec.mining_sentence_size', index=10,
188
+ number=12, type=5, cpp_type=1, label=1,
189
+ has_default_value=False, default_value=0,
190
+ message_type=None, enum_type=None, containing_type=None,
191
+ is_extension=False, extension_scope=None,
192
+ serialized_options=b'\030\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
193
+ _descriptor.FieldDescriptor(
194
+ name='training_sentence_size', full_name='sentencepiece.TrainerSpec.training_sentence_size', index=11,
195
+ number=13, type=5, cpp_type=1, label=1,
196
+ has_default_value=False, default_value=0,
197
+ message_type=None, enum_type=None, containing_type=None,
198
+ is_extension=False, extension_scope=None,
199
+ serialized_options=b'\030\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
200
+ _descriptor.FieldDescriptor(
201
+ name='seed_sentencepiece_size', full_name='sentencepiece.TrainerSpec.seed_sentencepiece_size', index=12,
202
+ number=14, type=5, cpp_type=1, label=1,
203
+ has_default_value=True, default_value=1000000,
204
+ message_type=None, enum_type=None, containing_type=None,
205
+ is_extension=False, extension_scope=None,
206
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
207
+ _descriptor.FieldDescriptor(
208
+ name='shrinking_factor', full_name='sentencepiece.TrainerSpec.shrinking_factor', index=13,
209
+ number=15, type=2, cpp_type=6, label=1,
210
+ has_default_value=True, default_value=float(0.75),
211
+ message_type=None, enum_type=None, containing_type=None,
212
+ is_extension=False, extension_scope=None,
213
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
214
+ _descriptor.FieldDescriptor(
215
+ name='max_sentence_length', full_name='sentencepiece.TrainerSpec.max_sentence_length', index=14,
216
+ number=18, type=5, cpp_type=1, label=1,
217
+ has_default_value=True, default_value=4192,
218
+ message_type=None, enum_type=None, containing_type=None,
219
+ is_extension=False, extension_scope=None,
220
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
221
+ _descriptor.FieldDescriptor(
222
+ name='num_threads', full_name='sentencepiece.TrainerSpec.num_threads', index=15,
223
+ number=16, type=5, cpp_type=1, label=1,
224
+ has_default_value=True, default_value=16,
225
+ message_type=None, enum_type=None, containing_type=None,
226
+ is_extension=False, extension_scope=None,
227
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
228
+ _descriptor.FieldDescriptor(
229
+ name='num_sub_iterations', full_name='sentencepiece.TrainerSpec.num_sub_iterations', index=16,
230
+ number=17, type=5, cpp_type=1, label=1,
231
+ has_default_value=True, default_value=2,
232
+ message_type=None, enum_type=None, containing_type=None,
233
+ is_extension=False, extension_scope=None,
234
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
235
+ _descriptor.FieldDescriptor(
236
+ name='max_sentencepiece_length', full_name='sentencepiece.TrainerSpec.max_sentencepiece_length', index=17,
237
+ number=20, type=5, cpp_type=1, label=1,
238
+ has_default_value=True, default_value=16,
239
+ message_type=None, enum_type=None, containing_type=None,
240
+ is_extension=False, extension_scope=None,
241
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
242
+ _descriptor.FieldDescriptor(
243
+ name='split_by_unicode_script', full_name='sentencepiece.TrainerSpec.split_by_unicode_script', index=18,
244
+ number=21, type=8, cpp_type=7, label=1,
245
+ has_default_value=True, default_value=True,
246
+ message_type=None, enum_type=None, containing_type=None,
247
+ is_extension=False, extension_scope=None,
248
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
249
+ _descriptor.FieldDescriptor(
250
+ name='split_by_number', full_name='sentencepiece.TrainerSpec.split_by_number', index=19,
251
+ number=23, type=8, cpp_type=7, label=1,
252
+ has_default_value=True, default_value=True,
253
+ message_type=None, enum_type=None, containing_type=None,
254
+ is_extension=False, extension_scope=None,
255
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
256
+ _descriptor.FieldDescriptor(
257
+ name='split_by_whitespace', full_name='sentencepiece.TrainerSpec.split_by_whitespace', index=20,
258
+ number=22, type=8, cpp_type=7, label=1,
259
+ has_default_value=True, default_value=True,
260
+ message_type=None, enum_type=None, containing_type=None,
261
+ is_extension=False, extension_scope=None,
262
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
263
+ _descriptor.FieldDescriptor(
264
+ name='treat_whitespace_as_suffix', full_name='sentencepiece.TrainerSpec.treat_whitespace_as_suffix', index=21,
265
+ number=24, type=8, cpp_type=7, label=1,
266
+ has_default_value=True, default_value=False,
267
+ message_type=None, enum_type=None, containing_type=None,
268
+ is_extension=False, extension_scope=None,
269
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
270
+ _descriptor.FieldDescriptor(
271
+ name='split_digits', full_name='sentencepiece.TrainerSpec.split_digits', index=22,
272
+ number=25, type=8, cpp_type=7, label=1,
273
+ has_default_value=True, default_value=False,
274
+ message_type=None, enum_type=None, containing_type=None,
275
+ is_extension=False, extension_scope=None,
276
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
277
+ _descriptor.FieldDescriptor(
278
+ name='control_symbols', full_name='sentencepiece.TrainerSpec.control_symbols', index=23,
279
+ number=30, type=9, cpp_type=9, label=3,
280
+ has_default_value=False, default_value=[],
281
+ message_type=None, enum_type=None, containing_type=None,
282
+ is_extension=False, extension_scope=None,
283
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
284
+ _descriptor.FieldDescriptor(
285
+ name='user_defined_symbols', full_name='sentencepiece.TrainerSpec.user_defined_symbols', index=24,
286
+ number=31, type=9, cpp_type=9, label=3,
287
+ has_default_value=False, default_value=[],
288
+ message_type=None, enum_type=None, containing_type=None,
289
+ is_extension=False, extension_scope=None,
290
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
291
+ _descriptor.FieldDescriptor(
292
+ name='required_chars', full_name='sentencepiece.TrainerSpec.required_chars', index=25,
293
+ number=36, type=9, cpp_type=9, label=1,
294
+ has_default_value=False, default_value=b"".decode('utf-8'),
295
+ message_type=None, enum_type=None, containing_type=None,
296
+ is_extension=False, extension_scope=None,
297
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
298
+ _descriptor.FieldDescriptor(
299
+ name='byte_fallback', full_name='sentencepiece.TrainerSpec.byte_fallback', index=26,
300
+ number=35, type=8, cpp_type=7, label=1,
301
+ has_default_value=True, default_value=False,
302
+ message_type=None, enum_type=None, containing_type=None,
303
+ is_extension=False, extension_scope=None,
304
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
305
+ _descriptor.FieldDescriptor(
306
+ name='vocabulary_output_piece_score', full_name='sentencepiece.TrainerSpec.vocabulary_output_piece_score', index=27,
307
+ number=32, type=8, cpp_type=7, label=1,
308
+ has_default_value=True, default_value=True,
309
+ message_type=None, enum_type=None, containing_type=None,
310
+ is_extension=False, extension_scope=None,
311
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
312
+ _descriptor.FieldDescriptor(
313
+ name='hard_vocab_limit', full_name='sentencepiece.TrainerSpec.hard_vocab_limit', index=28,
314
+ number=33, type=8, cpp_type=7, label=1,
315
+ has_default_value=True, default_value=True,
316
+ message_type=None, enum_type=None, containing_type=None,
317
+ is_extension=False, extension_scope=None,
318
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
319
+ _descriptor.FieldDescriptor(
320
+ name='use_all_vocab', full_name='sentencepiece.TrainerSpec.use_all_vocab', index=29,
321
+ number=34, type=8, cpp_type=7, label=1,
322
+ has_default_value=True, default_value=False,
323
+ message_type=None, enum_type=None, containing_type=None,
324
+ is_extension=False, extension_scope=None,
325
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
326
+ _descriptor.FieldDescriptor(
327
+ name='unk_id', full_name='sentencepiece.TrainerSpec.unk_id', index=30,
328
+ number=40, type=5, cpp_type=1, label=1,
329
+ has_default_value=True, default_value=0,
330
+ message_type=None, enum_type=None, containing_type=None,
331
+ is_extension=False, extension_scope=None,
332
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
333
+ _descriptor.FieldDescriptor(
334
+ name='bos_id', full_name='sentencepiece.TrainerSpec.bos_id', index=31,
335
+ number=41, type=5, cpp_type=1, label=1,
336
+ has_default_value=True, default_value=1,
337
+ message_type=None, enum_type=None, containing_type=None,
338
+ is_extension=False, extension_scope=None,
339
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
340
+ _descriptor.FieldDescriptor(
341
+ name='eos_id', full_name='sentencepiece.TrainerSpec.eos_id', index=32,
342
+ number=42, type=5, cpp_type=1, label=1,
343
+ has_default_value=True, default_value=2,
344
+ message_type=None, enum_type=None, containing_type=None,
345
+ is_extension=False, extension_scope=None,
346
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
347
+ _descriptor.FieldDescriptor(
348
+ name='pad_id', full_name='sentencepiece.TrainerSpec.pad_id', index=33,
349
+ number=43, type=5, cpp_type=1, label=1,
350
+ has_default_value=True, default_value=-1,
351
+ message_type=None, enum_type=None, containing_type=None,
352
+ is_extension=False, extension_scope=None,
353
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
354
+ _descriptor.FieldDescriptor(
355
+ name='unk_piece', full_name='sentencepiece.TrainerSpec.unk_piece', index=34,
356
+ number=45, type=9, cpp_type=9, label=1,
357
+ has_default_value=True, default_value=b"<unk>".decode('utf-8'),
358
+ message_type=None, enum_type=None, containing_type=None,
359
+ is_extension=False, extension_scope=None,
360
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
361
+ _descriptor.FieldDescriptor(
362
+ name='bos_piece', full_name='sentencepiece.TrainerSpec.bos_piece', index=35,
363
+ number=46, type=9, cpp_type=9, label=1,
364
+ has_default_value=True, default_value=b"<s>".decode('utf-8'),
365
+ message_type=None, enum_type=None, containing_type=None,
366
+ is_extension=False, extension_scope=None,
367
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
368
+ _descriptor.FieldDescriptor(
369
+ name='eos_piece', full_name='sentencepiece.TrainerSpec.eos_piece', index=36,
370
+ number=47, type=9, cpp_type=9, label=1,
371
+ has_default_value=True, default_value=b"</s>".decode('utf-8'),
372
+ message_type=None, enum_type=None, containing_type=None,
373
+ is_extension=False, extension_scope=None,
374
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
375
+ _descriptor.FieldDescriptor(
376
+ name='pad_piece', full_name='sentencepiece.TrainerSpec.pad_piece', index=37,
377
+ number=48, type=9, cpp_type=9, label=1,
378
+ has_default_value=True, default_value=b"<pad>".decode('utf-8'),
379
+ message_type=None, enum_type=None, containing_type=None,
380
+ is_extension=False, extension_scope=None,
381
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
382
+ _descriptor.FieldDescriptor(
383
+ name='unk_surface', full_name='sentencepiece.TrainerSpec.unk_surface', index=38,
384
+ number=44, type=9, cpp_type=9, label=1,
385
+ has_default_value=True, default_value=b" \342\201\207 ".decode('utf-8'),
386
+ message_type=None, enum_type=None, containing_type=None,
387
+ is_extension=False, extension_scope=None,
388
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
389
+ _descriptor.FieldDescriptor(
390
+ name='train_extremely_large_corpus', full_name='sentencepiece.TrainerSpec.train_extremely_large_corpus', index=39,
391
+ number=49, type=8, cpp_type=7, label=1,
392
+ has_default_value=True, default_value=False,
393
+ message_type=None, enum_type=None, containing_type=None,
394
+ is_extension=False, extension_scope=None,
395
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
396
+ ],
397
+ extensions=[
398
+ ],
399
+ nested_types=[],
400
+ enum_types=[
401
+ _TRAINERSPEC_MODELTYPE,
402
+ ],
403
+ serialized_options=None,
404
+ is_extendable=True,
405
+ syntax='proto2',
406
+ extension_ranges=[(200, 536870912), ],
407
+ oneofs=[
408
+ ],
409
+ serialized_start=45,
410
+ serialized_end=1358,
411
+ )
412
+
413
+
414
+ _NORMALIZERSPEC = _descriptor.Descriptor(
415
+ name='NormalizerSpec',
416
+ full_name='sentencepiece.NormalizerSpec',
417
+ filename=None,
418
+ file=DESCRIPTOR,
419
+ containing_type=None,
420
+ create_key=_descriptor._internal_create_key,
421
+ fields=[
422
+ _descriptor.FieldDescriptor(
423
+ name='name', full_name='sentencepiece.NormalizerSpec.name', index=0,
424
+ number=1, type=9, cpp_type=9, label=1,
425
+ has_default_value=False, default_value=b"".decode('utf-8'),
426
+ message_type=None, enum_type=None, containing_type=None,
427
+ is_extension=False, extension_scope=None,
428
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
429
+ _descriptor.FieldDescriptor(
430
+ name='precompiled_charsmap', full_name='sentencepiece.NormalizerSpec.precompiled_charsmap', index=1,
431
+ number=2, type=12, cpp_type=9, label=1,
432
+ has_default_value=False, default_value=b"",
433
+ message_type=None, enum_type=None, containing_type=None,
434
+ is_extension=False, extension_scope=None,
435
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
436
+ _descriptor.FieldDescriptor(
437
+ name='add_dummy_prefix', full_name='sentencepiece.NormalizerSpec.add_dummy_prefix', index=2,
438
+ number=3, type=8, cpp_type=7, label=1,
439
+ has_default_value=True, default_value=True,
440
+ message_type=None, enum_type=None, containing_type=None,
441
+ is_extension=False, extension_scope=None,
442
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
443
+ _descriptor.FieldDescriptor(
444
+ name='remove_extra_whitespaces', full_name='sentencepiece.NormalizerSpec.remove_extra_whitespaces', index=3,
445
+ number=4, type=8, cpp_type=7, label=1,
446
+ has_default_value=True, default_value=True,
447
+ message_type=None, enum_type=None, containing_type=None,
448
+ is_extension=False, extension_scope=None,
449
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
450
+ _descriptor.FieldDescriptor(
451
+ name='escape_whitespaces', full_name='sentencepiece.NormalizerSpec.escape_whitespaces', index=4,
452
+ number=5, type=8, cpp_type=7, label=1,
453
+ has_default_value=True, default_value=True,
454
+ message_type=None, enum_type=None, containing_type=None,
455
+ is_extension=False, extension_scope=None,
456
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
457
+ _descriptor.FieldDescriptor(
458
+ name='normalization_rule_tsv', full_name='sentencepiece.NormalizerSpec.normalization_rule_tsv', index=5,
459
+ number=6, type=9, cpp_type=9, label=1,
460
+ has_default_value=False, default_value=b"".decode('utf-8'),
461
+ message_type=None, enum_type=None, containing_type=None,
462
+ is_extension=False, extension_scope=None,
463
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
464
+ ],
465
+ extensions=[
466
+ ],
467
+ nested_types=[],
468
+ enum_types=[
469
+ ],
470
+ serialized_options=None,
471
+ is_extendable=True,
472
+ syntax='proto2',
473
+ extension_ranges=[(200, 536870912), ],
474
+ oneofs=[
475
+ ],
476
+ serialized_start=1361,
477
+ serialized_end=1570,
478
+ )
479
+
480
+
481
+ _SELFTESTDATA_SAMPLE = _descriptor.Descriptor(
482
+ name='Sample',
483
+ full_name='sentencepiece.SelfTestData.Sample',
484
+ filename=None,
485
+ file=DESCRIPTOR,
486
+ containing_type=None,
487
+ create_key=_descriptor._internal_create_key,
488
+ fields=[
489
+ _descriptor.FieldDescriptor(
490
+ name='input', full_name='sentencepiece.SelfTestData.Sample.input', index=0,
491
+ number=1, type=9, cpp_type=9, label=1,
492
+ has_default_value=False, default_value=b"".decode('utf-8'),
493
+ message_type=None, enum_type=None, containing_type=None,
494
+ is_extension=False, extension_scope=None,
495
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
496
+ _descriptor.FieldDescriptor(
497
+ name='expected', full_name='sentencepiece.SelfTestData.Sample.expected', index=1,
498
+ number=2, type=9, cpp_type=9, label=1,
499
+ has_default_value=False, default_value=b"".decode('utf-8'),
500
+ message_type=None, enum_type=None, containing_type=None,
501
+ is_extension=False, extension_scope=None,
502
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
503
+ ],
504
+ extensions=[
505
+ ],
506
+ nested_types=[],
507
+ enum_types=[
508
+ ],
509
+ serialized_options=None,
510
+ is_extendable=False,
511
+ syntax='proto2',
512
+ extension_ranges=[],
513
+ oneofs=[
514
+ ],
515
+ serialized_start=1641,
516
+ serialized_end=1682,
517
+ )
518
+
519
+ _SELFTESTDATA = _descriptor.Descriptor(
520
+ name='SelfTestData',
521
+ full_name='sentencepiece.SelfTestData',
522
+ filename=None,
523
+ file=DESCRIPTOR,
524
+ containing_type=None,
525
+ create_key=_descriptor._internal_create_key,
526
+ fields=[
527
+ _descriptor.FieldDescriptor(
528
+ name='samples', full_name='sentencepiece.SelfTestData.samples', index=0,
529
+ number=1, type=11, cpp_type=10, label=3,
530
+ has_default_value=False, default_value=[],
531
+ message_type=None, enum_type=None, containing_type=None,
532
+ is_extension=False, extension_scope=None,
533
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
534
+ ],
535
+ extensions=[
536
+ ],
537
+ nested_types=[_SELFTESTDATA_SAMPLE, ],
538
+ enum_types=[
539
+ ],
540
+ serialized_options=None,
541
+ is_extendable=True,
542
+ syntax='proto2',
543
+ extension_ranges=[(200, 536870912), ],
544
+ oneofs=[
545
+ ],
546
+ serialized_start=1572,
547
+ serialized_end=1693,
548
+ )
549
+
550
+
551
+ _MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor(
552
+ name='SentencePiece',
553
+ full_name='sentencepiece.ModelProto.SentencePiece',
554
+ filename=None,
555
+ file=DESCRIPTOR,
556
+ containing_type=None,
557
+ create_key=_descriptor._internal_create_key,
558
+ fields=[
559
+ _descriptor.FieldDescriptor(
560
+ name='piece', full_name='sentencepiece.ModelProto.SentencePiece.piece', index=0,
561
+ number=1, type=9, cpp_type=9, label=1,
562
+ has_default_value=False, default_value=b"".decode('utf-8'),
563
+ message_type=None, enum_type=None, containing_type=None,
564
+ is_extension=False, extension_scope=None,
565
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
566
+ _descriptor.FieldDescriptor(
567
+ name='score', full_name='sentencepiece.ModelProto.SentencePiece.score', index=1,
568
+ number=2, type=2, cpp_type=6, label=1,
569
+ has_default_value=False, default_value=float(0),
570
+ message_type=None, enum_type=None, containing_type=None,
571
+ is_extension=False, extension_scope=None,
572
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
573
+ _descriptor.FieldDescriptor(
574
+ name='type', full_name='sentencepiece.ModelProto.SentencePiece.type', index=2,
575
+ number=3, type=14, cpp_type=8, label=1,
576
+ has_default_value=True, default_value=1,
577
+ message_type=None, enum_type=None, containing_type=None,
578
+ is_extension=False, extension_scope=None,
579
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
580
+ ],
581
+ extensions=[
582
+ ],
583
+ nested_types=[],
584
+ enum_types=[
585
+ _MODELPROTO_SENTENCEPIECE_TYPE,
586
+ ],
587
+ serialized_options=None,
588
+ is_extendable=True,
589
+ syntax='proto2',
590
+ extension_ranges=[(200, 536870912), ],
591
+ oneofs=[
592
+ ],
593
+ serialized_start=1985,
594
+ serialized_end=2195,
595
+ )
596
+
597
+ _MODELPROTO = _descriptor.Descriptor(
598
+ name='ModelProto',
599
+ full_name='sentencepiece.ModelProto',
600
+ filename=None,
601
+ file=DESCRIPTOR,
602
+ containing_type=None,
603
+ create_key=_descriptor._internal_create_key,
604
+ fields=[
605
+ _descriptor.FieldDescriptor(
606
+ name='pieces', full_name='sentencepiece.ModelProto.pieces', index=0,
607
+ number=1, type=11, cpp_type=10, label=3,
608
+ has_default_value=False, default_value=[],
609
+ message_type=None, enum_type=None, containing_type=None,
610
+ is_extension=False, extension_scope=None,
611
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
612
+ _descriptor.FieldDescriptor(
613
+ name='trainer_spec', full_name='sentencepiece.ModelProto.trainer_spec', index=1,
614
+ number=2, type=11, cpp_type=10, label=1,
615
+ has_default_value=False, default_value=None,
616
+ message_type=None, enum_type=None, containing_type=None,
617
+ is_extension=False, extension_scope=None,
618
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
619
+ _descriptor.FieldDescriptor(
620
+ name='normalizer_spec', full_name='sentencepiece.ModelProto.normalizer_spec', index=2,
621
+ number=3, type=11, cpp_type=10, label=1,
622
+ has_default_value=False, default_value=None,
623
+ message_type=None, enum_type=None, containing_type=None,
624
+ is_extension=False, extension_scope=None,
625
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
626
+ _descriptor.FieldDescriptor(
627
+ name='self_test_data', full_name='sentencepiece.ModelProto.self_test_data', index=3,
628
+ number=4, type=11, cpp_type=10, label=1,
629
+ has_default_value=False, default_value=None,
630
+ message_type=None, enum_type=None, containing_type=None,
631
+ is_extension=False, extension_scope=None,
632
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
633
+ _descriptor.FieldDescriptor(
634
+ name='denormalizer_spec', full_name='sentencepiece.ModelProto.denormalizer_spec', index=4,
635
+ number=5, type=11, cpp_type=10, label=1,
636
+ has_default_value=False, default_value=None,
637
+ message_type=None, enum_type=None, containing_type=None,
638
+ is_extension=False, extension_scope=None,
639
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
640
+ ],
641
+ extensions=[
642
+ ],
643
+ nested_types=[_MODELPROTO_SENTENCEPIECE, ],
644
+ enum_types=[
645
+ ],
646
+ serialized_options=None,
647
+ is_extendable=True,
648
+ syntax='proto2',
649
+ extension_ranges=[(200, 536870912), ],
650
+ oneofs=[
651
+ ],
652
+ serialized_start=1696,
653
+ serialized_end=2206,
654
+ )
655
+
656
+ _TRAINERSPEC.fields_by_name['model_type'].enum_type = _TRAINERSPEC_MODELTYPE
657
+ _TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC
658
+ _SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA
659
+ _SELFTESTDATA.fields_by_name['samples'].message_type = _SELFTESTDATA_SAMPLE
660
+ _MODELPROTO_SENTENCEPIECE.fields_by_name['type'].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE
661
+ _MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO
662
+ _MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE
663
+ _MODELPROTO.fields_by_name['pieces'].message_type = _MODELPROTO_SENTENCEPIECE
664
+ _MODELPROTO.fields_by_name['trainer_spec'].message_type = _TRAINERSPEC
665
+ _MODELPROTO.fields_by_name['normalizer_spec'].message_type = _NORMALIZERSPEC
666
+ _MODELPROTO.fields_by_name['self_test_data'].message_type = _SELFTESTDATA
667
+ _MODELPROTO.fields_by_name['denormalizer_spec'].message_type = _NORMALIZERSPEC
668
+ DESCRIPTOR.message_types_by_name['TrainerSpec'] = _TRAINERSPEC
669
+ DESCRIPTOR.message_types_by_name['NormalizerSpec'] = _NORMALIZERSPEC
670
+ DESCRIPTOR.message_types_by_name['SelfTestData'] = _SELFTESTDATA
671
+ DESCRIPTOR.message_types_by_name['ModelProto'] = _MODELPROTO
672
+ _sym_db.RegisterFileDescriptor(DESCRIPTOR)
673
+
674
+ TrainerSpec = _reflection.GeneratedProtocolMessageType('TrainerSpec', (_message.Message,), {
675
+ 'DESCRIPTOR' : _TRAINERSPEC,
676
+ '__module__' : 'sentencepiece_model_pb2'
677
+ # @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec)
678
+ })
679
+ _sym_db.RegisterMessage(TrainerSpec)
680
+
681
+ NormalizerSpec = _reflection.GeneratedProtocolMessageType('NormalizerSpec', (_message.Message,), {
682
+ 'DESCRIPTOR' : _NORMALIZERSPEC,
683
+ '__module__' : 'sentencepiece_model_pb2'
684
+ # @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec)
685
+ })
686
+ _sym_db.RegisterMessage(NormalizerSpec)
687
+
688
+ SelfTestData = _reflection.GeneratedProtocolMessageType('SelfTestData', (_message.Message,), {
689
+
690
+ 'Sample' : _reflection.GeneratedProtocolMessageType('Sample', (_message.Message,), {
691
+ 'DESCRIPTOR' : _SELFTESTDATA_SAMPLE,
692
+ '__module__' : 'sentencepiece_model_pb2'
693
+ # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample)
694
+ })
695
+ ,
696
+ 'DESCRIPTOR' : _SELFTESTDATA,
697
+ '__module__' : 'sentencepiece_model_pb2'
698
+ # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData)
699
+ })
700
+ _sym_db.RegisterMessage(SelfTestData)
701
+ _sym_db.RegisterMessage(SelfTestData.Sample)
702
+
703
+ ModelProto = _reflection.GeneratedProtocolMessageType('ModelProto', (_message.Message,), {
704
+
705
+ 'SentencePiece' : _reflection.GeneratedProtocolMessageType('SentencePiece', (_message.Message,), {
706
+ 'DESCRIPTOR' : _MODELPROTO_SENTENCEPIECE,
707
+ '__module__' : 'sentencepiece_model_pb2'
708
+ # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece)
709
+ })
710
+ ,
711
+ 'DESCRIPTOR' : _MODELPROTO,
712
+ '__module__' : 'sentencepiece_model_pb2'
713
+ # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto)
714
+ })
715
+ _sym_db.RegisterMessage(ModelProto)
716
+ _sym_db.RegisterMessage(ModelProto.SentencePiece)
717
+
718
+
719
+ DESCRIPTOR._options = None
720
+ _TRAINERSPEC.fields_by_name['mining_sentence_size']._options = None
721
+ _TRAINERSPEC.fields_by_name['training_sentence_size']._options = None
722
+ # @@protoc_insertion_point(module_scope)
src/videogen_hub/depend/icetk/text_tokenizer.py ADDED
@@ -0,0 +1,77 @@
+ # -*- encoding: utf-8 -*-
+ '''
+ @File : text_tokenizer.py
+ @Time : 2021/12/20 01:26:12
+ @Author : Ming Ding
+ @Contact : [email protected]
+ '''
+
+ # here put the import lib
+ import os
+ import sys
+ import math
+ import random
+ from copy import copy
+ from typing import List
+
+ import sentencepiece as spm
+ from . import sentencepiece_model_pb2 as model
+
+
+ class TextTokenizer:
+     def __init__(self, model_path):
+         self.proto = model.ModelProto()
+         with open(model_path, 'rb') as fin:
+             proto_str = fin.read()
+             self.proto.ParseFromString(proto_str)
+         self.refresh()
+
+     def refresh(self):
+         self.sp = spm.SentencePieceProcessor()
+         self.sp.Load(model_proto=self.proto.SerializeToString())
+         self.num_tokens = self.sp.vocab_size()
+
+     def add_special_tokens(self, tokens):
+         for token in tokens:
+             new_token = model.ModelProto().SentencePiece()
+             new_token.piece = token
+             new_token.score = 0
+             self.proto.pieces.append(new_token)
+         self.refresh()
+
+     def discourage_tokens(self, tokens):
+         if isinstance(tokens, str):  # single token
+             tokens = [tokens]
+         for token in tokens:
+             for piece in self.proto.pieces:
+                 if piece.piece == token:
+                     piece.score = -100
+         self.refresh()
+
+     def discourage_ids(self, ids):
+         if isinstance(ids, int):
+             ids = [ids]
+         for idx in ids:
+             self.proto.pieces[idx].score = -100
+         self.refresh()
+
+     def encode(self, text):
+         return self.sp.EncodeAsIds(text)
+
+     def decode(self, ids: List[int]):
+         return self.sp.DecodeIds(ids)
+
+     def tokenize(self, text):
+         return self.sp.EncodeAsPieces(text)
+
+     def convert_tokens_to_ids(self, tokens):
+         return [self.sp.PieceToId(token) for token in tokens]
+
+     def convert_token_to_id(self, token):
+         return self.sp.PieceToId(token)
+
+     def convert_id_to_token(self, idx):
+         return self.sp.IdToPiece(idx)
+
+     def __len__(self):
+         return self.num_tokens
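
TextTokenizer can also be used on its own with a SentencePiece model file. A sketch assuming ice_text.model has already been downloaded to ~/.icetk_models; the sample sentence and the special token are illustrative.

import os
from videogen_hub.depend.icetk.text_tokenizer import TextTokenizer

tok = TextTokenizer(os.path.expanduser('~/.icetk_models/ice_text.model'))
ids = tok.encode('Hello world')         # raw SentencePiece ids, no image-token offset
print(tok.tokenize('Hello world'))      # subword pieces
print(tok.decode(ids))
tok.add_special_tokens(['<my_token>'])  # appends a piece to the proto and reloads the processor
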
src/videogen_hub/depend/icetk/utils.py ADDED
@@ -0,0 +1,46 @@
+ # -*- encoding: utf-8 -*-
+ '''
+ @File : utils.py
+ @Time : 2021/12/22 23:00:33
+ @Author : Ming Ding
+ @Contact : [email protected]
+ '''
+
+ # here put the import lib
+ import os
+ import sys
+ import math
+ import random
+ import requests
+
+ from tqdm import tqdm
+ import requests
+ from filelock import FileLock
+
+ def download_with_progress_bar(save_path, url):
+     with requests.get(url, stream=True) as r:
+         r.raise_for_status()
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+         with open(save_path, 'wb') as f:
+             pbar = tqdm(total=int(r.headers['Content-Length']), unit_scale=True)
+             for chunk in r.iter_content(chunk_size=32 * 1024):
+                 if chunk:  # filter out keep-alive new chunks
+                     f.write(chunk)
+                     pbar.update(len(chunk))
+
+ MODEL_ULRS = {
+     'ice_text.model': 'https://cloud.tsinghua.edu.cn/f/2c73ea6d3e7f4aed82ec/?dl=1',
+     'ice_image.pt': 'https://cloud.tsinghua.edu.cn/f/ae2cd37af814429d875d/?dl=1'
+ }
+
+ def auto_create(file_path):
+     os.makedirs(os.path.dirname(file_path), exist_ok=True)
+     lock = FileLock(file_path + '.lock')
+     with lock:
+         if os.path.exists(file_path):
+             return False
+         else:
+             url = MODEL_ULRS[os.path.basename(file_path)]
+             print(f'Downloading tokenizer models {url} into {file_path} ...')
+             download_with_progress_bar(file_path, url)
+             return True
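
The download helper above is what the lazy tokenizer properties rely on. A small sketch of its contract, with the path taken from the IceTokenizer defaults:

import os
from videogen_hub.depend.icetk.utils import auto_create

fp = os.path.expanduser('~/.icetk_models/ice_text.model')
created = auto_create(fp)   # True if the file was just downloaded, False if it already existed
print(created, os.path.exists(fp))
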
src/videogen_hub/depend/icetk/vqvae/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .vqvae_hierarchical import HVQVAE
+ from .enc_dec import Encoder, Decoder, ResidualDownSample
+ from .quantize import VectorQuantizeEMA
+
+ from .api import load_default_HVQVAE, load_ckpt
src/videogen_hub/depend/icetk/vqvae/api.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import torch
3
+ import json
4
+ import math
5
+ import os
6
+ import numpy as np
7
+
8
+ import torch.nn.functional as F
9
+
10
+ def new_module(config):
11
+ '''in config:
12
+ "target": module type
13
+ "params": dict of params'''
14
+ if type(config) == str:
15
+ with open(config, 'r') as file:
16
+ config = json.load(file)
17
+ assert type(config) == dict
18
+ if not "target" in config:
19
+ raise KeyError("Expected key `target` to instantiate.")
20
+ module, cls = config.get('target').rsplit(".", 1)
21
+ model = getattr(importlib.import_module(module, package=__package__), cls)(**config.get("params", dict()))
22
+
23
+ return model
24
+
25
+ def load_ckpt(model, path):
26
+ sd = torch.load(path, map_location="cpu")['module']
27
+ model.load_state_dict(sd, strict=False)
28
+ return model
29
+
30
+ def load_default_HVQVAE():
31
+ config = {
32
+ "target": "..vqvae.HVQVAE",
33
+ "params": {
34
+ "levels": 3,
35
+ "embedding_dim": 256,
36
+ "codebook_scale": 1,
37
+ "down_sampler_configs": [
38
+ {
39
+ "target": "..vqvae.ResidualDownSample",
40
+ "params": {
41
+ "in_channels": 256
42
+ }
43
+ },
44
+ {
45
+ "target": "..vqvae.ResidualDownSample",
46
+ "params": {
47
+ "in_channels": 256
48
+ }
49
+ }
50
+ ],
51
+ "enc_config": {
52
+ "target": "..vqvae.Encoder",
53
+ "params": {
54
+ "num_res_blocks": 2,
55
+ "channels_mult": [1,2,4]
56
+ }
57
+ },
58
+ "quantize_config": {
59
+ "target": "..vqvae.VectorQuantizeEMA",
60
+ "params": {
61
+ "hidden_dim": 256,
62
+ "embedding_dim": 256,
63
+ "n_embed": 20000,
64
+ "training_loc": False
65
+ }
66
+ },
67
+ "dec_configs": [
68
+ {
69
+ "target": "..vqvae.Decoder",
70
+ "params": {
71
+ "channels_mult": [1,1,1,2,4]
72
+ }
73
+ },
74
+ {
75
+ "target": "..vqvae.Decoder",
76
+ "params": {
77
+ "channels_mult": [1,1,2,4]
78
+ }
79
+ },
80
+ {
81
+ "target": "..vqvae.Decoder",
82
+ "params": {
83
+ "channels_mult": [1,2,4]
84
+ }
85
+ }
86
+ ]
87
+ }
88
+ }
89
+ return new_module(config)
90
+
91
+
92
+ if __name__ == '__main__':
93
+ pass
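A rough sketch of building the default hierarchical VQ-VAE described by this config and running a dummy image through it. The import path and checkpoint path are placeholders; `load_ckpt` expects a file whose state dict sits under a 'module' key, as in the function above:

import torch
from videogen_hub.depend.icetk.vqvae import load_default_HVQVAE, load_ckpt  # assumed import path

model = load_default_HVQVAE().eval()
# model = load_ckpt(model, '/path/to/ice_image.pt')   # optional: load pretrained weights
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)                   # dummy RGB image batch
    recons, vq_loss = model(x)                        # one reconstruction per level, plus total VQ loss
print([r.shape for r in recons], float(vq_loss))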
src/videogen_hub/depend/icetk/vqvae/enc_dec.py ADDED
@@ -0,0 +1,386 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ def nonlinearity(x):
8
+ return x * torch.sigmoid(x)
9
+
10
+ def Normalize(in_channels):
11
+ return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
12
+
13
+ class Upsample(nn.Module):
14
+ def __init__(self,
15
+ in_channels,
16
+ with_conv):
17
+ super().__init__()
18
+ self.with_conv = with_conv
19
+ if with_conv:
20
+ self.conv = nn.Conv2d(in_channels,
21
+ in_channels,
22
+ kernel_size=3,
23
+ stride=1,
24
+ padding=1)
25
+
26
+ def forward(self, x):
27
+ x = F.interpolate(x, scale_factor=2., mode="nearest")
28
+ if self.with_conv:
29
+ x = self.conv(x)
30
+ return x
31
+
32
+ class DownSample(nn.Module):
33
+ def __init__(self,
34
+ in_channels,
35
+ with_conv):
36
+ super().__init__()
37
+ self.with_conv = with_conv
38
+ if with_conv:
39
+ self.conv = nn.Conv2d(in_channels,
40
+ in_channels,
41
+ kernel_size=3,
42
+ stride=2,
43
+ padding=0)
44
+
45
+ def forward(self, x):
46
+ if self.with_conv:
47
+ pad = (0, 1, 0, 1)
48
+ x = F.pad(x, pad, mode='constant', value=0)
49
+ x = self.conv(x)
50
+ else:
51
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
52
+ return x
53
+
54
+ class ResidualDownSample(nn.Module):
55
+ def __init__(self, in_channels):
56
+ super().__init__()
57
+ self.in_channels = in_channels
58
+ self.pooling_down_sampler = DownSample(in_channels, with_conv=False)
59
+ self.conv_down_sampler = DownSample(in_channels, with_conv=True)
60
+
61
+ def forward(self, x):
62
+ return self.pooling_down_sampler(x) + self.conv_down_sampler(x)
63
+
64
+ class ResnetBlock(nn.Module):
65
+ def __init__(self,
66
+ in_channels,
67
+ dropout,
68
+ out_channels=None,
69
+ conv_shortcut=False):
70
+ super().__init__()
71
+ self.in_channels = in_channels
72
+ out_channels = in_channels if out_channels is None else out_channels
73
+ self.out_channels = out_channels
74
+ self.use_conv_shortcut = conv_shortcut
75
+
76
+ self.norm1 = Normalize(in_channels)
77
+ self.conv1 = nn.Conv2d(in_channels,
78
+ out_channels,
79
+ kernel_size=3,
80
+ stride=1,
81
+ padding=1)
82
+
83
+ self.norm2 = Normalize(out_channels)
84
+ self.dropout = nn.Dropout(dropout)
85
+ self.conv2 = nn.Conv2d(out_channels,
86
+ out_channels,
87
+ kernel_size=3,
88
+ stride=1,
89
+ padding=1)
90
+ if in_channels != out_channels:
91
+ if conv_shortcut:
92
+ self.conv_shortcut = nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ else:
98
+ self.nin_shortcut = nn.Conv2d(in_channels,
99
+ out_channels,
100
+ kernel_size=1,
101
+ stride=1,
102
+ padding=0)
103
+
104
+ def forward(self, x):
105
+ h = x
106
+ h = self.norm1(h)
107
+ h = nonlinearity(h)
108
+ h = self.conv1(h)
109
+
110
+ h = self.norm2(h)
111
+ h = nonlinearity(h)
112
+ h = self.dropout(h)
113
+ h = self.conv2(h)
114
+
115
+ if self.in_channels != self.out_channels:
116
+ if self.use_conv_shortcut:
117
+ x = self.conv_shortcut(x)
118
+ else:
119
+ x = self.nin_shortcut(x)
120
+
121
+ return x + h
122
+
123
+ class AttnBlock(nn.Module):
124
+ def __init__(self, in_channels):
125
+ super().__init__()
126
+ self.in_channels = in_channels
127
+
128
+ self.norm = Normalize(in_channels)
129
+ self.q = nn.Conv2d(in_channels,
130
+ in_channels,
131
+ kernel_size=1,
132
+ stride=1,
133
+ padding=0)
134
+ self.k = nn.Conv2d(in_channels,
135
+ in_channels,
136
+ kernel_size=1,
137
+ stride=1,
138
+ padding=0)
139
+ self.v = nn.Conv2d(in_channels,
140
+ in_channels,
141
+ kernel_size=1,
142
+ stride=1,
143
+ padding=0)
144
+ self.proj_out = nn.Conv2d(in_channels,
145
+ in_channels,
146
+ kernel_size=1,
147
+ stride=1,
148
+ padding=0)
149
+
150
+ def forward(self, x):
151
+ h_ = x
152
+ h_ = self.norm(h_)
153
+ q = self.q(h_)
154
+ k = self.k(h_)
155
+ v = self.v(h_)
156
+
157
+ B, C, H, W = q.shape
158
+ q = q.reshape(B, C, -1)
159
+ q = q.permute(0, 2, 1) # (B, H*W, C)
160
+ k = k.reshape(B, C, -1) # (B, C, H*W)
161
+ w_ = torch.bmm(q, k) # (B, H*W, H*W)
162
+ w_ = w_ * C**(-0.5)
163
+ w_ = F.softmax(w_, dim=2)
164
+
165
+ v = v.reshape(B, C, -1) # (B, C, H*W)
166
+ w_ = w_.permute(0, 2, 1)
167
+ h_ = torch.bmm(v, w_)
168
+ h_ = h_.reshape(B, C, H, W)
169
+
170
+ h_ = self.proj_out(h_)
171
+
172
+ return x + h_
173
+
174
+ class Encoder(nn.Module):
175
+ def __init__(self,
176
+ in_channels=3,
177
+ out_channels=3,
178
+ z_channels=256,
179
+ channels=128,
180
+ num_res_blocks=0,
181
+ resolution=256,
182
+ attn_resolutions=[16],
183
+ resample_with_conv=True,
184
+ channels_mult=(1,2,4,8),
185
+ dropout=0.
186
+ ):
187
+ super().__init__()
188
+
189
+ self.in_channels = in_channels
190
+ self.out_channels = out_channels
191
+ self.z_channels = z_channels
192
+ self.channels = channels
193
+ self.num_resolutions = len(channels_mult)
194
+ self.num_res_blocks = num_res_blocks
195
+ self.resolution = resolution
196
+
197
+ self.conv_in = nn.Conv2d(in_channels,
198
+ channels,
199
+ kernel_size=3,
200
+ stride=1,
201
+ padding=1)
202
+
203
+ current_resolution = resolution
204
+ in_channels_mult = (1,) + tuple(channels_mult)
205
+ self.down = nn.ModuleList()
206
+ for i_level in range(self.num_resolutions):
207
+ block = nn.ModuleList()
208
+ attn = nn.ModuleList()
209
+ block_in = channels * in_channels_mult[i_level]
210
+ block_out = channels * channels_mult[i_level]
211
+ for i_block in range(self.num_res_blocks):
212
+ block.append(ResnetBlock(in_channels=block_in,
213
+ out_channels=block_out,
214
+ dropout=dropout))
215
+ block_in = block_out
216
+ if current_resolution in attn_resolutions:
217
+ attn.append(AttnBlock(block_in))
218
+ down = nn.Module()
219
+ down.block = block
220
+ down.attn = attn
221
+ if i_level != self.num_resolutions - 1:
222
+ down.downsample = DownSample(block_in,
223
+ resample_with_conv)
224
+ current_resolution = current_resolution // 2
225
+ self.down.append(down)
226
+
227
+ # middle
228
+ self.mid = nn.Module()
229
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
230
+ out_channels=block_in,
231
+ dropout=dropout)
232
+ self.mid.attn_1 = AttnBlock(block_in)
233
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
234
+ out_channels=block_in,
235
+ dropout=dropout)
236
+
237
+ # end
238
+ self.norm_out = Normalize(block_in)
239
+ self.conv_out = nn.Conv2d(block_in,
240
+ z_channels,
241
+ kernel_size=3,
242
+ stride=1,
243
+ padding=1)
244
+
245
+ def test_forward(self, x):
246
+ # downsample
247
+ import pdb
248
+ hs = [self.conv_in(x)]
249
+ for i_level in range(self.num_resolutions):
250
+ for i_block in range(self.num_res_blocks):
251
+ h = self.down[i_level].block[i_block](hs[-1])
252
+ if len(self.down[i_level].attn) > 0:
253
+ h = self.down[i_level].attn[i_block](h)
254
+ hs.append(h)
255
+ if i_level != self.num_resolutions - 1:
256
+ hs.append(self.down[i_level].downsample(hs[-1]))
257
+
258
+ return hs
259
+
260
+ def forward(self, x):
261
+ # downsample
262
+ hs = [self.conv_in(x)]
263
+ for i_level in range(self.num_resolutions):
264
+ for i_block in range(self.num_res_blocks):
265
+ h = self.down[i_level].block[i_block](hs[-1])
266
+ if len(self.down[i_level].attn) > 0:
267
+ h = self.down[i_level].attn[i_block](h)
268
+ hs.append(h)
269
+ if i_level != self.num_resolutions - 1:
270
+ hs.append(self.down[i_level].downsample(hs[-1]))
271
+
272
+ # middle
273
+ h = hs[-1]
274
+ h = self.mid.block_1(h)
275
+ h = self.mid.attn_1(h)
276
+ h = self.mid.block_2(h)
277
+
278
+ # end
279
+ h = self.norm_out(h)
280
+ h = nonlinearity(h)
281
+ h = self.conv_out(h)
282
+
283
+ return h
284
+
285
+ class Decoder(nn.Module):
286
+ def __init__(self,
287
+ in_channels=3,
288
+ out_channels=3,
289
+ z_channels=256,
290
+ channels=128,
291
+ num_res_blocks=0,
292
+ resolution=256,
293
+ attn_resolutions=[16],
294
+ channels_mult=(1,2,4,8),
295
+ resample_with_conv=True,
296
+ dropout=0.
297
+ ):
298
+ super().__init__()
299
+ self.in_channels = in_channels
300
+ self.out_channels = out_channels
301
+ self.z_channels = z_channels
302
+ self.channels = channels
303
+ self.num_resolutions = len(channels_mult)
304
+ self.num_res_blocks = num_res_blocks
305
+ self.resolution = resolution
306
+
307
+ in_channels_mult = (1,) + tuple(channels_mult)
308
+ block_in = channels * channels_mult[self.num_resolutions - 1]
309
+ current_resolution = resolution // 2**(self.num_resolutions - 1)
310
+ self.z_shape = (1, z_channels, current_resolution, current_resolution)
311
+
312
+ # z to block_in
313
+ self.conv_in = nn.Conv2d(z_channels,
314
+ block_in,
315
+ kernel_size=3,
316
+ stride=1,
317
+ padding=1)
318
+
319
+ # middle
320
+ self.mid = nn.Module()
321
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
322
+ out_channels=block_in,
323
+ dropout=dropout)
324
+ self.mid.attn_1 = AttnBlock(block_in)
325
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
326
+ out_channels=block_in,
327
+ dropout=dropout)
328
+
329
+ # upsampling
330
+ self.up = nn.ModuleList()
331
+ for i_level in reversed(range(self.num_resolutions)):
332
+ block = nn.ModuleList()
333
+ attn = nn.ModuleList()
334
+ block_out = channels * channels_mult[i_level]
335
+ for i_block in range(self.num_res_blocks + 1):
336
+ block.append(ResnetBlock(in_channels=block_in,
337
+ out_channels=block_out,
338
+ dropout=dropout))
339
+ block_in = block_out
340
+ if current_resolution in attn_resolutions:
341
+ attn.append(AttnBlock(block_in))
342
+ up = nn.Module()
343
+ up.block = block
344
+ up.attn = attn
345
+ if i_level != 0:
346
+ up.upsample = Upsample(block_in,
347
+ resample_with_conv)
348
+ current_resolution = current_resolution * 2
349
+ self.up.insert(0, up)
350
+
351
+ # end
352
+ self.norm_out = Normalize(block_in)
353
+ self.conv_out = nn.Conv2d(block_in,
354
+ out_channels,
355
+ kernel_size=3,
356
+ stride=1,
357
+ padding=1)
358
+
359
+ def forward(self, z):
360
+ self.last_z_shape = z.shape
361
+
362
+ # z to block_in
363
+ h = self.conv_in(z)
364
+
365
+ # middle
366
+ h = self.mid.block_1(h)
367
+ h = self.mid.attn_1(h)
368
+ h = self.mid.block_2(h)
369
+
370
+ # upsampling
371
+ for i_level in reversed(range(self.num_resolutions)):
372
+ for i_block in range(self.num_res_blocks + 1):
373
+ h = self.up[i_level].block[i_block](h)
374
+ if len(self.up[i_level].attn) > 0:
375
+ h = self.up[i_level].attn[i_block](h)
376
+ if i_level != 0:
377
+ h = self.up[i_level].upsample(h)
378
+
379
+ # end
380
+ h = self.norm_out(h)
381
+ h = nonlinearity(h)
382
+ h = self.conv_out(h)
383
+ return h
384
+
385
+ def get_last_layer(self):
386
+ return self.conv_out.weight
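A quick, hedged shape check for the encoder/decoder pair above. With the default channels_mult=(1,2,4,8) a 256x256 input is downsampled by a factor of 8; num_res_blocks=1 is used here so channel counts line up between levels. The import path is an assumption:

import torch
from videogen_hub.depend.icetk.vqvae.enc_dec import Encoder, Decoder  # assumed import path

enc = Encoder(num_res_blocks=1)
dec = Decoder(num_res_blocks=1)
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)
    z = enc(x)    # expected (1, 256, 32, 32): z_channels=256 after three downsampling steps
    y = dec(z)    # expected (1, 3, 256, 256)
print(z.shape, y.shape)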
src/videogen_hub/depend/icetk/vqvae/quantize.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch import einsum
4
+ from torch.nn import functional as F
5
+
6
+ class VectorQuantize(nn.Module):
7
+ def __init__(self,
8
+ hidden_dim,
9
+ embedding_dim,
10
+ n_embed,
11
+ commitment_cost=1):
12
+ super().__init__()
13
+
14
+ self.hidden_dim = hidden_dim
15
+ self.embedding_dim = embedding_dim
16
+ self.n_embed = n_embed
17
+ self.commitment_cost = commitment_cost
18
+
19
+ self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
20
+ self.embed = nn.Embedding(n_embed, embedding_dim)
21
+ self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
22
+
23
+ def forward(self, z):
24
+ B, C, H, W = z.shape
25
+
26
+ z_e = self.proj(z)
27
+ z_e = z_e.permute(0, 2, 3, 1) # (B, H, W, C)
28
+ flatten = z_e.reshape(-1, self.embedding_dim)
29
+
30
+ dist = (
31
+ flatten.pow(2).sum(1, keepdim=True)
32
+ - 2 * flatten @ self.embed.weight.t()
33
+ + self.embed.weight.pow(2).sum(1, keepdim=True).t()
34
+ )
35
+ _, embed_ind = (-dist).max(1)
36
+ embed_ind = embed_ind.view(B, H, W)
37
+
38
+ z_q = self.embed_code(embed_ind)
39
+ diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() \
40
+ + (z_q - z_e.detach()).pow(2).mean()
41
+
42
+ z_q = z_e + (z_q - z_e).detach()
43
+ return z_q, diff, embed_ind
44
+
45
+ def embed_code(self, embed_id):
46
+ return F.embedding(embed_id, self.embed.weight)
47
+
48
+
49
+ class VectorQuantizeEMA(nn.Module):
50
+ def __init__(self,
51
+ hidden_dim,
52
+ embedding_dim,
53
+ n_embed,
54
+ commitment_cost=1,
55
+ decay=0.99,
56
+ eps=1e-5,
57
+ pre_proj=True,
58
+ training_loc=True):
59
+ super().__init__()
60
+
61
+ self.hidden_dim = hidden_dim
62
+ self.embedding_dim = embedding_dim
63
+ self.n_embed = n_embed
64
+ self.commitment_cost = commitment_cost
65
+ self.training_loc = training_loc
66
+
67
+ self.pre_proj = pre_proj
68
+ if self.pre_proj:
69
+ self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
70
+ self.embed = nn.Embedding(n_embed, embedding_dim)
71
+ self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
72
+
73
+ self.register_buffer("cluster_size", torch.zeros(n_embed))
74
+ self.register_buffer("embed_avg", self.embed.weight.data.clone())
75
+
76
+ self.decay = decay
77
+ self.eps = eps
78
+
79
+ def forward(self, z):
80
+ B, C, H, W = z.shape
81
+
82
+ if self.pre_proj:
83
+ z_e = self.proj(z)
84
+ else:
85
+ z_e = z
86
+ z_e = z_e.permute(0, 2, 3, 1) # (B, H, W, C)
87
+ flatten = z_e.reshape(-1, self.embedding_dim)
88
+
89
+ dist = (
90
+ flatten.pow(2).sum(1, keepdim=True)
91
+ - 2 * flatten @ self.embed.weight.t()
92
+ + self.embed.weight.pow(2).sum(1, keepdim=True).t()
93
+ )
94
+ _, embed_ind = (-dist).max(1)
95
+ embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
96
+ embed_ind = embed_ind.view(B, H, W)
97
+
98
+ z_q = self.embed_code(embed_ind)
99
+
100
+ diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean()
101
+
102
+ z_q = z_e + (z_q - z_e).detach()
103
+ return z_q, diff, embed_ind
104
+
105
+ def embed_code(self, embed_id):
106
+ return F.embedding(embed_id, self.embed.weight)
107
+
108
+
109
+ class GumbelQuantize(nn.Module):
110
+ def __init__(self,
111
+ hidden_dim,
112
+ embedding_dim,
113
+ n_embed,
114
+ commitment_cost=1,
115
+ straight_through=True,
116
+ kl_weight=5e-4,
117
+ temp_init=1.,
118
+ eps=1e-5):
119
+ super().__init__()
120
+
121
+ self.hidden_dim = hidden_dim
122
+ self.embedding_dim = embedding_dim
123
+ self.n_embed = n_embed
124
+ self.commitment_cost = commitment_cost
125
+
126
+ self.kl_weight = kl_weight
127
+ self.temperature = temp_init
128
+ self.eps = eps
129
+
130
+ self.proj = nn.Conv2d(hidden_dim, n_embed, 1)
131
+ self.embed = nn.Embedding(n_embed, embedding_dim)
132
+ self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
133
+
134
+ self.straight_through = straight_through
135
+
136
+ def forward(self, z, temp=None):
137
+ hard = self.straight_through if self.training else True
138
+ temp = self.temperature if temp is None else temp
139
+
140
+ B, C, H, W = z.shape
141
+
142
+ z_e = self.proj(z)
143
+
144
+ soft_one_hot = F.gumbel_softmax(z_e, tau=temp, dim=1, hard=hard)
145
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
146
+
147
+ qy = F.softmax(z_e, dim=1)
148
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + self.eps), dim=1).mean()
149
+
150
+ embed_ind = soft_one_hot.argmax(dim=1)
151
+ z_q = z_q.permute(0, 2, 3, 1)
152
+ return z_q, diff, embed_ind
153
+
154
+ def embed_code(self, embed_id):
155
+ return F.embedding(embed_id, self.embed.weight)
156
+
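A short sketch of quantizing a feature map with the EMA codebook above; the sizes are arbitrary, the import path is an assumption, and `training_loc=False` mirrors the default HVQVAE config in api.py:

import torch
from videogen_hub.depend.icetk.vqvae.quantize import VectorQuantizeEMA  # assumed import path

vq = VectorQuantizeEMA(hidden_dim=256, embedding_dim=256, n_embed=1024, training_loc=False)
with torch.no_grad():
    feats = torch.randn(1, 256, 32, 32)       # (B, hidden_dim, H, W)
    z_q, commit_loss, codes = vq(feats)
print(z_q.shape, codes.shape)                 # (1, 32, 32, 256) and (1, 32, 32)
# vq.embed_code(codes) maps the discrete codes back to embedding vectors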
src/videogen_hub/depend/icetk/vqvae/vqvae_hierarchical.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ from torch import nn
3
+ import json
4
+ import os
5
+
6
+ from .api import new_module
7
+
8
+ class HVQVAE(nn.Module):
9
+ def __init__(
10
+ self,
11
+ levels,
12
+ embedding_dim,
13
+ enc_config,
14
+ quantize_config,
15
+ down_sampler_configs,
16
+ dec_configs,
17
+ codebook_scale=1.
18
+ ):
19
+ super().__init__()
20
+
21
+ self.levels = levels
22
+
23
+ self.enc = new_module(enc_config)
24
+
25
+ self.decs = nn.ModuleList()
26
+ for i in range(levels):
27
+ self.decs.append(new_module(dec_configs[i]))
28
+
29
+ self.quantize = new_module(quantize_config)
30
+ self.down_samplers = nn.ModuleList()
31
+ for i in range(levels-1):
32
+ self.down_samplers.append(new_module(down_sampler_configs[i]))
33
+ self.codebook_scale = codebook_scale
34
+
35
+ def forward(self, input):
36
+ quants, diffs, ids = self.encode(input)
37
+ dec_outputs = self.decode(quants[::-1])
38
+
39
+ total_diff = diffs[0]
40
+ scale = 1.
41
+ for diff in diffs[1:]:
42
+ scale *= self.codebook_scale
43
+ total_diff = total_diff + diff * scale
44
+ return dec_outputs, total_diff
45
+
46
+ def encode(self, input):
47
+ enc_output = self.enc(input)
48
+ enc_outputs = [enc_output]
49
+ for l in range(self.levels-1):
50
+ enc_outputs.append(self.down_samplers[l](enc_outputs[-1]))
51
+
52
+ quants, diffs, ids = [], [], []
53
+ for enc_output in enc_outputs:
54
+ quant, diff, id = self.quantize(enc_output)
55
+ quants.append(quant.permute(0, 3, 1, 2))
56
+ diffs.append(diff)
57
+ ids.append(id)
58
+
59
+ return quants, diffs, ids
60
+
61
+ def decode(self, quants):
62
+ dec_outputs = []
63
+ for l in range(self.levels-1, -1, -1):
64
+ dec_outputs.append(self.decs[l](quants[l]))
65
+
66
+ return dec_outputs
67
+
68
+ def decode_code(self, codes):
69
+ quants = []
70
+ for l in range(self.levels):
71
+ quants.append(self.quantize.embed_code(codes[l]).permute(0, 3, 1, 2))
72
+ dec_outputs = self.decode(quants)
73
+
74
+ return dec_outputs
75
+
76
+ def single_encode(self, input, l):
77
+ assert l >= 0 and l <= 2
78
+ enc_output = self.enc(input)
79
+ for i in range(l):
80
+ enc_output = self.down_samplers[i](enc_output)
81
+
82
+ quant, diff, id = self.quantize(enc_output)
83
+
84
+ return quant, diff, id
85
+
86
+ def single_decode(self, quant, l):
87
+ assert l >= 0 and l <= 2
88
+ return self.decs[l](quant)
89
+
90
+ def single_decode_code(self, code, l):
91
+ assert l >= 0 and l <= 2
92
+ quant = self.quantize.embed_code(code).permute(0, 3, 1, 2)
93
+ return self.decs[2-l](quant)
94
+
95
+ def get_last_layer(self):
96
+ return self.decs[-1].get_last_layer()
97
+
src/videogen_hub/infermodels/__init__.py ADDED
@@ -0,0 +1,59 @@
1
+ # ==========================================================
2
+ # Text-to-Video Generation
3
+ from .lavie import LaVie
4
+ from .videocrafter import VideoCrafter2
5
+ from .modelscope import ModelScope
6
+ from .streamingt2v import StreamingT2V
7
+ from .show_one import ShowOne
8
+ from .opensora import OpenSora
9
+ from .opensora_plan import OpenSoraPlan
10
+ from .t2v_turbo import T2VTurbo
11
+ from .opensora_12 import OpenSora12
12
+ from .cogvideox import CogVideoX
13
+
14
+ # from .cogvideo import CogVideo # Not supporting CogVideo ATM
15
+
16
+ # ==========================================================
17
+ # Image-to-Video Generation
18
+ from .seine import SEINE
19
+ from .consisti2v import ConsistI2V
20
+ from .dynamicrafter import DynamiCrafter
21
+ from .i2vgen_xl import I2VGenXL
22
+
23
+ # ==========================================================
24
+
25
+ import sys
26
+ from functools import partial
27
+
28
+
29
+ def get_model(model_name: str = None, init_with_default_params: bool = True):
30
+ """
31
+ Retrieves a model class or instance by its name.
32
+
33
+ Args:
34
+ model_name (str): Name of the model class. Raises a ValueError if no model with this name exists.
35
+ init_with_default_params (bool, optional): If True, returns an initialized model instance; otherwise, returns
36
+ the model class. Default is True. If set to True, be cautious of potential ``OutOfMemoryError`` with insufficient CUDA memory.
37
+
38
+ Returns:
39
+ model_class or model_instance: Depending on ``init_with_default_params``, either the model class or an instance of the model.
40
+
41
+ Examples::
42
+ initialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=True)
43
+
44
+ uninitialized_model = infermodels.get_model(model_name='<Model>', init_with_default_params=False)
45
+ initialized_model = uninitialized_model(device="cuda", <...>)
46
+ """
47
+
48
+ if not hasattr(sys.modules[__name__], model_name):
49
+ raise ValueError(f"No model named {model_name} found in infermodels.")
50
+
51
+ model_class = getattr(sys.modules[__name__], model_name)
52
+ if init_with_default_params:
53
+ model_instance = model_class()
54
+ return model_instance
55
+ return model_class
56
+
57
+
58
+ load_model = partial(get_model, init_with_default_params=True)
59
+ load = partial(get_model)
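A hedged usage note: `load_model` eagerly instantiates a model with its default arguments (which may download checkpoints and allocate GPU memory), while `load` simply forwards to `get_model`. For example:

from videogen_hub import infermodels

model = infermodels.load_model('LaVie')                               # instantiated with defaults
ModelCls = infermodels.load('LaVie', init_with_default_params=False)  # fetch the class only
model = ModelCls(device='cuda')                                       # construct it yourself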
src/videogen_hub/infermodels/cogvideo.py ADDED
@@ -0,0 +1,54 @@
1
+ class CogVideo:
2
+ def __init__(self, device="cuda"):
3
+ """
4
+ Initializes the CogVideo model with a specific device.
5
+
6
+ Args:
7
+ device (str, optional): The device to run the model on. Defaults to "cuda".
8
+ """
9
+
10
+ import argparse
11
+
12
+ # Manually creating an args object
13
+ self.args = argparse.Namespace(
14
+ generate_frame_num=5,
15
+ coglm_temperature2=0.89,
16
+ use_guidance_stage1=True,
17
+ use_guidance_stage2=False, # Assuming this is not set
18
+ guidance_alpha=3.0,
19
+ stage_1=False, # Assuming this is not set
20
+ stage_2=False, # Assuming this is not set
21
+ both_stages=True,
22
+ parallel_size=1,
23
+ stage1_max_inference_batch_size=-1,
24
+ multi_gpu=False, # Assuming this is not set
25
+ device=3,
26
+ )
27
+
28
+ def infer_one_video(
29
+ self,
30
+ prompt: str = None,
31
+ size: list = [320, 512],
32
+ seconds: int = 2,
33
+ fps: int = 8,
34
+ seed: int = 42,
35
+ ):
36
+ """
37
+ Generates a single video based on the provided prompt and parameters.
38
+
39
+ Args:
40
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
41
+ size (list, optional): The size of the video as [height, width]. Defaults to [320, 512].
42
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
43
+ fps (int, optional): The frames per second of the video. Defaults to 8.
44
+ seed (int, optional): The seed for random number generation. Defaults to 42.
45
+
46
+ Returns:
47
+ torch.Tensor: The generated video as a tensor.
48
+ """
49
+
50
+ from videogen_hub.pipelines.cogvideo.cogvideo_pipeline import pipeline
51
+
52
+ return pipeline(
53
+ self.args, raw_text=prompt, height=size[0], width=size[1], duration=seconds
54
+ )
src/videogen_hub/infermodels/cogvideox.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+
3
+ class CogVideoX:
4
+ def __init__(self, weight="THUDM/CogVideoX-2b", device="cuda"):
5
+ """
6
+ Initializes the CogVideoX pipeline from the given pretrained weights on a specific device.
7
+
8
+ Args:
9
+ device (str, optional): The device to run the model on. Defaults to "cuda".
10
+ """
11
+ from diffusers import CogVideoXPipeline
12
+
13
+ self.pipe = CogVideoXPipeline.from_pretrained(weight).to(device)
14
+
15
+ def infer_one_video(
16
+ self,
17
+ prompt: str = None,
18
+ size: list = [320, 512],
19
+ seconds: int = 2,
20
+ fps: int = 8,
21
+ seed: int = 42,
22
+ ):
23
+ """
24
+ Generates a single video based on the provided prompt and parameters.
25
+
26
+ Args:
27
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
28
+ size (list, optional): The size of the video as [height, width]. Defaults to [320, 512].
29
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
30
+ fps (int, optional): The frames per second of the video. Defaults to 8.
31
+ seed (int, optional): The seed for random number generation. Defaults to 42.
32
+
33
+ Returns:
34
+ torch.Tensor: The generated video as a tensor.
35
+ """
36
+
37
+ video = self.pipe(prompt=prompt,
38
+ guidance_scale=6,
39
+ num_frames=seconds * fps,
40
+ #height=size[0],
41
+ #width=size[1],
42
+ num_inference_steps=50,
43
+ generator=torch.manual_seed(seed)).frames[0]
44
+ from videogen_hub.utils import images_to_tensor
45
+ video = video[:-1] # drop the last frame
46
+ video = images_to_tensor(video) # parse it back to tensor (T, C, H, W)
47
+
48
+ return video
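A hypothetical call to the wrapper above; 2 seconds at 8 fps requests 16 frames, and the prompt is only an example:

model = CogVideoX()                              # downloads THUDM/CogVideoX-2b on first use
video = model.infer_one_video(
    prompt="a golden retriever running on a beach",
    seconds=2,
    fps=8,
    seed=123,
)
print(video.shape)                               # (T, C, H, W) tensor from images_to_tensor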
src/videogen_hub/infermodels/consisti2v.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+
3
+ from PIL import Image
4
+ from huggingface_hub import snapshot_download
5
+
6
+ from videogen_hub import MODEL_PATH
7
+
8
+
9
+ class ConsistI2V:
10
+ def __init__(self, device="cuda"):
11
+
12
+ class Args:
13
+ def __init__(self):
14
+ self.inference_config = "configs/inference/inference.yaml"
15
+ self.prompt = None
16
+ self.n_prompt = ""
17
+ self.seed = "random"
18
+ self.path_to_first_frame = None
19
+ self.prompt_config = "configs/prompts/default.yaml"
20
+ self.format = "mp4"
21
+ self.save_model = False
22
+ self.optional_args = []
23
+
24
+ self.args = Args()
25
+ model_path = os.path.join(MODEL_PATH, "TIGER-Lab", "ConsistI2V").replace("\\", "\\\\")
26
+ yaml_config = f"""
27
+ output_dir: "samples/inference"
28
+ output_name: "i2v"
29
+ pretrained_model_path: "{model_path}"
30
+ unet_path: null
31
+ unet_ckpt_prefix: "module."
32
+ pipeline_pretrained_path: null
33
+
34
+ sampling_kwargs:
35
+ height: 256
36
+ width: 256
37
+ n_frames: 16
38
+ steps: 50
39
+ ddim_eta: 0.0
40
+ guidance_scale_txt: 7.5
41
+ guidance_scale_img: 1.0
42
+ guidance_rescale: 0.0
43
+ num_videos_per_prompt: 1
44
+ frame_stride: 3
45
+
46
+ unet_additional_kwargs:
47
+ variant: null
48
+ n_temp_heads: 8
49
+ augment_temporal_attention: true
50
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
51
+ first_frame_condition_mode: "concat"
52
+ use_frame_stride_condition: true
53
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
54
+ noise_alpha: 1.0
55
+
56
+ noise_scheduler_kwargs:
57
+ beta_start: 0.00085
58
+ beta_end: 0.012
59
+ beta_schedule: "linear"
60
+ steps_offset: 1
61
+ clip_sample: false
62
+ rescale_betas_zero_snr: false # true if using zero terminal snr
63
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
64
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
65
+
66
+ frameinit_kwargs:
67
+ enable: true
68
+ camera_motion: null
69
+ noise_level: 850
70
+ filter_params:
71
+ method: 'gaussian'
72
+ d_s: 0.25
73
+ d_t: 0.25
74
+ """
75
+
76
+ from omegaconf import OmegaConf
77
+
78
+ self.config = OmegaConf.create(yaml_config)
79
+ model_path = os.path.join(MODEL_PATH, "ConsistI2V").replace("\\", "\\\\")
80
+ snapshot_download("TIGER-Lab/ConsistI2V", local_dir=model_path)
81
+ from videogen_hub.pipelines.consisti2v.scripts.animate import main
82
+
83
+ self.pipeline = main
84
+
85
+ def infer_one_video(
86
+ self,
87
+ input_image: Image.Image,
88
+ prompt: str = None,
89
+ size: list = [320, 512],
90
+ seconds: int = 2,
91
+ fps: int = 8,
92
+ seed: int = 42,
93
+ ):
94
+ """
95
+ Generates a single video based on a textual prompt and first frame image, using either a provided image or an image path as the starting point. The output is a tensor representing the video.
96
+
97
+ Args:
98
+ input_image (PIL.Image.Image): The input image to use as the basis for video generation.
99
+ prompt (str, optional): The text prompt that guides the video generation. If not specified, the video generation will rely solely on the input image. Defaults to None.
100
+ size (list, optional): Specifies the resolution of the output video as [height, width]. Defaults to [320, 512].
101
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
102
+ fps (int, optional): The number of frames per second in the generated video. This determines how smooth the video appears. Defaults to 8.
103
+ seed (int, optional): A seed value for random number generation, ensuring reproducibility of the video generation process. Defaults to 42.
104
+
105
+ Returns:
106
+ torch.Tensor: A tensor representing the generated video, structured as (time, channel, height, width).
107
+ """
108
+
109
+ self.args.prompt = prompt
110
+ self.args.path_to_first_frame = input_image
111
+ self.args.seed = str(seed)
112
+ self.config.sampling_kwargs.height = size[0]
113
+ self.config.sampling_kwargs.width = size[1]
114
+ self.config.sampling_kwargs.n_frames = seconds * fps
115
+
116
+ return self.pipeline(self.args, self.config)
src/videogen_hub/infermodels/dynamicrafter.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+
3
+ from huggingface_hub import hf_hub_download
4
+ from PIL import Image
5
+
6
+ from videogen_hub import MODEL_PATH
7
+
8
+
9
+ class DynamiCrafter:
10
+ def __init__(self, version: str = "256"):
11
+ """
12
+ Initializes the DynamiCrafter model using the Doubiiu/DynamiCrafter_{version} checkpoint from the Hugging Face Hub.
13
+ and loads it into "MODEL_PATH/dynamicrafter_{version}_v1".
14
+
15
+ Args:
16
+ version (str, optional): The resolution of the video to generate. Choose from '256', '512', or '1024'. Defaults to '256'.
17
+ """
18
+ from videogen_hub.pipelines.dynamicrafter.inference import DynamiCrafterPipeline
19
+
20
+ if version == "256":
21
+ (self.height, self.width) = 256, 256
22
+ self.fs = 3
23
+ self.model_path = hf_hub_download(
24
+ repo_id="Doubiiu/DynamiCrafter",
25
+ filename="model.ckpt",
26
+ local_dir=os.path.join(MODEL_PATH, "dynamicrafter_256_v1"),
27
+ )
28
+
29
+ elif version == "512":
30
+ (self.height, self.width) = 320, 512
31
+ self.fs = 24
32
+ self.model_path = hf_hub_download(
33
+ repo_id="Doubiiu/DynamiCrafter_512",
34
+ filename="model.ckpt",
35
+ local_dir=os.path.join(MODEL_PATH, "dynamicrafter_512_v1"),
36
+ )
37
+
38
+ elif version == "1024":
39
+ (self.height, self.width) = 576, 1024
40
+ self.fs = 10
41
+ self.model_path = hf_hub_download(
42
+ repo_id="Doubiiu/DynamiCrafter_1024",
43
+ filename="model.ckpt",
44
+ local_dir=os.path.join(MODEL_PATH, "dynamicrafter_1024_v1"),
45
+ )
46
+ else:
47
+ raise ValueError("Invalid input. Please enter 256, 512, or 1024.")
48
+
49
+ self.arg_list = [
50
+ "--ckpt_path",
51
+ self.model_path,
52
+ "--config",
53
+ f"src/videogen_hub/pipelines/dynamicrafter/configs/inference_{version}_v1.0.yaml",
54
+ "--n_samples",
55
+ "1",
56
+ "--bs",
57
+ "1",
58
+ "--height",
59
+ str(self.height),
60
+ "--width",
61
+ str(self.width),
62
+ "--text_input",
63
+ "--unconditional_guidance_scale",
64
+ "7.5",
65
+ "--ddim_steps",
66
+ "50",
67
+ "--ddim_eta",
68
+ "1.0",
69
+ "--video_length",
70
+ "16",
71
+ "--frame_stride",
72
+ str(self.fs),
73
+ ]
74
+
75
+ self.pipeline = DynamiCrafterPipeline(self.arg_list)
76
+
77
+ def infer_one_video(
78
+ self,
79
+ input_image: Image.Image,
80
+ prompt: str = None,
81
+ seconds: int = 2,
82
+ fps: int = 8,
83
+ seed: int = 42,
84
+ ):
85
+ """
86
+ Generates a single video based on a textual prompt and first frame image, using either a provided image or an image path as the starting point. The output is a tensor representing the video.
87
+
88
+ Args:
89
+ input_image (PIL.Image.Image): The input image to use as the basis for video generation.
90
+ prompt (str, optional): The text prompt that guides the video generation. If not specified, the video generation will rely solely on the input image. Defaults to None.
91
+ Note: the output resolution is fixed by the `version` chosen at initialization ('256', '512', or '1024'), so there is no size argument.
92
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
93
+ fps (int, optional): The number of frames per second in the generated video. This determines how smooth the video appears. Defaults to 8.
94
+ seed (int, optional): A seed value for random number generation, ensuring reproducibility of the video generation process. Defaults to 42.
95
+
96
+ Returns:
97
+ torch.Tensor: A tensor representing the generated video, structured as (time, channel, height, width).
98
+ """
99
+ self.pipeline.args.seed = seed
100
+ self.pipeline.args.text_input = prompt
101
+ self.pipeline.args.video_length = fps * seconds
102
+ video = self.pipeline.run_inference(input_image)
103
+
104
+ return video
src/videogen_hub/infermodels/i2vgen_xl.py ADDED
@@ -0,0 +1,57 @@
1
+ import os
2
+ from typing import Union
3
+ import torch
4
+ from huggingface_hub import snapshot_download, hf_hub_download
5
+ from PIL import Image
6
+
7
+ from videogen_hub import MODEL_PATH
8
+
9
+
10
+ class I2VGenXL:
11
+ def __init__(self):
12
+ """
13
+ Initializes the I2VGenXL model using the ali-vilab/i2vgen-xl checkpoint from the Hugging Face Hub.
14
+
15
+ Args:
16
+ None
17
+ """
18
+
19
+ from diffusers import I2VGenXLPipeline
20
+ model_path = os.path.join(MODEL_PATH, "i2vgen-xl")
21
+ model_path = snapshot_download("ali-vilab/i2vgen-xl", local_dir=model_path, ignore_patterns=["*fp16*", "*png"])
22
+ self.pipeline = I2VGenXLPipeline.from_pretrained(
23
+ model_path, torch_dtype=torch.float16, variant="fp16"
24
+ )
25
+
26
+ def infer_one_video(
27
+ self,
28
+ input_image: Image.Image,
29
+ prompt: str = None,
30
+ size: list = [320, 512],
31
+ seconds: int = 2,
32
+ fps: int = 8,
33
+ seed: int = 42,
34
+ ):
35
+ """
36
+ Generates a single video based on a textual prompt and first frame image, using either a provided image or an image path as the starting point. The output is a tensor representing the video.
37
+
38
+ Args:
39
+ input_image (Image.Image): The input image path or tensor to use as the basis for video generation.
40
+ prompt (str, optional): The text prompt that guides the video generation. If not specified, the video generation will rely solely on the input image. Defaults to None.
41
+ size (list, optional): Specifies the resolution of the output video as [height, width]. Defaults to [320, 512].
42
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
43
+ fps (int, optional): The number of frames per second in the generated video. This determines how smooth the video appears. Defaults to 8.
44
+ seed (int, optional): A seed value for random number generation, ensuring reproducibility of the video generation process. Defaults to 42.
45
+
46
+ Returns:
47
+ torch.Tensor: A tensor representing the generated video, structured as (time, channel, height, width).
48
+ """
49
+ return self.pipeline(
50
+ prompt=prompt,
51
+ image=input_image,
52
+ height=size[0],
53
+ width=size[1],
54
+ target_fps=fps,
55
+ num_frames=seconds * fps,
56
+ generator=torch.manual_seed(seed),
57
+ )
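A hedged image-to-video sketch for the wrapper above. The image path is a placeholder, and the return value is the raw diffusers pipeline output, whose `.frames` field holds the generated frames:

from PIL import Image

model = I2VGenXL()
first_frame = Image.open("first_frame.png").convert("RGB")   # placeholder input image
result = model.infer_one_video(
    input_image=first_frame,
    prompt="a sailboat drifting at sunset",
    size=[320, 512],
    seconds=2,
    fps=8,
    seed=42,
)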
src/videogen_hub/infermodels/lavie.py ADDED
@@ -0,0 +1,103 @@
1
+ import os, sys
2
+ import torch
3
+
4
+ from videogen_hub import MODEL_PATH
5
+
6
+
7
+ class LaVie():
8
+ def __init__(self, model_path=os.path.join(MODEL_PATH, "lavie"), device="cuda"):
9
+ """
10
+ 1. Download all necessary models from huggingface.
11
+ 2. Initializes the LaVie model with a specific model path and device.
12
+
13
+ Args:
14
+ model_path (str, optional): The path to the model checkpoints. Defaults to "MODEL_PATH/lavie".
15
+ device (str, optional): The device to run the model on. Defaults to "cuda".
16
+ """
17
+
18
+ # Put the source code imports here to avoid dependency version issues
19
+ from videogen_hub.pipelines.lavie.lavie_src.base.pipelines.pipeline_videogen import VideoGenPipeline
20
+ from videogen_hub.pipelines.lavie.lavie_src.base.download import find_model
21
+ from videogen_hub.pipelines.lavie.lavie_src.base.models.unet import UNet3DConditionModel
22
+ from diffusers.schedulers import DDPMScheduler
23
+ from diffusers.models import AutoencoderKL
24
+ from transformers import CLIPTokenizer, CLIPTextModel
25
+ from huggingface_hub import snapshot_download
26
+ from omegaconf import OmegaConf
27
+
28
+ snapshot_download(repo_id="Vchitect/LaVie", local_dir=model_path)
29
+ snapshot_download(repo_id="CompVis/stable-diffusion-v1-4", local_dir=os.path.join(model_path, "stable-diffusion-v1-4"))
30
+ snapshot_download(repo_id="stabilityai/stable-diffusion-x4-upscaler",
31
+ local_dir=os.path.join(model_path, "stable-diffusion-x4-upscaler"))
32
+
33
+ torch.set_grad_enabled(False)
34
+ self.device = device
35
+
36
+ config = {
37
+ "model_config": {
38
+ "use_compile": False,
39
+ "use_fp16": True,
40
+ "run_time": 0,
41
+ "guidance_scale": 7.5,
42
+ "num_sampling_steps": 50
43
+ },
44
+ "scheduler_config": {
45
+ "sample_method": "ddpm",
46
+ "beta_start": 0.0001,
47
+ "beta_end": 0.02,
48
+ "beta_schedule": "linear"
49
+ }
50
+ }
51
+ self.config = OmegaConf.create(config)
52
+
53
+ sd_path = os.path.join(model_path, "stable-diffusion-v1-4")
54
+ unet = UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet").to(device, dtype=torch.float16)
55
+ state_dict = find_model(os.path.join(model_path, "lavie_base.pt"))
56
+ unet.load_state_dict(state_dict)
57
+
58
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
59
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
60
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder",
61
+ torch_dtype=torch.float16).to(device) # huge
62
+
63
+ scheduler = DDPMScheduler.from_pretrained(sd_path,
64
+ subfolder="scheduler",
65
+ beta_start=self.config.scheduler_config.beta_start,
66
+ beta_end=self.config.scheduler_config.beta_end,
67
+ beta_schedule=self.config.scheduler_config.beta_schedule)
68
+
69
+ self.videogen_pipeline = VideoGenPipeline(vae=vae,
70
+ text_encoder=text_encoder_one,
71
+ tokenizer=tokenizer_one,
72
+ scheduler=scheduler,
73
+ unet=unet).to(device)
74
+ self.videogen_pipeline.enable_xformers_memory_efficient_attention()
75
+
76
+ def infer_one_video(self,
77
+ prompt: str = None,
78
+ size: list = [320, 512],
79
+ seconds: int = 2,
80
+ fps: int = 8,
81
+ seed: int = 42):
82
+ """
83
+ Generates a single video based on the provided prompt and parameters.
84
+
85
+ Args:
86
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
87
+ size (list, optional): The size of the video as [height, width]. Defaults to [320, 512].
88
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
89
+ fps (int, optional): The frames per second of the video. Defaults to 8.
90
+ seed (int, optional): The seed for random number generation. Defaults to 42.
91
+
92
+ Returns:
93
+ torch.Tensor: The generated video as a tensor.
94
+ """
95
+ if seed is not None:
96
+ torch.manual_seed(seed)
97
+ videos = self.videogen_pipeline(prompt,
98
+ video_length=seconds * fps,
99
+ height=size[0],
100
+ width=size[1],
101
+ num_inference_steps=self.config.model_config.num_sampling_steps,
102
+ guidance_scale=self.config.model_config.guidance_scale).video
103
+ return videos[0]
src/videogen_hub/infermodels/modelscope.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+
3
+ import torch
4
+ from huggingface_hub import snapshot_download
5
+
6
+ from videogen_hub import MODEL_PATH
7
+
8
+
9
+ class ModelScope:
10
+ def __init__(self, device="gpu"):
11
+ """
12
+ 1. Download the pretrained model and put it inside checkpoints/modelscope
13
+ 2. Create Pipeline
14
+ Note: it seems that the model needed from model_dir cannot support cpu
15
+ Args:
16
+ device: 'gpu' or 'cpu' the device to use the model
17
+ """
18
+ from modelscope.pipelines import pipeline
19
+ from modelscope.models import Model
20
+
21
+ model_dir = snapshot_download(
22
+ repo_id="ali-vilab/modelscope-damo-text-to-video-synthesis",
23
+ local_dir=os.path.join(MODEL_PATH, "modelscope"),
24
+
25
+ )
26
+ model = Model.from_pretrained(model_dir)
27
+ self.pipeline = pipeline("text-to-video-synthesis", model=model, device=device)
28
+
29
+ def infer_one_video(
30
+ self, prompt: str = None, seconds: int = 2, fps: int = 8, seed: int = 42
31
+ ):
32
+ """
33
+ Generates a single video based on the provided prompt and parameters.
34
+ The generated video always has resolution 256x256
35
+
36
+ Args:
37
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
38
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
39
+ fps (int, optional): The frames per second of the video. Defaults to 8.
40
+ seed (int, optional): The seed for random number generation. Defaults to 42.
41
+
42
+ Returns:
43
+ torch.Tensor: The generated video as a tensor.
44
+ """
45
+ from modelscope.outputs import OutputKeys
46
+ from decord import VideoReader
47
+ from decord import cpu, gpu
48
+ import io
49
+
50
+ torch.manual_seed(seed)
51
+ self.pipeline.model.config.model.model_args.max_frames = fps * seconds
52
+
53
+ test_text = {
54
+ "text": prompt,
55
+ }
56
+ output_video_path = self.pipeline(
57
+ test_text,
58
+ )[OutputKeys.OUTPUT_VIDEO]
59
+ result = io.BytesIO(output_video_path)
60
+ result = VideoReader(result, ctx=cpu(0))
61
+ result = torch.from_numpy(result.get_batch(range(len(result))).asnumpy())
62
+ return result
src/videogen_hub/infermodels/opensora.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+
3
+ from huggingface_hub import snapshot_download, hf_hub_download
4
+
5
+ from videogen_hub import MODEL_PATH
6
+
7
+
8
+ class OpenSora:
9
+ def __init__(self, device="gpu"):
10
+ """
11
+ 1. Download the pretrained Open-Sora weights into MODEL_PATH
12
+ 2. Create Pipeline
13
+ Note: this pipeline does not appear to support CPU inference.
14
+ Args:
15
+ device: 'gpu' or 'cpu' the device to use the model
16
+ """
17
+
18
+ from mmengine import Config as mmengine_config
19
+ from videogen_hub.pipelines.opensora.scripts.inference import main
20
+
21
+ self.pipeline = main
22
+ self.config = {
23
+ # Basic video frame settings
24
+ "num_frames": 32, # Total number of frames in a clip
25
+ "frame_interval": 3, # Interval between frames
26
+ "fps": 24, # Frames per second
27
+ "image_size": [480, 854], # Resolution of each frame (height, width)
28
+ # Model configuration for multi-resolution and specific model parameters
29
+ "multi_resolution": "STDiT2", # Multi-resolution model type
30
+ "model": {
31
+ "type": "STDiT2-XL/2", # Model type and size
32
+ "from_pretrained": os.path.join(MODEL_PATH, "STDiT2-XL_2"), # Path to pretrained checkpoint
33
+ "file_name": "model.safetensors", # Name of the model file
34
+ "input_sq_size": 512, # Input square size for the model
35
+ "qk_norm": True, # Whether to normalize query-key in attention
36
+ "enable_flashattn": False, # Enable flash attention mechanism, require flash_attn package
37
+ "enable_layernorm_kernel": False, # Enable layer normalization in kernel, requires apex package
38
+ },
39
+ # Variational Autoencoder (VAE) specific settings
40
+ "vae": {
41
+ "type": "VideoAutoencoderKL", # Type of the autoencoder
42
+ "from_pretrained": "stabilityai/sd-vae-ft-ema", # Pretrained model from Hugging Face
43
+ "cache_dir": os.path.join(MODEL_PATH, "sd-vae-ft-ema"), # Local cache directory for model weights
44
+ "micro_batch_size": 4, # Batch size for processing
45
+ },
46
+ # Text encoder settings for embedding textual information
47
+ "text_encoder": {
48
+ "type": "t5", # Text encoder model type
49
+ "from_pretrained": "DeepFloyd/t5-v1_1-xxl", # Pretrained model
50
+ "cache_dir": os.path.join(MODEL_PATH, "t5-v1_1-xxl"), # Cache directory
51
+ "model_max_length": 200, # Max length of text inputs
52
+ },
53
+ # Scheduler settings for diffusion models
54
+ "scheduler": {
55
+ "type": "iddpm", # Type of scheduler for the diffusion process
56
+ "num_sampling_steps": 50, # Number of sampling steps in diffusion
57
+ "cfg_scale": 7.0, # Scale for classifier-free guidance
58
+ "cfg_channel": 3, # Number of channels for guidance
59
+ },
60
+ # Additional settings for processing and output
61
+ "dtype": "bf16", # Data type for computation (bfloat16)
62
+ # "prompt_path": "./assets/texts/t2v_samples.txt", # Path to text prompts
63
+ "prompt_path": None, # Path to text prompts
64
+ "prompt": [
65
+ "A beautiful sunset over the city"
66
+ ], # List of prompts for generation
67
+ "batch_size": 1, # Batch size for generation
68
+ "seed": 42, # Seed for random number generators
69
+ "save_dir": "./samples/samples/", # Directory to save generated samples
70
+ "config": "sample.py", # Path to this configuration file
71
+ "prompt_as_path": False, # Treat the prompt as a file path (True/False)
72
+ "reference_path": None, # Path to reference image/video for conditioning
73
+ "loop": 1, # Number of times to loop the processing
74
+ "sample_name": None, # Specific name for the generated sample
75
+ "num_sample": 1, # Number of samples to generate
76
+ }
77
+ self.config = mmengine_config(self.config)
78
+
79
+ hf_hub_download(
80
+ repo_id="hpcai-tech/OpenSora-STDiT-v2-stage2",
81
+ filename="model.safetensors",
82
+ local_dir=self.config.model.from_pretrained,
83
+ )
84
+
85
+ hf_hub_download(
86
+ repo_id="stabilityai/sd-vae-ft-ema",
87
+ filename="diffusion_pytorch_model.safetensors",
88
+ local_dir=self.config.vae.cache_dir,
89
+ )
90
+
91
+ hf_hub_download(
92
+ repo_id="DeepFloyd/t5-v1_1-xxl",
93
+ filename="pytorch_model-00001-of-00002.bin",
94
+ local_dir=self.config.text_encoder.cache_dir,
95
+ )
96
+
97
+ def infer_one_video(
98
+ self,
99
+ prompt: str = None,
100
+ size: list = [320, 512],
101
+ seconds: int = 2,
102
+ fps: int = 8,
103
+ seed: int = 42,
104
+ ):
105
+ """
106
+ Generates a single video based on the provided prompt and parameters.
107
+ The output resolution is determined by the size argument.
108
+
109
+ Args:
110
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
111
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
112
+ fps (int, optional): The frames per second of the video. Defaults to 8.
113
+ seed (int, optional): The seed for random number generation. Defaults to 42.
114
+
115
+ Returns:
116
+ torch.Tensor: The generated video as a tensor.
117
+ """
118
+
119
+ self.config.num_frames = fps * seconds
120
+ self.config.fps = fps
121
+ self.config.seed = seed
122
+ self.config.prompt = [prompt]
123
+ self.config.image_size = size
124
+
125
+ all_batch_samples = self.pipeline(self.config)
126
+
127
+ sample = all_batch_samples[0][0]
128
+ # sample is torch.Size([1, C, f, H, W])
129
+
130
+ output = sample.squeeze(0).permute(1, 2, 3, 0).cpu().float()
131
+ # torch.Size([1, C, f, H, W]) -> torch.Size([f, H, W, C])
132
+ # BFloat16 -> Float
133
+
134
+ return output
src/videogen_hub/infermodels/opensora_12.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+
3
+ from huggingface_hub import snapshot_download, hf_hub_download
4
+
5
+ from videogen_hub import MODEL_PATH
6
+
7
+
8
+ class OpenSora12:
9
+ def __init__(self, device="gpu"):
10
+ """
11
+ 1. Download the pretrained Open-Sora 1.2 weights into MODEL_PATH
12
+ 2. Create Pipeline
13
+ Note: this pipeline does not appear to support CPU inference.
14
+ Args:
15
+ device: 'gpu' or 'cpu' the device to use the model
16
+ """
17
+
18
+ from mmengine import Config as mmengine_config
19
+ from videogen_hub.pipelines.opensora.scripts.inference import main
20
+ model_path = snapshot_download("hpcai-tech/OpenSora-STDiT-v3",
21
+ local_dir=os.path.join(MODEL_PATH, 'OpenSora-STDiT-v3'))
22
+ self.pipeline = main
23
+ self.config = {
24
+ # Basic video frame settings
25
+ "num_frames": 51, # Total number of frames in a clip
26
+ "frame_interval": 1, # Interval between frames
27
+ "fps": 24, # Frames per second
28
+ "image_size": [480, 854], # Resolution of each frame (height, width)
29
+ # Model configuration for multi-resolution and specific model parameters
30
+ "multi_resolution": "STDiT2", # Multi-resolution model type
31
+ "model": {
32
+ "type": "STDiT3-XL/2", # Model type and size
33
+ "from_pretrained": os.path.join(MODEL_PATH, "STDiT3-XL_2"), # Path to pretrained checkpoint
34
+ "file_name": "model.safetensors", # Name of the model file
35
+ "input_sq_size": 512, # Input square size for the model
36
+ "qk_norm": True, # Whether to normalize query-key in attention
37
+ "enable_flashattn": False, # Enable flash attention mechanism, require flash_attn package
38
+ "enable_layernorm_kernel": False, # Enable layer normalization in kernel, requires apex package
39
+ },
40
+ # Variational Autoencoder (VAE) specific settings
41
+ "vae": {
42
+ "type": "OpenSoraVAE_V1_2", # Type of the autoencoder
43
+ "from_pretrained": "hpcai-tech/OpenSora-VAE-v1.2", # Pretrained model from Hugging Face
44
+ #"cache_dir": os.path.join(MODEL_PATH, "OpenSora-VAE-v1.2"), # Local cache directory for model weights
45
+ "micro_frame_size": 17,
46
+ "micro_batch_size": 4, # Batch size for processing
47
+ },
48
+ # Text encoder settings for embedding textual information
49
+ "text_encoder": {
50
+ "type": "t5", # Text encoder model type
51
+ "from_pretrained": "DeepFloyd/t5-v1_1-xxl", # Pretrained model
52
+ "cache_dir": os.path.join(MODEL_PATH, "t5-v1_1-xxl"), # Cache directory
53
+ "model_max_length": 300, # Max length of text inputs
54
+ },
55
+ # Scheduler settings for diffusion models
56
+ "scheduler": {
57
+ "type": "rflow", # Type of scheduler for the diffusion process
58
+ "num_sampling_steps": 30, # Number of sampling steps in diffusion
59
+ "cfg_scale": 7.0, # Scale for classifier-free guidance
60
+ # "cfg_channel": 3, # Number of channels for guidance
61
+ },
62
+ # Additional settings for processing and output
63
+ "dtype": "bf16", # Data type for computation (bfloat16)
64
+ # "prompt_path": "./assets/texts/t2v_samples.txt", # Path to text prompts
65
+ "prompt_path": None, # Path to text prompts
66
+ "prompt": [
67
+ "A beautiful sunset over the city"
68
+ ], # List of prompts for generation
69
+ "batch_size": 1, # Batch size for generation
70
+ "seed": 42, # Seed for random number generators
71
+ "save_dir": "./samples/samples/", # Directory to save generated samples
72
+ "config": "sample.py", # Path to this configuration file
73
+ "prompt_as_path": False, # Treat the prompt as a file path (True/False)
74
+ "reference_path": None, # Path to reference image/video for conditioning
75
+ "loop": 1, # Number of times to loop the processing
76
+ "sample_name": None, # Specific name for the generated sample
77
+ "num_sample": 1, # Number of samples to generate
78
+ "aes": 6.5,
79
+ "flow": None,
80
+ }
81
+ self.config = mmengine_config(self.config)
82
+
83
+ hf_hub_download(
84
+ repo_id="hpcai-tech/OpenSora-STDiT-v3",
85
+ filename="model.safetensors",
86
+ local_dir=self.config.model.from_pretrained,
87
+ )
88
+
89
+ hf_hub_download(
90
+ repo_id="hpcai-tech/OpenSora-VAE-v1.2",
91
+ filename="model.safetensors",
92
+ local_dir=os.path.join(MODEL_PATH, "OpenSora-VAE-v1.2"),
93
+ )
94
+
95
+ hf_hub_download(
96
+ repo_id="DeepFloyd/t5-v1_1-xxl",
97
+ filename="pytorch_model-00001-of-00002.bin",
98
+ local_dir=self.config.text_encoder.cache_dir,
99
+ )
100
+
101
+ def infer_one_video(
102
+ self,
103
+ prompt: str = None,
104
+ size: list = [320, 512],
105
+ seconds: int = 2,
106
+ fps: int = 8,
107
+ seed: int = 42,
108
+ ):
109
+ """
110
+ Generates a single video based on the provided prompt and parameters.
111
+ The generated video always has resolution 256x256
112
+
113
+ Args:
114
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
115
+ size (list, optional): The resolution of the video. Defaults to [320, 512].
116
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
117
+ fps (int, optional): The frames per second of the video. Defaults to 8.
118
+ seed (int, optional): The seed for random number generation. Defaults to 42.
119
+
120
+ Returns:
121
+ torch.Tensor: The generated video as a tensor.
122
+ """
123
+
124
+ self.config.num_frames = fps * seconds
125
+ self.config.fps = fps
126
+ self.config.seed = seed
127
+ self.config.prompt = [prompt]
128
+ self.config.image_size = size
129
+
130
+ all_batch_samples = self.pipeline(self.config)
131
+
132
+ sample = all_batch_samples[0][0]
133
+ # sample is torch.Size([1, C, f, H, W])
134
+
135
+ output = sample.squeeze(0).permute(1, 2, 3, 0).cpu().float()
136
+ # torch.Size([1, C, f, H, W]) -> torch.Size([f, H, W, C])
137
+ # BFloat16 -> Float
138
+
139
+ return output
src/videogen_hub/infermodels/opensora_plan.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+
3
+ from huggingface_hub import snapshot_download, hf_hub_download
4
+ import torch
5
+
6
+ from videogen_hub import MODEL_PATH
7
+
8
+
9
+ class OpenSoraPlan():
10
+ def __init__(self, device="cuda"):
11
+ """
12
+ 1. Download the pretrained model and put it inside MODEL_PATH
13
+ 2. Create Pipeline
14
+ Note: this pipeline does not appear to support CPU inference.
15
+ Args:
16
+ device: 'cuda' or 'cpu' the device to use the model
17
+ """
18
+ from videogen_hub.pipelines.opensora_plan.opensora.sample_t2v import OpenSoraPlanPipeline
19
+
20
+ model_path = snapshot_download('LanguageBind/Open-Sora-Plan-v1.1.0', local_dir = os.path.join(MODEL_PATH, 'Open-Sora-Plan-v1.1.0'))
21
+
22
+ arg_list = ['--model_path', model_path,
23
+ '--version', '65x512x512',
24
+ '--num_frames', '65',
25
+ '--height', '512',
26
+ '--width', '512',
27
+ '--cache_dir', MODEL_PATH,
28
+ '--text_encoder_name', 'DeepFloyd/t5-v1_1-xxl',
29
+ '--text_prompt', 'prompt_list_0.txt',
30
+ '--ae', 'CausalVAEModel_4x8x8',
31
+ '--ae_path', "/remote-home1/yeyang/CausalVAEModel_4x8x8",
32
+ '--save_img_path', "./sample_video_65x512x512",
33
+ '--fps', '24',
34
+ '--guidance_scale', '7.5',
35
+ '--num_sampling_steps', '150',
36
+ '--enable_tiling']
37
+ self.pipeline = OpenSoraPlanPipeline(arg_list, device)
38
+
39
+ def infer_one_video(
40
+ self,
41
+ prompt: str = None,
42
+ size: list = [320, 512],
43
+ seconds: int = 2,
44
+ fps: int = 8,
45
+ seed: int = 42,
46
+ ):
47
+ """
48
+ Generates a single video based on the provided prompt and parameters.
49
+ Note that there are only 3 available shapes: (1 or 65 or 221)xHxW
50
+ The output is of shape [frames, channels, height, width].
51
+ Args:
52
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
53
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
54
+ fps (int, optional): The frames per second of the video. Defaults to 8.
55
+ seed (int, optional): The seed for random number generation. Defaults to 42.
56
+
57
+ Returns:
58
+ torch.Tensor: The generated video as a tensor.
59
+ """
60
+
61
+ torch.manual_seed(seed)
62
+
63
+ self.pipeline.args.text_prompt = prompt
64
+ self.pipeline.args.num_frames = fps * seconds
65
+ self.pipeline.args.fps = fps
66
+ self.pipeline.args.height = size[0]
67
+ self.pipeline.args.width = size[1]
68
+
69
+ samples = self.pipeline.inference(save_output=False)
70
+ # samples is torch.Size([B, T, H, W, C])
71
+
72
+ output = samples.squeeze(0).permute(0, 3, 1, 2).cpu().float()
73
+ return output
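A minimal usage sketch for the class above. The prompt is illustrative; the fps/seconds pair is chosen so that fps * seconds lands on one of the supported frame counts (1, 65 or 221), and 512x512 matches the 65x512x512 checkpoint selected in __init__.

from videogen_hub.infermodels.opensora_plan import OpenSoraPlan

model = OpenSoraPlan(device="cuda")
video = model.infer_one_video(
    prompt="A corgi running on the beach",  # illustrative prompt
    size=[512, 512],
    seconds=13,  # 13 s * 5 fps = 65 frames, one of the supported counts
    fps=5,
    seed=42,
)
print(video.shape)  # expected (65, C, 512, 512)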
src/videogen_hub/infermodels/seine.py ADDED
@@ -0,0 +1,52 @@
1
+ import os
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from huggingface_hub import snapshot_download, hf_hub_download
6
+
7
+ from videogen_hub import MODEL_PATH
8
+
9
+
10
+ class SEINE():
11
+ def __init__(self):
12
+ """
13
+ 1. Download the pretrained model and put it inside MODEL_PATH/SEINE
14
+ 2. Create Pipeline.
15
+ """
16
+ from videogen_hub.pipelines.seine.SEINEPipeline import SEINEPipeline
17
+
18
+ seine_path = hf_hub_download(repo_id="Vchitect/SEINE", filename="seine.pt", local_dir=os.path.join(MODEL_PATH, "SEINE"))
19
+ pretrained_model_path = snapshot_download(repo_id="CompVis/stable-diffusion-v1-4",
20
+ local_dir=os.path.join(MODEL_PATH, "SEINE", "stable-diffusion-v1-4"),
21
+ ignore_patterns=["*pytorch_model.bin", "*fp16*", "*non_ema*"])
22
+
23
+ self.pipeline = SEINEPipeline(seine_path, pretrained_model_path,
24
+ 'src/videogen_hub/pipelines/seine/sample_i2v.yaml')
25
+
26
+ def infer_one_video(self,
27
+ input_image: Image.Image,
28
+ prompt: str = None,
29
+ size: list = [320, 512],
30
+ seconds: int = 2,
31
+ fps: int = 8,
32
+ seed: int = 42):
33
+ """
34
+ Generates a single video based on a textual prompt and a first-frame image, using either a provided image or an image path as the starting point. The output is a tensor representing the video.
35
+
36
+ Args:
37
+ input_image (PIL.Image.Image): The input image to use as the basis for video generation.
38
+ prompt (str, optional): The text prompt that guides the video generation. If not specified, the video generation will rely solely on the input image. Defaults to None.
39
+ size (list, optional): Specifies the resolution of the output video as [height, width]. Defaults to [320, 512].
40
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
41
+ fps (int, optional): The number of frames per second in the generated video. This determines how smooth the video appears. Defaults to 8.
42
+ seed (int, optional): A seed value for random number generation, ensuring reproducibility of the video generation process. Defaults to 42.
43
+
44
+ Returns:
45
+ torch.Tensor: A tensor representing the generated video, structured as (time, channel, height, width).
46
+ """
47
+ video = self.pipeline.infer_one_video(input_image=input_image,
48
+ text_prompt=prompt,
49
+ output_size=size,
50
+ num_frames=seconds * fps,
51
+ seed=seed)
52
+ return video
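A minimal usage sketch, assuming a first-frame image is available on disk; the file path and prompt are illustrative only.

from PIL import Image
from videogen_hub.infermodels.seine import SEINE

model = SEINE()
first_frame = Image.open("first_frame.png").convert("RGB")  # illustrative path
video = model.infer_one_video(
    input_image=first_frame,
    prompt="The waves slowly roll onto the shore",  # illustrative prompt
    size=[320, 512],
    seconds=2,
    fps=8,
    seed=42,
)
print(video.shape)  # (time, channel, height, width) per the docstring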
src/videogen_hub/infermodels/show_one.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+
3
+
4
+ from videogen_hub import MODEL_PATH
5
+
6
+
7
+ class ShowOne():
8
+ def __init__(self):
9
+ """
10
+ Initialize the Pipeline, which download all necessary models.
11
+ """
12
+ from videogen_hub.pipelines.show_1.run_inference import ShowOnePipeline
13
+ from huggingface_hub import snapshot_download
14
+
15
+ base_path = snapshot_download(
16
+ repo_id="showlab/show-1-base",
17
+ local_dir=os.path.join(MODEL_PATH, "showlab", "show-1-base"),
18
+ local_dir_use_symlinks = False
19
+ )
20
+
21
+ interp_path = snapshot_download(
22
+ repo_id="showlab/show-1-interpolation",
23
+ local_dir=os.path.join(MODEL_PATH, "showlab", "show-1-interpolation"),
24
+
25
+ )
26
+
27
+ deepfloyd_path = snapshot_download(
28
+ repo_id="DeepFloyd/IF-II-L-v1.0",
29
+ local_dir=os.path.join(MODEL_PATH, "DeepFloyd/IF-II-L-v1.0"),
30
+
31
+ )
32
+
33
+ sr1_path = snapshot_download(
34
+ repo_id="showlab/show-1-sr1",
35
+ local_dir=os.path.join(MODEL_PATH, "showlab", "show-1-sr1"),
36
+
37
+ )
38
+
39
+ sr2_path = snapshot_download(
40
+ repo_id="showlab/show-1-sr2",
41
+ local_dir=os.path.join(MODEL_PATH, "showlab", "show-1-sr2"),
42
+
43
+ )
44
+
45
+ self.pipeline = ShowOnePipeline(base_path, interp_path, deepfloyd_path, sr1_path, sr2_path)
46
+
47
+ def infer_one_video(self,
48
+ prompt: str = None,
49
+ size: list = [320, 512],
50
+ seconds: int = 2,
51
+ fps: int = 8,
52
+ seed: int = 42):
53
+ """
54
+ Generates a single video based on a textual prompt. The output is a tensor representing the video.
55
+ Since initial_num_frames is set to 8 in the pipeline (as in the paper),
56
+ the total number of frames minus 1 must be divisible by 7 for the interpolation stages.
57
+
58
+ Args:
59
+ prompt (str, optional): The text prompt that guides the video generation. Defaults to None.
60
+ size (list, optional): Specifies the resolution of the output video as [height, width]. Defaults to [320, 512].
61
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
62
+ fps (int, optional): The number of frames per second in the generated video. This determines how smooth the video appears. Defaults to 8.
63
+ seed (int, optional): A seed value for random number generation, ensuring reproducibility of the video generation process. Defaults to 42.
64
+
65
+ Returns:
66
+ torch.Tensor: A tensor representing the generated video, structured as (time, channel, height, width).
67
+ """
68
+ num_frames = fps * seconds
69
+
70
+ assert (num_frames - 1) % 7 == 0
71
+ scaling_factor = (num_frames - 1) // 7
72
+ video = self.pipeline.inference(prompt=prompt,
73
+ negative_prompt="",
74
+ output_size=size,
75
+ initial_num_frames=8,
76
+ scaling_factor=scaling_factor,
77
+ seed=seed)
78
+
79
+ return video
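A usage sketch that respects the divisibility constraint above. With the defaults (8 fps, 2 s) the assertion fails because 16 - 1 = 15 is not a multiple of 7; 8 fps * 1 s = 8 frames (scaling_factor 1) or 8 fps * 8 s = 64 frames (scaling_factor 9) both pass. The prompt is illustrative.

from videogen_hub.infermodels.show_one import ShowOne

model = ShowOne()
video = model.infer_one_video(
    prompt="A panda eating bamboo in the forest",  # illustrative prompt
    size=[320, 512],
    seconds=1,  # 8 frames: (8 - 1) % 7 == 0, so the assertion holds
    fps=8,
    seed=42,
)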
src/videogen_hub/infermodels/streamingt2v.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ from videogen_hub import MODEL_PATH
6
+
7
+
8
+ class StreamingT2V:
9
+ def __init__(self, device="cuda"):
10
+ """
11
+ Initializes the StreamingT2V model.
12
+
13
+ Args:
14
+ device (str, optional): The device to run the model on. Defaults to "cuda".
15
+ """
16
+
17
+ from videogen_hub.pipelines.streamingt2v.streamingt2v_pipeline import pipeline
18
+ # https://huggingface.co/spaces/PAIR/StreamingT2V/resolve/main/t2v_enhanced/checkpoints/streaming_t2v.ckpt?download=true
19
+ model_url = "https://huggingface.co/spaces/PAIR/StreamingT2V/resolve/main/t2v_enhanced/checkpoints/streaming_t2v.ckpt?download=true"
20
+ # Download the file
21
+ ckpt_file_streaming_t2v = hf_hub_download(repo_id="PAIR/StreamingT2V",
22
+ filename="streaming_t2v.ckpt",
23
+ local_dir=os.path.join(MODEL_PATH, "streamingtv2"))
24
+
25
+ self.pipeline = pipeline
26
+
27
+ def infer_one_video(
28
+ self,
29
+ prompt: str = None,
30
+ size: list = [320, 512],
31
+ seconds: int = 2,
32
+ fps: int = 8,
33
+ seed: int = 42,
34
+ ):
35
+ """
36
+ Generates a single video based on the provided prompt and parameters.
37
+
38
+ Args:
39
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
40
+ size (list, optional): The size of the video as [height, width]. Defaults to [320, 512].
41
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
42
+ fps (int, optional): The frames per second of the video. Defaults to 8.
43
+ seed (int, optional): The seed for random number generation. Defaults to 42.
44
+
45
+ Returns:
46
+ torch.Tensor: The generated video as a tensor.
47
+ """
48
+
49
+ return self.pipeline(prompt, size, seconds, fps, seed)
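A minimal usage sketch for the wrapper above; the prompt is illustrative and the heavy lifting happens inside the imported pipeline.

from videogen_hub.infermodels.streamingt2v import StreamingT2V

model = StreamingT2V(device="cuda")
video = model.infer_one_video(
    prompt="A hot air balloon drifting over rolling hills",  # illustrative prompt
    size=[320, 512],
    seconds=2,
    fps=8,
    seed=42,
)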
src/videogen_hub/infermodels/t2v_turbo.py ADDED
@@ -0,0 +1,147 @@
1
+ import os
2
+
3
+ from huggingface_hub import hf_hub_download, snapshot_download
4
+ import torch
5
+
6
+ from videogen_hub import MODEL_PATH
7
+
8
+
9
+ class T2VTurbo():
10
+ def __init__(self, base_model="vc2", merged=True, device="cuda"):
11
+ """
12
+ 1. Download the pretrained model and put it inside MODEL_PATH
13
+ 2. Create Pipeline
14
+ Args:
15
+ base_model: 'vc2' to build on VideoCrafter2; any other value falls back to the ModelScope text-to-video base
+ merged: if True (with base_model='vc2'), load the single merged checkpoint instead of base weights plus LoRA
+ device: 'cuda' or 'cpu', the device on which to run the model
16
+ """
17
+ from videogen_hub.pipelines.t2v_turbo.inference_vc2 import T2VTurboVC2Pipeline1
18
+ from videogen_hub.pipelines.t2v_turbo.inference_ms import T2VTurboMSPipeline1
19
+
20
+ self.config = {
21
+ "model": {
22
+ "target": "lvdm.models.ddpm3d.LatentDiffusion",
23
+ "params": {
24
+ "linear_start": 0.00085,
25
+ "linear_end": 0.012,
26
+ "num_timesteps_cond": 1,
27
+ "timesteps": 1000,
28
+ "first_stage_key": "video",
29
+ "cond_stage_key": "caption",
30
+ "cond_stage_trainable": False,
31
+ "conditioning_key": "crossattn",
32
+ "image_size": [320, 512],
33
+ "channels": 4,
34
+ "scale_by_std": False,
35
+ "scale_factor": 0.18215,
36
+ "use_ema": False,
37
+ "uncond_type": "empty_seq",
38
+ "use_scale": True,
39
+ "scale_b": 0.7,
40
+ "unet_config": {
41
+ "target": "lvdm.modules.networks.openaimodel3d.UNetModel",
42
+ "params": {
43
+ "in_channels": 4,
44
+ "out_channels": 4,
45
+ "model_channels": 320,
46
+ "attention_resolutions": [4, 2, 1],
47
+ "num_res_blocks": 2,
48
+ "channel_mult": [1, 2, 4, 4],
49
+ "num_head_channels": 64,
50
+ "transformer_depth": 1,
51
+ "context_dim": 1024,
52
+ "use_linear": True,
53
+ "use_checkpoint": True,
54
+ "temporal_conv": True,
55
+ "temporal_attention": True,
56
+ "temporal_selfatt_only": True,
57
+ "use_relative_position": False,
58
+ "use_causal_attention": False,
59
+ "temporal_length": 16,
60
+ "addition_attention": True,
61
+ "fps_cond": True
62
+ }
63
+ },
64
+ "first_stage_config": {
65
+ "target": "lvdm.models.autoencoder.AutoencoderKL",
66
+ "params": {
67
+ "embed_dim": 4,
68
+ "monitor": "val / rec_loss",
69
+ "ddconfig": {
70
+ "double_z": True,
71
+ "z_channels": 4,
72
+ "resolution": 512,
73
+ "in_channels": 3,
74
+ "out_ch": 3,
75
+ "ch": 128,
76
+ "ch_mult": [1, 2, 4, 4],
77
+ "num_res_blocks": 2,
78
+ "attn_resolutions": [],
79
+ "dropout": 0.0
80
+ },
81
+ "lossconfig": {
82
+ "target": "torch.nn.Identity"
83
+ }
84
+ }
85
+ },
86
+ "cond_stage_config": {
87
+ "target": "lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder",
88
+ "params": {
89
+ "freeze": True,
90
+ "layer": "penultimate"
91
+ }
92
+ }
93
+ }
94
+ }
95
+ }
96
+ if base_model == "vc2" and merged:
97
+ merged_model_path = hf_hub_download(repo_id="jiachenli-ucsb/T2V-Turbo-VC2-Merged",
98
+ filename="t2v_turbo_vc2.pt",
99
+ local_dir=os.path.join(MODEL_PATH, "T2V-Turbo-VC2"))
100
+ self.pipeline = T2VTurboVC2Pipeline1(self.config, merged, device, None, merged_model_path)
101
+
102
+ elif base_model == "vc2":
103
+ base_model_path = hf_hub_download(repo_id="VideoCrafter/VideoCrafter2",
104
+ filename="model.ckpt",
105
+ local_dir=os.path.join(MODEL_PATH, "videocrafter2"))
106
+
107
+ unet_lora_path = hf_hub_download(repo_id="jiachenli-ucsb/T2V-Turbo-VC2",
108
+ filename="unet_lora.pt",
109
+ local_dir=os.path.join(MODEL_PATH, "T2V-Turbo-VC2"))
110
+ # It uses the config provided above.
111
+ self.pipeline = T2VTurboVC2Pipeline1(self.config, merged, device, unet_lora_path, base_model_path)
112
+ else:
113
+ base_model_path = snapshot_download(repo_id="ali-vilab/text-to-video-ms-1.7b",
114
+ local_dir=os.path.join(MODEL_PATH, "modelscope_1.7b"))
115
+
116
+ unet_lora_path = hf_hub_download(repo_id="jiachenli-ucsb/T2V-Turbo-MS",
117
+ filename="unet_lora.pt",
118
+ local_dir=os.path.join(MODEL_PATH, "T2V-Turbo-MS"))
119
+
120
+ # It uses the config provided by base_model.
121
+ self.pipeline = T2VTurboMSPipeline1(device, unet_lora_path, base_model_path)
122
+
123
+ def infer_one_video(
124
+ self,
125
+ prompt: str = None,
126
+ size: list = [320, 512],
127
+ seconds: int = 2,
128
+ fps: int = 8,
129
+ seed: int = 42,
130
+ ):
131
+ """
132
+ Generates a single video based on the provided prompt and parameters.
133
+ The output is of shape [frames, channels, height, width].
134
+ Args:
135
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
136
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
137
+ fps (int, optional): The frames per second of the video. Defaults to 8.
138
+ seed (int, optional): The seed for random number generation. Defaults to 42.
139
+
140
+ Returns:
141
+ torch.Tensor: The generated video as a tensor.
142
+ """
143
+ output = self.pipeline.inference(prompt=prompt, height=size[0], width=size[1],
144
+ seed=seed, num_frames=seconds * fps, fps=fps, randomize_seed=False)
145
+ # [channels, frames, height, width] -> [frames, channels, height, width]
146
+ output = output.squeeze().permute(1, 0, 2, 3)
147
+ return output.cpu()
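A sketch of the three variants wired up in __init__; the prompt is illustrative.

from videogen_hub.infermodels.t2v_turbo import T2VTurbo

# T2VTurbo(base_model="vc2", merged=True)   -> single merged VC2 checkpoint
# T2VTurbo(base_model="vc2", merged=False)  -> VideoCrafter2 base weights + LoRA
# T2VTurbo(base_model="ms")                 -> any non-"vc2" value uses the ModelScope base + LoRA
model = T2VTurbo(base_model="vc2", merged=True, device="cuda")
video = model.infer_one_video(
    prompt="A time-lapse of clouds over a mountain ridge",  # illustrative prompt
    size=[320, 512],
    seconds=2,
    fps=8,
    seed=42,
)
print(video.shape)  # (frames, channels, height, width)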
src/videogen_hub/infermodels/videocrafter.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch
2
+ from huggingface_hub import hf_hub_download
3
+ from pathlib import Path
4
+ import os
5
+
6
+ from videogen_hub import MODEL_PATH
7
+
8
+
9
+ class VideoCrafter2():
10
+ def __init__(self, device="cuda"):
11
+ """
12
+ 1. Download the pretrained model and put it inside MODEL_PATH/videocrafter2
13
+ 2. Create Pipeline
14
+ Args:
15
+ device: 'cuda' or 'cpu', the device on which to run the model
16
+ """
17
+ from videogen_hub.pipelines.videocrafter.inference import VideoCrafterPipeline
18
+
19
+ model_path = hf_hub_download(repo_id="VideoCrafter/VideoCrafter2",
20
+ filename="model.ckpt",
21
+ local_dir=os.path.join(MODEL_PATH, "videocrafter2"))
22
+ config_path = str(Path(__file__).parent.parent.absolute())
23
+ config_path = os.path.join(config_path, 'pipelines/videocrafter/inference_t2v_512_v2.0.yaml')
24
+
25
+ arg_list = ['--mode', 'base',
26
+ '--ckpt_path', model_path,
27
+ '--config', config_path,
28
+ '--n_samples', '1',
29
+ '--bs', '1',
30
+ '--unconditional_guidance_scale', '12.0',
31
+ '--ddim_steps', '50',
32
+ '--ddim_eta', '1.0',
33
+ '--fps', '8']
34
+
35
+ self.pipeline = VideoCrafterPipeline(arg_list, device, 0, 1)
36
+
37
+ def infer_one_video(self,
38
+ prompt: str = None,
39
+ size: list = [320, 512],
40
+ seconds: int = 2,
41
+ fps: int = 8,
42
+ seed: int = 42):
43
+ """
44
+ Generates a single video based on the provided prompt and parameters.
45
+
46
+ Args:
47
+ prompt (str, optional): The text prompt to generate the video from. Defaults to None.
48
+ size (list, optional): The size of the video as [height, width]. Defaults to [320, 512].
49
+ seconds (int, optional): The duration of the video in seconds. Defaults to 2.
50
+ fps (int, optional): The frames per second of the video. Defaults to 8.
51
+ seed (int, optional): The seed for random number generation. Defaults to 42.
52
+
53
+ Returns:
54
+ torch.Tensor: The generated video as a tensor, the shape being [num_frames, 3, height, width]
55
+
56
+ """
57
+ torch.manual_seed(seed)
58
+ video = self.pipeline.run_inference(prompt,
59
+ video_length=seconds * fps,
60
+ height=size[0],
61
+ width=size[1])
62
+
63
+ return video.squeeze(0, 1).cpu().permute(1, 0, 2, 3)
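A usage sketch that also writes the result to disk, assuming the returned (num_frames, 3, H, W) tensor is float-valued in [0, 1]; the value range, prompt and output path are assumptions, not guaranteed by the pipeline.

import torch
from torchvision.io import write_video
from videogen_hub.infermodels.videocrafter import VideoCrafter2

model = VideoCrafter2(device="cuda")
video = model.infer_one_video(prompt="A sailboat crossing a calm bay", seconds=2, fps=8)
# Rescale to uint8 and reorder to (T, H, W, C) for torchvision's writer.
frames = (video.clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1)
write_video("videocrafter2_sample.mp4", frames, fps=8)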
src/videogen_hub/metrics/__init__.py ADDED
File without changes
src/videogen_hub/metrics/brisque_metric.py ADDED
@@ -0,0 +1,47 @@
1
+ from brisque import BRISQUE
2
+ from PIL import Image
3
+ import numpy as np
4
+ from typing import List
5
+
6
+ ROUND_DIGIT=3
7
+ NUM_ASPECT=5
8
+
9
+ BRISQUE_POINT_LOW=10
10
+ BRISQUE_POINT_MID=30
11
+ BRISQUE_POINT_HIGH=50
12
+
13
+ class MetricBRISQUE():
14
+ def __init__(self) -> None:
15
+ """
16
+ Initialize a class MetricBRISQUE for testing visual quality of a given video.
17
+
18
+ """
19
+ None
20
+
21
+ def evaluate(self,frame_list:List[Image.Image]):
22
+ """
23
+ Calculate BRISQUE for visual quality for each frame of the given video and take the average value,
24
+ then quantize the original output based on some predefined thresholds.
25
+
26
+ Args:
27
+ frame_list:List[Image.Image], frames of the video used in calculation
28
+
29
+ Returns:
30
+ brisque_avg: float, the computed average BRISQUE among the frames
31
+ quantized_ans: int, the quantized value of the above avg score based on pre-defined thresholds.
32
+ """
33
+ brisque_list=[]
34
+ for frame in frame_list:
35
+ brisque_score=BRISQUE().score(frame)
36
+ brisque_list.append(brisque_score)
37
+ brisque_avg=np.mean(brisque_list)
38
+ quantized_ans=0
39
+ if brisque_avg < BRISQUE_POINT_LOW:
40
+ quantized_ans=4
41
+ elif brisque_avg < BRISQUE_POINT_MID:
42
+ quantized_ans=3
43
+ elif brisque_avg < BRISQUE_POINT_HIGH:
44
+ quantized_ans=2
45
+ else:
46
+ quantized_ans=1
47
+ return brisque_avg, quantized_ans
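The metrics below all take a list of PIL frames, while the inference classes return tensors, so a small conversion helper is useful; a sketch, assuming a (T, C, H, W) float tensor in [0, 1].

import torch
from torchvision import transforms
from videogen_hub.metrics.brisque_metric import MetricBRISQUE

def tensor_to_frames(video: torch.Tensor) -> list:
    # video: (T, C, H, W) float tensor, assumed to be in [0, 1]
    to_pil = transforms.ToPILImage()
    return [to_pil(frame.clamp(0, 1)) for frame in video]

# score, level = MetricBRISQUE().evaluate(tensor_to_frames(video))
# Lower BRISQUE is better: an average below 10 maps to level 4, 50 or more to level 1.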
src/videogen_hub/metrics/clip-sim_metric.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch.nn.functional as F
4
+ from typing import List
5
+ from transformers import CLIPProcessor, CLIPModel
6
+
7
+ ROUND_DIGIT=3
8
+ NUM_ASPECT=5
9
+
10
+ CLIP_POINT_HIGH=0.97
11
+ CLIP_POINT_MID=0.9
12
+ CLIP_POINT_LOW=0.8
13
+
14
+
15
+ class MetricCLIP_sim():
16
+ def __init__(self, device = "cuda") -> None:
17
+ """
18
+ Initialize a class MetricCLIP_sim with the specified device for testing temporal consistency of a given video.
19
+
20
+ Args:
21
+ device (str, optional): The device on which the model will run. Defaults to "cuda".
22
+ """
23
+ self.device = device
24
+ self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
25
+ self.model.to(self.device)
26
+ self.tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
27
+
28
+ def evaluate(self,frame_list:List[Image.Image]):
29
+ """
30
+ Calculate the cosine similarity between the CLIP features of adjacent frames of a given video to test temporal consistency,
31
+ then quantize the original output based on some predefined thresholds.
32
+
33
+ Args:
34
+ frame_list:List[Image.Image], frames of the video used in calculation.
35
+
36
+ Returns:
37
+ clip_frame_score: float, the computed CLIP feature cosine similarity between each adjacent pair of frames and then averaged among all the pairs.
38
+ quantized_ans: int, the quantized value of the above avg CLIP-Sim scores based on pre-defined thresholds.
39
+ """
40
+
41
+ device=self.model.device
42
+ frame_sim_list=[]
43
+ for f_idx in range(len(frame_list)-1):
44
+ frame_1 = frame_list[f_idx]
45
+ frame_2 = frame_list[f_idx+1]
46
+ input_1 = self.tokenizer(images=frame_1, return_tensors="pt", padding=True).to(device)
47
+ input_2 = self.tokenizer(images=frame_2, return_tensors="pt", padding=True).to(device)
48
+ output_1 = self.model.get_image_features(**input_1).flatten()
49
+ output_2 = self.model.get_image_features(**input_2).flatten()
50
+ cos_sim = F.cosine_similarity(output_1, output_2, dim=0).item()
51
+ frame_sim_list.append(cos_sim)
52
+
53
+ clip_frame_score = np.mean(frame_sim_list)
54
+ quantized_ans=0
55
+ if clip_frame_score >= CLIP_POINT_HIGH:
56
+ quantized_ans=4
57
+ elif clip_frame_score < CLIP_POINT_HIGH and clip_frame_score >= CLIP_POINT_MID:
58
+ quantized_ans=3
59
+ elif clip_frame_score < CLIP_POINT_MID and clip_frame_score >= CLIP_POINT_LOW:
60
+ quantized_ans=2
61
+ else:
62
+ quantized_ans=1
63
+ return clip_frame_score, quantized_ans
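Because the filename contains a hyphen, this module cannot be imported with a plain import statement; importlib can still load it by its dotted path. A sketch, with illustrative frame paths, assuming the package is importable as videogen_hub.

import importlib
from PIL import Image

clip_sim_mod = importlib.import_module("videogen_hub.metrics.clip-sim_metric")
frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]  # illustrative paths
score, level = clip_sim_mod.MetricCLIP_sim(device="cuda").evaluate(frames)
# level 4 corresponds to an average adjacent-frame CLIP similarity >= 0.97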
src/videogen_hub/metrics/clipscore_metric.py ADDED
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch.nn.functional as F
4
+ from typing import List
5
+ from transformers import CLIPProcessor, CLIPModel
6
+
7
+ NUM_ASPECT=5
8
+ ROUND_DIGIT=3
9
+ MAX_LENGTH = 76
10
+
11
+ MAX_NUM_FRAMES=8
12
+
13
+ CLIP_POINT_LOW=0.27
14
+ CLIP_POINT_MID=0.31
15
+ CLIP_POINT_HIGH=0.35
16
+
17
+
18
+ class MetricCLIPScore():
19
+ def __init__(self, device="cuda") -> None:
20
+ """
21
+ Initialize a MetricCLIPScore object with the specified device.
22
+
23
+ Args:
24
+ device (str, optional): The device on which the model will run. Defaults to "cuda".
25
+ """
26
+ self.device = device
27
+ self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
28
+ self.model.to(self.device)
29
+ self.tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
30
+
31
+ def evaluate(self, frame_list:List[Image.Image], text:str,):
32
+ """
33
+ Calculate the cosine similarity between the CLIP features of the text prompt and each frame of a given video to test text-to-video alignment,
34
+ then quantize the original output based on some predefined thresholds.
35
+
36
+ Args:
37
+ frame_list:List[Image.Image], frames of the video used in calculation.
38
+ text:str, text prompt for generating the video.
39
+
40
+ Returns:
41
+ clip_score_avg: float, the computed average CLIP-Score between each frame and the text prompt.
42
+ quantized_ans: int, the quantized value of the above avg CLIP-Score based on pre-defined thresholds.
43
+ """
44
+
45
+ device=self.model.device
46
+ input_t = self.tokenizer(text=text, max_length=MAX_LENGTH, truncation=True, return_tensors="pt", padding=True).to(device)
47
+ cos_sim_list=[]
48
+ for image in frame_list:
49
+ input_f = self.tokenizer(images=image, return_tensors="pt", padding=True).to(device)
50
+ output_t = self.model.get_text_features(**input_t).flatten()
51
+ output_f = self.model.get_image_features(**input_f).flatten()
52
+ cos_sim = F.cosine_similarity(output_t, output_f, dim=0).item()
53
+ cos_sim_list.append(cos_sim)
54
+ clip_score_avg=np.mean(cos_sim_list)
55
+ quantized_ans=0
56
+ if clip_score_avg < CLIP_POINT_LOW:
57
+ quantized_ans=1
58
+ elif clip_score_avg >= CLIP_POINT_LOW and clip_score_avg < CLIP_POINT_MID:
59
+ quantized_ans=2
60
+ elif clip_score_avg >= CLIP_POINT_MID and clip_score_avg < CLIP_POINT_HIGH:
61
+ quantized_ans=3
62
+ else:
63
+ quantized_ans=4
64
+ return clip_score_avg, quantized_ans
65
+
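A minimal usage sketch; the frame paths and prompt are illustrative.

from PIL import Image
from videogen_hub.metrics.clipscore_metric import MetricCLIPScore

prompt = "A corgi running on the beach"  # should be the prompt the video was generated from
frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]  # illustrative paths
clip_score, level = MetricCLIPScore(device="cuda").evaluate(frames, prompt)
# Thresholds: < 0.27 -> level 1, 0.27-0.31 -> 2, 0.31-0.35 -> 3, >= 0.35 -> 4.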
src/videogen_hub/metrics/dino-sim_metric.py ADDED
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from typing import List
6
+ from torchvision.models import vit_b_16
7
+ import torchvision.transforms as transforms
8
+
9
+ ROUND_DIGIT=3
10
+ NUM_ASPECT=5
11
+
12
+ DINO_POINT_HIGH=0.97
13
+ DINO_POINT_MID=0.9
14
+ DINO_POINT_LOW=0.8
15
+
16
+
17
+ class MetricDINO_sim():
18
+ def __init__(self, device="cuda") -> None:
19
+ """
20
+ Initialize a class MetricDINO_sim with the specified device for testing temporal consistency of a given video.
21
+
22
+ Args:
23
+ device (str, optional): The device on which the model will run. Defaults to "cuda".
24
+ """
25
+ self.device = device
26
+ self.model = vit_b_16(pretrained=True)
27
+ self.model.to(self.device).eval()
28
+ self.preprocess = transforms.Compose([
29
+ transforms.Resize(256),
30
+ transforms.CenterCrop(224),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
+ ])
34
+
35
+ def evaluate(self, frame_list:List[Image.Image]):
36
+ """
37
+ Calculate the cosine similarity between the DINO features of adjacent frames of a given video to test temporal consistency,
38
+ then quantize the original output based on some predefined thresholds.
39
+
40
+ Args:
41
+ frame_list:List[Image.Image], frames of the video used in calculation.
42
+
43
+ Returns:
44
+ dino_frame_score: float, the computed DINO feature cosine similarity between each adjacent pair of frames and then averaged among all the pairs.
45
+ quantized_ans: int, the quantized value of the above avg DINO-Sim scores based on pre-defined thresholds.
46
+ """
47
+
48
+ device = self.device
49
+ frame_sim_list=[]
50
+ for f_idx in range(len(frame_list)-1):
51
+ frame_1=frame_list[f_idx]
52
+ frame_2=frame_list[f_idx+1]
53
+ frame_tensor_1 = self.preprocess(frame_1).unsqueeze(0).to(device)
54
+ frame_tensor_2 = self.preprocess(frame_2).unsqueeze(0).to(device)
55
+ with torch.no_grad():
56
+ feat_1 = self.model(frame_tensor_1).flatten()
57
+ feat_2 = self.model(frame_tensor_2).flatten()
58
+ cos_sim=F.cosine_similarity(feat_1, feat_2, dim=0).item()
59
+ frame_sim_list.append(cos_sim)
60
+
61
+ dino_frame_score = np.mean(frame_sim_list)
62
+ quantized_ans=0
63
+ if dino_frame_score >= DINO_POINT_HIGH:
64
+ quantized_ans=4
65
+ elif dino_frame_score < DINO_POINT_HIGH and dino_frame_score >= DINO_POINT_MID:
66
+ quantized_ans=3
67
+ elif dino_frame_score < DINO_POINT_MID and dino_frame_score >= DINO_POINT_LOW:
68
+ quantized_ans=2
69
+ else:
70
+ quantized_ans=1
71
+ return dino_frame_score, quantized_ans
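Loaded the same way as the CLIP similarity metric above (via importlib, because of the hyphen in the filename); frame paths are illustrative.

import importlib
from PIL import Image

dino_sim_mod = importlib.import_module("videogen_hub.metrics.dino-sim_metric")
frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]  # illustrative paths
score, level = dino_sim_mod.MetricDINO_sim(device="cuda").evaluate(frames)
# level 4 corresponds to an average adjacent-frame feature similarity >= 0.97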
src/videogen_hub/metrics/mse-dyn_metric.py ADDED
@@ -0,0 +1,59 @@
1
+ import numpy as np
2
+ import cv2
3
+ from PIL import Image
4
+ from typing import List
5
+ from skimage.metrics import structural_similarity as ssim
6
+ from skimage import io, color
7
+
8
+ ROUND_DIGIT=3
9
+ DYN_SAMPLE_STEP=4
10
+ NUM_ASPECT=5
11
+
12
+ MSE_POINT_HIGH=3000
13
+ MSE_POINT_MID=1000
14
+ MSE_POINT_LOW=100
15
+
16
+
17
+ class MetricMSE_dyn():
18
+ def __init__(self) -> None:
19
+ """
20
+ Initialize a class MetricMSE_dyn for testing dynamic degree of a given video.
21
+
22
+ """
23
+ None
24
+
25
+ def evaluate(self, frame_list:List[Image.Image]):
26
+ """
27
+ Calculate the MSE (Mean Squared Error) between frames sampled at regular intervals of a given video to test dynamic degree,
28
+ then quantize the original output based on some predefined thresholds.
29
+
30
+ Args:
31
+ frame_list:List[Image.Image], frames of the video used in calculation.
32
+
33
+ Returns:
34
+ mse_avg: float, the computed MSE between frames sampled at regular intervals and then averaged among all the pairs.
35
+ quantized_ans: int, the quantized value of the above avg MSE scores based on pre-defined thresholds.
36
+ """
37
+
38
+ mse_list=[]
39
+ sampled_list = frame_list[::DYN_SAMPLE_STEP]
40
+ for f_idx in range(len(sampled_list)-1):
41
+ imageA = cv2.cvtColor(np.array(sampled_list[f_idx]), cv2.COLOR_RGB2BGR)
42
+ imageB = cv2.cvtColor(np.array(sampled_list[f_idx+1]), cv2.COLOR_RGB2BGR)
43
+
44
+ err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
45
+ err /= float(imageA.shape[0] * imageA.shape[1])
46
+ mse_value = err
47
+ mse_list.append(mse_value)
48
+ mse_avg=np.mean(mse_list)
49
+ quantized_ans=0
50
+ if mse_avg >= MSE_POINT_HIGH:
51
+ quantized_ans=4
52
+ elif mse_avg < MSE_POINT_HIGH and mse_avg >= MSE_POINT_MID:
53
+ quantized_ans=3
54
+ elif mse_avg < MSE_POINT_MID and mse_avg >= MSE_POINT_LOW:
55
+ quantized_ans=2
56
+ else:
57
+ quantized_ans=1
58
+
59
+ return mse_avg, quantized_ans
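A self-contained check of the thresholds above using synthetic frames. With nine 64x64 frames (five black, four grey with value 50), indices 0, 4 and 8 are sampled: the (black, black) pair gives MSE 0 and the (black, grey) pair gives 50^2 * 3 channels = 7500 (the divisor is H*W only), so the average of 3750 exceeds 3000 and maps to dynamic-degree level 4.

import importlib
from PIL import Image

mse_dyn_mod = importlib.import_module("videogen_hub.metrics.mse-dyn_metric")  # hyphenated filename, hence importlib
frames = [Image.new("RGB", (64, 64), (0, 0, 0))] * 5 + [Image.new("RGB", (64, 64), (50, 50, 50))] * 4
mse_avg, level = mse_dyn_mod.MetricMSE_dyn().evaluate(frames)
print(round(mse_avg), level)  # 3750 4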
src/videogen_hub/metrics/piqe_metric.py ADDED
@@ -0,0 +1,49 @@
1
+ from pypiqe import piqe
2
+ from PIL import Image
3
+ import numpy as np
4
+ from typing import List
5
+
6
+ ROUND_DIGIT=3
7
+ NUM_ASPECT=5
8
+
9
+ PIQE_POINT_LOW=15
10
+ PIQE_POINT_MID=30
11
+ PIQE_POINT_HIGH=50
12
+
13
+ class MetricPIQE():
14
+ def __init__(self) -> None:
15
+ """
16
+ Initialize a class MetricPIQE for testing visual quality of a given video.
17
+
18
+ """
19
+ None
20
+
21
+ def evaluate(self,frame_list:List[Image.Image]):
22
+ """
23
+ Calculate PIQE for visual quality for each frame of the given video and take the average value,
24
+ then quantize the original output based on some predefined thresholds.
25
+
26
+ Args:
27
+ frame_list:List[Image.Image], frames of the video used in calculation.
28
+
29
+ Returns:
30
+ piqe_avg: float, the computed average PIQE among the frames.
31
+ quantized_ans: int, the quantized value of the above avg score based on pre-defined thresholds.
32
+ """
33
+ piqe_list=[]
34
+ for frame in frame_list:
35
+ frame=np.array(frame)
36
+ piqe_score, _,_,_ = piqe(frame)
37
+ piqe_list.append(piqe_score)
38
+ piqe_avg=np.mean(piqe_list)
39
+ quantized_ans=0
40
+ if piqe_avg < PIQE_POINT_LOW:
41
+ quantized_ans=4
42
+ elif piqe_avg < PIQE_POINT_MID:
43
+ quantized_ans=3
44
+ elif piqe_avg < PIQE_POINT_HIGH:
45
+ quantized_ans=2
46
+ else:
47
+ quantized_ans=1
48
+ return piqe_avg, quantized_ans
49
+
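A minimal usage sketch; frame paths are illustrative. Like BRISQUE, lower PIQE means better perceived quality, so the quantized level runs from 4 (average below 15) down to 1 (average of 50 or more).

from PIL import Image
from videogen_hub.metrics.piqe_metric import MetricPIQE

frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]  # illustrative paths
piqe_avg, level = MetricPIQE().evaluate(frames)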
src/videogen_hub/metrics/ssim-dyn_metric.py ADDED
@@ -0,0 +1,60 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ from typing import List
4
+ from skimage.metrics import structural_similarity as ssim
5
+ from skimage import io, color
6
+
7
+ ROUND_DIGIT=3
8
+ DYN_SAMPLE_STEP=4
9
+ NUM_ASPECT=5
10
+
11
+ SSIM_POINT_HIGH=0.9
12
+ SSIM_POINT_MID=0.7
13
+ SSIM_POINT_LOW=0.5
14
+
15
+
16
+
17
+ class MetricSSIM_dyn():
18
+ def __init__(self) -> None:
19
+ """
20
+ Initialize a class MetricSSIM_dyn for testing dynamic degree of a given video.
21
+
22
+ """
23
+ None
24
+
25
+ def evaluate(self, frame_list:List[Image.Image]):
26
+ """
27
+ Calculate the SSIM between frames sampled at regular intervals of a given video to test dynamic degree,
28
+ then quantize the original output based on some predefined thresholds.
29
+
30
+ Args:
31
+ frame_list:List[Image.Image], frames of the video used in calculation.
32
+
33
+ Returns:
34
+ ssim_avg: float, the computed SSIM between frames sampled at regular intervals and then averaged among all the pairs.
35
+ quantized_ans: int, the quantized value of the above avg SSIM scores based on pre-defined thresholds.
36
+ """
37
+
38
+ ssim_list=[]
39
+ sampled_list = frame_list[::DYN_SAMPLE_STEP]
40
+ for f_idx in range(len(sampled_list)-1):
41
+ frame_1=sampled_list[f_idx]
42
+ frame_1_gray=color.rgb2gray(frame_1)
43
+ frame_2=sampled_list[f_idx+1]
44
+ frame_2_gray=color.rgb2gray(frame_2)
45
+
46
+ ssim_value, _ = ssim(frame_1_gray, frame_2_gray, full=True,\
47
+ data_range=frame_2_gray.max() - frame_2_gray.min())
48
+ ssim_list.append(ssim_value)
49
+ ssim_avg=np.mean(ssim_list)
50
+
51
+ quantized_ans=0
52
+ if ssim_avg >= SSIM_POINT_HIGH:
53
+ quantized_ans=1
54
+ elif ssim_avg <= SSIM_POINT_HIGH and ssim_avg > SSIM_POINT_MID:
55
+ quantized_ans=2
56
+ elif ssim_avg <= SSIM_POINT_MID and ssim_avg > SSIM_POINT_LOW:
57
+ quantized_ans=3
58
+ else:
59
+ quantized_ans=4
60
+ return ssim_avg, quantized_ans
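Note that this scale is inverted relative to the similarity metrics: a nearly static video has high SSIM between sampled frames (>= 0.9) and is scored 1, while a highly dynamic one (SSIM below 0.5) is scored 4. A usage sketch with illustrative frame paths:

import importlib
from PIL import Image

ssim_dyn_mod = importlib.import_module("videogen_hub.metrics.ssim-dyn_metric")  # hyphenated filename, hence importlib
frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]  # illustrative paths
ssim_avg, level = ssim_dyn_mod.MetricSSIM_dyn().evaluate(frames)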
src/videogen_hub/metrics/ssim-sim_metric.py ADDED
@@ -0,0 +1,59 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from typing import List
6
+ from skimage.metrics import structural_similarity as ssim
7
+ from skimage import io, color
8
+
9
+ ROUND_DIGIT=3
10
+ NUM_ASPECT=5
11
+
12
+ TEM_SSIM_POINT_HIGH=0.9
13
+ TEM_SSIM_POINT_MID=0.75
14
+ TEM_SSIM_POINT_LOW=0.6
15
+
16
+
17
+ class MetricSSIM_sim():
18
+ def __init__(self) -> None:
19
+ """
20
+ Initialize a class MetricSSIM_sim for testing temporal consistency of a given video.
21
+
22
+ """
23
+ None
24
+
25
+ def evaluate(self, frame_list:List[Image.Image]):
26
+ """
27
+ Calculate the SSIM between adjacent frames of a given video to test temporal consistency,
28
+ then quantize the original output based on some predefined thresholds.
29
+
30
+ Args:
31
+ frame_list:List[Image.Image], frames of the video used in calculation.
32
+
33
+ Returns:
34
+ ssim_avg: float, the computed SSIM between each adjacent pair of frames and then averaged among all the pairs.
35
+ quantized_ans: int, the quantized value of the above avg SSIM scores based on pre-defined thresholds.
36
+ """
37
+
38
+ ssim_list=[]
39
+ for f_idx in range(len(frame_list)-1):
40
+ frame_1=frame_list[f_idx]
41
+ frame_1_gray=color.rgb2gray(frame_1)
42
+ frame_2=frame_list[f_idx+1]
43
+ frame_2_gray=color.rgb2gray(frame_2)
44
+
45
+ ssim_value, _ = ssim(frame_1_gray, frame_2_gray, full=True,\
46
+ data_range=frame_2_gray.max() - frame_2_gray.min())
47
+ ssim_list.append(ssim_value)
48
+ ssim_avg=np.mean(ssim_list)
49
+ quantized_ans=0
50
+ if ssim_avg >= TEM_SSIM_POINT_HIGH:
51
+ quantized_ans=4
52
+ elif ssim_avg < TEM_SSIM_POINT_HIGH and ssim_avg >= TEM_SSIM_POINT_MID:
53
+ quantized_ans=3
54
+ elif ssim_avg < TEM_SSIM_POINT_MID and ssim_avg >= TEM_SSIM_POINT_LOW:
55
+ quantized_ans=2
56
+ else:
57
+ quantized_ans=1
58
+ return ssim_avg, quantized_ans
59
+
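A usage sketch with illustrative frame paths; here, unlike the dynamic-degree variant, higher adjacent-frame SSIM means better temporal consistency (>= 0.9 maps to level 4, below 0.6 to level 1).

import importlib
from PIL import Image

ssim_sim_mod = importlib.import_module("videogen_hub.metrics.ssim-sim_metric")  # hyphenated filename, hence importlib
frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]  # illustrative paths
ssim_avg, level = ssim_sim_mod.MetricSSIM_sim().evaluate(frames)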
src/videogen_hub/metrics/xclipscore_metric.py ADDED
@@ -0,0 +1,72 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch.nn.functional as F
4
+ from typing import List
5
+ from transformers import AutoTokenizer, AutoModel, AutoProcessor
6
+
7
+ NUM_ASPECT=5
8
+ ROUND_DIGIT=3
9
+ MAX_LENGTH = 76
10
+
11
+ MAX_NUM_FRAMES=8
12
+
13
+ X_CLIP_POINT_LOW=0.15
14
+ X_CLIP_POINT_MID=0.225
15
+ X_CLIP_POINT_HIGH=0.30
16
+
17
+
18
+ def _read_video_frames(frames, max_frames):
19
+ total_frames = len(frames)
20
+ indices = np.linspace(0, total_frames - 1, num=max_frames).astype(int)
21
+
22
+ selected_frames = [np.array(frames[i]) for i in indices]
23
+ return np.stack(selected_frames)
24
+
25
+
26
+ class MetricXCLIPScore():
27
+ def __init__(self, device="cuda") -> None:
28
+ """
29
+ Initialize a MetricXCLIPScore object with the specified device.
30
+
31
+ Args:
32
+ device (str, optional): The device on which the model will run. Defaults to "cuda".
33
+ """
34
+
35
+ self.model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
36
+ self.processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
37
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/xclip-base-patch32")
38
+
39
+ def evaluate(self, frame_list:List[Image.Image], text:str,):
40
+ """
41
+ Calculate the cosine similarity between the X-CLIP features of the text prompt and the given video to test text-to-video alignment,
42
+ then quantize the original output based on some predefined thresholds.
43
+
44
+ Args:
45
+ frame_list:List[Image.Image], frames of the video used in calculation.
46
+ text:str, text prompt for generating the video.
47
+
48
+ Returns:
49
+ xclip_score_avg: float, the computed X-CLIP-Score between video and its text prompt.
50
+ quantized_ans: int, the quantized value of the above X-CLIP-Score based on pre-defined thresholds.
51
+ """
52
+
53
+ input_text = self.tokenizer([text], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors="pt")
54
+ text_feature = self.model.get_text_features(**input_text).flatten()
55
+
56
+ video=_read_video_frames(frame_list,MAX_NUM_FRAMES)
57
+
58
+ input_video = self.processor(videos=list(video), return_tensors="pt")
59
+ video_feature = self.model.get_video_features(**input_video).flatten()
60
+ cos_sim=F.cosine_similarity(text_feature, video_feature, dim=0).item()
61
+ quantized_ans=0
62
+ if cos_sim < X_CLIP_POINT_LOW:
63
+ quantized_ans=1
64
+ elif cos_sim >= X_CLIP_POINT_LOW and cos_sim < X_CLIP_POINT_MID:
65
+ quantized_ans=2
66
+ elif cos_sim >= X_CLIP_POINT_MID and cos_sim < X_CLIP_POINT_HIGH:
67
+ quantized_ans=3
68
+ else:
69
+ quantized_ans=4
70
+ return cos_sim, quantized_ans
71
+
72
+
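A minimal usage sketch; the prompt and frame paths are illustrative. Eight frames are sampled uniformly from the list before scoring, and a similarity of 0.30 or more maps to level 4.

from PIL import Image
from videogen_hub.metrics.xclipscore_metric import MetricXCLIPScore

prompt = "A corgi running on the beach"  # should be the prompt the video was generated from
frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(16)]  # illustrative paths
xclip_score, level = MetricXCLIPScore().evaluate(frames, prompt)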
src/videogen_hub/utils/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+
5
+ def images_to_tensor(image_list):
6
+ """
7
+ Parse a list of PIL images and convert them to a PyTorch tensor in shape (T, C, H, W).
8
+ """
9
+ transform = transforms.ToTensor()
10
+
11
+ # Convert each PIL image to tensor and store in a list
12
+ tensor_list = [transform(img) for img in image_list]
13
+
14
+ # Stack the list of tensors along a new dimension to create the final tensor
15
+ tensor = torch.stack(tensor_list, dim=0)
16
+
17
+ return tensor
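A round-trip sketch: stack PIL frames into the (T, C, H, W) tensor and convert one frame back for inspection; the image paths are illustrative.

from PIL import Image
from torchvision import transforms
from videogen_hub.utils import images_to_tensor

frames = [Image.open(f"frames/{i:03d}.png").convert("RGB") for i in range(8)]  # illustrative paths
video = images_to_tensor(frames)                 # (T, C, H, W), float in [0, 1]
first_frame = transforms.ToPILImage()(video[0])  # back to PIL, e.g. for the metrics above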
src/videogen_hub/utils/file_helper.py ADDED
@@ -0,0 +1,24 @@
1
+ import os
2
+ from typing import Union, List, Optional
3
+ from urllib.parse import urlparse
4
+ import requests
5
+
6
+ def get_file_path(filename: Union[str, os.PathLike], search_from: Union[str, os.PathLike] = "."):
7
+ """
8
+ Search for a file across a directory and return its absolute path.
9
+
10
+ Args:
11
+ filename (Union[str, os.PathLike]): The name of the file to search for.
12
+ search_from (Union[str, os.PathLike], optional): The directory from which to start the search. Defaults to ".".
13
+
14
+ Returns:
15
+ str: Absolute path to the found file.
16
+
17
+ Raises:
18
+ FileNotFoundError: If the file is not found.
19
+ """
20
+ for root, dirs, files in os.walk(search_from):
21
+ for name in files:
22
+ if name == filename:
23
+ return os.path.abspath(os.path.join(root, name))
24
+ raise FileNotFoundError(filename, "not found.")
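A usage sketch; the filename exists elsewhere in this package (the VideoCrafter config), but the search root is illustrative.

from videogen_hub.utils.file_helper import get_file_path

try:
    cfg = get_file_path("inference_t2v_512_v2.0.yaml", search_from="src/videogen_hub")
    print(cfg)
except FileNotFoundError as err:
    print("config not found:", err)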