seungduk commited on
Commit
05bcc9e
·
unverified ·
1 Parent(s): 3bd8203

Train parameters exclusively in specific ranges (#1390)

Browse files

* Train parameters exclusively in specific ranges

* Fix the style and update docs

* Update yaml example

examples/mistral/mixtral.yml CHANGED
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
16
 
17
  ## You can optionally freeze the entire model and unfreeze a subset of parameters
18
  unfrozen_parameters:
19
- # - lm_head.*
20
- # - model.embed_tokens.*
21
- # - model.layers.2[0-9]+.block_sparse_moe.gate.*
22
- # - model.layers.2[0-9]+.block_sparse_moe.experts.*
23
- # - model.layers.3[0-9]+.block_sparse_moe.gate.*
24
- # - model.layers.3[0-9]+.block_sparse_moe.experts.*
25
 
26
  model_config:
27
  output_router_logits: true
 
16
 
17
  ## You can optionally freeze the entire model and unfreeze a subset of parameters
18
  unfrozen_parameters:
19
+ # - ^lm_head.weight$
20
+ # - ^model.embed_tokens.weight$[:32000]
21
+ # - model.layers.2[0-9]+.block_sparse_moe.gate
22
+ # - model.layers.2[0-9]+.block_sparse_moe.experts
23
+ # - model.layers.3[0-9]+.block_sparse_moe.gate
24
+ # - model.layers.3[0-9]+.block_sparse_moe.experts
25
 
26
  model_config:
27
  output_router_logits: true
src/axolotl/train.py CHANGED
@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
19
  from axolotl.common.cli import TrainerCliArgs
20
  from axolotl.logging_config import configure_logging
21
  from axolotl.utils.dict import DictDefault
22
- from axolotl.utils.freeze import freeze_parameters_except
23
  from axolotl.utils.models import load_model, load_tokenizer
24
  from axolotl.utils.trainer import setup_trainer
25
 
@@ -99,7 +99,7 @@ def train(
99
  safe_serialization = cfg.save_safetensors is True
100
 
101
  if cfg.unfrozen_parameters:
102
- freeze_parameters_except(model, cfg.unfrozen_parameters)
103
 
104
  trainer = setup_trainer(
105
  cfg,
 
19
  from axolotl.common.cli import TrainerCliArgs
20
  from axolotl.logging_config import configure_logging
21
  from axolotl.utils.dict import DictDefault
22
+ from axolotl.utils.freeze import freeze_layers_except
23
  from axolotl.utils.models import load_model, load_tokenizer
24
  from axolotl.utils.trainer import setup_trainer
25
 
 
99
  safe_serialization = cfg.save_safetensors is True
100
 
101
  if cfg.unfrozen_parameters:
102
+ freeze_layers_except(model, cfg.unfrozen_parameters)
103
 
104
  trainer = setup_trainer(
105
  cfg,
src/axolotl/utils/freeze.py CHANGED
@@ -3,13 +3,14 @@ module to freeze/unfreeze parameters by name
3
  """
4
  import logging
5
  import re
 
6
 
7
  from axolotl.utils.distributed import is_main_process
8
 
9
  LOG = logging.getLogger("axolotl.utils.freeze")
10
 
11
 
12
- def freeze_parameters_except(model, regex_patterns):
13
  """
14
  Freezes all layers of the given model except for the layers that match given regex patterns.
15
  Periods in the patterns are treated as literal periods, not as wildcard characters.
@@ -17,22 +18,209 @@ def freeze_parameters_except(model, regex_patterns):
17
  Parameters:
18
  - model (nn.Module): The PyTorch model to be modified.
19
  - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
 
 
 
 
20
 
21
  Returns:
22
  None; the model is modified in place.
23
  """
24
- # Escape periods and compile the regex patterns
25
- compiled_patterns = [
26
- re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
27
- ]
28
 
29
- # First, freeze all parameters in the model
30
- for param in model.parameters():
31
- param.requires_grad = False
32
 
33
  # Unfreeze layers that match the regex patterns
34
  for name, param in model.named_parameters():
35
- if any(pattern.match(name) for pattern in compiled_patterns):
36
- if is_main_process():
37
- LOG.debug(f"unfreezing {name}")
 
 
 
38
  param.requires_grad = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  """
4
  import logging
5
  import re
6
+ from typing import Callable, List, Tuple
7
 
8
  from axolotl.utils.distributed import is_main_process
9
 
10
  LOG = logging.getLogger("axolotl.utils.freeze")
11
 
12
 
13
+ def freeze_layers_except(model, regex_patterns):
14
  """
15
  Freezes all layers of the given model except for the layers that match given regex patterns.
16
  Periods in the patterns are treated as literal periods, not as wildcard characters.
 
18
  Parameters:
19
  - model (nn.Module): The PyTorch model to be modified.
20
  - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
21
+ Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
22
+ Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
23
+ The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
24
+ E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
25
 
26
  Returns:
27
  None; the model is modified in place.
28
  """
29
+ if isinstance(regex_patterns, str):
30
+ regex_patterns = [regex_patterns]
 
 
31
 
32
+ patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
 
 
33
 
34
  # Unfreeze layers that match the regex patterns
35
  for name, param in model.named_parameters():
36
+ param.requires_grad = False
37
+ unfrozen_ranges = []
38
+ for pattern in patterns:
39
+ if not pattern.match(name):
40
+ continue
41
+
42
  param.requires_grad = True
43
+
44
+ if pattern.range is not None:
45
+ unfrozen_ranges.append(pattern.range)
46
+
47
+ merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
48
+
49
+ if param.requires_grad and is_main_process():
50
+ unfrozen_ranges = (
51
+ f" with ranges {merged_unfrozen_ranges}"
52
+ if merged_unfrozen_ranges
53
+ else ""
54
+ )
55
+ LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
56
+
57
+ if not merged_unfrozen_ranges:
58
+ continue
59
+
60
+ # The range list we need is actually the inverted of the merged ranges
61
+ ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
62
+
63
+ param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
64
+
65
+ if is_main_process() and all(
66
+ not param.requires_grad for param in model.parameters()
67
+ ):
68
+ LOG.warning("All parameters are frozen. Model will not be trained.")
69
+
70
+
71
+ def _invert_ranges(
72
+ given_ranges: List[Tuple[int, int]], layer_size: int
73
+ ) -> List[Tuple[int, int]]:
74
+ """
75
+ Inverts a list of ranges to obtain the ranges not covered by the given ranges.
76
+
77
+ Parameters:
78
+ - given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
79
+ - layer_size (int): The length of the layer. E.g., len(model.layer.weight)
80
+ Returns:
81
+ - List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
82
+ """
83
+ if not given_ranges:
84
+ return [(0, layer_size)]
85
+
86
+ inverted_ranges = []
87
+ current_start = 0
88
+
89
+ for start, end in sorted(given_ranges):
90
+ if start > current_start:
91
+ inverted_ranges.append((current_start, start))
92
+ current_start = max(current_start, end)
93
+
94
+ # Handle the case where the last given range does not reach the end of the total_size
95
+ if current_start < layer_size:
96
+ inverted_ranges.append((current_start, layer_size))
97
+
98
+ return inverted_ranges
99
+
100
+
101
+ def _merge_ranges(
102
+ given_ranges: List[Tuple[int, int | None]], layer_size: int
103
+ ) -> List[Tuple[int, int]]:
104
+ """
105
+ Merges overlapping ranges and sorts the given ranges.
106
+
107
+ This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
108
+ as tuples, where the first element is the start index (inclusive) and the second element is the end
109
+ index (exclusive). The end index can be None, indicating that the range extends to the end of the
110
+ sequence.
111
+
112
+ Parameters:
113
+ - given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
114
+ - layer_size (int): The length of the layer. E.g., len(model.layer.weight)
115
+
116
+ Returns:
117
+ - List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
118
+ """
119
+ # End of each range can be determined now since we have the total size
120
+ processed_ranges = [
121
+ (start, end if end is not None else layer_size) for start, end in given_ranges
122
+ ]
123
+
124
+ # No need to merge if there's only one or no ranges
125
+ if len(processed_ranges) <= 1:
126
+ return processed_ranges
127
+
128
+ sorted_ranges = sorted(processed_ranges)
129
+
130
+ merged_ranges = [sorted_ranges[0]]
131
+ for start, end in sorted_ranges[1:]:
132
+ prev_start, prev_end = merged_ranges[-1]
133
+ if start <= prev_end:
134
+ merged_ranges[-1] = (prev_start, max(prev_end, end))
135
+ else:
136
+ merged_ranges.append((start, end))
137
+
138
+ return merged_ranges
139
+
140
+
141
+ def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
142
+ """
143
+ Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
144
+
145
+ This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
146
+ two integers representing the start and end indices of the range.
147
+
148
+ Parameters:
149
+ - ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
150
+
151
+ Returns:
152
+ - Callable: A hook function to be used with `register_hook` on parameters.
153
+
154
+ Example usage:
155
+ ```
156
+ ranges_to_freeze = [(0, 10), (20, 30)]
157
+ hook = _create_freeze_parameters_hook(ranges_to_freeze)
158
+ model.register_hook(hook)
159
+ ```
160
+ """
161
+
162
+ def freeze_parameters_hook(gradients):
163
+ for start, end in ranges_to_freeze:
164
+ gradients[start:end].zero_()
165
+
166
+ return freeze_parameters_hook
167
+
168
+
169
+ class LayerNamePattern:
170
+ """
171
+ Represents a regex pattern for layer names, potentially including a parameter index range.
172
+ """
173
+
174
+ def __init__(self, pattern: str):
175
+ """
176
+ Initializes a new instance of the LayerNamePattern class.
177
+
178
+ Parameters:
179
+ - pattern (str): The regex pattern for layer names, potentially including a parameter index range.
180
+ """
181
+ self.raw_pattern = pattern
182
+ name_pattern, self.range = self._parse_pattern(pattern)
183
+ self.name_regex = re.compile(name_pattern.replace(".", "\\."))
184
+
185
+ def match(self, name: str) -> bool:
186
+ """
187
+ Checks if the given layer name matches the regex pattern.
188
+
189
+ Parameters:
190
+ - name (str): The layer name to check.
191
+
192
+ Returns:
193
+ - bool: True if the layer name matches the pattern, False otherwise.
194
+ """
195
+ return self.name_regex.match(name) is not None
196
+
197
+ def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
198
+ """
199
+ Extracts the range pattern from the given pattern.
200
+
201
+ Parameters:
202
+ - pattern (str): The pattern to extract the range from.
203
+
204
+ Returns:
205
+ - Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
206
+ """
207
+ match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
208
+ if not match:
209
+ return pattern, None
210
+
211
+ base_pattern, start_part, end_part = match.groups()
212
+
213
+ if end_part is None and start_part.isdecimal():
214
+ index = int(start_part)
215
+ return base_pattern, (index, index + 1)
216
+
217
+ # [:end] or [start:] or [start:end]
218
+ start = int(start_part) if start_part else 0
219
+ end = int(end_part) if end_part else None
220
+
221
+ if end is not None and start >= end:
222
+ raise ValueError(
223
+ f"Invalid range in layer name pattern: {pattern}."
224
+ "End of range must be greater than start."
225
+ )
226
+ return base_pattern, (start, end)
tests/test_freeze.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains unit tests for the `freeze_layers_except` function.
3
+
4
+ The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
5
+ The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
6
+ """
7
+
8
+ import unittest
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+ from axolotl.utils.freeze import freeze_layers_except
14
+
15
+ ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
16
+ ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
17
+
18
+
19
+ class TestFreezeLayersExcept(unittest.TestCase):
20
+ """
21
+ A test case class for the `freeze_layers_except` function.
22
+ """
23
+
24
+ def setUp(self):
25
+ self.model = _TestModel()
26
+
27
+ def test_freeze_layers_with_dots_in_name(self):
28
+ freeze_layers_except(self.model, ["features.layer"])
29
+ self.assertTrue(
30
+ self.model.features.layer.weight.requires_grad,
31
+ "model.features.layer should be trainable.",
32
+ )
33
+ self.assertFalse(
34
+ self.model.classifier.weight.requires_grad,
35
+ "model.classifier should be frozen.",
36
+ )
37
+
38
+ def test_freeze_layers_without_dots_in_name(self):
39
+ freeze_layers_except(self.model, ["classifier"])
40
+ self.assertFalse(
41
+ self.model.features.layer.weight.requires_grad,
42
+ "model.features.layer should be trainable.",
43
+ )
44
+ self.assertTrue(
45
+ self.model.classifier.weight.requires_grad,
46
+ "model.classifier should be frozen.",
47
+ )
48
+
49
+ def test_freeze_layers_regex_patterns(self):
50
+ # The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'.
51
+ freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
52
+ self.assertTrue(
53
+ self.model.features.layer.weight.requires_grad,
54
+ "model.features.layer should be trainable.",
55
+ )
56
+ self.assertFalse(
57
+ self.model.classifier.weight.requires_grad,
58
+ "model.classifier should be frozen.",
59
+ )
60
+
61
+ def test_all_layers_frozen(self):
62
+ freeze_layers_except(self.model, [])
63
+ self.assertFalse(
64
+ self.model.features.layer.weight.requires_grad,
65
+ "model.features.layer should be frozen.",
66
+ )
67
+ self.assertFalse(
68
+ self.model.classifier.weight.requires_grad,
69
+ "model.classifier should be frozen.",
70
+ )
71
+
72
+ def test_all_layers_unfrozen(self):
73
+ freeze_layers_except(self.model, ["features.layer", "classifier"])
74
+ self.assertTrue(
75
+ self.model.features.layer.weight.requires_grad,
76
+ "model.features.layer should be trainable.",
77
+ )
78
+ self.assertTrue(
79
+ self.model.classifier.weight.requires_grad,
80
+ "model.classifier should be trainable.",
81
+ )
82
+
83
+ def test_freeze_layers_with_range_pattern_start_end(self):
84
+ freeze_layers_except(self.model, ["features.layer[1:5]"])
85
+ self.assertTrue(
86
+ self.model.features.layer.weight.requires_grad,
87
+ "model.features.layer should be trainable.",
88
+ )
89
+ self.assertFalse(
90
+ self.model.classifier.weight.requires_grad,
91
+ "model.classifier should be frozen.",
92
+ )
93
+
94
+ self._assert_gradient_output(
95
+ [
96
+ ZERO,
97
+ ONE_TO_TEN,
98
+ ONE_TO_TEN,
99
+ ONE_TO_TEN,
100
+ ONE_TO_TEN,
101
+ ZERO,
102
+ ZERO,
103
+ ZERO,
104
+ ZERO,
105
+ ZERO,
106
+ ]
107
+ )
108
+
109
+ def test_freeze_layers_with_range_pattern_single_index(self):
110
+ freeze_layers_except(self.model, ["features.layer[5]"])
111
+ self.assertTrue(
112
+ self.model.features.layer.weight.requires_grad,
113
+ "model.features.layer should be trainable.",
114
+ )
115
+ self.assertFalse(
116
+ self.model.classifier.weight.requires_grad,
117
+ "model.classifier should be frozen.",
118
+ )
119
+
120
+ self._assert_gradient_output(
121
+ [ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
122
+ )
123
+
124
+ def test_freeze_layers_with_range_pattern_start_omitted(self):
125
+ freeze_layers_except(self.model, ["features.layer[:5]"])
126
+ self.assertTrue(
127
+ self.model.features.layer.weight.requires_grad,
128
+ "model.features.layer should be trainable.",
129
+ )
130
+ self.assertFalse(
131
+ self.model.classifier.weight.requires_grad,
132
+ "model.classifier should be frozen.",
133
+ )
134
+
135
+ self._assert_gradient_output(
136
+ [
137
+ ONE_TO_TEN,
138
+ ONE_TO_TEN,
139
+ ONE_TO_TEN,
140
+ ONE_TO_TEN,
141
+ ONE_TO_TEN,
142
+ ZERO,
143
+ ZERO,
144
+ ZERO,
145
+ ZERO,
146
+ ZERO,
147
+ ]
148
+ )
149
+
150
+ def test_freeze_layers_with_range_pattern_end_omitted(self):
151
+ freeze_layers_except(self.model, ["features.layer[4:]"])
152
+ self.assertTrue(
153
+ self.model.features.layer.weight.requires_grad,
154
+ "model.features.layer should be trainable.",
155
+ )
156
+ self.assertFalse(
157
+ self.model.classifier.weight.requires_grad,
158
+ "model.classifier should be frozen.",
159
+ )
160
+
161
+ self._assert_gradient_output(
162
+ [
163
+ ZERO,
164
+ ZERO,
165
+ ZERO,
166
+ ZERO,
167
+ ONE_TO_TEN,
168
+ ONE_TO_TEN,
169
+ ONE_TO_TEN,
170
+ ONE_TO_TEN,
171
+ ONE_TO_TEN,
172
+ ONE_TO_TEN,
173
+ ]
174
+ )
175
+
176
+ def test_freeze_layers_with_range_pattern_merge_included(self):
177
+ freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
178
+ self.assertTrue(
179
+ self.model.features.layer.weight.requires_grad,
180
+ "model.features.layer should be trainable.",
181
+ )
182
+ self.assertFalse(
183
+ self.model.classifier.weight.requires_grad,
184
+ "model.classifier should be frozen.",
185
+ )
186
+
187
+ self._assert_gradient_output(
188
+ [
189
+ ZERO,
190
+ ZERO,
191
+ ZERO,
192
+ ZERO,
193
+ ONE_TO_TEN,
194
+ ONE_TO_TEN,
195
+ ONE_TO_TEN,
196
+ ONE_TO_TEN,
197
+ ONE_TO_TEN,
198
+ ONE_TO_TEN,
199
+ ]
200
+ )
201
+
202
+ def test_freeze_layers_with_range_pattern_merge_intersect(self):
203
+ freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
204
+ self.assertTrue(
205
+ self.model.features.layer.weight.requires_grad,
206
+ "model.features.layer should be trainable.",
207
+ )
208
+ self.assertFalse(
209
+ self.model.classifier.weight.requires_grad,
210
+ "model.classifier should be frozen.",
211
+ )
212
+
213
+ self._assert_gradient_output(
214
+ [
215
+ ZERO,
216
+ ZERO,
217
+ ZERO,
218
+ ZERO,
219
+ ONE_TO_TEN,
220
+ ONE_TO_TEN,
221
+ ONE_TO_TEN,
222
+ ONE_TO_TEN,
223
+ ZERO,
224
+ ZERO,
225
+ ]
226
+ )
227
+
228
+ def test_freeze_layers_with_range_pattern_merge_separate(self):
229
+ freeze_layers_except(
230
+ self.model,
231
+ ["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
232
+ )
233
+ self.assertTrue(
234
+ self.model.features.layer.weight.requires_grad,
235
+ "model.features.layer should be trainable.",
236
+ )
237
+ self.assertFalse(
238
+ self.model.classifier.weight.requires_grad,
239
+ "model.classifier should be frozen.",
240
+ )
241
+
242
+ self._assert_gradient_output(
243
+ [
244
+ ZERO,
245
+ ONE_TO_TEN,
246
+ ZERO,
247
+ ONE_TO_TEN,
248
+ ZERO,
249
+ ONE_TO_TEN,
250
+ ZERO,
251
+ ZERO,
252
+ ZERO,
253
+ ZERO,
254
+ ]
255
+ )
256
+
257
+ def _assert_gradient_output(self, expected):
258
+ input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
259
+
260
+ self.model.features.layer.weight.grad = None # Reset gradients
261
+ output = self.model.features.layer(input_tensor)
262
+ loss = output.sum()
263
+ loss.backward()
264
+
265
+ expected_grads = torch.tensor(expected)
266
+ torch.testing.assert_close(
267
+ self.model.features.layer.weight.grad, expected_grads
268
+ )
269
+
270
+
271
+ class _SubLayerModule(nn.Module):
272
+ def __init__(self):
273
+ super().__init__()
274
+ self.layer = nn.Linear(10, 10)
275
+
276
+
277
+ class _TestModel(nn.Module):
278
+ def __init__(self):
279
+ super().__init__()
280
+ self.features = _SubLayerModule()
281
+ self.classifier = nn.Linear(10, 2)
282
+
283
+
284
+ if __name__ == "__main__":
285
+ unittest.main()