Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from functools import partial | |
from typing import Any, Optional | |
import torch | |
import videosys | |
from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port | |
class VideoSysEngine:
    """Run a video-generation pipeline on one or more GPUs.

    The driver pipeline (rank 0) is created in the current process. When
    ``config.num_gpus > 1``, one worker process per additional rank
    (1 .. num_gpus - 1) is started through ``ProcessWorkerWrapper``.
    :meth:`generate` is broadcast to every rank and the driver's result is
    returned.

    The design is partly inspired by vLLM's multiprocessing executor.
    """

    def __init__(self, config):
        """Create the driver pipeline and, if needed, the worker processes.

        Args:
            config: Engine configuration. This class reads ``config.num_gpus``
                and ``config.pipeline_cls``; the whole object is also passed to
                ``pipeline_cls(config)`` on every rank.
        """
        self.config = config
        # Futures from an asynchronous _run_workers() call, if one is pending;
        # drained by stop_remote_worker_execution_loop().
        self.parallel_worker_tasks = None
        self._init_worker(config.pipeline_cls)

    def _init_worker(self, pipeline_cls):
        """Configure the process environment and create all pipeline ranks.

        Sets ``self.workers`` (remote ranks), ``self.worker_monitor`` and
        ``self.driver_worker`` (the rank-0 pipeline living in this process).
        """
        world_size = self.config.num_gpus

        # Expose GPUs 0..world_size-1 unless the user already chose devices.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
        # contention amongst the shards
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = "1"

        # NOTE: The two following lines need adaption for multi-node
        assert world_size <= torch.cuda.device_count(), (
            f"config.num_gpus={world_size} exceeds the "
            f"{torch.cuda.device_count()} visible CUDA device(s)"
        )

        # change addr for multi-node
        distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())

        if world_size == 1:
            # Single GPU: everything runs in the driver, no helper processes.
            self.workers = []
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            # One process per non-driver rank; rank 0 is the driver below.
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_pipeline,
                        pipeline_cls=pipeline_cls,
                        rank=rank,
                        local_rank=rank,
                        distributed_init_method=distributed_init_method,
                    ),
                )
                for rank in range(1, world_size)
            ]

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        # The driver is rank 0 (the _create_pipeline default).
        self.driver_worker = self._create_pipeline(
            pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
        )

    # TODO: add more options here for pipeline, or wrap all options into config
    def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
        """Initialize videosys for ``rank`` and build a pipeline instance.

        ``local_rank`` is accepted for interface parity with the worker
        factory but is not used here. The seed is fixed at 42 for all ranks.
        """
        videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)
        pipeline = pipeline_cls(self.config)
        return pipeline

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Invoke pipeline method ``method`` on every worker and on the driver.

        Worker calls are dispatched first so they run concurrently with the
        driver call.

        Args:
            method: Name of the pipeline method to call on each rank.
            *args: Forwarded unchanged to every call.
            async_run_tensor_parallel_workers_only: If True, dispatch to the
                worker processes only and return their futures without calling
                the driver or waiting for completion.
            max_concurrent_workers: Accepted for interface compatibility;
                currently ignored.
            **kwargs: Forwarded unchanged to every call.

        Returns:
            The list of worker futures in async mode, otherwise
            ``[driver_output, *worker_outputs]``.
        """
        # Start the workers first.
        worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*args, **kwargs)

        # Get the results of the workers (blocks until each one finishes).
        return [driver_worker_output] + [output.get() for output in worker_outputs]

    def _driver_execute_model(self, *args, **kwargs):
        """Call ``generate`` on the driver pipeline only."""
        return self.driver_worker.generate(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Run ``generate`` on all ranks and return the driver's (rank 0) output."""
        return self._run_workers("generate", *args, **kwargs)[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Wait for any pending asynchronous worker tasks, then forget them."""
        if self.parallel_worker_tasks is None:
            return

        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        for result in parallel_worker_tasks:
            result.get()

    def save_video(self, video, output_path):
        """Delegate video saving to the driver pipeline."""
        return self.driver_worker.save_video(video, output_path)

    def shutdown(self):
        """Stop the worker monitor and tear down torch.distributed.

        Safe to call repeatedly (``__del__`` calls it again after an explicit
        shutdown). Also safe when initialization failed before a default
        process group was created.
        """
        worker_monitor = getattr(self, "worker_monitor", None)
        if worker_monitor is not None:
            # Clear first so a repeated call (e.g. from __del__) does not close twice.
            self.worker_monitor = None
            worker_monitor.close()

        # destroy_process_group() raises when no default group exists, which
        # is the case after a previous shutdown() or if init never completed.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    def __del__(self):
        self.shutdown()