|
import torch |
|
from transformers import PreTrainedModel |
|
from .configuration_x3d import X3DConfig |
|
from .x3d import build_model |
|
|
|
|
|
class X3DModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a network produced by ``build_model``.

    The wrapped network is constructed from ``config.cfg`` and kept on the
    instance as ``self.model``; ``forward`` simply delegates to it.
    """

    config_class = X3DConfig

    def __init__(self, config, **kwargs):
        """Build the wrapped network and optionally restore saved weights.

        Args:
            config: Configuration object; its ``cfg`` attribute is handed to
                ``build_model``.
            **kwargs: Only the ``checkpoint`` key is read. When it is truthy
                it is passed to ``torch.load`` and the ``"model_state"``
                entry of the loaded object is applied to the wrapped network.
                All other keyword arguments are ignored.
        """
        super().__init__(config)
        self.model = build_model(config.cfg)

        checkpoint_source = kwargs.get("checkpoint")
        if not checkpoint_source:
            # Nothing to restore: keep the freshly built weights.
            return

        # Load onto CPU with weights_only=True, as the checkpoint is only
        # used as a source of tensors for the state dict.
        loaded = torch.load(
            checkpoint_source,
            map_location=torch.device("cpu"),
            weights_only=True,
        )
        self.model.load_state_dict(loaded["model_state"])

    def forward(self, input_video):
        """Pass ``input_video`` to the wrapped network and return its output."""
        return self.model(input_video)
|
|