import os
import re
import csv
import time
import shutil

import cv2
import numpy as np
import pandas as pd
import torch
import gradio as gr
from omegaconf import DictConfig
from hydra import compose, initialize

from src.sts.demo.sts import handle_sts
from src.ir.ir import handle_ir
from src.ir.src.models.tc_classifier import TCClassifier
from src.tracker.signboard_track import SignboardTracker

signboardTracker = SignboardTracker()

tracking_result_dir = ""
output_track_format = "mp4v"
output_track = ""
output_sts = ""
video_dir = ""
vd_dir = ""
labeling_dir = ""
frame_out = {}
rs = {}
results = []
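
# Module-level state shared between the Gradio callbacks: the path globals are
# filled in by get_meta_from_video(), and `results` holds the tracker output
# ({frame: [{'id': ..., 'box': ...}, ...]}) produced by tracking().
# Working layout under static/videos/<video_name>/:
#   original/        per-frame images written by the tracker
#   track/cropped/   signboard crops named <signboard_id>_<frame>.png
#   track/labeling/  one <signboard_id>_<frame>.csv label file per crop
#   track/sts/       created up front; not used directly in this script
#   track/segment/   created up front; not used directly in this script
#   track/label.csv  summary CSV (only the header row is written here)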

# with initialize(version_base=None, config_path="src/ir/configs", job_name="ir"):
#     config = compose(config_name="test")
#     config: DictConfig
#     model_ir = TCClassifier(config.model.train.model_name,
#                             config.model.train.n_classes,
#                             config.model.train.lr,
#                             config.model.train.scheduler_type,
#                             config.model.train.max_steps,
#                             config.model.train.weight_decay,
#                             config.model.train.classifier_dropout,
#                             config.model.train.mixout,
#                             config.model.train.freeze_encoder)
#     model_ir = model_ir.load_from_checkpoint(checkpoint_path=config.ckpt_path, map_location=torch.device("cuda"))
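# NOTE: the commented-out block above would build the image-retrieval (IR)
# classifier from its Hydra config and load the checkpoint on GPU. It is kept
# for reference only; handle_ir is likewise only referenced in commented-out
# code below.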


def create_dir(list_dir_path):
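    """Create every directory in list_dir_path that does not exist yet."""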
    for dir_path in list_dir_path:
        if not os.path.isdir(dir_path):
            os.makedirs(dir_path)


def get_meta_from_video(input_video):
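    """Prepare working directories for the uploaded video and return its FPS.

    Fills in the module-level path globals, re-encodes the upload frame by
    frame into vd_dir, and returns the source FPS wrapped in a gr.Textbox.
    Returns None when no video is given or the video was already processed.
    """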
    if input_video is not None:
        video_name = os.path.basename(input_video).split('.')[0]
        global video_dir
        video_dir = os.path.join("static/videos/", f"{video_name}")
        global vd_dir
        vd_dir = os.path.join(video_dir, os.path.basename(input_video))
        global output_track
        output_track = os.path.join(video_dir, "original")
        global tracking_result_dir
        tracking_result_dir = os.path.join(video_dir, "track/cropped")
        global output_sts
        output_sts = os.path.join(video_dir, "track/sts")
        global labeling_dir
        labeling_dir = os.path.join(video_dir, "track/labeling")
        if os.path.isdir(video_dir):
            return None
        else:
            create_dir([output_track, video_dir, os.path.join(video_dir, "track/segment"), output_sts, tracking_result_dir, labeling_dir])
            # initialize the video stream
            video_cap = cv2.VideoCapture(input_video)
            # grab the width, height, and fps of the frames in the video stream
            frame_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(video_cap.get(cv2.CAP_PROP_FPS))
            # total number of frames and total seconds of the video (kept for reference):
            # total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # total_seconds = total_frames / video_cap.get(cv2.CAP_PROP_FPS)
            # initialize the FourCC and a video writer object
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            output = cv2.VideoWriter(vd_dir, fourcc, fps, (frame_width, frame_height))
            # re-encode the video frame by frame into vd_dir
            while True:
                success, frame = video_cap.read()
                if not success:
                    break
                output.write(frame)
            video_cap.release()
            output.release()
            # return gr.Slider(1, fps, value=4, label="FPS", step=1, info="Choose between 1 and {fps}", interactive=True)
            return gr.Textbox(value=fps)


def get_signboard(evt: gr.SelectData):
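    """Return the signboard crops found in the frame picked from a gallery.

    Treats evt.index + 1 as the frame number, matches crops named
    <signboard_id>_<frame>.png in tracking_result_dir, and returns the sorted
    list of crop paths together with that frame number.
    """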
    name_fr = int(evt.index) + 1
    ids_dir = tracking_result_dir
    all_ids = os.listdir(ids_dir)
    gallery = []
    fr_id = str(name_fr)
    for i in all_ids:
        if re.search(r"[\d]*_" + fr_id + r"\.png", i):
            gallery.append(os.path.join(ids_dir, i))
    gallery = sorted(gallery)
    return gallery, name_fr


def tracking(fps_target):
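    """Run signboard tracking and text spotting at the requested sampling FPS.

    fps_target comes from the "Choose fps for output" textbox. The tracker
    fills output_track and tracking_result_dir, track/label.csv gets a header
    row, handle_sts is called for every tracked frame (labeling_dir is passed
    as its output location), and a Dropdown of the unique signboard ids is
    returned.
    """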
    start = time.time()
    fps_target = int(fps_target)
    global results
    results = signboardTracker.inference_signboard(fps_target, vd_dir, output_track, output_track_format, tracking_result_dir)[0]
    list_id = []
    with open(os.path.join(video_dir, "track/label.csv"), 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Signboard", "Frame", "Text"])
        for frame, values in results.items():
            frame_dir = os.path.join(output_track, f"{frame}.jpg")
            list_boxs = []
            list_id_tmp = []
            for value in values:
                list_boxs.append(value['box'])
                list_id_tmp.append(value['id'])
            _, dict_rec_sign_out = handle_sts(frame_dir, labeling_dir, list_boxs, list_id_tmp)
            # predicted = handle_ir(frame_dir, dict_rec_sign_out, os.path.join(video_dir, "ir"))
            list_id.extend(list_id_tmp)
    list_id = sorted(set(list_id))
    print("tracking + STS took", round(time.time() - start, 2), "seconds")
    return gr.Dropdown(label="signboard", choices=list_id, interactive=True)


def get_select_index(img_id, evt: gr.SelectData):
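    """Return every crop of the signboard picked from a frame's gallery.

    Rebuilds the gallery for frame img_id, takes the crop chosen via
    evt.index, parses its signboard id from the <id>_<frame>.png filename,
    and returns all crops of that signboard across frames, sorted.
    """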
    ids_dir = tracking_result_dir
    all_ids = os.listdir(ids_dir)
    gallery = []
    fr_id = str(img_id)
    for i in all_ids:
        if re.search(r"[\d]*_" + fr_id + r"\.png", i):
            gallery.append(os.path.join(ids_dir, i))
    gallery = sorted(gallery)
    # signboard id of the selected crop, parsed from the "<id>_<frame>.png" filename
    signboard_id = os.path.basename(gallery[evt.index]).split(".")[0].split("_")[0]
    gallery_id = []
    for i in all_ids:
        if re.search(r"^" + signboard_id + r"_[\d]*\.png", i):
            gallery_id.append(os.path.join(ids_dir, i))
    gallery_id = sorted(gallery_id)
    return gallery_id


# currently selected signboard id (set by select_id)
id_glb = None


def select_id(evt: gr.SelectData):
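    """Remember the selected signboard id and list the frames it appears in."""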
    choice = []
    global id_glb
    id_glb = evt.value
    for key, values in results.items():
        for value in values:
            if value['id'] == evt.value:
                choice.append(int(key))
    return gr.Dropdown(label="frame", choices=choice, interactive=True)


# currently selected frame number (set by select_frame)
frame_glb = None


def select_frame(evt: gr.SelectData):
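    """Return the full frame, the signboard crop, and its label CSV as a table."""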
    full_img = os.path.join(output_track, str(evt.value) + ".jpg")
    crop_img = os.path.join(tracking_result_dir, str(id_glb) + "_" + str(evt.value) + ".png")
    global frame_glb
    frame_glb = evt.value
    data = pd.read_csv(os.path.join(labeling_dir, str(id_glb) + "_" + str(frame_glb) + '.csv'), header=0)
    return full_img, crop_img, data


def get_data(dtfr):
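    """Write the edited Tag/Value table back to <id>_<frame>.csv in labeling_dir."""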
    # The edits could also be merged back into track/label.csv (kept for reference):
    # df = pd.read_csv(os.path.join(video_dir, "track/label.csv"))
    # for i, row in df.iterrows():
    #     if str(row["Signboard"]) == str(id_tmp) and str(row["Frame"]) == str(frame_tmp):
    #         df_new = df.replace(str(row["Text"]), str(labeling))
    dtfr.to_csv(os.path.join(labeling_dir, str(id_glb) + "_" + str(frame_glb) + '.csv'), index=False, header=True)
    return


def seg_track_app():
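    """Build the Gradio Blocks UI, wire the callbacks, and launch the demo."""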
    ##########################################################
    ######################  Front-end  #######################
    ##########################################################
    with gr.Blocks(css=".gradio-container {background-color: white}") as demo:
        gr.Markdown(
            '''
            <div style="text-align:center;">
                <span style="font-size:3em; font-weight:bold;">POI Engineering</span>
            </div>
            '''
        )
        with gr.Row():
            # video input and tracking controls
            with gr.Column(scale=0.2):
                tab_video_input = gr.Row(label="Video type input")
                with tab_video_input:
                    input_video = gr.Video(label='Input video')
                tab_everything = gr.Row(label="Tracking")
                with tab_everything:
                    with gr.Row():
                        seg_signboard = gr.Button(value="Tracking", interactive=True)
                all_info = gr.Row(label="Information about video")
                with all_info:
                    with gr.Row():
                        text = gr.Textbox(label="Fps")
                        check_fps = gr.Textbox(label="Choose fps for output", interactive=True)
            # review and labeling panel
            with gr.Column(scale=1):
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Row():
                            with gr.Column(scale=1):
                                id_drop = gr.Dropdown(label="Signboards", choices=[])
                            with gr.Column(scale=1):
                                fr_drop = gr.Dropdown(label="Frames", choices=[])
                        full_img = gr.Image(label="Full Image")
                    with gr.Column(scale=1):
                        crop_img = gr.Image(label="Cropped Image")
                with gr.Row():
                    dtfr = gr.Dataframe(headers=["Tag", "Value"], datatype=["str", "str"], interactive=True)
                with gr.Row():
                    submit = gr.Button(value="Submit", interactive=True)

        ##########################################################
        ######################  Back-end  ########################
        ##########################################################
        input_video.change(
            fn=get_meta_from_video,
            inputs=input_video,
            outputs=text
        )
        seg_signboard.click(
            fn=tracking,
            inputs=check_fps,
            outputs=id_drop
        )
        id_drop.select(select_id, None, fr_drop)
        fr_drop.select(select_frame, None, [full_img, crop_img, dtfr])
        submit.click(get_data, dtfr, None)

    demo.queue(concurrency_count=1)
    demo.launch(debug=True, enable_queue=True, share=True)
if __name__ == "__main__": | |
seg_track_app() | |