File size: 359 Bytes
2c26ac8
ba918ab
 
df825a0
2c26ac8
 
 
ba918ab
2c26ac8
ba918ab
 
 
 
2c26ac8
ba918ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import PretrainedConfig
from .cfg import load_config


class X3DConfig(PretrainedConfig):
    """Configuration class registered under the model type ``"x3d"``.

    Wraps the object returned by ``load_config`` and exposes it as
    ``self.cfg``.

    Recognised keyword arguments:
        path: Value handed to ``load_config``; ``None`` when not supplied.
            NOTE(review): how ``load_config`` treats ``None`` is not
            visible from this file -- confirm before relying on it.
        gpu_num: Value written to ``cfg.NUM_GPUS``; ``0`` when not supplied.

    Every keyword argument (including the two above) is also forwarded
    unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "x3d"

    def __init__(self, **kwargs):
        """Initialise the base config, then load and attach ``cfg``."""
        # The base class receives the complete keyword set first.
        super().__init__(**kwargs)

        # Load the underlying config and attach it to this instance.
        loaded_cfg = load_config(kwargs.get("path"))
        self.cfg = loaded_cfg

        # Record the requested GPU count on the loaded config itself.
        loaded_cfg.NUM_GPUS = kwargs.get("gpu_num", 0)