# SX_ispymac_inDelphi / pipeline.py
# ljw20180420's picture
# Upload pipeline.py with huggingface_hub
# 66376de verified
from diffusers import DiffusionPipeline
import torch
from sklearn.neighbors import KNeighborsRegressor
import numpy as np
class inDelphiPipeline(DiffusionPipeline):
    """Inference pipeline combining an inDelphi model with insertion predictors.

    The pipeline holds three things:

    * ``inDelphi_model`` - the registered module that produces the
      microhomology, microhomology-less and total deletion-length weights.
    * ``insertion_model`` - a distance-weighted k-nearest-neighbours regressor
      fitted on standardised ``onebp_features``.
    * ``m654`` / ``m4`` - row-normalised (L1) lookup tables used to pick the
      1-bp insertion distribution for each sample.
    """

    def __init__(self, inDelphi_model, onebp_features, insert_probabilities, m654):
        """Register the model and fit the insertion predictors.

        Args:
            inDelphi_model: Model registered as a pipeline module; it is
                called in ``__call__`` and must expose ``device`` and
                ``config.DELLEN_LIMIT``.
            onebp_features: 2-D array of training features (one row per
                sample) for the insertion regressor.
            insert_probabilities: Regression targets matching
                ``onebp_features`` row by row.
            m654: Array with 256 elements that is row-normalised and also
                collapsed via ``reshape(16, 4, 4).sum(axis=0)`` into ``m4``.
                NOTE(review): presumably 64 three-base contexts x 4 inserted
                bases - confirm against the training code.
        """
        super().__init__()
        self.register_modules(inDelphi_model=inDelphi_model)

        self.onebp_feature_mean = onebp_features.mean(axis=0)
        feature_std = onebp_features.std(axis=0)
        # A constant feature column has zero standard deviation; dividing by
        # it would put NaN/inf into the regressor input. Use a unit scale for
        # such columns (non-zero entries are left exactly as computed).
        self.onebp_feature_std = np.where(feature_std == 0, 1.0, feature_std)

        standardized = (onebp_features - self.onebp_feature_mean) / self.onebp_feature_std
        self.insertion_model = KNeighborsRegressor(weights='distance').fit(
            standardized, insert_probabilities
        )

        # Row-wise L1 normalisation; the floor avoids dividing by zero for
        # all-zero rows.
        self.m654 = m654 / np.maximum(
            np.linalg.norm(m654, ord=1, axis=1, keepdims=True), 1e-6
        )
        # m4 is built from the raw (un-normalised) m654 and then normalised.
        m4_counts = m654.reshape(16, 4, 4).sum(axis=0)
        self.m4 = m4_counts / np.maximum(
            np.linalg.norm(m4_counts, ord=1, axis=1, keepdims=True), 1e-6
        )

    @torch.no_grad()
    def __call__(self, batch, use_m654=False):
        """Run the model on one batch and predict insertion statistics.

        Args:
            batch: Mapping that provides the tensors ``"mh_input"``,
                ``"mh_del_len"``, ``"onebp_feature"`` and the index array
                ``"m654"``.
            use_m654: If true, look up 1-bp insertions in the full ``m654``
                table; otherwise use the collapsed ``m4`` table.

        Returns:
            dict with keys ``"mh_weight"`` (list with one tensor per sample,
            restricted to entries whose deletion length is below
            ``config.DELLEN_LIMIT``), ``"mhless_weight"``,
            ``"total_del_len_weight"``, ``"pre_insert_probability"`` and
            ``"pre_insert_1bp"``.
        """
        model = self.inDelphi_model
        device = model.device
        mh_weights, mhless_weights, total_del_len_weights = model(
            batch["mh_input"].to(device),
            batch["mh_del_len"].to(device)
        ).values()

        # Number of leading deletion-length columns used for the precision
        # score. NOTE(review): presumably deletion lengths 1..28 - confirm.
        precision_columns = 28
        log_total_weights = total_del_len_weights.sum(dim=1, keepdim=True).log()
        # Precision = 1 - (entropy normalised by its maximum, log(n_columns)).
        entropy = torch.distributions.Categorical(
            total_del_len_weights[:, :precision_columns]
        ).entropy()
        precisions = 1 - entropy / torch.log(torch.tensor(precision_columns))

        # Features are assembled on the CPU because the regressor is a
        # scikit-learn model operating on NumPy arrays.
        onebp_features = torch.cat([
            batch["onebp_feature"],
            precisions[:, None].cpu(),
            log_total_weights.cpu()
        ], dim=1).cpu().numpy()
        pre_insert_probabilities = self.insertion_model.predict(
            (onebp_features - self.onebp_feature_mean) / self.onebp_feature_std
        )

        table = self.m654 if use_m654 else self.m4
        if table.shape[0] == 4:
            # The collapsed table has one row per value of index % 4.
            pre_insert_1bps = table[batch['m654'] % 4]
        else:
            pre_insert_1bps = table[batch['m654']]

        del_len_limit = model.config.DELLEN_LIMIT
        return {
            "mh_weight": [
                mh_weights[i, del_lens < del_len_limit]
                for i, del_lens in enumerate(batch["mh_del_len"])
            ],
            "mhless_weight": mhless_weights,
            "total_del_len_weight": total_del_len_weights,
            "pre_insert_probability": pre_insert_probabilities,
            "pre_insert_1bp": pre_insert_1bps
        }