jadechoghari
commited on
Update patch.py
Browse files
patch.py
CHANGED
@@ -7,7 +7,7 @@ from einops import rearrange
|
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
|
10 |
-
from . import
|
11 |
from .utils import isinstance_str, init_generator, join_frame, split_frame, func_warper, join_warper, split_warper
|
12 |
|
13 |
|
@@ -42,7 +42,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
|
|
42 |
|
43 |
# Recursive merge multi-frame tokens into one set. Such as 4->1 for 4 frames and 8->2->1 for 8 frames when target stride is 4.
|
44 |
while curF > 1:
|
45 |
-
m, u, ret_dict =
|
46 |
local_tokens, curF, args["local_merge_ratio"], unm, generator, args["target_stride"], args["align_batch"])
|
47 |
unm += ret_dict["unm_num"]
|
48 |
m_ls.append(m)
|
@@ -70,7 +70,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
|
|
70 |
[module.global_tokens.to(local_tokens), local_tokens], dim=1)
|
71 |
local_chunk = 1
|
72 |
|
73 |
-
m, u, _ =
|
74 |
tokens, src_len, args["global_merge_ratio"], args["align_batch"], unmerge_chunk=local_chunk)
|
75 |
merged_tokens = m(tokens)
|
76 |
m_ls.append(m)
|
@@ -84,7 +84,7 @@ def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str,
|
|
84 |
m = func_warper(m_ls)
|
85 |
u = func_warper(u_ls[::-1])
|
86 |
else:
|
87 |
-
m, u = (
|
88 |
merged_tokens = x
|
89 |
|
90 |
# Return merge op, unmerge op, and merged tokens.
|
|
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
|
10 |
+
from .merge import bipartite_soft_matching_randframe, bipartite_soft_matching_2s, do_nothing
|
11 |
from .utils import isinstance_str, init_generator, join_frame, split_frame, func_warper, join_warper, split_warper
|
12 |
|
13 |
|
|
|
42 |
|
43 |
# Recursive merge multi-frame tokens into one set. Such as 4->1 for 4 frames and 8->2->1 for 8 frames when target stride is 4.
|
44 |
while curF > 1:
|
45 |
+
m, u, ret_dict = bipartite_soft_matching_randframe(
|
46 |
local_tokens, curF, args["local_merge_ratio"], unm, generator, args["target_stride"], args["align_batch"])
|
47 |
unm += ret_dict["unm_num"]
|
48 |
m_ls.append(m)
|
|
|
70 |
[module.global_tokens.to(local_tokens), local_tokens], dim=1)
|
71 |
local_chunk = 1
|
72 |
|
73 |
+
m, u, _ = bipartite_soft_matching_2s(
|
74 |
tokens, src_len, args["global_merge_ratio"], args["align_batch"], unmerge_chunk=local_chunk)
|
75 |
merged_tokens = m(tokens)
|
76 |
m_ls.append(m)
|
|
|
84 |
m = func_warper(m_ls)
|
85 |
u = func_warper(u_ls[::-1])
|
86 |
else:
|
87 |
+
m, u = (do_nothing, do_nothing)
|
88 |
merged_tokens = x
|
89 |
|
90 |
# Return merge op, unmerge op, and merged tokens.
|