Sapir commited on
Commit
4f52f00
·
1 Parent(s): 86b1a7e

Added tpu flash attention.

Browse files
xora/models/transformers/attention.py CHANGED
@@ -20,6 +20,13 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
20
  from einops import rearrange
21
  from torch import nn
22
 
 
 
 
 
 
 
 
23
  # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
24
 
25
  logger = logging.get_logger(__name__)
@@ -162,6 +169,15 @@ class BasicTransformerBlock(nn.Module):
162
  self._chunk_size = None
163
  self._chunk_dim = 0
164
 
 
 
 
 
 
 
 
 
 
165
 
166
  def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
167
  # Sets chunk feed-forward
@@ -461,6 +477,13 @@ class Attention(nn.Module):
461
  processor = AttnProcessor2_0()
462
  self.set_processor(processor)
463
 
 
 
 
 
 
 
 
464
  def set_processor(self, processor: "AttnProcessor") -> None:
465
  r"""
466
  Set the attention processor to use.
 
20
  from einops import rearrange
21
  from torch import nn
22
 
23
+ try:
24
+ from torch_xla.experimental.custom_kernel import flash_attention
25
+ except ImportError:
26
+ # workaround for automatic tests. Currently this function is manually patched
27
+ # to the torch_xla lib on setup of container
28
+ pass
29
+
30
  # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
31
 
32
  logger = logging.get_logger(__name__)
 
169
  self._chunk_size = None
170
  self._chunk_dim = 0
171
 
172
+ def set_use_tpu_flash_attention(self, device):
173
+ r"""
174
+ Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
175
+ attention kernel.
176
+ """
177
+ if device == "xla":
178
+ self.use_tpu_flash_attention = True
179
+ self.attn1.set_use_tpu_flash_attention(device)
180
+ self.attn2.set_use_tpu_flash_attention(device)
181
 
182
  def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
183
  # Sets chunk feed-forward
 
477
  processor = AttnProcessor2_0()
478
  self.set_processor(processor)
479
 
480
+ def set_use_tpu_flash_attention(self, device_type):
481
+ r"""
482
+ Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel.
483
+ """
484
+ if device_type == "xla":
485
+ self.use_tpu_flash_attention = True
486
+
487
  def set_processor(self, processor: "AttnProcessor") -> None:
488
  r"""
489
  Set the attention processor to use.
xora/models/transformers/transformer3d.py CHANGED
@@ -153,11 +153,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
153
  """
154
  logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
155
  # if using TPU -> configure components to use TPU flash attention
156
- if dist_util.acceleration_type() == dist_util.AccelerationType.TPU:
157
  self.use_tpu_flash_attention = True
158
  # push config down to the attention modules
159
  for block in self.transformer_blocks:
160
- block.set_use_tpu_flash_attention()
161
 
162
  def initialize(self, embedding_std: float, mode: Literal["xora", "pixart"]):
163
  def _basic_init(module):
 
153
  """
154
  logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
155
  # if using TPU -> configure components to use TPU flash attention
156
+ if self.device.type == "xla":
157
  self.use_tpu_flash_attention = True
158
  # push config down to the attention modules
159
  for block in self.transformer_blocks:
160
+ block.set_use_tpu_flash_attention(self.device.type)
161
 
162
  def initialize(self, embedding_std: float, mode: Literal["xora", "pixart"]):
163
  def _basic_init(module):
xora/utils/dist_util.py CHANGED
@@ -1,11 +1,5 @@
1
  from enum import Enum
2
 
3
- class AccelerationType(Enum):
4
- CPU = "cpu"
5
- GPU = "gpu"
6
- TPU = "tpu"
7
- MPS = "mps"
8
-
9
  def execute_graph() -> None:
10
  if _acceleration_type == AccelerationType.TPU:
11
  xm.mark_step()
 
1
  from enum import Enum
2
 
 
 
 
 
 
 
3
  def execute_graph() -> None:
4
  if _acceleration_type == AccelerationType.TPU:
5
  xm.mark_step()