from typing import List, Optional

from transformers import PretrainedConfig


class UNet3DConfig(PretrainedConfig):
    """Configuration for the 3D U-Net model.

    The constructor hyperparameters are stored as plain instance attributes,
    so ``PretrainedConfig`` includes them when the config is serialized
    (e.g. via ``save_pretrained`` / ``to_dict``).

    Args:
        in_ch: Presumably the number of input channels (TODO confirm against
            the model implementation). Defaults to 1.
        out_ch: Presumably the number of output channels (TODO confirm).
            Defaults to 1.
        init_features: Presumably the feature width of the first stage of the
            network (TODO confirm). Defaults to 64.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers records in the serialized config and uses to
    # map a saved config back to this class (e.g. via AutoConfig).
    model_type = "UNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Model-specific attributes are assigned before delegating to the base
        # class, matching the transformers custom-config pattern.
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.init_features = init_features
        super().__init__(**kwargs)


class UNetMSS3DConfig(PretrainedConfig):
    """Configuration for the 3D U-Net "MSS" model variant.

    It has the same hyperparameters as ``UNet3DConfig`` plus ``output_dir``.
    Every value is stored as an instance attribute, so all of them, including
    ``output_dir``, are included when the config is serialized.

    Args:
        in_ch: Presumably the number of input channels (TODO confirm).
            Defaults to 1.
        out_ch: Presumably the number of output channels (TODO confirm).
            Defaults to 1.
        output_dir: Stored as-is. Presumably a directory path used by the
            model (TODO confirm how it is consumed). Defaults to ``None``.
        init_features: Presumably the feature width of the first stage of the
            network (TODO confirm). Defaults to 64.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers records in the serialized config and uses to
    # map a saved config back to this class (e.g. via AutoConfig).
    model_type = "UNetMSS"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        output_dir: Optional[str] = None,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # NOTE: positional order differs from UNet3DConfig (output_dir comes
        # before init_features); it is kept for backward compatibility.
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.output_dir = output_dir
        self.init_features = init_features
        super().__init__(**kwargs)