winglian committed on
Commit
3355706
·
unverified ·
1 Parent(s): daa4fac

Add support for GPTQ using native transformers/peft (#468)

Browse files

* auto gptq support

* more tweaks and add yml

* remove old gptq docker

* don't need explicit peft install for tests

* fix setup.py to use extra index url

install torch for tests
fix cuda version for autogptq index
set torch in requirements so that it installs properly
move gptq install around to work with github cicd

* gptq doesn't play well with sample packing

* address pr feedback

* remove torch install for now

* set quantization_config from model config

* Fix the implementation for getting quant config from model config

.github/workflows/main.yml CHANGED
@@ -23,11 +23,6 @@ jobs:
23
  python_version: "3.10"
24
  pytorch: 2.0.1
25
  axolotl_extras:
26
- - cuda: 118
27
- cuda_version: 11.8.0
28
- python_version: "3.9"
29
- pytorch: 2.0.1
30
- axolotl_extras: gptq
31
  runs-on: self-hosted
32
  steps:
33
  - name: Checkout
@@ -73,11 +68,6 @@ jobs:
73
  pytorch: 2.0.1
74
  axolotl_extras:
75
  is_latest: true
76
- - cuda: 118
77
- cuda_version: 11.8.0
78
- python_version: "3.9"
79
- pytorch: 2.0.1
80
- axolotl_extras: gptq
81
  runs-on: self-hosted
82
  steps:
83
  - name: Checkout
 
23
  python_version: "3.10"
24
  pytorch: 2.0.1
25
  axolotl_extras:
 
 
 
 
 
26
  runs-on: self-hosted
27
  steps:
28
  - name: Checkout
 
68
  pytorch: 2.0.1
69
  axolotl_extras:
70
  is_latest: true
 
 
 
 
 
71
  runs-on: self-hosted
72
  steps:
73
  - name: Checkout
.github/workflows/tests.yml CHANGED
@@ -24,7 +24,7 @@ jobs:
24
 
25
  - name: Install dependencies
26
  run: |
27
- pip install -e .[peft]
28
  pip install -r requirements-tests.txt
29
 
30
  - name: Run tests
 
24
 
25
  - name: Install dependencies
26
  run: |
27
+ pip install -e .
28
  pip install -r requirements-tests.txt
29
 
30
  - name: Run tests
docker/Dockerfile CHANGED
@@ -11,14 +11,13 @@ RUN apt-get update && \
11
 
12
  WORKDIR /workspace
13
 
14
- RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
15
  RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
16
  # If AXOLOTL_EXTRAS is set, append it in brackets
17
  RUN cd axolotl && \
18
  if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
19
- pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
20
  else \
21
- pip install -e .[flash-attn]; \
22
  fi
23
 
24
  # fix so that git fetch/pull from remote works
 
11
 
12
  WORKDIR /workspace
13
 
 
14
  RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
15
  # If AXOLOTL_EXTRAS is set, append it in brackets
16
  RUN cd axolotl && \
17
  if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
18
+ pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \
19
  else \
20
+ pip install -e .[flash-attn,gptq]; \
21
  fi
22
 
23
  # fix so that git fetch/pull from remote works
examples/gptq-lora-7b/README.md DELETED
@@ -1,8 +0,0 @@
1
- # LLaMa 7B using LoRA
2
-
3
- This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
4
-
5
- ```shell
6
- accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
7
-
8
- ```
 
 
 
 
 
 
 
 
 
examples/{gptq-lora-7b/config.yml → llama-2/gptq-lora.yml} RENAMED
@@ -1,63 +1,76 @@
1
- base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
2
- base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
3
- model_type: LlamaForCausalLM
4
- tokenizer_type: LlamaTokenizer
5
- trust_remote_code:
6
- load_in_8bit: true
7
  gptq: true
 
 
 
 
 
 
 
 
 
 
8
  datasets:
9
- - path: vicgalle/alpaca-gpt4
10
  type: alpaca
11
  dataset_prepared_path: last_run_prepared
12
- val_set_size: 0.02
13
- adapter:
14
  lora_model_dir:
15
- sequence_len: 2048
16
- max_packed_sequence_len:
17
  lora_r: 8
18
- lora_alpha: 16
19
  lora_dropout: 0.05
20
  lora_target_modules:
 
 
21
  - q_proj
22
  - v_proj
23
- lora_fan_in_fan_out: false
24
- wandb_project: llama-7b-lora-int4
25
- wandb_entity:
26
  wandb_watch:
27
  wandb_run_id:
28
  wandb_log_model:
29
- output_dir: ./llama-7b-lora-int4
30
  gradient_accumulation_steps: 1
31
  micro_batch_size: 1
32
  num_epochs: 3
33
- optimizer: adamw_bnb_8bit
 
 
 
34
  torchdistx_path:
35
  lr_scheduler: cosine
36
- learning_rate: 0.0000002
 
37
  train_on_inputs: false
38
  group_by_length: false
39
- fp16: true
40
  bf16: false
 
 
41
  tf32: true
 
42
  early_stopping_patience:
43
  resume_from_checkpoint:
44
  local_rank:
45
- logging_steps: 5
46
  xformers_attention:
47
  flash_attention:
48
- gradient_checkpointing: true
49
- gptq_groupsize: 128
50
- gptq_model_v1: false
51
- warmup_steps: 20
52
- eval_steps: 110
53
- save_steps: 660
 
54
  debug:
55
  deepspeed:
56
- weight_decay: 0.0001
57
- fsdp:
58
- fsdp_config:
59
- tokens:
60
- pad_token: "<pad>"
61
  bos_token: "<s>"
62
  eos_token: "</s>"
63
  unk_token: "<unk>"
 
1
+ base_model: TheBloke/Llama-2-7B-GPTQ
2
+ base_model_config: TheBloke/Llama-2-7B-GPTQ
3
+ is_llama_derived_model: false
 
 
 
4
  gptq: true
5
+ gptq_bits: 4
6
+ model_type: AutoModelForCausalLM
7
+ tokenizer_type: LlamaTokenizer
8
+ tokenizer_use_fast: true
9
+ tokenizer_legacy: true
10
+ load_in_8bit: false
11
+ load_in_4bit: false
12
+ strict: false
13
+ push_dataset_to_hub:
14
+ hf_use_auth_token: true
15
  datasets:
16
+ - path: mhenrichsen/alpaca_2k_test
17
  type: alpaca
18
  dataset_prepared_path: last_run_prepared
19
+ val_set_size: 0.01
20
+ adapter: lora
21
  lora_model_dir:
22
+ sequence_len: 4096
23
+ sample_packing:
24
  lora_r: 8
25
+ lora_alpha: 32
26
  lora_dropout: 0.05
27
  lora_target_modules:
28
+ - k_proj
29
+ - o_proj
30
  - q_proj
31
  - v_proj
32
+ lora_target_linear:
33
+ lora_fan_in_fan_out:
34
+ wandb_project:
35
  wandb_watch:
36
  wandb_run_id:
37
  wandb_log_model:
38
+ output_dir: ./model-out
39
  gradient_accumulation_steps: 1
40
  micro_batch_size: 1
41
  num_epochs: 3
42
+ optimizer: adamw_torch
43
+ adam_beta2: 0.95
44
+ adam_eps: 0.00001
45
+ max_grad_norm: 1.0
46
  torchdistx_path:
47
  lr_scheduler: cosine
48
+ lr_quadratic_warmup: true
49
+ learning_rate: 0.000017
50
  train_on_inputs: false
51
  group_by_length: false
 
52
  bf16: false
53
+ fp16: false
54
+ float16: true
55
  tf32: true
56
+ gradient_checkpointing: true
57
  early_stopping_patience:
58
  resume_from_checkpoint:
59
  local_rank:
60
+ logging_steps: 1
61
  xformers_attention:
62
  flash_attention:
63
+ sdp_attention:
64
+ flash_optimum:
65
+ gptq_groupsize:
66
+ gptq_model_v1:
67
+ warmup_steps: 100
68
+ eval_steps:
69
+ save_steps:
70
  debug:
71
  deepspeed:
72
+ weight_decay: 0.1
73
+ special_tokens:
 
 
 
74
  bos_token: "<s>"
75
  eos_token: "</s>"
76
  unk_token: "<unk>"
requirements.txt CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  packaging
2
  peft @ git+https://github.com/huggingface/peft.git
3
  transformers @ git+https://github.com/huggingface/transformers.git
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
3
+ torch==2.0.1
4
+ auto-gptq
5
  packaging
6
  peft @ git+https://github.com/huggingface/peft.git
7
  transformers @ git+https://github.com/huggingface/transformers.git
setup.py CHANGED
@@ -2,15 +2,27 @@
2
 
3
  from setuptools import find_packages, setup
4
 
5
- install_requires = []
6
- with open("./requirements.txt", encoding="utf-8") as requirements_file:
7
- # don't include peft yet until we check the int4
8
- # need to manually install peft for now...
9
- reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
10
- reqs = [r for r in reqs if "flash-attn" not in r]
11
- reqs = [r for r in reqs if r and r[0] != "#"]
12
- for r in reqs:
13
- install_requires.append(r)
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  setup(
16
  name="axolotl",
@@ -19,12 +31,10 @@ setup(
19
  package_dir={"": "src"},
20
  packages=find_packages(),
21
  install_requires=install_requires,
 
22
  extras_require={
23
  "gptq": [
24
- "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
25
- ],
26
- "gptq_triton": [
27
- "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
28
  ],
29
  "flash-attn": [
30
  "flash-attn==2.0.8",
@@ -32,8 +42,5 @@ setup(
32
  "extras": [
33
  "deepspeed",
34
  ],
35
- "peft": [
36
- "peft @ git+https://github.com/huggingface/peft.git",
37
- ],
38
  },
39
  )
 
2
 
3
  from setuptools import find_packages, setup
4
 
5
+
6
def parse_requirements():
    """Split ``./requirements.txt`` into install requirements and index URLs.

    Lines mentioning ``auto-gptq`` are dropped before any other processing,
    and lines mentioning ``flash-attn`` are left out of the requirements
    list; both packages are declared in ``extras_require`` instead. Blank
    lines and ``#`` comment lines are ignored.

    Returns:
        tuple[list[str], list[str]]: ``(install_requires, dependency_links)``
        where ``dependency_links`` holds the URL of every
        ``--extra-index-url`` line, in file order.

    Raises:
        ValueError: if an ``--extra-index-url`` line carries no URL.
    """
    index_flag = "--extra-index-url"

    # Read everything first so the file handle is released before parsing.
    with open("./requirements.txt", encoding="utf-8") as requirements_file:
        lines = [raw.strip() for raw in requirements_file if "auto-gptq" not in raw]

    _install_requires = []
    _dependency_links = []
    for line in lines:
        if not line or line.startswith("#"):
            continue
        if line.startswith(index_flag):
            # Accept both "--extra-index-url URL" and "--extra-index-url=URL";
            # only the first token after the flag is treated as the URL.
            remainder = line[len(index_flag):].lstrip("= \t").split()
            if not remainder:
                raise ValueError(f"missing URL in requirements line: {line!r}")
            _dependency_links.append(remainder[0])
        elif "flash-attn" not in line:
            # Everything else is passed through as a standard requirement.
            _install_requires.append(line)
    return _install_requires, _dependency_links
22
+
23
+
24
+ install_requires, dependency_links = parse_requirements()
25
+
26
 
27
  setup(
28
  name="axolotl",
 
31
  package_dir={"": "src"},
32
  packages=find_packages(),
33
  install_requires=install_requires,
34
+ dependency_links=dependency_links,
35
  extras_require={
36
  "gptq": [
37
+ "auto-gptq",
 
 
 
38
  ],
39
  "flash-attn": [
40
  "flash-attn==2.0.8",
 
42
  "extras": [
43
  "deepspeed",
44
  ],
 
 
 
45
  },
46
  )
src/axolotl/utils/config.py CHANGED
@@ -108,9 +108,7 @@ def validate_config(cfg):
108
  "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
109
  )
110
  if cfg.load_4bit:
111
- raise ValueError(
112
- "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
113
- )
114
 
115
  if cfg.adapter == "qlora":
116
  if cfg.merge_lora:
 
108
  "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
109
  )
110
  if cfg.load_4bit:
111
+ raise ValueError("cfg.load_4bit parameter has been deprecated")
 
 
112
 
113
  if cfg.adapter == "qlora":
114
  if cfg.merge_lora:
src/axolotl/utils/models.py CHANGED
@@ -4,19 +4,19 @@
4
  import logging
5
  import math
6
  import os
7
- from pathlib import Path
8
  from typing import Optional, Tuple # noqa: F401
9
 
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
13
  from optimum.bettertransformer import BetterTransformer
14
- from peft import PeftConfig
15
  from transformers import ( # noqa: F401
16
  AutoConfig,
17
  AutoModelForCausalLM,
18
  AutoTokenizer,
19
  BitsAndBytesConfig,
 
20
  LlamaConfig,
21
  PreTrainedModel,
22
  PreTrainedTokenizerBase,
@@ -155,32 +155,17 @@ def load_model(
155
  LOG.info("patching _expand_mask")
156
  hijack_expand_mask()
157
 
158
- try:
159
- if cfg.gptq:
160
- from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
161
- replace_peft_model_with_int4_lora_model,
162
- )
163
-
164
- replace_peft_model_with_int4_lora_model()
165
- except Exception as err:
166
- LOG.exception(err)
167
- raise err
168
-
169
- if not cfg.gptq and (
170
- (cfg.adapter == "lora" and load_in_8bit)
171
- or (cfg.adapter == "qlora" and cfg.load_in_4bit)
172
- ):
173
- try:
174
- from peft import prepare_model_for_kbit_training
175
- except ImportError:
176
- # For backward compatibility
177
- from peft import (
178
- prepare_model_for_int8_training as prepare_model_for_kbit_training,
179
- )
180
-
181
  model_kwargs = {}
182
  if cfg.model_revision:
183
  model_kwargs["revision"] = cfg.model_revision
 
 
 
 
 
 
 
 
184
  if cfg.adapter == "qlora" and cfg.load_in_4bit:
185
  model_kwargs["quantization_config"] = BitsAndBytesConfig(
186
  load_in_4bit=True,
@@ -191,45 +176,7 @@ def load_model(
191
  bnb_4bit_quant_type="nf4",
192
  )
193
  try:
194
- if cfg.gptq and cfg.is_llama_derived_model:
195
- from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
196
- from huggingface_hub import snapshot_download
197
-
198
- try:
199
- snapshot_download_kwargs = {}
200
- if cfg.base_model_ignore_patterns:
201
- snapshot_download_kwargs[
202
- "ignore_patterns"
203
- ] = cfg.base_model_ignore_patterns
204
- cache_model_path = Path(
205
- snapshot_download(base_model, **snapshot_download_kwargs)
206
- )
207
- files = (
208
- list(cache_model_path.glob("*.pt"))
209
- + list(cache_model_path.glob("*.safetensors"))
210
- + list(cache_model_path.glob("*.bin"))
211
- )
212
- if len(files) > 0:
213
- model_path = str(files[0])
214
- else:
215
- LOG.warning(
216
- "unable to find a cached model file, this will likely fail..."
217
- )
218
- model_path = str(cache_model_path)
219
- except Exception: # pylint: disable=broad-exception-caught
220
- model_path = cfg.base_model
221
- model, _ = load_llama_model_4bit_low_ram(
222
- base_model_config if base_model_config else base_model,
223
- model_path,
224
- device_map=cfg.device_map,
225
- half=cfg.fp16,
226
- groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
227
- is_v1_model=cfg.gptq_model_v1
228
- if cfg.gptq_model_v1 is not None
229
- else True,
230
- )
231
- load_in_8bit = False
232
- elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
233
  from transformers import LlamaForCausalLM
234
 
235
  config_kwargs = {}
@@ -275,15 +222,24 @@ def load_model(
275
  # )
276
  # model.train() # sets to train instead of eval mode
277
  elif model_type and not cfg.trust_remote_code:
278
- model = getattr(transformers, model_type).from_pretrained(
279
- base_model,
280
- device_map=cfg.device_map,
281
- load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
282
- load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
283
- torch_dtype=cfg.torch_dtype,
284
- trust_remote_code=cfg.trust_remote_code or False,
285
- **model_kwargs,
286
- )
 
 
 
 
 
 
 
 
 
287
  else:
288
  config = AutoConfig.from_pretrained(
289
  base_model,
@@ -359,11 +315,12 @@ def load_model(
359
  module.to(torch.float32)
360
 
361
  needs_fa2_dtype = cfg.adapter or cfg.fsdp
362
- if not cfg.gptq and (
363
- (cfg.adapter == "lora" and load_in_8bit)
364
- or (cfg.adapter == "qlora" and cfg.load_in_4bit)
365
  ):
366
  LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
 
 
367
  model = prepare_model_for_kbit_training(
368
  model, use_gradient_checkpointing=cfg.gradient_checkpointing
369
  )
@@ -385,22 +342,10 @@ def load_model(
385
  if cfg.ddp and not load_in_8bit:
386
  model.to(f"cuda:{cfg.local_rank}")
387
 
388
- if cfg.gptq:
389
- # Scales to half
390
- LOG.info("Fitting 4bit scales and zeros to half")
391
- for _, module in model.named_modules():
392
- if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
393
- type(module)
394
- ):
395
- if hasattr(module, "is_v1_model") and module.is_v1_model:
396
- module.zeros = module.zeros.half()
397
- module.scales = module.scales.half()
398
- module.bias = module.bias.half()
399
-
400
  if (
401
  torch.cuda.device_count() > 1
402
  and int(os.getenv("WORLD_SIZE", "1")) > 1
403
- and (cfg.gptq or cfg.load_in_4bit)
404
  ):
405
  # llama is PROBABLY model parallelizable, but the default isn't that it is
406
  # so let's only set it for the 4bit, see
 
4
  import logging
5
  import math
6
  import os
 
7
  from typing import Optional, Tuple # noqa: F401
8
 
9
  import bitsandbytes as bnb
10
  import torch
11
  import transformers
12
  from optimum.bettertransformer import BetterTransformer
13
+ from peft import PeftConfig, prepare_model_for_kbit_training
14
  from transformers import ( # noqa: F401
15
  AutoConfig,
16
  AutoModelForCausalLM,
17
  AutoTokenizer,
18
  BitsAndBytesConfig,
19
+ GPTQConfig,
20
  LlamaConfig,
21
  PreTrainedModel,
22
  PreTrainedTokenizerBase,
 
155
  LOG.info("patching _expand_mask")
156
  hijack_expand_mask()
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  model_kwargs = {}
159
  if cfg.model_revision:
160
  model_kwargs["revision"] = cfg.model_revision
161
+ if cfg.gptq:
162
+ model_config = load_model_config(cfg)
163
+ if hasattr(model_config, "quantization_config"):
164
+ LOG.warning("model config does not contain quantization_config information")
165
+ else:
166
+ model_kwargs["quantization_config"] = GPTQConfig(
167
+ **model_config.quantization_config
168
+ )
169
  if cfg.adapter == "qlora" and cfg.load_in_4bit:
170
  model_kwargs["quantization_config"] = BitsAndBytesConfig(
171
  load_in_4bit=True,
 
176
  bnb_4bit_quant_type="nf4",
177
  )
178
  try:
179
+ if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  from transformers import LlamaForCausalLM
181
 
182
  config_kwargs = {}
 
222
  # )
223
  # model.train() # sets to train instead of eval mode
224
  elif model_type and not cfg.trust_remote_code:
225
+ if cfg.gptq:
226
+ model = AutoModelForCausalLM.from_pretrained(
227
+ base_model,
228
+ device_map=cfg.device_map,
229
+ torch_dtype=cfg.torch_dtype,
230
+ trust_remote_code=cfg.trust_remote_code or False,
231
+ **model_kwargs,
232
+ )
233
+ else:
234
+ model = getattr(transformers, model_type).from_pretrained(
235
+ base_model,
236
+ device_map=cfg.device_map,
237
+ load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
238
+ load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
239
+ torch_dtype=cfg.torch_dtype,
240
+ trust_remote_code=cfg.trust_remote_code or False,
241
+ **model_kwargs,
242
+ )
243
  else:
244
  config = AutoConfig.from_pretrained(
245
  base_model,
 
315
  module.to(torch.float32)
316
 
317
  needs_fa2_dtype = cfg.adapter or cfg.fsdp
318
+ if (cfg.adapter == "lora" and load_in_8bit) or (
319
+ cfg.adapter == "qlora" and cfg.load_in_4bit
 
320
  ):
321
  LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
322
+ if cfg.gradient_checkpointing:
323
+ model.gradient_checkpointing_enable()
324
  model = prepare_model_for_kbit_training(
325
  model, use_gradient_checkpointing=cfg.gradient_checkpointing
326
  )
 
342
  if cfg.ddp and not load_in_8bit:
343
  model.to(f"cuda:{cfg.local_rank}")
344
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  if (
346
  torch.cuda.device_count() > 1
347
  and int(os.getenv("WORLD_SIZE", "1")) > 1
348
+ and (cfg.load_in_4bit)
349
  ):
350
  # llama is PROBABLY model parallelizable, but the default isn't that it is
351
  # so let's only set it for the 4bit, see
src/axolotl/utils/trainer.py CHANGED
@@ -514,23 +514,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
514
  training_arguments_kwargs["seed"] = cfg.seed
515
 
516
  if cfg.gradient_checkpointing:
517
- if cfg.gptq:
518
- from alpaca_lora_4bit.gradient_checkpointing import (
519
- apply_gradient_checkpointing,
520
- )
521
-
522
- gradient_checkpointing_ratio = (
523
- cfg.gradient_checkpointing_ratio
524
- if cfg.gradient_checkpointing_ratio
525
- else 1.0
526
- )
527
- apply_gradient_checkpointing(
528
- model, checkpoint_ratio=gradient_checkpointing_ratio
529
- )
530
- else:
531
- training_arguments_kwargs[
532
- "gradient_checkpointing"
533
- ] = cfg.gradient_checkpointing
534
  if cfg.fsdp:
535
  training_arguments_kwargs["fsdp"] = cfg.fsdp
536
  if cfg.fsdp_config:
 
514
  training_arguments_kwargs["seed"] = cfg.seed
515
 
516
  if cfg.gradient_checkpointing:
517
+ training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  if cfg.fsdp:
519
  training_arguments_kwargs["fsdp"] = cfg.fsdp
520
  if cfg.fsdp_config: