import random
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from .model_overrides import get_forward

# A custom encode function that uses the overridden forward to return the
# pooled representation at every layer of the model in a single pass
def encode_custom(forward, encoder, sentence_feature):
    # Pop embed_mask before the forward call and restore it afterwards so that
    # pooling can still use it
    embed_mask = None
    if "embed_mask" in sentence_feature:
        embed_mask = sentence_feature.pop("embed_mask")
    out, reps = forward(encoder.model, **sentence_feature)
    sentence_feature["embed_mask"] = embed_mask

    return [encoder.get_pooling(sentence_feature, emb) for emb in reps]
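
# Assumed contract, inferred from the calls above rather than guaranteed here:
# `forward(model, **features)` returns `(out, reps)` where `reps` holds one
# hidden-state tensor per transformer layer, so pooling each element of `reps`
# yields one sentence embedding per candidate pruning depth.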

def l3prune(encoder, dataset, loss_fn, batch_size=64, num_batches=100):
    """Find good pruning depths by scoring every layer's pooled representation.

    Samples batch_size * num_batches (query, positive, negative) triples from
    `dataset`, computes `loss_fn(q, d, d_neg)` on the embeddings produced at
    each layer depth, and returns `(small_p, large_p)`: the layer counts with
    the lowest mean loss in the first and second half of the network.
    """
    # Materialize the dataset, sample a fixed evaluation subset, and split the
    # (query, positive, negative) text triples into batches
    dataset = list(dataset)
    subset = random.sample(dataset, batch_size * num_batches)
    subset = [[encoder.prepare_for_tokenization(t) for t in s.texts] for s in subset]
    subset = [subset[i:i + batch_size] for i in range(0, len(subset), batch_size)]

    num_layers = encoder.model.config.num_hidden_layers
    # Collect per-batch losses keyed by the number of retained layers (1..num_layers)
    loss = {i: [] for i in range(1, num_layers+1)}
    forward = get_forward(encoder.model)

    with torch.no_grad():
        # Override the forward of the model to get the intermediate representations in only one pass
        if forward:
            encode = partial(encode_custom, forward)
            for batch in tqdm(subset):
                features = []
                # Encode the query, positive, and negative texts in turn
                for j in range(3):
                    embs = [t[j] for t in batch]
                    embs = encoder.tokenize(embs).to(encoder.model.device)
                    embs = encode(encoder, embs)
                    features += [embs]
                q, d, d_neg = features
                # q[i], d[i], d_neg[i] are the embeddings at a pruning depth of i+1 layers
                for i in range(num_layers):
                    loss[i+1] += [loss_fn(q[i], d[i], d_neg[i])]
        else:
            # Without the override, we have to rerun the forward pass with each layer pruned
            for l in range(num_layers, 0, -1):
                # Prune the model to its first l layers, then score the subset with it
                encoder.prune(layer_prune=l)
                for batch in tqdm(subset):
                    features = []
                    for j in range(3):
                        embs = [t[j] for t in batch]
                        embs = encoder.tokenize(embs).to(encoder.model.device)
                        embs = encoder.forward(embs)
                        features += [embs]
                    q, d, d_neg = features
                    loss[l] += [loss_fn(q, d, d_neg)]

        # Average the batch losses for each candidate layer count
        loss = [torch.tensor(loss[i]).mean().float().detach() for i in range(1, num_layers+1)]
    
    # Pick the lowest-loss layer count in each half of the network: small_p is
    # the more aggressive prune, large_p the more conservative one
    midpoint = num_layers // 2
    small_p = np.argmin(loss[:midpoint]) + 1
    large_p = np.argmin(loss[midpoint:]) + midpoint + 1
    return small_p, large_p
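
# A minimal usage sketch, not part of the module: `encoder` and `train_dataset`
# below are hypothetical placeholders, assumed to provide the interface used
# above (`prepare_for_tokenization`, `tokenize`, `prune`, `.texts` triples).
# Only the loss function is defined here; the call itself is left commented out.
if __name__ == "__main__":
    import torch.nn.functional as F

    def cosine_triplet_loss(q, d, d_neg, margin=0.5):
        # Margin loss on cosine similarity between query/positive and query/negative
        pos = F.cosine_similarity(q, d)
        neg = F.cosine_similarity(q, d_neg)
        return torch.relu(margin - pos + neg).mean()

    # small_p, large_p = l3prune(encoder, train_dataset, cosine_triplet_loss,
    #                            batch_size=32, num_batches=10)
    # encoder.prune(layer_prune=small_p)  # keep only the first small_p layers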