ReactXT / model /blip2.py
SyrWin
init
95f97c5
raw
history blame
4.29 kB
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import torch
import torch.nn as nn
from lavis.common.dist_utils import download_cached_file
from lavis.common.utils import is_url
from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from transformers import BertTokenizer
from model.gin_model import GNN
class Blip2Base(BaseModel):
    """Shared construction helpers for BLIP-2 style models with a graph encoder.

    Provides class-level factories for the BERT tokenizer, the Q-Former and a
    GIN graph encoder, a device-aware autocast context, and checkpoint loading.
    """

    # Pretrained BERT variant used by both ``init_tokenizer`` and
    # ``init_Qformer``. The original code kept a disabled local-path
    # alternative (``'bert_pretrained/'``) behind ``if True``; subclasses may
    # override this attribute to load from a local copy instead.
    _BERT_NAME = 'allenai/scibert_scivocab_uncased'

    @classmethod
    def init_tokenizer(cls):
        """Build the BERT tokenizer with ``[DEC]`` registered as its BOS token.

        Returns:
            A ``BertTokenizer`` loaded from ``cls._BERT_NAME``.
        """
        tokenizer = BertTokenizer.from_pretrained(cls._BERT_NAME)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context unless the model is on the CPU.

        ``self.device`` is presumably provided by ``BaseModel`` — not visible
        here.

        Args:
            dtype: Autocast dtype used when the model is not on the CPU.

        Returns:
            ``torch.autocast(device_type="cuda", dtype=dtype)`` on non-CPU
            devices, otherwise ``contextlib.nullcontext()``.
        """
        if self.device != torch.device("cpu"):
            # Same behavior as the deprecated ``torch.cuda.amp.autocast(dtype=...)``.
            return torch.autocast(device_type="cuda", dtype=dtype)
        return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, model_name, num_query_token, graph_width, cross_attention_freq=2):
        """Build a BERT-initialised Q-Former and its learnable query tokens.

        Args:
            model_name: Must be ``'scibert'``; no other backbone is supported.
            num_query_token: Number of learnable query tokens; also stored on
                the config as ``query_length``.
            graph_width: Stored on the config as ``encoder_width`` (presumably
                the width of the graph-encoder features the cross-attention
                layers consume — confirm against the Q-Former implementation).
            cross_attention_freq: Frequency at which cross-attention layers are
                inserted (the default of 2 means every other block).

        Returns:
            ``(Qformer, query_tokens)``; ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``
            initialised from ``N(0, initializer_range)``.

        Raises:
            ValueError: If ``model_name`` is not ``'scibert'``.
        """
        # An ``assert`` here would be stripped under ``python -O``.
        if model_name != 'scibert':
            raise ValueError(
                f"init_Qformer only supports model_name='scibert', got {model_name!r}"
            )
        print("bert load scibert")
        bert_name = cls._BERT_NAME
        encoder_config = BertConfig.from_pretrained(bert_name)
        encoder_config.encoder_width = graph_width
        # insert cross-attention layer every ``cross_attention_freq`` blocks
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained(bert_name, config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_graph_encoder(
            cls, gin_num_layers, gin_hidden_dim, gin_drop_ratio,
            gin_ckpt_path='gin_pretrained/graphcl_80.pth'):
        """Build a GIN encoder, load pretrained weights, and a matching LayerNorm.

        Args:
            gin_num_layers: Number of GNN layers (``num_layer``).
            gin_hidden_dim: Embedding width (``emb_dim``).
            gin_drop_ratio: Dropout ratio (``drop_ratio``).
            gin_ckpt_path: State-dict checkpoint to load; defaults to the path
                that was previously hard-coded.

        Returns:
            ``(graph_encoder, ln_graph)`` where ``ln_graph`` is a ``LayerNorm``
            over ``graph_encoder.num_features``.
        """
        graph_encoder = GNN(
            num_layer=gin_num_layers,
            emb_dim=gin_hidden_dim,
            gnn_type='gin',
            drop_ratio=gin_drop_ratio,
            JK='last',
        )
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        ckpt = torch.load(gin_ckpt_path, map_location=torch.device('cpu'))
        # Non-strict load: key mismatches are printed for inspection, not raised.
        missing_keys, unexpected_keys = graph_encoder.load_state_dict(ckpt, strict=False)
        if missing_keys or unexpected_keys:
            print(missing_keys)
            print(unexpected_keys)
        ln_graph = LayerNorm(graph_encoder.num_features)
        return graph_encoder, ln_graph

    def load_from_pretrained(self, url_or_filename):
        """Load ``checkpoint["model"]`` from a URL or local file into this model.

        URLs are downloaded via ``download_cached_file``; the state dict is
        loaded non-strictly.

        Args:
            url_or_filename: A URL or a path to an existing local file.

        Returns:
            The value returned by ``self.load_state_dict`` (missing and
            unexpected keys).

        Raises:
            RuntimeError: If the argument is neither a URL nor an existing file.
        """
        if is_url(url_or_filename):
            checkpoint_path = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
        elif os.path.isfile(url_or_filename):
            checkpoint_path = url_or_filename
        else:
            raise RuntimeError("checkpoint url or path is invalid")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        msg = self.load_state_dict(checkpoint["model"], strict=False)
        logging.info("load checkpoint from %s", url_or_filename)
        return msg
def disabled_train(self, mode=True):
    """No-op stand-in for a module's ``train`` method.

    Meant to overwrite ``model.train`` so later train/eval switches have no
    effect. The requested ``mode`` is ignored, and the object is returned
    unchanged.
    """
    del mode  # deliberately ignored
    return self
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the input tensor's dtype.

    NOTE(review): this subclass was intended to handle fp16 inputs by
    normalising in float32. That upcast is currently disabled, so ``x`` is
    passed to ``nn.LayerNorm.forward`` unchanged, and only the result is
    converted back to ``x.dtype``.
    """

    def forward(self, x: torch.Tensor, mask=None):
        # ``mask`` is accepted but ignored.
        input_dtype = x.dtype
        normed = super().forward(x)
        return normed.type(input_dtype)