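"""UniCL model: a two-tower image-text model with learned projection heads.

Builds a language encoder and an image encoder from a config dict, projects
both towers into a shared embedding space, and provides helpers for loading
pretrained weights and for muP (Maximal Update Parametrization) scaling.
"""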
import pathlib
import tempfile
import logging
import os
import copy

import torch
from torch import nn

from timm.models.layers import trunc_normal_

from .ImageEncoder import build_image_encoder
from .LangEncoder import build_lang_encoder
from .LangEncoder import build_tokenizer

import mup.init
from mup import set_base_shapes

from safetensors.torch import load_file


logger = logging.getLogger(__name__)


class UniCLModel(nn.Module):
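    """Two-tower image-text model: a language encoder and an image encoder,
    each followed by a linear projection into a shared embedding space, plus
    a learnable temperature for the contrastive logits."""
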
    def __init__(self, config: dict):
        super().__init__()

        self.conf_lang_encoder = config['LANG_ENCODER']
        self.tokenizer = build_tokenizer(self.conf_lang_encoder)

        self.lang_encoder = build_lang_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE'])

        dim_projection = config['UNICL_MODEL']['DIM_PROJECTION']
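        # If the language encoder does not expose its output width, infer it
        # by running a dummy (1, 1) token batch and reading the hidden size.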
        if hasattr(self.lang_encoder, 'dim_out'):
            dim_out = self.lang_encoder.dim_out
        else:
            with torch.no_grad():
                dim_out = self.lang_encoder(
                    torch.zeros(1, 1, dtype=torch.long)
                )['last_hidden_state'].size(2)

        self.lang_projection = nn.Parameter(torch.empty(dim_out, dim_projection))

        self.conf_image_encoder = config['IMAGE_ENCODER']
        self.image_encoder = build_image_encoder(self.conf_image_encoder, config['VERBOSE'])

        self.image_projection = nn.Parameter(
            torch.empty(self.image_encoder.dim_out, dim_projection)
        )

        self.logit_scale = nn.Parameter(torch.ones([]))

        if torch.cuda.is_available():
            self.device = torch.device(type="cuda", index=0)
        else:
            self.device = torch.device(type="cpu")

    def custom_init_weights(self, use_original_init=True):
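        """Initialize the projection matrices with either the original
        truncated-normal init or muP's fan-in-aware variant."""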
        self.use_original_init = use_original_init
        logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))

        if self.use_original_init:
            custom_trunc_normal_ = trunc_normal_
        else:
            custom_trunc_normal_ = mup.init.trunc_normal_

        custom_trunc_normal_(self.lang_projection, std=.02)
        custom_trunc_normal_(self.image_projection, std=.02)

    def _convert_old_weights(self, model_dict):
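        """Remap keys from the old checkpoint layout ('visual.*', 'text.*')
        to the current 'image_encoder.*' / 'lang_encoder.*' naming."""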
|
model_dict_updated = {} |
|
for k, v in model_dict.items(): |
|
if k.startswith('visual.'): |
|
model_dict_updated['image_encoder.'+k[7:]] = v |
|
elif k.startswith('text.'): |
|
model_dict_updated['lang_encoder.'+k[5:]] = v |
|
elif k == 'vision_projection': |
|
model_dict_updated['image_projection'] = v |
|
elif k == 'text_projection': |
|
model_dict_updated['lang_projection'] = v |
|
else: |
|
model_dict_updated[k] = v |
|
|
|
return model_dict_updated |

    def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
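        """Load weights from a safetensors checkpoint, remapping legacy key
        names and resizing positional embeddings whose length differs."""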
|
if not os.path.isfile(pretrained): |
|
logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight') |
|
return |
|
|
|
|
|
pretrained_dict = load_file(pretrained) |
|
logger.info(f'=> Loading pretrained model {pretrained}') |
|
model_dict = self.state_dict() |
|
pretrained_dict = self._convert_old_weights(pretrained_dict) |

        pretrained_dict = {
            k: v.to(self.device) for k, v in pretrained_dict.items()
        }
        need_init_state_dict = {}
        image_encoder_state_dict = {}
        for k, v in pretrained_dict.items():
            need_init = (
                k.split('.')[0] in pretrained_layers
                # Guard the index: an empty pretrained_layers list loads
                # nothing instead of raising IndexError.
                or (len(pretrained_layers) > 0 and pretrained_layers[0] == '*')
            )

            if need_init:
                if k.startswith('image_encoder.'):
                    image_encoder_state_dict[k] = v.to(self.device)
                else:
                    if verbose:
                        logger.info(f'=> init {k} from {pretrained}')
                    if 'positional_embedding' in k and v.size() != model_dict[k].size():
                        positional_embedding_pretrained = v
                        positional_embedding_current = model_dict[k]
                        L1, nH1 = positional_embedding_pretrained.size()
                        L2, nH2 = positional_embedding_current.size()
                        if nH1 != nH2:
                            logger.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logger.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((L1, nH1), (L2, nH2))
                                )

                                posemb = positional_embedding_pretrained.float()
                                posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
                                posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
                                posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
                                v = posemb_grid

                    need_init_state_dict[k] = v.to(self.device)

        self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
        self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        no_weight_decay = {'logit_scale'}
        if hasattr(self.lang_encoder, 'no_weight_decay'):
            for k in self.lang_encoder.no_weight_decay():
                no_weight_decay.add('lang_encoder.' + k)

        if hasattr(self.image_encoder, 'no_weight_decay'):
            for k in self.image_encoder.no_weight_decay():
                no_weight_decay.add('image_encoder.' + k)

        return no_weight_decay

    @property
    def dtype(self):
        return self.logit_scale.dtype

    def encode_image(self, image, norm=True):
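        """Project image-encoder features into the shared space; L2-normalize by default."""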
|
x = self.image_encoder.forward_features(image) |
|
x = x @ self.image_projection |
|
|
|
if norm: |
|
x = x / x.norm(dim=-1, keepdim=True) |
|
|
|
return x |

    def encode_text(self, text, norm=True):
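        """Project text features into the shared space; L2-normalize by default.

        With the CLIP tokenizer the sentence embedding is read at the EOT
        position (the highest token id); otherwise the first ([CLS]) token
        is used.
        """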
|
x = self.lang_encoder(**text) |
|
x = x['last_hidden_state'] |
|
|
|
if self.conf_lang_encoder['TOKENIZER'] == 'clip': |
|
x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)] |
|
else: |
|
x = x[:, 0] |
|
|
|
x = x @ self.lang_projection |
|
|
|
if norm: |
|
x = x / x.norm(dim=-1, keepdim=True) |
|
|
|
return x |

    def forward(self, image, text):
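        """Return normalized image features, normalized text features, and the
        temperature T = exp(logit_scale) used to scale the contrastive logits."""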
|
features_image = self.encode_image(image) |
|
features_text = self.encode_text(text) |
|
|
|
|
|
T = self.logit_scale.exp() |
|
|
|
return features_image, features_text, T |


def create_model(config):
    model = UniCLModel(config)
    return model


def create_mup_model(config):
    def gen_config(config, wm):
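        # Build a scaled copy of the config: width-like dimensions are
        # multiplied by `wm` and STANDPARAM is switched on, so the base and
        # delta models are constructed under standard parameterization.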
        assert (not config['UNICL_MODEL']['STANDPARAM']) and \
            (not config['LANG_ENCODER']['STANDPARAM']) and \
            (not config['IMAGE_ENCODER']['SPEC']['STANDPARAM'])
        new_config = copy.deepcopy(config)
        logger.info(f'Generate config with width mult = {wm}:')

        new_config_section = new_config['UNICL_MODEL']
        new_config_section['STANDPARAM'] = True
        for name in ['DIM_PROJECTION']:
            base_name = 'BASE_' + name
            new_value = round(new_config_section[base_name] * wm)
            logger.info(f'config["UNICL_MODEL"]["{name}"]: {new_config_section[name]} -> {new_value}')
            new_config_section[name] = new_value

        new_config_section = new_config['LANG_ENCODER']
        new_config_section['STANDPARAM'] = True
        for name in ['WIDTH', 'HEADS']:
            base_name = 'BASE_' + name
            new_value = round(new_config_section[base_name] * wm)
            logger.info(f'config["LANG_ENCODER"]["{name}"]: {new_config_section[name]} -> {new_value}')
            new_config_section[name] = new_value

        new_config_section = new_config['IMAGE_ENCODER']['SPEC']
        new_config_section['STANDPARAM'] = True
        for name in ['DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS']:
            base_name = 'BASE_' + name
            # These entries are per-stage lists, so scale each element.
            new_values = [round(base_value * wm) for base_value in new_config_section[base_name]]
            logger.info(f'config["IMAGE_ENCODER"]["SPEC"]["{name}"]: {new_config_section[name]} -> {new_values}')
            new_config_section[name] = new_values

        return new_config

    logger.info('muP: Create models and set base shapes')
    logger.info('=> Create model')
    model = create_model(config)
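
    # Detach the encoders before set_base_shapes: their base shapes are set
    # by their own builders, so only the UniCL-level parameters (projections
    # and logit scale) are handled here.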
    lang_encoder, image_encoder = model.lang_encoder, model.image_encoder
    model.lang_encoder, model.image_encoder = None, None

    logger.info('=> Create base model')
    base_config = gen_config(config, wm=1.0)
    base_model = create_model(base_config)
    del base_model.lang_encoder, base_model.image_encoder

    logger.info('=> Create delta model')
    delta_config = gen_config(config, wm=2.0)
    delta_model = create_model(delta_config)
    del delta_model.lang_encoder, delta_model.image_encoder

    logger.info('=> Set base shapes in model for training')
    set_base_shapes(model, base=base_model, delta=delta_model)
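
    # Restore the encoders now that base shapes have been recorded.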
    model.lang_encoder, model.image_encoder = lang_encoder, image_encoder

    return model


def build_unicl_model(config, **kwargs):
    standparam = config['UNICL_MODEL'].get('STANDPARAM', True)

    if standparam:
        logger.info('Create model with standard parameterization')
        model = create_model(config)
        use_original_init = True
    else:
        logger.info('Create model with mu parameterization')
        model = create_mup_model(config)
        use_original_init = False

    # Initialize weights only after muP base shapes are set, so muP's
    # fan-in-scaled init sees the correct shape metadata.
    model.custom_init_weights(use_original_init=use_original_init)

    if config['UNICL_MODEL']['LOAD_PRETRAINED']:
        pretrained_path = config['UNICL_MODEL']['PRETRAINED']
        from .Distributed.Utils import is_valid_url, download_file
        if is_valid_url(pretrained_path):
            # Remote checkpoints are downloaded to a temporary directory first.
            with tempfile.TemporaryDirectory() as tmp_path:
                file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
                download_file(pretrained_path, file_local_path)
                model.from_pretrained(str(file_local_path), config['UNICL_MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
        else:
            model.from_pretrained(pretrained_path, config['UNICL_MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])

    return model
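

# Minimal usage sketch (illustrative only): the exact schema of the
# 'LANG_ENCODER' and 'IMAGE_ENCODER' sections depends on their respective
# builders, and every value below is a placeholder assumption based on the
# keys read above, not the library's required config.
#
#     config = {
#         'VERBOSE': True,
#         'LANG_ENCODER': {...},   # consumed by build_tokenizer / build_lang_encoder
#         'IMAGE_ENCODER': {...},  # consumed by build_image_encoder
#         'UNICL_MODEL': {
#             'DIM_PROJECTION': 512,
#             'STANDPARAM': True,
#             'LOAD_PRETRAINED': False,
#         },
#     }
#     model = build_unicl_model(config)
#     image_feat, text_feat, T = model(images, tokenized_text)
#     logits = T * image_feat @ text_feat.t()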