Optimization
The optimum.fx.optimization module provides a set of torch.fx graph transformations, along with classes and functions to write your own transformations and compose them.
The transformation guide
In 🤗 Optimum, there are two kinds of transformations: reversible and non-reversible transformations.
Write a non-reversible transformation
Non-reversible transformations are the most basic case: once applied to a graph module, there is no way to get the original model back. Implementing one in 🤗 Optimum is easy: subclass Transformation and implement the transform() method.
For instance, the following transformation changes all the multiplications to additions:
>>> import operator
>>> from optimum.fx.optimization import Transformation

>>> class ChangeMulToAdd(Transformation):
...     def transform(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 node.target = operator.add
...         return graph_module
After implementing it, your transformation can be used as a regular function:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
...     model,
...     input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)
Write a reversible transformation
A reversible transformation implements both the transformation and its reverse, allowing you to retrieve the original model from the transformed one. To implement such a transformation, subclass ReversibleTransformation and implement the transform() and reverse() methods.
For instance, the following transformation is reversible:
>>> import operator
>>> from optimum.fx.optimization import ReversibleTransformation

>>> class MulToMulTimesTwo(ReversibleTransformation):
...     def transform(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 x, y = node.args
...                 node.args = (2 * x, y)
...         return graph_module
...
...     def reverse(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 x, y = node.args
...                 node.args = (x / 2, y)
...         return graph_module
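Assuming the traced BERT model from the previous section, the transformation can be applied as a callable, and the original model can be recovered by passing reverse=True:
>>> transformation = MulToMulTimesTwo()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)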
Composing transformations together
As applying multiple transformations in a chain is needed more often than not, compose() is provided: a utility function that creates a single transformation by chaining multiple other transformations together.
>>> from optimum.fx.optimization import compose
>>> composition = compose(MulToMulTimesTwo(), ChangeMulToAdd())
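Assuming the traced model from the examples above, the composition can then be applied like any single transformation:
>>> transformed_model = composition(traced)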
The Optimization API
Main classes and functions
class optimum.fx.optimization.Transformation
A torch.fx graph transformation.
It must implement the transform() method, and be used as a callable.
__call__( graph_module: GraphModule, lint_and_recompile: bool = True ) → torch.fx.GraphModule
Parameters
- graph_module (torch.fx.GraphModule) — The module to transform.
- lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.
Returns
torch.fx.GraphModule — The transformed module.
get_transformed_nodes( graph_module: GraphModule ) → List[torch.fx.Node]
Returns the list of nodes that were marked as transformed by this transformation.
mark_as_transformed( node: Node )
Marks a node as transformed by this transformation.
transform( graph_module: GraphModule ) → torch.fx.GraphModule
Performs the transformation on graph_module and returns the transformed module. Must be implemented in subclasses.
transformed( node: Node ) → bool
Returns whether the node was transformed by this transformation.
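As a minimal sketch of how these helpers fit together (assuming the traced model from the guide above), a transform() implementation can call mark_as_transformed() on each node it modifies, after which get_transformed_nodes() and transformed() can be used for inspection:
>>> import operator
>>> from optimum.fx.optimization import Transformation

>>> class ChangeMulToAdd(Transformation):
...     def transform(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 node.target = operator.add
...                 # Record that this node was modified by the transformation.
...                 self.mark_as_transformed(node)
...         return graph_module

>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)
>>> # Retrieve the nodes that were marked during transform().
>>> nodes = transformation.get_transformed_nodes(transformed_model)
>>> # Check whether a given node was transformed.
>>> all(transformation.transformed(node) for node in nodes)
True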
class optimum.fx.optimization.ReversibleTransformation
A torch.fx graph transformation that is reversible.
It must implement the transform() and reverse() methods, and be used as a callable.
__call__( graph_module: GraphModule, lint_and_recompile: bool = True, reverse: bool = False ) → torch.fx.GraphModule
Parameters
- graph_module (torch.fx.GraphModule) — The module to transform.
- lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.
- reverse (bool, defaults to False) — If True, the reverse transformation is performed.
Returns
torch.fx.GraphModule — The transformed module.
mark_as_restored( node: Node )
Marks a node as restored back to its original state.
reverse( graph_module: GraphModule ) → torch.fx.GraphModule
Performs the reverse transformation on graph_module and returns the restored module. Must be implemented in subclasses.
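As a sketch, the reverse() method of the MulToMulTimesTwo example from the guide could additionally call mark_as_restored() to clear the bookkeeping left by mark_as_transformed() (a hypothetical extension, not required for the transformation to work):
>>> def reverse(self, graph_module):
...     for node in graph_module.graph.nodes:
...         if node.op == "call_function" and node.target == operator.mul:
...             x, y = node.args
...             node.args = (x / 2, y)
...             # Clear the "transformed" mark set during transform().
...             self.mark_as_restored(node)
...     return graph_module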
optimum.fx.optimization.compose( *args: Transformation, inplace: bool = True )
Parameters
- args (Transformation) — The transformations to compose together.
- inplace (bool, defaults to True) — Whether the resulting transformation should be inplace, or create a new graph module.
Composes a list of transformations together.
Example:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
...     model,
...     input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> composition = compose(ChangeTrueDivToMulByInverse(), MergeLinears())
>>> transformed_model = composition(traced)
Transformations
class optimum.fx.optimization.MergeLinears
Transformation that merges linear layers that take the same input into one big linear layer.
Example:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import MergeLinears
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
...     model,
...     input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = MergeLinears()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
class optimum.fx.optimization.ChangeTrueDivToMulByInverse
Transformation that changes truediv nodes to multiplication by the inverse nodes when the denominator is static. For example, that is sometimes the case for the scaling factor in attention layers.
Example:
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
...     model,
...     input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = ChangeTrueDivToMulByInverse()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)