""" © Battelle Memorial Institute 2023 Made available under the GNU General Public License v 2.0 BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. """ import torch import torch.nn as nn from .positional_encoding import PositionalEncoding class FupBERTModel(nn.Module): """ A class that extends torch.nn.Module that implements a custom Transformer encoder model to create a single embedding for Fup prediction. """ def __init__( self, ntoken, ninp, nhead, nhid, nlayers, token_reduction, padding_idx, cls_idx, edge_idx, num_out, dropout=0.1, ): """ Initializes a FubBERT object. Parameters ---------- ntoken : int The maximum number of tokens the embedding layer should expect. This is the same as the size of the vocabulary. ninp : int The hidden dimension that should be used for embedding and input to the Transformer encoder. nhead : int The number of heads to use in the Transformer encoder. nhid : int The size of the hidden dimension to use throughout the Transformer encoder. nlayers : int The number of layers to use in a single head of the Transformer encoder. token_reduction : str The type of token reduction to use. This can be either 'mean' or 'cls'. padding_idx : int The index used as padding for the input sequences. cls_idx : int The index used as the cls token for the input sequences. edge_idx : int The index used as the edge token for the input sequences. num_out : int The number of outputs to predict with the model. dropout : float, optional The fractional dropout to apply to the model. The default is 0.1. Returns ------- None. """ super(FupBERTModel, self).__init__() # Store the input parameters self.ntoken = ntoken self.ninp = ninp self.nhead = nhead self.nhid = nhid self.nlayers = nlayers self.token_reduction = token_reduction self.padding_idx = padding_idx self.cls_idx = cls_idx self.edge_idx = edge_idx self.num_out = num_out self.dropout = dropout # Set the model parameters self.model_type = "Transformer Encoder" self.embedding = nn.Embedding( self.ntoken, self.ninp, padding_idx=self.padding_idx ) self.pos_encoder = PositionalEncoding(self.ninp, self.dropout) encoder_layers = nn.TransformerEncoderLayer( self.ninp, self.nhead, self.nhid, self.dropout, activation="gelu", batch_first=True, ) self.transformer_encoder = nn.TransformerEncoder(encoder_layers, self.nlayers) self.pred_head = nn.Linear(self.ninp, self.num_out) def _generate_src_key_mask(self, src): mask = src == self.padding_idx mask = mask.type(torch.bool) return mask def forward(self, src): """ Perform a forward pass of the module. Parameters ---------- src : tensor The input tensor. The shape should be (batch size, sequence length). Returns ------- output : tensor The output tensor. The shape will be (batch size, num_out). """ src = self.get_embeddings(src) output = self.pred_head(src) return output def get_embeddings(self, src): """ Perform a forward pass of the module excluding the classification layers. This will return the embeddings from the encoder. 

        Parameters
        ----------
        src : tensor
            The input tensor. The shape should be (batch size, sequence
            length).

        Returns
        -------
        embeds : tensor
            The output tensor of sequence embeddings. The shape should be
            (batch size, self.ninp).

        """
        src_mask = self._generate_src_key_mask(src)
        x = self.embedding(src)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, src_key_padding_mask=src_mask)
        # Mask the data based on the token reduction strategy
        if self.token_reduction == "mean":
            pad_mask = src == self.padding_idx
            cls_mask = src == self.cls_idx
            edge_mask = src == self.edge_idx
            mask = torch.logical_or(pad_mask, cls_mask)
            mask = torch.logical_or(mask, edge_mask)
            # Apply the mask
            x[mask[:, : x.shape[1]]] = torch.nan
            # Take the mean of the embeddings
            embeds = torch.nanmean(x, dim=1)
        elif self.token_reduction == "cls":
            embeds = x[:, 0, :]
        else:
            raise ValueError(
                "Token reduction must be mean or cls. "
                "Received {}".format(self.token_reduction)
            )

        return embeds
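

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It builds a small
    # FupBERTModel and runs a forward pass on random token indices to show the
    # expected input/output shapes. The vocabulary size and the special-token
    # indices below (padding_idx=0, cls_idx=1, edge_idx=2) are hypothetical
    # placeholders; the real values come from the tokenizer/vocabulary used to
    # prepare the input sequences. Because of the relative import above, run
    # this with `python -m <package>.<module>` from the package root.
    model = FupBERTModel(
        ntoken=128,
        ninp=64,
        nhead=4,
        nhid=256,
        nlayers=2,
        token_reduction="mean",
        padding_idx=0,
        cls_idx=1,
        edge_idx=2,
        num_out=1,
    )
    model.eval()

    # Fake batch of 8 sequences of length 16: a leading cls token, random
    # "real" tokens, and padding at the end.
    src = torch.randint(3, 128, (8, 16))
    src[:, 0] = 1  # cls token
    src[:, -4:] = 0  # padding tokens

    with torch.no_grad():
        embeds = model.get_embeddings(src)  # (8, 64) sequence embeddings
        preds = model(src)  # (8, 1) Fup predictions
    print(embeds.shape, preds.shape)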