import torch
from transformers import PreTrainedModel

from .configuration_x3d import X3DConfig
from .x3d import build_model


class X3DModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the network built by ``build_model``.

    The wrapped network is constructed from ``config.cfg`` and stored on
    ``self.model``; ``forward`` hands its input straight to that network.
    """

    config_class = X3DConfig

    def __init__(self, config, **kwargs):
        """Build the wrapped network and optionally restore saved weights.

        Args:
            config: Configuration object; its ``cfg`` attribute is passed to
                ``build_model``.
            **kwargs: Only the ``"checkpoint"`` key is consulted. When it is
                present and truthy, it is handed to ``torch.load`` and the
                ``"model_state"`` entry of the result is loaded into the
                wrapped network. Other keys are ignored here.
        """
        super().__init__(config)
        self.model = build_model(config.cfg)

        checkpoint = kwargs.get("checkpoint")
        if not checkpoint:
            # No checkpoint requested: keep the weights build_model produced.
            return

        # Load onto the CPU regardless of where the tensors were saved, and
        # use torch.load's restricted ``weights_only`` mode.
        saved = torch.load(
            checkpoint,
            weights_only=True,
            map_location=torch.device("cpu"),
        )
        self.model.load_state_dict(saved["model_state"])

    def forward(self, input_video):
        """Run the wrapped network on ``input_video`` and return its output.

        The input is passed through unchanged.
        NOTE(review): expected input layout is defined by ``build_model``'s
        network, not visible here — confirm against ``.x3d``.
        """
        return self.model(input_video)