|
from transformers import PreTrainedModel |
|
from .config import MonoSceneConfig |
|
from monoscene.monoscene import MonoScene |
|
|
|
|
|
class MonoSceneModel(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around ``MonoScene``.

    Builds a ``MonoScene`` network from a ``MonoSceneConfig`` and delegates
    ``forward`` to it. Setting ``config_class`` lets the ``transformers``
    loading/saving machinery construct this model from a ``MonoSceneConfig``.
    """

    config_class = MonoSceneConfig

    # Projection resolutions used when the config does not define
    # ``project_res``; these are the values this wrapper previously hard-coded.
    _DEFAULT_PROJECT_RES = ("1", "2", "4", "8")

    def __init__(self, config):
        """Build the wrapped ``MonoScene`` network from ``config``.

        Args:
            config: A ``MonoSceneConfig``. Its ``dataset``, ``n_classes``,
                ``feature``, ``project_scale`` and ``full_scene_size``
                attributes are passed straight through to ``MonoScene``.
                An optional ``project_res`` attribute overrides the default
                projection resolutions ``['1', '2', '4', '8']``.
        """
        super().__init__(config)

        # Older configs have no ``project_res``; keep the historical default.
        project_res = getattr(config, "project_res", None)
        if project_res is None:
            project_res = self._DEFAULT_PROJECT_RES

        self.model = MonoScene(
            dataset=config.dataset,
            n_classes=config.n_classes,
            feature=config.feature,
            # Pass a fresh list so MonoScene never aliases the config's
            # (or this class's) sequence object.
            project_res=list(project_res),
            project_scale=config.project_scale,
            full_scene_size=config.full_scene_size,
        )

    def forward(self, tensor):
        """Run the wrapped ``MonoScene`` network on ``tensor``.

        The submodule is invoked through ``__call__`` rather than
        ``.forward`` so that any PyTorch forward hooks registered on it run.

        NOTE(review): despite the parameter name, the expected input is
        whatever ``MonoScene.forward`` accepts (possibly a batch dict rather
        than a bare tensor). Confirm against ``monoscene.monoscene``.

        Returns:
            Whatever ``MonoScene`` returns for this input.
        """
        return self.model(tensor)