Spaces:
Running
on
Zero
Running
on
Zero
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func | |
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward | |
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription | |
class T5EncoderPolicy(Policy):
    """Shardformer policy for a T5 encoder.

    The policy leaves the model untouched in ``preprocess``/``postprocess``
    and only registers optional replacements in ``module_policy``:
    layer-norm submodule replacements (when the optional imports succeed)
    and JIT-fused method replacements (when enabled in the shard config).
    """

    def config_sanity_check(self):
        """Assert that the shard config does not enable unsupported features.

        Raises:
            AssertionError: if tensor parallelism or flash attention is
                enabled in ``self.shard_config``.
        """
        assert (
            not self.shard_config.enable_tensor_parallelism
        ), "T5EncoderPolicy requires enable_tensor_parallelism to be disabled"
        assert (
            not self.shard_config.enable_flash_attention
        ), "T5EncoderPolicy requires enable_flash_attention to be disabled"

    def preprocess(self):
        """Return ``self.model`` unchanged."""
        return self.model

    def module_policy(self):
        """Build the mapping of target module classes to their replacements.

        Returns:
            The policy dict populated through
            ``append_or_create_submodule_replacement`` and
            ``append_or_create_method_replacement``. It is empty when the
            optional imports fail and ``enable_jit_fused`` is falsy.
        """
        # Imported lazily, as in the original, so that importing this module
        # does not itself require transformers' T5 implementation.
        from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack

        policy = {}

        # check whether apex is installed
        # Only the imports live inside ``try``: an ImportError raised while
        # registering the replacements must not be mistaken for "apex missing"
        # and silently leave the policy half-populated.
        try:
            from apex.normalization import FusedRMSNorm  # noqa
            from videosys.core.shardformer.t5.modeling import T5LayerNorm
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so this single
            # clause covers both; the norm replacement is simply skipped.
            pass
        else:
            # recover hf from fused rms norm to T5 norm which is faster
            # (target class, submodule suffix) pairs, registered in this order.
            norm_replacements = (
                (T5LayerFF, "layer_norm"),
                (T5LayerSelfAttention, "layer_norm"),
                (T5Stack, "final_layer_norm"),
            )
            for target_key, suffix in norm_replacements:
                self.append_or_create_submodule_replacement(
                    description=SubModuleReplacementDescription(
                        suffix=suffix,
                        target_module=T5LayerNorm,
                    ),
                    policy=policy,
                    target_key=target_key,
                )

        # use jit operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_T5_layer_ff_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerFF,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_T5_layer_self_attention_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=T5LayerSelfAttention,
            )

        return policy

    def postprocess(self):
        """Return ``self.model`` unchanged."""
        return self.model