x3d / modeling_x3d.py
zhong-al
Revert changes
ba918ab
raw
history blame
672 Bytes
import torch
from transformers import PreTrainedModel
from .configuration_x3d import X3DConfig
from .x3d import build_model
class X3DModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an X3D network.

    The underlying network is created by ``build_model(config.cfg)`` and kept
    in ``self.model``; :meth:`forward` delegates to it unchanged.

    Args:
        config: An ``X3DConfig``; its ``cfg`` attribute is handed to
            ``build_model``.
        **kwargs: Only ``checkpoint`` is read. When it is truthy it is passed
            to ``torch.load`` (a path or file-like object) and the weights it
            contains are loaded into ``self.model``. Other keyword arguments
            are accepted and ignored.
    """

    config_class = X3DConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = build_model(config.cfg)

        checkpoint_path = kwargs.get("checkpoint", None)
        if checkpoint_path:
            # weights_only=True restricts unpickling to tensors and plain
            # containers; mapping to CPU lets the file load without a GPU.
            checkpoint = torch.load(
                checkpoint_path,
                weights_only=True,
                map_location=torch.device("cpu"),
            )
            self.model.load_state_dict(self._extract_state_dict(checkpoint))

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the model weights held by a loaded checkpoint object.

        Checkpoints that wrap the weights under a ``"model_state"`` key are
        unwrapped. Anything else is assumed to already be a state dict and is
        returned as-is, so ``load_state_dict`` reports any key mismatch.
        """
        if isinstance(checkpoint, dict) and "model_state" in checkpoint:
            return checkpoint["model_state"]
        return checkpoint

    def forward(self, input_video):
        """Run the wrapped X3D network on ``input_video``.

        Args:
            input_video: Passed unchanged to ``self.model``. Presumably a
                batched video tensor, e.g. ``(batch, channels, time, height,
                width)`` — TODO confirm against the network ``build_model``
                constructs.

        Returns:
            Exactly what ``self.model`` returns; no post-processing is applied.
        """
        return self.model(input_video)