File size: 4,521 Bytes
43a369c
 
 
 
 
 
 
 
 
 
b03b419
738bdfa
 
43a369c
 
adf34e4
43a369c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74503df
43a369c
 
 
 
 
 
 
 
 
 
 
74503df
6965bae
43a369c
 
 
6965bae
 
 
0366edb
43a369c
6965bae
0366edb
6965bae
 
 
 
 
74503df
 
 
 
 
 
 
 
 
6965bae
 
 
 
 
43a369c
 
 
 
 
 
 
 
 
 
6965bae
0366edb
6965bae
0366edb
 
6965bae
0366edb
 
43a369c
 
 
738bdfa
0f72f6a
 
738bdfa
 
43a369c
 
74503df
43a369c
 
 
b03b419
43a369c
b03b419
6965bae
 
43a369c
 
 
 
 
 
6965bae
 
43a369c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from paths import *
import numpy as np
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
import os
import matplotlib.pyplot as plt
import io
from PIL import Image

import torch.nn.functional as F
from utils import *

from huggingface_hub import hf_hub_download
# --- Module-level model setup (runs once at import/start-up) ---------------
# Fetch the checkpoint file from the Hugging Face Hub model repo
# "Viglong/OriNet"; downloaded files are cached under the current directory.
ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)

# NOTE: save_path is not referenced anywhere in the visible part of this file.
save_path = './'
# All inference in this file runs on CPU.
device = 'cpu'
# out_dim = 360+180+60+2: the prediction vector is sliced downstream
# (get_3angle / get_3angle_infer_aug) into 360, 180 and 60 bins that are
# arg-maxed, plus 2 trailing logits that are soft-maxed into a confidence.
dino = DINOv2_MLP(
                    dino_mode   = 'large',
                    in_dim      = 1024,
                    out_dim     = 360+180+60+2,
                    evaluate    = True,
                    mask_dino   = False,
                    frozen_back = False
                ).to(device)

# Switch to evaluation mode before loading weights; the model is only used
# for inference here.
dino.eval()
print('model create')
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoint
# is a plain state dict, passing weights_only=True would be safer -- confirm
# against the checkpoint contents before changing.
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
print('weight loaded')
# Image pre-processor used by the inference functions below. DINO_LARGE is
# not defined in this file; presumably it comes from one of the star imports
# (`paths` or `utils`) -- verify.
val_preprocess   = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')


def get_3angle(image):
    """Predict three orientation angles and a confidence score for one image.

    The image is pre-processed with the module-level ``val_preprocess`` and
    fed to the module-level ``dino`` model. The model output is split by
    column range:

    * ``[0, 360)``    -> arg-max bin, used as-is (shown as azimuth in the UI)
    * ``[360, 540)``  -> arg-max bin minus 90 (shown as polar)
    * ``[540, 600)``  -> arg-max bin minus 30 (shown as rotation)
    * last 2 columns -> softmax; entry 0 of the first row is the confidence

    Args:
        image: Image accepted by ``val_preprocess`` (the caller in this file
            passes a PIL image).

    Returns:
        A 4-element float tensor ``[angle0, angle1, angle2, confidence]``.
    """
    batch = val_preprocess(images=image)
    pixel_array = np.array(batch['pixel_values'])
    batch['pixel_values'] = torch.from_numpy(pixel_array).to(device)

    with torch.no_grad():
        logits = dino(batch)

    # Column boundaries of the three angle heads inside the output vector.
    azimuth_end = 360
    polar_end = azimuth_end + 180
    rotation_end = polar_end + 60

    azimuth_bin = torch.argmax(logits[:, :azimuth_end], dim=-1)
    polar_bin = torch.argmax(logits[:, azimuth_end:polar_end], dim=-1)
    rotation_bin = torch.argmax(logits[:, polar_end:rotation_end], dim=-1)

    # Softmax over the two trailing logits; take row 0, class 0.
    class_probs = F.softmax(logits[:, -2:], dim=-1)
    confidence = class_probs[0][0]

    angles = torch.zeros(4)
    angles[0] = azimuth_bin
    angles[1] = polar_bin - 90
    angles[2] = rotation_bin - 30
    angles[3] = confidence
    return angles

def get_3angle_infer_aug(origin_img, rm_bkg_img):
    """Predict angles and confidence by aggregating over image crops.

    Crops are taken from both the original and the background-removed image
    via ``get_crop_images(..., num=3)``, batched through ``val_preprocess``
    and the ``dino`` model, and the per-crop arg-max bins are combined with
    the outlier-rejecting averaging helpers (the circular variant is used for
    the first, 360-bin head).

    Args:
        origin_img: The unmodified input image.
        rm_bkg_img: The same image after background removal.

    Returns:
        A 4-element float tensor ``[angle0, angle1, angle2, confidence]``
        where ``angle1`` is offset by -90, ``angle2`` by -30, and
        ``confidence`` is entry 0 of the softmax over the last two logits,
        averaged over all crops.
    """
    crops = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
    batch = val_preprocess(images=crops)
    batch['pixel_values'] = torch.from_numpy(np.array(batch['pixel_values'])).to(device)

    with torch.no_grad():
        logits = dino(batch)

    def _argmax_bins(start, stop):
        # Per-crop winning bin within columns [start, stop), as float32.
        return torch.argmax(logits[:, start:stop], dim=-1).to(torch.float32)

    azimuth_bins = _argmax_bins(0, 360)
    polar_bins = _argmax_bins(360, 360 + 180)
    rotation_bins = _argmax_bins(360 + 180, 360 + 180 + 60)

    azimuth = remove_outliers_and_average_circular(azimuth_bins)
    polar = remove_outliers_and_average(polar_bins)
    rotation = remove_outliers_and_average(rotation_bins)

    # Mean class probabilities across crops; class 0 is the confidence.
    mean_probs = torch.mean(F.softmax(logits[:, -2:], dim=-1), dim=0)
    confidence = mean_probs[0]

    angles = torch.zeros(4)
    angles[0] = azimuth
    angles[1] = polar - 90
    angles[2] = rotation - 30
    angles[3] = confidence
    return angles


def figure_to_img(fig):
    """Render ``fig`` to an in-memory JPG and return it as a PIL image.

    Args:
        fig: Object exposing ``savefig(buf, format=..., bbox_inches=...)``
            (presumably a matplotlib figure).

    Returns:
        A copy of the decoded image, independent of the temporary buffer.
    """
    buffer = io.BytesIO()
    try:
        fig.savefig(buffer, format='JPG', bbox_inches='tight')
        buffer.seek(0)
        # copy() produces an image that no longer depends on the buffer,
        # which is closed as soon as this function returns.
        rendered = Image.open(buffer).copy()
    finally:
        buffer.close()
    return rendered

def infer_func(img, do_rm_bkg, do_infer_aug):
    """Gradio callback: estimate orientation and draw the axis overlay.

    Args:
        img: Image array delivered by the ``gr.Image`` input, or ``None``
            when the user submits without uploading anything.
        do_rm_bkg: Whether to remove the background before prediction.
            Ignored when ``do_infer_aug`` is true, because that path always
            uses a background-removed copy alongside the original.
        do_infer_aug: Whether to use the crop-aggregating prediction
            (``get_3angle_infer_aug``) instead of the single-pass one.

    Returns:
        ``[result_image, angle0, angle1, angle2, confidence]`` with the four
        numbers rounded to two decimals, in the order of the UI outputs.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard Image.fromarray(None) fails with an opaque
    # AttributeError; gr.Error surfaces a readable message in the UI instead.
    if img is None:
        raise gr.Error("Please upload an image before submitting.")

    origin_img = Image.fromarray(img)
    if do_infer_aug:
        rm_bkg_img = background_preprocess(origin_img, True)
        angles = get_3angle_infer_aug(origin_img, rm_bkg_img)
    else:
        rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
        angles = get_3angle(rm_bkg_img)

    phi = np.radians(angles[0])
    theta = np.radians(angles[1])
    # NOTE(review): unlike phi/theta, gamma is passed on without conversion
    # to radians -- confirm that render_3D_axis expects it in degrees.
    gamma = angles[2]

    render_axis = render_3D_axis(phi, theta, gamma)
    res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)

    rounded = [round(float(value), 2) for value in angles[:4]]
    return [res_img, *rounded]

# Web UI: one image plus two toggles in, annotated image plus four numeric
# read-outs out. Component order must match infer_func's parameters and the
# order of the list it returns.
# NOTE(review): the rotation value produced by the inference functions is an
# arg-max over 60 bins minus 30, which does not obviously match the
# "-90~90" range advertised in the label below -- confirm.
server = gr.Interface(
    fn=infer_func,
    inputs=[
        gr.Image(height=512, width=512, label="upload your image"),
        gr.Checkbox(label="Remove Background", value=True),
        gr.Checkbox(label="Inference time augmentation", value=False),
    ],
    outputs=[
        gr.Image(height=512, width=512, label="result image"),
        *[
            gr.Textbox(lines=1, label=caption)
            for caption in (
                'Azimuth(0~360°)',
                'Polar(-90~90°)',
                'Rotation(-90~90°)',
                'Confidence(0~1)',
            )
        ],
    ],
    flagging_mode='never',
)

server.launch()