jbilcke-hf HF staff commited on
Commit
2557c6e
·
verified ·
1 Parent(s): 3006814

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +119 -0
handler.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import os
3
+ from pathlib import Path
4
+ import time
5
+ from datetime import datetime
6
+ import torch
7
+ import base64
8
+ from io import BytesIO
9
+
10
+ from hyvideo.utils.file_utils import save_videos_grid
11
+ from hyvideo.config import parse_args
12
+ from hyvideo.inference import HunyuanVideoSampler
13
+
14
+ class EndpointHandler:
15
+ def __init__(self, path: str = ""):
16
+ """Initialize the handler with the model path.
17
+
18
+ Args:
19
+ path: Path to the model weights directory
20
+ """
21
+ self.args = parse_args()
22
+ models_root_path = Path(path)
23
+ if not models_root_path.exists():
24
+ raise ValueError(f"`models_root` not exists: {models_root_path}")
25
+
26
+ # Initialize model
27
+ self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
28
+
29
+ # Default parameters
30
+ self.default_params = {
31
+ "num_inference_steps": 50,
32
+ "guidance_scale": 1.0,
33
+ "flow_shift": 7.0,
34
+ "embedded_guidance_scale": 6.0,
35
+ "video_length": 129, # 5s
36
+ "resolution": "1280x720"
37
+ }
38
+
39
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
40
+ """Process the input data and generate video.
41
+
42
+ Args:
43
+ data: Dictionary containing the input parameters
44
+ Required:
45
+ - inputs (str): The prompt text
46
+ Optional:
47
+ - resolution (str): Video resolution like "1280x720"
48
+ - video_length (int): Number of frames
49
+ - seed (int): Random seed (-1 for random)
50
+ - num_inference_steps (int): Number of inference steps
51
+ - guidance_scale (float): Guidance scale value
52
+ - flow_shift (float): Flow shift value
53
+ - embedded_guidance_scale (float): Embedded guidance scale value
54
+
55
+ Returns:
56
+ Dictionary containing the base64 encoded video
57
+ """
58
+ # Get prompt
59
+ prompt = data.pop("inputs", None)
60
+ if prompt is None:
61
+ raise ValueError("No prompt provided in the 'inputs' field")
62
+
63
+ # Get optional parameters with defaults
64
+ resolution = data.pop("resolution", self.default_params["resolution"])
65
+ video_length = int(data.pop("video_length", self.default_params["video_length"]))
66
+ seed = int(data.pop("seed", -1))
67
+ num_inference_steps = int(data.pop("num_inference_steps", self.default_params["num_inference_steps"]))
68
+ guidance_scale = float(data.pop("guidance_scale", self.default_params["guidance_scale"]))
69
+ flow_shift = float(data.pop("flow_shift", self.default_params["flow_shift"]))
70
+ embedded_guidance_scale = float(data.pop("embedded_guidance_scale", self.default_params["embedded_guidance_scale"]))
71
+
72
+ # Process resolution
73
+ width, height = resolution.split("x")
74
+ width, height = int(width), int(height)
75
+
76
+ # Set seed
77
+ seed = None if seed == -1 else seed
78
+
79
+ # Generate video
80
+ outputs = self.model.predict(
81
+ prompt=prompt,
82
+ height=height,
83
+ width=width,
84
+ video_length=video_length,
85
+ seed=seed,
86
+ negative_prompt="", # not applicable in inference
87
+ infer_steps=num_inference_steps,
88
+ guidance_scale=guidance_scale,
89
+ num_videos_per_prompt=1,
90
+ flow_shift=flow_shift,
91
+ batch_size=1,
92
+ embedded_guidance_scale=embedded_guidance_scale
93
+ )
94
+
95
+ # Process output video
96
+ samples = outputs['samples']
97
+ sample = samples[0].unsqueeze(0)
98
+
99
+ # Save video to temporary file
100
+ temp_dir = "/tmp/video_output"
101
+ os.makedirs(temp_dir, exist_ok=True)
102
+
103
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
104
+ video_path = f"{temp_dir}/{time_flag}_seed{outputs['seeds'][0]}.mp4"
105
+ save_videos_grid(sample, video_path, fps=24)
106
+
107
+ # Read video file and convert to base64
108
+ with open(video_path, "rb") as f:
109
+ video_bytes = f.read()
110
+ video_base64 = base64.b64encode(video_bytes).decode()
111
+
112
+ # Clean up
113
+ os.remove(video_path)
114
+
115
+ return {
116
+ "video_base64": video_base64,
117
+ "seed": outputs['seeds'][0],
118
+ "prompt": outputs['prompts'][0]
119
+ }