Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn.functional as F | |
from model.model import PointSemSeg, Find3D | |
import numpy as np | |
import random | |
from transformers import AutoTokenizer, AutoModel | |
DEVICE = "cuda:0" | |
#if torch.cuda.is_available(): | |
#DEVICE = "cuda:0" | |
def get_seg_color(labels): | |
part_num = labels.max() | |
cmap_matrix = torch.tensor([[1,1,1], [1,0,0], [0,1,0], [0,0,1], [1,1,0], [1,0,1], | |
[0,1,1], [0.5,0.5,0.5], [0.5,0.5,0], [0.5,0,0.5],[0,0.5,0.5], | |
[0.1,0.2,0.3],[0.2,0.5,0.3], [0.6,0.3,0.2], [0.5,0.3,0.5], | |
[0.6,0.7,0.2],[0.5,0.8,0.3]])[:part_num+1,:] | |
onehot = F.one_hot(labels.long(), num_classes=part_num+1) * 1.0 # n_pts, part_num+1, each row 00.010.0, first place is unlabeled (0 originally) | |
pts_rgb = torch.matmul(onehot, cmap_matrix) | |
return pts_rgb | |
def get_legend(parts): | |
colors = ["white", "red", "green", "blue", "yellow", "magenta", "cyan","grey", "olive", | |
"purple", "teal", "navy", "darkgreen", "brown", "pinkpurple", "yellowgreen", "limegreen"] | |
legends = [] | |
i = 1 | |
for part in parts: | |
cur_color = colors[i] | |
legends.append(f"{cur_color}:{part}") | |
i += 1 | |
legend = " ".join(legends) | |
return legend | |
def load_model(): | |
model = Find3D.from_pretrained("ziqima/find3d-checkpt0", dim_output=768) | |
#model.load_state_dict(torch.load("find3d_checkpoint.pth")["model_state_dict"]) | |
model.eval() | |
model = model.to(DEVICE) | |
return model | |
def fnv_hash_vec(arr): | |
""" | |
FNV64-1A | |
""" | |
assert arr.ndim == 2 | |
# Floor first for negative coordinates | |
arr = arr.copy() | |
arr = arr.astype(np.uint64, copy=False) | |
hashed_arr = np.uint64(14695981039346656037) * np.ones( | |
arr.shape[0], dtype=np.uint64 | |
) | |
for j in range(arr.shape[1]): | |
hashed_arr *= np.uint64(1099511628211) | |
hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j]) | |
return hashed_arr | |
def grid_sample_numpy(xyz, rgb, normal, grid_size): # this should hopefully be 5000 or close | |
xyz = xyz.cpu().numpy() | |
rgb = rgb.cpu().numpy() | |
normal = normal.cpu().numpy() | |
scaled_coord = xyz / np.array(grid_size) | |
grid_coord = np.floor(scaled_coord).astype(int) | |
min_coord = grid_coord.min(0) | |
grid_coord -= min_coord | |
scaled_coord -= min_coord | |
min_coord = min_coord * np.array(grid_size) | |
key = fnv_hash_vec(grid_coord) | |
idx_sort = np.argsort(key) | |
key_sort = key[idx_sort] | |
_, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True) | |
idx_select = ( | |
np.cumsum(np.insert(count, 0, 0)[0:-1]) | |
+ np.random.randint(0, count.max(), count.size) % count | |
) | |
idx_unique = idx_sort[idx_select] | |
grid_coord = grid_coord[idx_unique] | |
xyz = torch.tensor(xyz[idx_unique]).to(DEVICE) | |
rgb = torch.tensor(rgb[idx_unique]).to(DEVICE) | |
normal = torch.tensor(normal[idx_unique]).to(DEVICE) | |
grid_coord = torch.tensor(grid_coord).to(DEVICE) | |
return xyz, rgb, normal, grid_coord | |
def encode_text(texts): | |
siglip = AutoModel.from_pretrained("google/siglip-base-patch16-224") # dim 768 #"google/siglip-so400m-patch14-384") | |
tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")#"google/siglip-so400m-patch14-384") | |
inputs = tokenizer(texts, padding="max_length", return_tensors="pt") | |
for key in inputs: | |
inputs[key] = inputs[key].to(DEVICE) | |
with torch.no_grad(): | |
text_feat = siglip.to(DEVICE).get_text_features(**inputs) | |
text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-12) | |
return text_feat | |
def preprocess_pcd(xyz, rgb, normal): # rgb should be 0-1 | |
assert rgb.max() <=1 | |
# normalize | |
# this is the same preprocessing I do before training | |
center = xyz.mean(0) | |
scale = max((xyz - center).abs().max(0)[0]) | |
xyz -= center | |
xyz *= (0.75 / float(scale)) # put in 0.75-size box | |
# axis swap | |
xyz = torch.cat([-xyz[:,0].reshape(-1,1), xyz[:,2].reshape(-1,1), xyz[:,1].reshape(-1,1)], dim=1) | |
# center shift | |
xyz_min = xyz.min(dim=0)[0] | |
xyz_max = xyz.max(dim=0)[0] | |
xyz_max[2] = 0 | |
shift = (xyz_min+xyz_max)/2 | |
xyz -= shift | |
# subsample/upsample to 5000 pts for grid sampling | |
if xyz.shape[0] != 5000: | |
random_indices = torch.randint(0, xyz.shape[0], (5000,)) | |
pts_xyz_subsampled = xyz[random_indices] | |
pts_rgb_subsampled = rgb[random_indices] | |
normal_subsampled = normal[random_indices] | |
else: | |
pts_xyz_subsampled = xyz | |
pts_rgb_subsampled = rgb | |
normal_subsampled = normal | |
# grid sampling | |
pts_xyz_gridsampled, pts_rgb_gridsampled, normal_gridsampled, grid_coord = grid_sample_numpy(pts_xyz_subsampled, pts_rgb_subsampled, normal_subsampled, 0.02) | |
# another center shift, z=false | |
xyz_min = pts_xyz_gridsampled.min(dim=0)[0] | |
xyz_min[2] = 0 | |
xyz_max = pts_xyz_gridsampled.max(dim=0)[0] | |
xyz_max[2] = 0 | |
shift = (xyz_min+xyz_max)/2 | |
pts_xyz_gridsampled -= shift | |
xyz -= shift | |
# normalize color | |
pts_rgb_gridsampled = pts_rgb_gridsampled / 0.5 - 1 | |
# combine color and normal as feat | |
feat = torch.cat([pts_rgb_gridsampled, normal_gridsampled], dim=1) | |
data_dict = {} | |
data_dict["coord"] = pts_xyz_gridsampled | |
data_dict["feat"] = feat | |
data_dict["grid_coord"] = grid_coord | |
data_dict["xyz_full"] = xyz | |
data_dict["offset"] = torch.tensor([pts_xyz_gridsampled.shape[0]]) | |
return data_dict |