FSFM-3C
init_test
2fa1887
raw
history blame
13.6 kB
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
import math
import sys
from typing import Iterable, Optional
import numpy as np
import torch
from timm.data import Mixup
from timm.utils import accuracy
import util.misc as misc
import util.lr_sched as lr_sched
from util.metrics import *
import torch.nn.functional as F
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None):
    """Train ``model`` for one epoch with gradient accumulation.

    The learning rate is adjusted per iteration (at the start of every
    accumulation window) rather than per epoch.

    Args:
        model: Network to train; switched to training mode here.
        criterion: Loss called as ``criterion(outputs, targets)``.
        data_loader: Iterable of ``(samples, targets)`` pairs; must support ``len()``.
        optimizer: Optimizer whose gradients are zeroed after each accumulation window.
        device: Device the samples, targets and model outputs are moved to.
        epoch: Current epoch index, used for the header, LR schedule and log x-axis.
        loss_scaler: Callable that performs the backward pass / optimizer step
            (invoked with ``clip_grad``, ``parameters``, ``create_graph`` and
            ``update_grad``).
        max_norm: Value forwarded to ``loss_scaler`` as ``clip_grad``.
        mixup_fn: Optional transform applied to ``(samples, targets)`` after the
            device transfer.
        log_writer: Optional writer exposing ``log_dir`` and ``add_scalar``.
        args: Namespace providing ``accum_iter``; also forwarded to the LR scheduler.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Note:
        The process exits with status 1 as soon as a non-finite loss is seen.
    """
    model.train(True)

    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    logged_batches = metric_logger.log_every(data_loader, print_freq, header)
    for data_iter_step, (samples, targets) in enumerate(logged_batches):
        # Fractional epoch position of this iteration (epoch + progress within it).
        epoch_progress = data_iter_step / len(data_loader) + epoch
        opens_accum_window = data_iter_step % accum_iter == 0
        closes_accum_window = (data_iter_step + 1) % accum_iter == 0

        # Per-iteration (not per-epoch) LR schedule, refreshed once per window.
        if opens_accum_window:
            lr_sched.adjust_learning_rate(optimizer, epoch_progress, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples).to(device, non_blocking=True)  # modified
            loss = criterion(outputs, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale the loss so gradients summed over the window match one large batch.
        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=closes_accum_window)
        if closes_accum_window:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        # Largest LR across parameter groups (floored at 0.) is what gets reported.
        max_lr = max([0.] + [group["lr"] for group in optimizer.param_groups])
        metric_logger.update(lr=max_lr)

        # Called on every iteration, even when nothing is written to the log.
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is None or not closes_accum_window:
            continue

        # We use epoch_1000x as the x-axis in tensorboard.
        # This calibrates different curves when batch size changes.
        epoch_1000x = int(epoch_progress * 1000)
        log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
        log_writer.add_scalar('lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` and report loss, accuracy and AUC.

    The AUC is computed once over all samples seen by this process instead of
    being averaged over mini-batches: a per-batch ROC-AUC is not the dataset
    AUC, and it is undefined for any batch that contains a single class.

    Args:
        data_loader: Iterable of batches; ``batch[0]`` holds the images and
            ``batch[-1]`` the integer class targets.
        model: Classifier producing per-class scores of shape ``(batch, classes)``;
            column 1 of the softmax output is used as the AUC score.
        device: Device the images, targets and model outputs are moved to.

    Returns:
        Dict mapping each meter name (``loss``, ``acc``, ``auc``) to its
        ``global_avg``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    all_scores = []  # softmax probability of class index 1, one per sample
    all_labels = []  # ground-truth class index, one per sample

    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images).to(device, non_blocking=True)  # modified
            loss = criterion(output, target)

        acc = float(accuracy(output, target, topk=(1,))[0])

        batch_scores = F.softmax(output, dim=1)[:, 1].detach().cpu().numpy()
        all_scores.extend(batch_scores.ravel().tolist())
        all_labels.extend(target.detach().cpu().numpy().ravel().tolist())

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc'].update(acc, n=batch_size)

    # One AUC over every sample evaluated by this process. It is recorded
    # before synchronisation, weighted by sample count, so the meter goes
    # through the same cross-process reduction as 'loss' and 'acc'.
    auc_score = roc_auc_score(np.asarray(all_labels), np.asarray(all_scores)) * 100.
    metric_logger.meters['auc'].update(auc_score, n=len(all_labels))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print('* Acc {acc.global_avg:.3f} Auc {auc.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(acc=metric_logger.acc, auc=metric_logger.auc, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def test_ori(data_loader, model, device):
    """Evaluate ``model`` and report loss, accuracy and a dataset-level AUC.

    Per-sample scores and labels are collected in Python lists and converted to
    arrays once, instead of growing NumPy arrays with ``np.append`` (which
    copies the whole accumulated array on every batch).

    Args:
        data_loader: Iterable of batches; ``batch[0]`` holds the images and
            ``batch[-1]`` the integer class targets.
        model: Classifier producing per-class scores of shape ``(batch, classes)``;
            column 1 of the softmax output is used as the AUC score.
        device: Device the images, targets and model outputs are moved to.

    Returns:
        Dict mapping each meter name (``loss``, ``acc``, ``auc``) to its
        ``global_avg``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    labels = []  # ground-truth class index, one per sample
    preds = []   # softmax probability of class index 1, one per sample

    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images).to(device, non_blocking=True)  # modified
            loss = criterion(output, target)

        acc = float(accuracy(output, target, topk=(1,))[0])

        pred = F.softmax(output, dim=1)[:, 1].detach().cpu().numpy()
        preds.extend(pred.ravel().tolist())
        labels.extend(target.detach().cpu().numpy().ravel().tolist())

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc'].update(acc, n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    # float64 arrays, matching what np.append onto np.array([]) used to produce.
    # NOTE(review): the AUC is added after synchronisation, so it is presumably
    # not reduced across processes - confirm for distributed runs.
    auc_score = roc_auc_score(np.asarray(labels, dtype=np.float64),
                              np.asarray(preds, dtype=np.float64)) * 100.
    metric_logger.meters['auc'].update(auc_score)

    print('* Acc {acc.global_avg:.3f} Auc {auc.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(acc=metric_logger.acc, auc=metric_logger.auc, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def test(data_loader, model, device):
    """Evaluate ``model`` and report frame-level Acc, balanced Acc, AUC, EER and loss.

    Per-frame results are collected in Python lists and converted to arrays once,
    instead of growing NumPy arrays with ``np.append`` (which copies the whole
    accumulated array on every batch).

    Args:
        data_loader: Iterable of batches; ``batch[0]`` holds the images
            (``[BS, C, H, W]``) and ``batch[1]`` the integer targets (``[BS]``).
        model: Classifier producing per-class scores of shape ``(batch, classes)``.
        device: Device the images, targets and model outputs are moved to.

    Returns:
        Dict mapping each meter name (``loss``, ``frame_acc``,
        ``frame_balanced_acc``, ``frame_auc``, ``frame_eer``) to its ``global_avg``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")

    # switch to evaluation mode
    model.eval()

    label_values = []   # ground-truth class index per frame
    score_values = []   # softmax probability of class index 1 per frame
    y_pred_values = []  # arg-max predicted class index per frame

    for batch in data_loader:
        images = batch[0].to(device, non_blocking=True)  # torch.Size([BS, C, H, W])
        target = batch[1].to(device, non_blocking=True)  # torch.Size([BS])

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images).to(device, non_blocking=True)  # modified
            loss = criterion(output, target)

        frame_pred = F.softmax(output, dim=1)[:, 1].detach().cpu().numpy()
        score_values.extend(frame_pred.ravel().tolist())

        frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
        y_pred_values.extend(frame_y_pred.ravel().tolist())

        label_values.extend(target.detach().cpu().numpy().ravel().tolist())

        metric_logger.update(loss=loss.item())

    # float64 arrays, matching what np.append onto np.array([]) used to produce.
    frame_labels = np.asarray(label_values, dtype=np.float64)
    frame_preds = np.asarray(score_values, dtype=np.float64)
    frame_y_preds = np.asarray(y_pred_values, dtype=np.float64)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    # NOTE(review): these meters are filled after synchronisation, so they are
    # presumably computed on this process's samples only - confirm for
    # distributed runs.
    metric_logger.meters['frame_acc'].update(frame_level_acc(frame_labels, frame_y_preds))
    metric_logger.meters['frame_balanced_acc'].update(frame_level_balanced_acc(frame_labels, frame_y_preds))
    metric_logger.meters['frame_auc'].update(frame_level_auc(frame_labels, frame_preds))
    metric_logger.meters['frame_eer'].update(frame_level_eer(frame_labels, frame_preds))

    print('*[------FRAME-LEVEL------] \n'
          'Acc {frame_acc.global_avg:.3f} Balanced_Acc {frame_balanced_acc.global_avg:.3f} '
          'Auc {frame_auc.global_avg:.3f} EER {frame_eer.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(frame_acc=metric_logger.frame_acc, frame_balanced_acc=metric_logger.frame_balanced_acc,
                  frame_auc=metric_logger.frame_auc, frame_eer=metric_logger.frame_eer, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def test_all(data_loader, model, device):
    """Run inference and return frame-level and video-level prediction scores.

    Per-frame results are collected directly in Python lists rather than being
    appended to NumPy arrays (which copies the accumulated array on every
    batch) and converted with ``.tolist()`` afterwards. The arg-max
    predictions, which were accumulated but never used, are no longer computed.

    Args:
        data_loader: Iterable of batches; ``batch[0]`` holds the images
            (``[BS, C, H, W]``), ``batch[1]`` the integer targets (``[BS]``)
            and ``batch[-1]`` the per-frame video names.
        model: Classifier producing per-class scores of shape ``(batch, classes)``.
        device: Device the images, targets and model outputs are moved to.

    Returns:
        Tuple ``(frame_preds_list, video_pred_list)``: the per-frame softmax
        probability of class index 1 as a list of floats, and the second value
        returned by ``get_video_level_label_pred``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")

    # switch to evaluation mode
    model.eval()

    frame_labels_list = []  # ground-truth class index per frame, as floats
    frame_preds_list = []   # softmax probability of class index 1 per frame
    video_names_list = []   # video name of each frame, aligned with the lists above

    for batch in data_loader:
        images = batch[0].to(device, non_blocking=True)  # torch.Size([BS, C, H, W])
        target = batch[1].to(device, non_blocking=True)  # torch.Size([BS])
        video_name = batch[-1]  # list[BS]

        # compute output (autocast is intentionally not used in this function)
        output = model(images).to(device, non_blocking=True)  # modified
        loss = criterion(output, target)

        frame_pred = F.softmax(output, dim=1)[:, 1].detach().cpu().numpy()
        frame_preds_list.extend(frame_pred.ravel().tolist())

        # Labels are cast to float64 so the list holds floats, exactly as the
        # former float-array round-trip produced.
        frame_label = target.detach().cpu().numpy()
        frame_labels_list.extend(frame_label.astype(np.float64).ravel().tolist())

        video_names_list.extend(list(video_name))

        # The loss is tracked locally but is not part of the return value.
        metric_logger.update(loss=loss.item())

    # video-level metrics: only the aggregated prediction scores are returned.
    _, video_pred_list, _ = get_video_level_label_pred(
        frame_labels_list, video_names_list, frame_preds_list)

    return frame_preds_list, video_pred_list