Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,651 Bytes
07c6a04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
class T5EncoderPolicy(Policy):
    """Shardformer policy for the Hugging Face T5 encoder modules.

    The policy refuses tensor parallelism and flash attention, optionally swaps
    the layer norms of ``T5LayerFF``, ``T5LayerSelfAttention`` and ``T5Stack``
    for ``T5LayerNorm`` when the optional dependencies import cleanly, and
    optionally installs the JIT-fused forward/dropout-add methods.
    """

    def config_sanity_check(self):
        """Validate that the shard config only uses features this policy supports.

        Raises:
            AssertionError: If ``enable_tensor_parallelism`` or
                ``enable_flash_attention`` is set on ``self.shard_config``.
        """
        assert (
            not self.shard_config.enable_tensor_parallelism
        ), "T5EncoderPolicy does not support tensor parallelism"
        assert (
            not self.shard_config.enable_flash_attention
        ), "T5EncoderPolicy does not support flash attention"

    def preprocess(self):
        """Return ``self.model`` unchanged."""
        return self.model

    def module_policy(self):
        """Build the replacement policy for the T5 encoder modules.

        Returns:
            dict: The policy mapping populated through
            ``append_or_create_submodule_replacement`` and
            ``append_or_create_method_replacement``, keyed by the
            ``transformers`` T5 module classes.
        """
        from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack

        policy = {}

        # check whether apex is installed
        # Only the optional imports live inside ``try``: an ImportError raised
        # while registering the replacements must not be silently swallowed.
        try:
            from apex.normalization import FusedRMSNorm  # noqa

            from videosys.core.shardformer.t5.modeling import T5LayerNorm
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
            pass
        else:
            # recover hf from fused rms norm to T5 norm which is faster
            # (target class, attribute suffix) pairs, registered in this order.
            layer_norm_targets = (
                (T5LayerFF, "layer_norm"),
                (T5LayerSelfAttention, "layer_norm"),
                (T5Stack, "final_layer_norm"),
            )
            for target_key, suffix in layer_norm_targets:
                self.append_or_create_submodule_replacement(
                    description=SubModuleReplacementDescription(
                        suffix=suffix,
                        target_module=T5LayerNorm,
                    ),
                    policy=policy,
                    target_key=target_key,
                )

        # use jit operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_T5_layer_ff_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerFF,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_T5_layer_self_attention_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerSelfAttention,
            )

        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged."""
        return self.model
|