Optimum documentation


You are viewing v1.3.0 version. A newer version v1.24.0 is available.


The optimum.fx.optimization module provides a set of torch.fx graph transformations, along with classes and functions to write your own transformations and compose them.

The transformation guide

In 🤗 Optimum, there are two kinds of transformations: reversible and non-reversible transformations.

Write a non-reversible transformation

The most basic transformations are non-reversible ones. They cannot be undone: once one has been applied to a graph module, there is no way to get the original model back. Implementing such a transformation in 🤗 Optimum is straightforward: subclass Transformation and implement the transform() method.

For instance, the following transformation changes all the multiplications to additions:

>>> import operator
>>> from optimum.fx.optimization import Transformation

>>> class ChangeMulToAdd(Transformation):
>>>     def transform(self, graph_module):
>>>         # Retarget every `a * b` call in the graph to `a + b`.
>>>         for node in graph_module.graph.nodes:
>>>             if node.op == "call_function" and node.target == operator.mul:
>>>                 node.target = operator.add
>>>         return graph_module

After implementing it, your transformation can be used as a regular function:

>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )

>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)
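
To see the effect of this rewrite without downloading a model, the same node edit can be tried on a toy module using only torch.fx. This is a minimal sketch, not Optimum code: `TinyModule` and `change_mul_to_add` are hypothetical names, and the lint and recompile steps are done by hand here, standing in for what calling a transformation with lint_and_recompile=True is described as doing.

```python
import operator

import torch
from torch import fx


class TinyModule(torch.nn.Module):  # hypothetical toy module
    def forward(self, x, y):
        return x * y


def change_mul_to_add(graph_module: fx.GraphModule) -> fx.GraphModule:
    # Same node rewrite as ChangeMulToAdd.transform(), written as a plain function.
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == operator.mul:
            node.target = operator.add
    graph_module.graph.lint()  # check that the edited graph is still well-formed
    graph_module.recompile()   # regenerate forward() from the edited graph
    return graph_module


traced = fx.symbolic_trace(TinyModule())
x, y = torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0])

before = traced(x, y)                    # element-wise product
after = change_mul_to_add(traced)(x, y)  # element-wise sum
```

The rewrite happens in place, so `traced` itself computes the sum afterwards; this is why `before` is evaluated first.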

Write a reversible transformation

A reversible transformation implements both the transformation and its reverse, which makes it possible to retrieve the original model from the transformed one. To implement such a transformation, subclass ReversibleTransformation and implement the transform() and reverse() methods.

For instance, the following transformation is reversible:

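>>> import operator
>>> from optimum.fx.optimization import ReversibleTransformation

>>> class MulToMulTimesTwo(ReversibleTransformation):
>>>     def transform(self, graph_module):
>>>         graph = graph_module.graph
>>>         for node in graph.nodes:
>>>             if node.op == "call_function" and node.target == operator.mul:
>>>                 x, y = node.args
>>>                 # Nodes do not support arithmetic, so insert an explicit `2 * x` node.
>>>                 with graph.inserting_before(node):
>>>                     doubled = graph.call_function(operator.mul, (2, x))
>>>                 self.mark_as_transformed(doubled)
>>>                 node.args = (doubled, y)
>>>         return graph_module

>>>     def reverse(self, graph_module):
>>>         graph = graph_module.graph
>>>         for doubled in self.get_transformed_nodes(graph_module):
>>>             # Reconnect users to the original operand, then drop the inserted node.
>>>             doubled.replace_all_uses_with(doubled.args[1])
>>>             graph.erase_node(doubled)
>>>         return graph_module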
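
The round trip that a reversible transformation promises can be checked on a toy module with torch.fx alone. The sketch below is an illustration under stated assumptions, not Optimum's implementation: `Product`, `double_first_operand` and `undo_doubling` are hypothetical names, and the inserted nodes are tracked in a plain list rather than through the library's bookkeeping.

```python
import operator

import torch
from torch import fx


class Product(torch.nn.Module):  # hypothetical toy module
    def forward(self, x, y):
        return x * y


def double_first_operand(gm: fx.GraphModule) -> list:
    """Rewrite x * y into (2 * x) * y and return the nodes that were inserted."""
    inserted = []
    # Iterate over a snapshot so the freshly inserted nodes are never revisited.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == operator.mul:
            x, y = node.args
            with gm.graph.inserting_before(node):
                doubled = gm.graph.call_function(operator.mul, (2, x))
            node.args = (doubled, y)
            inserted.append(doubled)
    gm.graph.lint()
    gm.recompile()
    return inserted


def undo_doubling(gm: fx.GraphModule, inserted: list) -> None:
    """Remove the inserted nodes, restoring the original graph."""
    for doubled in inserted:
        doubled.replace_all_uses_with(doubled.args[1])
        gm.graph.erase_node(doubled)
    gm.graph.lint()
    gm.recompile()


gm = fx.symbolic_trace(Product())
x, y = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])

original = gm(x, y)
inserted = double_first_operand(gm)
transformed = gm(x, y)   # twice the original product
undo_doubling(gm, inserted)
restored = gm(x, y)      # same as the original product again
```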

Composing transformations together

Since applying multiple transformations in a chain is needed more often than not, compose() is provided. It is a utility function that creates a transformation by chaining multiple other transformations.

>>> from optimum.fx.optimization import compose
>>> composition = compose(MulToMulTimesTwo(), ChangeMulToAdd())
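
The order of the chained steps matters. The following sketch mimics the idea of composition with plain functions on a toy torch.fx module; it assumes steps are applied left to right with a single lint and recompile at the end. `compose_sketch`, `mul_to_add`, `double_first_operand` and `Product` are hypothetical names for illustration and do not reproduce the library's compose().

```python
import operator

import torch
from torch import fx


class Product(torch.nn.Module):  # hypothetical toy module
    def forward(self, x, y):
        return x * y


def mul_to_add(gm: fx.GraphModule) -> fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == operator.mul:
            node.target = operator.add
    return gm


def double_first_operand(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == operator.mul:
            x, y = node.args
            with gm.graph.inserting_before(node):
                doubled = gm.graph.call_function(operator.mul, (2, x))
            node.args = (doubled, y)
    return gm


def compose_sketch(*steps):
    """Chain graph rewrites left to right, then lint and recompile once."""
    def composed(gm: fx.GraphModule) -> fx.GraphModule:
        for step in steps:
            gm = step(gm)
        gm.graph.lint()
        gm.recompile()
        return gm
    return composed


x, y = torch.tensor([1.0]), torch.tensor([3.0])

# Doubling first creates a second multiplication, and both become additions.
a = compose_sketch(double_first_operand, mul_to_add)(fx.symbolic_trace(Product()))
# Converting first leaves no multiplication for the doubling step to match.
b = compose_sketch(mul_to_add, double_first_operand)(fx.symbolic_trace(Product()))

out_a = a(x, y)  # computes (2 + x) + y
out_b = b(x, y)  # computes x + y
```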

The Optimization API

Main classes and functions

class optimum.fx.optimization.Transformation


( )


  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

A torch.fx graph transformation.

It must implement the transform() method, and is used as a callable.


__call__

( graph_module: GraphModule, lint_and_recompile: bool = True ) → torch.fx.GraphModule


  • graph_module (torch.fx.GraphModule) — The module to transform.
  • lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.



The transformed module.
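
Why linting and recompiling matters can be seen with torch.fx alone: editing the nodes of a graph does not change the generated forward() until the module is recompiled. The snippet below is a minimal sketch on a hypothetical toy module (`Product`), calling torch.fx's own `Graph.lint()` and `GraphModule.recompile()` directly.

```python
import operator

import torch
from torch import fx


class Product(torch.nn.Module):  # hypothetical toy module
    def forward(self, x, y):
        return x * y


gm = fx.symbolic_trace(Product())
x, y = torch.tensor([2.0]), torch.tensor([5.0])

# Edit the graph in place: turn the multiplication into an addition.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == operator.mul:
        node.target = operator.add

stale = gm(x, y)  # forward() was generated before the edit, so it still multiplies

gm.graph.lint()   # sanity-check the edited graph
gm.recompile()    # regenerate forward() from the graph
fresh = gm(x, y)  # now reflects the edit and adds
```

When several rewrites are chained, doing this once at the end avoids regenerating the code after every intermediate step.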


get_transformed_nodes

( graph_module: GraphModule ) → List[torch.fx.Node]


  • graph_module (torch.fx.GraphModule) — The graph_module to get the nodes from.



Gives the list of nodes that were transformed by the transformation.


mark_as_transformed

( node: Node )


  • node (torch.fx.Node) — The node to mark as transformed.

Marks a node as transformed by this transformation.


transform

( graph_module: GraphModule ) → torch.fx.GraphModule


  • graph_module (torch.fx.GraphModule) — The module to transform.



The transformed module.


transformed

( node: Node ) → bool


  • node (torch.fx.Node) — The node to check.



Specifies whether the node was transformed by this transformation or not.

class optimum.fx.optimization.ReversibleTransformation


( )


  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

A torch.fx graph transformation that is reversible.

It must implement the transform() and reverse() methods, and is used as a callable.


__call__

( graph_module: GraphModule, lint_and_recompile: bool = True, reverse: bool = False ) → torch.fx.GraphModule


  • graph_module (torch.fx.GraphModule) — The module to transform.
  • lint_and_recompile (bool, defaults to True) — Whether the transformed module should be linted and recompiled. This can be set to False when chaining transformations together to perform this operation only once.
  • reverse (bool, defaults to False) — If True, the reverse transformation is performed.



The transformed module.


mark_as_restored

( node: Node )


  • node (torch.fx.Node) — The node to mark as restored.

Marks a node as restored back to its original state.


reverse

( graph_module: GraphModule ) → torch.fx.GraphModule


  • graph_module (torch.fx.GraphModule) — The module to transform.



The reverse transformed module.


optimum.fx.optimization.compose

( *args: Transformation, inplace: bool = True )


  • args (Transformation) — The transformations to compose together.
  • inplace (bool, defaults to True) — Whether the resulting transformation should be inplace, or create a new graph module.

Composes a list of transformations together.


>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )
>>> composition = compose(ChangeTrueDivToMulByInverse(), MergeLinears())
>>> transformed_model = composition(traced)


class optimum.fx.optimization.MergeLinears


( )


  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

Transformation that merges linear layers that take the same input into one big linear layer.


>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import MergeLinears

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )
>>> transformation = MergeLinears()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
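
The idea behind merging linear layers can be illustrated numerically: two linear layers fed the same input are equivalent to one larger layer whose weights and biases are the concatenation of theirs, followed by a split of the output. The sketch below shows only this underlying identity with plain torch; the layer names and sizes are hypothetical, and it is not how Optimum rewrites the graph.

```python
import torch

torch.manual_seed(0)

# Two hypothetical projections that consume the same input.
query = torch.nn.Linear(8, 4)
key = torch.nn.Linear(8, 6)

# One big layer whose rows are the rows of both projections stacked together.
merged = torch.nn.Linear(8, 4 + 6)
with torch.no_grad():
    merged.weight.copy_(torch.cat([query.weight, key.weight], dim=0))
    merged.bias.copy_(torch.cat([query.bias, key.bias], dim=0))

x = torch.randn(2, 8)

# A single matrix multiplication, then a split back into the two outputs.
merged_q, merged_k = merged(x).split([4, 6], dim=-1)
```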

class optimum.fx.optimization.ChangeTrueDivToMulByInverse


( )


  • preserves_computation (bool, defaults to False) — Whether the transformation preserves the graph computation or not. If True, the original and the transformed graph should produce the same outputs.

Transformation that replaces truediv nodes with a multiplication by the inverse when the denominator is static. This is sometimes the case, for example, for the scaling factor in attention layers.


>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> from optimum.fx.optimization import ChangeTrueDivToMulByInverse

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
>>>     model,
>>>     input_names=["input_ids", "attention_mask", "token_type_ids"],
>>> )
>>> transformation = ChangeTrueDivToMulByInverse()
>>> transformed_model = transformation(traced)
>>> restored_model = transformation(transformed_model, reverse=True)
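
The rewrite itself can be sketched on a toy module with torch.fx. This is an illustration under assumptions, not the library's code: `Scale` and `truediv_to_mul_by_inverse` are hypothetical names, and "static" is taken here to mean that the denominator is a plain Python number stored in the node arguments.

```python
import operator

import torch
from torch import fx


class Scale(torch.nn.Module):  # hypothetical stand-in for an attention scaling step
    def forward(self, scores):
        return scores / 8.0


def truediv_to_mul_by_inverse(gm: fx.GraphModule) -> fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == operator.truediv:
            numerator, denominator = node.args
            # Only rewrite when the denominator is a constant known at trace time.
            if isinstance(denominator, (int, float)):
                node.target = operator.mul
                node.args = (numerator, 1 / denominator)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = fx.symbolic_trace(Scale())
scores = torch.tensor([16.0, 4.0])

expected = gm(scores)                              # scores / 8.0
rewritten = truediv_to_mul_by_inverse(gm)(scores)  # scores * 0.125
```

A denominator that is itself a graph node (for example, the output of another operation) is left untouched by this sketch, since its inverse cannot be precomputed.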