import gradio as gr
from paths import *
import numpy as np
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
import os
import matplotlib.pyplot as plt
import io
from PIL import Image
import torch.nn.functional as F
from utils import *
from huggingface_hub import hf_hub_download
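
# Fetch the pretrained orientation head weights from the Hugging Face Hub (cached under ./).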
ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)
save_path = './'
device = 'cpu'
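
# DINOv2-large backbone with an MLP head; the output vector packs 360 azimuth bins,
# 180 polar bins, 60 rotation bins and a 2-way confidence logit (hence out_dim).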
dino = DINOv2_MLP(
    dino_mode   = 'large',
    in_dim      = 1024,
    out_dim     = 360+180+60+2,
    evaluate    = True,
    mask_dino   = False,
    frozen_back = False
).to(device)
dino.eval()
print('model created')
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
print('weight loaded')
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
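
# Single forward pass: decode each angle from the argmax over its bin range and
# shift the polar/rotation bins back to signed degrees.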
def get_3angle(image):
    # image = Image.open(image_path).convert('RGB')
    image_inputs = val_preprocess(images=image)
    image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
    with torch.no_grad():
        dino_pred = dino(image_inputs)

    gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
    gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
    gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
    confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]

    angles = torch.zeros(4)
    angles[0] = gaus_ax_pred
    angles[1] = gaus_pl_pred - 90
    angles[2] = gaus_ro_pred - 30
    angles[3] = confidence
    return angles
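
# Inference-time augmentation: predict on several crops of the original and the
# background-removed image, then average the per-crop angles with outliers removed.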
def get_3angle_infer_aug(origin_img, rm_bkg_img):
    # image = Image.open(image_path).convert('RGB')
    image = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
    image_inputs = val_preprocess(images=image)
    image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
    with torch.no_grad():
        dino_pred = dino(image_inputs)

    gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1).to(torch.float32)
    gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1).to(torch.float32)
    gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1).to(torch.float32)

    gaus_ax_pred = remove_outliers_and_average_circular(gaus_ax_pred)
    gaus_pl_pred = remove_outliers_and_average(gaus_pl_pred)
    gaus_ro_pred = remove_outliers_and_average(gaus_ro_pred)

    confidence = torch.mean(F.softmax(dino_pred[:, -2:], dim=-1), dim=0)[0]

    angles = torch.zeros(4)
    angles[0] = gaus_ax_pred
    angles[1] = gaus_pl_pred - 90
    angles[2] = gaus_ro_pred - 30
    angles[3] = confidence
    return angles
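
# Render a matplotlib figure into an in-memory PIL image.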
def figure_to_img(fig):
    with io.BytesIO() as buf:
        fig.savefig(buf, format='JPG', bbox_inches='tight')
        buf.seek(0)
        image = Image.open(buf).copy()
    return image
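
# Gradio callback: optionally remove the background and/or apply inference-time
# augmentation, then overlay the rendered 3D axis on the processed image.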
def infer_func(img, do_rm_bkg, do_infer_aug):
    origin_img = Image.fromarray(img)
    if do_infer_aug:
        rm_bkg_img = background_preprocess(origin_img, True)
        angles = get_3angle_infer_aug(origin_img, rm_bkg_img)
    else:
        rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
        angles = get_3angle(rm_bkg_img)

    phi = np.radians(angles[0])
    theta = np.radians(angles[1])
    gamma = angles[2]

    render_axis = render_3D_axis(phi, theta, gamma)
    res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)

    # axis_model = "axis.obj"
    return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
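
# Gradio UI: an image and two checkboxes in; the annotated image, the three angles
# and the confidence score out.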
server = gr.Interface(
    flagging_mode='never',
    fn=infer_func,
    inputs=[
        gr.Image(height=512, width=512, label="Upload your image"),
        gr.Checkbox(label="Remove Background", value=True),
        gr.Checkbox(label="Inference-time augmentation", value=False)
    ],
    outputs=[
        gr.Image(height=512, width=512, label="Result image"),
        # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
        gr.Textbox(lines=1, label='Azimuth (0~360°)'),
        gr.Textbox(lines=1, label='Polar (-90~90°)'),
        gr.Textbox(lines=1, label='Rotation (-90~90°)'),
        gr.Textbox(lines=1, label='Confidence (0~1)')
    ]
)

server.launch()