from transformers import EsmForSequenceClassification, AutoTokenizer
import torch


def model_fn(model_dir):
    # Load the fine-tuned ESM classifier and its tokenizer from the model
    # directory; device_map="auto" places the weights on GPU when available.
    model = EsmForSequenceClassification.from_pretrained(model_dir, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    model, tokenizer = model_and_tokenizer
    model.eval()

    # Accept either {"inputs": <sequence>} or a raw payload.
    inputs = data.pop("inputs", data)

    # Tokenize the protein sequence and move the tensors to the model's device.
    encoding = tokenizer(inputs, return_tensors="pt")
    encoding = {k: v.to(model.device) for k, v in encoding.items()}

    # Run inference without tracking gradients.
    with torch.no_grad():
        results = model(**encoding)

    # Convert logits to per-class probabilities and return the probability
    # of the positive (membrane) class.
    probs = torch.sigmoid(results.logits).cpu()
    return {"membrane_probability": probs[0][1].item()}
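
# A minimal local smoke test (a sketch, not part of the serving contract):
# it assumes a fine-tuned checkpoint has been saved to a local directory;
# the path "./esm2-membrane-classifier" and the amino-acid sequence below
# are hypothetical placeholders -- substitute your own model_dir and input.
if __name__ == "__main__":
    model_and_tokenizer = model_fn("./esm2-membrane-classifier")
    payload = {"inputs": "MKTIIALSYIFCLVFA"}  # hypothetical example sequence
    print(predict_fn(payload, model_and_tokenizer))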