import torch
from transformers import PreTrainedModel

from .configuration_x3d import X3DConfig
from .x3d import build_model


class X3DModel(PreTrainedModel):
    """Hugging Face wrapper around the X3D video backbone."""

    config_class = X3DConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        # Build the underlying X3D network from the nested config.
        self.model = build_model(config.cfg)

        # Optionally load pretrained weights from a local checkpoint path.
        checkpoint = kwargs.get("checkpoint", None)
        if checkpoint:
            checkpoint = torch.load(
                checkpoint, weights_only=True, map_location=torch.device("cpu"))
            self.model.load_state_dict(checkpoint["model_state"])

    def forward(self, input_video):
        outputs = self.model(input_video)
        return outputs
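

# Usage sketch: how this wrapper might be instantiated and run on a dummy clip.
# The bare X3DConfig() call, the optional checkpoint path, and the input tensor
# shape below are assumptions for illustration, not values defined in this
# repository; adjust them to the actual config and preprocessing you use.
if __name__ == "__main__":
    config = X3DConfig()
    # Pass checkpoint="path/to/weights.pyth" to load pretrained weights (hypothetical path).
    model = X3DModel(config)
    clip = torch.randn(1, 3, 16, 224, 224)  # (batch, channels, frames, height, width)
    with torch.no_grad():
        outputs = model(clip)
    print(outputs.shape)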