File size: 947 Bytes
d282272 0412820 d282272 0412820 d282272 0412820 d282272 0412820 d282272 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from transformers import PreTrainedModel
from typing import Optional
import torch
class MinerUModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the MinerU model set.

    The sub-models are loaded by ``model_loader.MinerUModelLoader.load_models``
    and stored in ``self.models``. ``forward`` indexes that mapping with the
    key ``"layout"``.
    """

    def __init__(self, config, model_path="./"):
        """Initialise the wrapper and load the sub-models.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
            model_path: Location handed to ``MinerUModelLoader.load_models``.
                Defaults to ``"./"``, the value previously hard-coded.
        """
        super().__init__(config)
        self.config = config
        self.model_path = model_path
        self._setup_models()

    def _setup_models(self):
        """Load the sub-models from ``self.model_path`` into ``self.models``."""
        # Imported lazily so importing this module does not require model_loader.
        from model_loader import MinerUModelLoader

        # NOTE(review): ``self.models`` is whatever the loader returns (used as a
        # mapping below). If it is a plain dict of nn.Modules, those are not
        # registered as submodules, so ``.to()``/``.eval()`` on this wrapper will
        # not reach them -- confirm whether that is intended.
        self.models = MinerUModelLoader.load_models(self.model_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build a ``MinerUModel`` whose sub-models are loaded from a path.

        Unlike ``PreTrainedModel.from_pretrained``, this does not load weights
        itself; it delegates to ``MinerUModelLoader`` through ``__init__``.

        Args:
            pretrained_model_name_or_path: Location of the MinerU models; passed
                to ``MinerUModelLoader.load_models``.
            *model_args: Accepted for signature compatibility; ignored.
            **kwargs: Only ``config`` is used; other keys are ignored.

        Returns:
            A fully initialised ``MinerUModel``.
        """
        config = kwargs.pop("config", None)
        if config is None:
            # PreTrainedModel.__init__ requires a PretrainedConfig instance, so a
            # bare None would raise. Fall back to an empty default config.
            # NOTE(review): use PretrainedConfig.from_pretrained(path) instead if
            # the model directory ships a config.json.
            from transformers import PretrainedConfig

            config = PretrainedConfig()
        # __init__ already loads the sub-models, so no extra setup call is needed.
        return cls(config, model_path=pretrained_model_name_or_path)

    def forward(self, input_data):
        """Run the ``"layout"`` sub-model on ``input_data`` and return its output.

        Raises:
            KeyError: If the loaded models contain no ``"layout"`` entry.
        """
        # Forward-pass logic: currently only the layout stage is invoked.
        return self.models["layout"](input_data)
def load_model(model_path="./"):
    """Create a model via ``MinerUModel.from_pretrained``.

    Args:
        model_path: Path passed to ``MinerUModel.from_pretrained``. Defaults to
            ``"./"``, the value previously hard-coded.

    Returns:
        The object returned by ``MinerUModel.from_pretrained``.
    """
    return MinerUModel.from_pretrained(model_path)
def inference(pdf_content, model=None):
    """Run ``pdf_content`` through a MinerU model and return the result.

    Args:
        pdf_content: Input forwarded unchanged to the model's ``__call__``.
        model: Optional already-loaded model to reuse. When ``None`` (the
            default), a new one is created with ``load_model()``. Pass a model
            explicitly when calling repeatedly, to avoid reloading every
            sub-model on each call.

    Returns:
        Whatever calling the model returns.
    """
    if model is None:
        model = load_model()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        return model(pdf_content)