Update README.md
README.md CHANGED
@@ -55,29 +55,34 @@ The HuggingFace script run_clm.py can be found here: https://github.com/huggingf

### **How to select the best sequences**
We've observed that perplexity values correlate with AlphaFold2's plddt.
-We recommend
+We recommend computing perplexity for each sequence as follows:

```
+sequence='MGEAMGLTQPAVSRAVARLEERVGIRIFNRTARAITLTDEGRRFYEAVAPLLAGIEMHGYR\nVNVEGVAQLLELYARDILAEGRLVQLLPEWAD'
+
+#Convert the sequence to a string like this
+#(note we have to introduce new line characters every 60 amino acids,
+#following the FASTA file format).
+
+sequence = "<|endoftext|>MGEAMGLTQPAVSRAVARLEERVGIRIFNRTARAITLTDEGRRFYEAVAPLLAGIEMHGY\nRVNVEGVAQLLELYARDILAEGRLVQLLPEWAD<|endoftext|>"
+
+# ppl function
def calculatePerplexity(sequence, model, tokenizer):
+    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
+    input_ids = input_ids.to(device)
    with torch.no_grad():
-        outputs = model(
-
+        outputs = model(input_ids, labels=input_ids)
+    loss, logits = outputs[:2]
    return math.exp(loss)
-
-#
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-tokenizer = AutoTokenizer.from_pretrained('/path/to/tokenizer') # replace with the actual path
-model = GPT2LMHeadModel.from_pretrained('/path/to/output').to(device)
-output = model.generate("<|endoftext|>", max_length=400, do_sample=True, top_k=950, repetition_penalty=1.2, num_return_sequences=10, eos_token_id=0)
-
-# Take (for example) the first sequence
-sequence = output[0]
+
+#And hence:
ppl = calculatePerplexity(sequence, model, tokenizer)
+
```

Where `ppl` is a value with the perplexity for that sequence.
-We do not yet have a threshold as
+We do not yet have a threshold as to what perplexity value gives a 'good' or 'bad' sequence, but given the fast inference times, the best is to sample many sequences, order them by perplexity, and select those with the lower values (the lower the better).


### **Training specs**
-The model was trained on 128 NVIDIA A100 GPUs for 50 epochs, using a block size of 512
+The model was trained on 128 NVIDIA A100 GPUs for 50 epochs, using a block size of 512 and a total batch size of 1024 (65,536 tokens per batch). The optimizer used was Adam (beta1 = 0.9, beta2 = 0.999) with a learning rate of 1e-3.
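
The added lines recommend sampling many sequences, ordering them by perplexity, and keeping those with the lowest values, and they note the FASTA-style input format (a newline every 60 amino acids, wrapped in `<|endoftext|>` tokens). Below is a minimal sketch of that workflow. It reuses the README's `calculatePerplexity` function together with the placeholder model/tokenizer paths and the sampling parameters shown in the removed generation lines; `format_for_protgpt2` is a hypothetical helper added here for illustration and is not part of the README.

```python
import math
import textwrap

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('/path/to/tokenizer')  # placeholder path, as in the README
model = GPT2LMHeadModel.from_pretrained('/path/to/output').to(device)


def format_for_protgpt2(seq):
    # Hypothetical helper (not in the README): wrap a plain amino-acid string to
    # 60 characters per line (FASTA-style) and add the <|endoftext|> delimiters.
    return "<|endoftext|>" + "\n".join(textwrap.wrap(seq, 60)) + "<|endoftext|>"


def calculatePerplexity(sequence, model, tokenizer):
    # Same perplexity function as in the README snippet above.
    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
    input_ids = input_ids.to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss, logits = outputs[:2]
    return math.exp(loss)


# Sample a batch of candidates, reusing the sampling parameters from the
# generation example shown in the removed lines of the diff.
input_ids = tokenizer("<|endoftext|>", return_tensors="pt").input_ids.to(device)
generated = model.generate(input_ids, max_length=400, do_sample=True, top_k=950,
                           repetition_penalty=1.2, num_return_sequences=10,
                           eos_token_id=0)
sequences = [tokenizer.decode(ids) for ids in generated]

# For an externally supplied sequence, format it first, e.g.:
# calculatePerplexity(format_for_protgpt2('MGEAMGLTQPAVSRAVARLEERVGIRIF...'), model, tokenizer)

# Order candidates by perplexity and keep the lowest values (the lower the better).
ranked = sorted(sequences, key=lambda s: calculatePerplexity(s, model, tokenizer))
best = ranked[:3]
```

Since each call returns only ten samples here, the generation step can be repeated and the candidate lists concatenated before ranking; `ranked[0]` is then the sequence the model is most confident about.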
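As a rough guide to reproducing this setup with the run_clm.py script referenced above, the sketch below maps the reported hyperparameters onto HuggingFace `TrainingArguments` (run_clm.py accepts the same names as command-line flags). The per-device batch size and paths are assumptions for illustration; the README only reports the aggregate numbers, and the block size of 512 is passed to run_clm.py separately as a data argument (`--block_size 512`).

```python
from transformers import TrainingArguments

# Illustrative mapping of the reported training hyperparameters; the effective
# batch size is per_device_train_batch_size x number of GPUs x
# gradient_accumulation_steps.
training_args = TrainingArguments(
    output_dir="/path/to/output",      # placeholder path, as in the README
    do_train=True,
    num_train_epochs=50,
    learning_rate=1e-3,
    adam_beta1=0.9,
    adam_beta2=0.999,
    per_device_train_batch_size=8,     # placeholder; scale with GPU count to reach the reported total
)
```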