File size: 8,325 Bytes
5ea3aa3
 
 
 
 
2ea70eb
5ea3aa3
 
 
 
 
 
05bcc9e
5ea3aa3
 
 
 
 
 
 
05bcc9e
 
 
 
5ea3aa3
 
 
 
05bcc9e
 
5ea3aa3
05bcc9e
5ea3aa3
 
 
05bcc9e
 
 
 
 
 
5ea3aa3
05bcc9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ea70eb
05bcc9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ea70eb
 
 
05bcc9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""
module to freeze/unfreeze parameters by name
"""
import logging
import re
from typing import Callable, List, Tuple, Union

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.utils.freeze")


def freeze_layers_except(model, regex_patterns):
    """
    Freezes all layers of the given model except for the layers that match given regex patterns.
    Periods in the patterns are treated as literal periods, not as wildcard characters.

    Parameters:
    - model (nn.Module): The PyTorch model to be modified.
    - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
      Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
      Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
      The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
      E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]

    Returns:
    None; the model is modified in place.
    """
    if isinstance(regex_patterns, str):
        regex_patterns = [regex_patterns]

    patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]

    # Unfreeze layers that match the regex patterns
    for name, param in model.named_parameters():
        param.requires_grad = False
        unfrozen_ranges = []
        for pattern in patterns:
            if not pattern.match(name):
                continue

            param.requires_grad = True

            if pattern.range is not None:
                unfrozen_ranges.append(pattern.range)

        merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))

        if param.requires_grad and is_main_process():
            unfrozen_ranges = (
                f" with ranges {merged_unfrozen_ranges}"
                if merged_unfrozen_ranges
                else ""
            )
            LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")

        if not merged_unfrozen_ranges:
            continue

        # The range list we need is actually the inverted of the merged ranges
        ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))

        param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))

    if is_main_process() and all(
        not param.requires_grad for param in model.parameters()
    ):
        LOG.warning("All parameters are frozen. Model will not be trained.")


def _invert_ranges(
    given_ranges: List[Tuple[int, int]], layer_size: int
) -> List[Tuple[int, int]]:
    """
    Inverts a list of ranges to obtain the ranges not covered by the given ranges.

    Parameters:
    - given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
    - layer_size (int): The length of the layer. E.g., len(model.layer.weight)
    Returns:
    - List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
    """
    if not given_ranges:
        return [(0, layer_size)]

    inverted_ranges = []
    current_start = 0

    for start, end in sorted(given_ranges):
        if start > current_start:
            inverted_ranges.append((current_start, start))
        current_start = max(current_start, end)

    # Handle the case where the last given range does not reach the end of the total_size
    if current_start < layer_size:
        inverted_ranges.append((current_start, layer_size))

    return inverted_ranges


def _merge_ranges(
    given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
) -> List[Tuple[int, int]]:
    """
    Merges overlapping ranges and sorts the given ranges.

    This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
    as tuples, where the first element is the start index (inclusive) and the second element is the end
    index (exclusive). The end index can be None, indicating that the range extends to the end of the
    sequence.

    Parameters:
    - given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
    - layer_size (int): The length of the layer. E.g., len(model.layer.weight)

    Returns:
    - List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
    """
    # End of each range can be determined now since we have the total size
    processed_ranges = [
        (start, end if end is not None else layer_size) for start, end in given_ranges
    ]

    # No need to merge if there's only one or no ranges
    if len(processed_ranges) <= 1:
        return processed_ranges

    sorted_ranges = sorted(processed_ranges)

    merged_ranges = [sorted_ranges[0]]
    for start, end in sorted_ranges[1:]:
        prev_start, prev_end = merged_ranges[-1]
        if start <= prev_end:
            merged_ranges[-1] = (prev_start, max(prev_end, end))
        else:
            merged_ranges.append((start, end))

    return merged_ranges


def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
    """
    Create a hook to freeze parameters in specified ranges by setting their gradients to zero.

    This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
    two integers representing the start and end indices of the range.

    Parameters:
    - ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.

    Returns:
    - Callable: A hook function to be used with `register_hook` on parameters.

    Example usage:
    ```
    ranges_to_freeze = [(0, 10), (20, 30)]
    hook = _create_freeze_parameters_hook(ranges_to_freeze)
    model.register_hook(hook)
    ```
    """

    def freeze_parameters_hook(gradients):
        for start, end in ranges_to_freeze:
            gradients[start:end].zero_()

    return freeze_parameters_hook


class LayerNamePattern:
    """
    Represents a regex pattern for layer names, potentially including a parameter index range.
    """

    def __init__(self, pattern: str):
        """
        Initializes a new instance of the LayerNamePattern class.

        Parameters:
        - pattern (str): The regex pattern for layer names, potentially including a parameter index range.
        """
        self.raw_pattern = pattern
        name_pattern, self.range = self._parse_pattern(pattern)
        self.name_regex = re.compile(name_pattern.replace(".", "\\."))

    def match(self, name: str) -> bool:
        """
        Checks if the given layer name matches the regex pattern.

        Parameters:
        - name (str): The layer name to check.

        Returns:
        - bool: True if the layer name matches the pattern, False otherwise.
        """
        return self.name_regex.match(name) is not None

    def _parse_pattern(
        self, pattern: str
    ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
        """
        Extracts the range pattern from the given pattern.

        Parameters:
        - pattern (str): The pattern to extract the range from.

        Returns:
        - Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
        """
        match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
        if not match:
            return pattern, None

        base_pattern, start_part, end_part = match.groups()

        if end_part is None and start_part.isdecimal():
            index = int(start_part)
            return base_pattern, (index, index + 1)

        # [:end] or [start:] or [start:end]
        start = int(start_part) if start_part else 0
        end = int(end_part) if end_part else None

        if end is not None and start >= end:
            raise ValueError(
                f"Invalid range in layer name pattern: {pattern}."
                "End of range must be greater than start."
            )
        return base_pattern, (start, end)