jadechoghari commited on
Commit
7682d1f
·
verified ·
1 Parent(s): 4bf2a26

Update patch.py

Browse files
Files changed (1) hide show
  1. patch.py +4 -4
patch.py CHANGED
@@ -7,7 +7,7 @@ from einops import rearrange
7
  import torch
8
  import torch.nn.functional as F
9
 
10
- from . import merge
11
  from .utils import isinstance_str, init_generator, join_frame, split_frame, func_warper, join_warper, split_warper
12
 
13
 
@@ -42,7 +42,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
42
 
43
  # Recursive merge multi-frame tokens into one set. Such as 4->1 for 4 frames and 8->2->1 for 8 frames when target stride is 4.
44
  while curF > 1:
45
- m, u, ret_dict = merge.bipartite_soft_matching_randframe(
46
  local_tokens, curF, args["local_merge_ratio"], unm, generator, args["target_stride"], args["align_batch"])
47
  unm += ret_dict["unm_num"]
48
  m_ls.append(m)
@@ -70,7 +70,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
70
  [module.global_tokens.to(local_tokens), local_tokens], dim=1)
71
  local_chunk = 1
72
 
73
- m, u, _ = merge.bipartite_soft_matching_2s(
74
  tokens, src_len, args["global_merge_ratio"], args["align_batch"], unmerge_chunk=local_chunk)
75
  merged_tokens = m(tokens)
76
  m_ls.append(m)
@@ -84,7 +84,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
84
  m = func_warper(m_ls)
85
  u = func_warper(u_ls[::-1])
86
  else:
87
- m, u = (merge.do_nothing, merge.do_nothing)
88
  merged_tokens = x
89
 
90
  # Return merge op, unmerge op, and merged tokens.
 
7
  import torch
8
  import torch.nn.functional as F
9
 
10
+ from .merge import bipartite_soft_matching_randframe, bipartite_soft_matching_2s, do_nothing
11
  from .utils import isinstance_str, init_generator, join_frame, split_frame, func_warper, join_warper, split_warper
12
 
13
 
 
42
 
43
  # Recursive merge multi-frame tokens into one set. Such as 4->1 for 4 frames and 8->2->1 for 8 frames when target stride is 4.
44
  while curF > 1:
45
+ m, u, ret_dict = bipartite_soft_matching_randframe(
46
  local_tokens, curF, args["local_merge_ratio"], unm, generator, args["target_stride"], args["align_batch"])
47
  unm += ret_dict["unm_num"]
48
  m_ls.append(m)
 
70
  [module.global_tokens.to(local_tokens), local_tokens], dim=1)
71
  local_chunk = 1
72
 
73
+ m, u, _ = bipartite_soft_matching_2s(
74
  tokens, src_len, args["global_merge_ratio"], args["align_batch"], unmerge_chunk=local_chunk)
75
  merged_tokens = m(tokens)
76
  m_ls.append(m)
 
84
  m = func_warper(m_ls)
85
  u = func_warper(u_ls[::-1])
86
  else:
87
+ m, u = (do_nothing, do_nothing)
88
  merged_tokens = x
89
 
90
  # Return merge op, unmerge op, and merged tokens.