from transformers import PretrainedConfig | |
from .cfg import load_config | |
class X3DConfig(PretrainedConfig):
    """Hugging Face configuration that wraps an X3D config loaded via ``load_config``.

    Keyword arguments read by this class (all kwargs, these included, are
    forwarded unchanged to ``PretrainedConfig.__init__``):

        path: handed to ``load_config``; ``None`` when not supplied.  What
            ``load_config`` does with ``None`` is not visible here.
        gpu_num: written to ``cfg.NUM_GPUS``; ``0`` when not supplied.

    Attributes:
        cfg: the object returned by ``load_config(path)``, with its
            ``NUM_GPUS`` field overwritten by ``gpu_num``.
    """

    model_type = "x3d"

    def __init__(self, **kwargs):
        # Base-class setup runs first on the full kwargs.  path/gpu_num are
        # only read (not popped), so the base class receives them too —
        # presumably so they round-trip through serialization; verify.
        super().__init__(**kwargs)

        cfg_path = kwargs.get("path")
        n_gpus = kwargs.get("gpu_num", 0)

        self.cfg = load_config(cfg_path)
        # NUM_GPUS is always overridden, even when gpu_num was not given
        # (it then becomes 0) — this discards any value from the loaded file.
        self.cfg.NUM_GPUS = n_gpus