-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluation_all.py
94 lines (72 loc) · 3.24 KB
/
evaluation_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import random
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from core.models.FDENet.FDENet import FDENet, FDENetPlus
from core.dataset.dataset import BaseDataset
from torch.nn import MSELoss
from core.models.HRNet.hrnet_infer import HRNetInfer
# from sklearn.metrics.pairwise import cosine_similarity
# Seed PyTorch (CPU, then CUDA) and Python's `random` with the same value so
# random draws are repeatable from run to run.
for _seed_fn in (torch.manual_seed, torch.cuda.random.manual_seed, random.seed):
    _seed_fn(123)
del _seed_fn
class Evaluator(object):
    """Evaluate an FDENetPlus checkpoint on a validation set.

    HRNet heatmaps are fed to the model together with the images, and the
    predictions are compared with the dataset labels using MSE loss, cosine
    similarity and L2 pairwise distance.
    """

    def __init__(self, is_init=True):
        """Create the evaluator.

        Args:
            is_init: When True (default), apply the default configuration
                (``initConfig``) and build HRNet, the dataset/loader and the
                model (``init``). Pass False to skip both, e.g. when the caller
                sets the required attributes itself.
        """
        if is_init:
            self.initConfig()
            self.init()

    def initConfig(self):
        """Set the default evaluation configuration as instance attributes."""
        self.datasetDir = "./dataset/10k/val"
        self.batch_size = 8
        self.num_dim = 54  # not read by any method of this class
        self.hidden_dim = 128  # passed to FDENetPlus in init()
        # Fall back to CPU so the script still runs on machines without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_load_path = "checkpoints/FDE_all/epoch 1.pth"
        self.im_size = (256, 256)
        self.hrnet_weight_path: str = "pretrained_models/HR18-WFLW.pth"

    def init(self):
        """Build HRNet, the validation dataset/loader and the FDENetPlus model.

        If ``self.model_load_path`` is not None, that checkpoint is loaded
        into the model. The model is then moved to ``self.device``.
        """
        self.hrnet_infer = HRNetInfer(self.hrnet_weight_path, self.device)
        # NOTE(review): the three positional False flags are interpreted by
        # BaseDataset (not visible here) — confirm their meaning there.
        self.dataset = BaseDataset(self.datasetDir, False, False, False, self.im_size)
        # shuffle=False keeps the batch order (and the per-batch printout) stable.
        self.dataLoader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False)
        # The region sizes sum to 54 (== num_dim). Presumably this is the label
        # width left after evaluate() drops the last 5 columns — confirm.
        self.model = FDENetPlus({"face": 19, "eyes": 13, "nose": 15, "mouth": 7}, self.hidden_dim)
        if self.model_load_path is not None:
            self.load_model(self.model_load_path, self.model)
        self.model = self.model.to(self.device)

    def evaluate(self,):
        """Run the model over ``self.dataLoader`` and report regression metrics.

        For each batch, the model output is compared with ``data["label"]``
        minus its last 5 columns. Three metrics are computed: MSE loss, mean
        cosine similarity (dim=1) and mean L2 pairwise distance. Per-batch
        metrics are printed; dataset-level averages are printed and returned.

        Each batch is weighted by its number of samples, so a smaller final
        batch does not skew the averages.

        Returns:
            tuple[float, float, float]: ``(avg_loss, avg_similarity,
            avg_distance)``. All three are ``nan`` if the loader yields no
            samples.
        """
        self.lossfunc = MSELoss().to(self.device)
        self.model.eval()
        num_samples = 0
        total_loss = 0.0
        total_similarity = 0.0
        total_distance = 0.0
        for idx, data in enumerate(self.dataLoader):
            imgs: torch.Tensor = data["img"].to(self.device)
            # The last 5 label columns are excluded from the comparison.
            labels: torch.Tensor = data["label"][:, :-5].to(self.device)
            batch_n = labels.size(0)
            # Pure inference: build no autograd graph for HRNet or the model.
            with torch.no_grad():
                heatmap: torch.Tensor = self.hrnet_infer.get_heatmap(imgs)
                output: torch.Tensor = self.model(imgs, heatmap)
                loss = self.lossfunc(output, labels).item()
                similarity = torch.nn.functional.cosine_similarity(output, labels, dim=1).mean().item()
                distance = torch.nn.functional.pairwise_distance(output, labels, p=2).mean().item()
            # Each value above is a mean over the batch. Re-weight by batch size
            # so the final averages are per-sample, not per-batch.
            num_samples += batch_n
            total_loss += loss * batch_n
            total_similarity += similarity * batch_n
            total_distance += distance * batch_n
            print(f"batch: {idx+1} | distance: {distance:.3f} | cosine similarity: {similarity:.3f}")
        if num_samples == 0:
            # Nothing to average; report nan instead of dividing by zero.
            avg_loss = avg_similarity = avg_distance = float("nan")
        else:
            avg_loss = total_loss / num_samples
            avg_similarity = total_similarity / num_samples
            avg_distance = total_distance / num_samples
        print(f"avg loss: {avg_loss} | avg similarity: {avg_similarity} | avg distance: {avg_distance}")
        return avg_loss, avg_similarity, avg_distance

    def load_model(self, load_path, model, strict=True, map_location="cpu"):
        """Load a state dict from ``load_path`` into ``model`` and set eval mode.

        Args:
            load_path: Path to a file written by ``torch.save`` that holds a
                state dict.
            model: ``torch.nn.Module`` that receives the weights.
            strict: Forwarded to ``load_state_dict``. When True, the keys must
                match exactly.
            map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
                checkpoints saved on a GPU also load on CPU-only machines.
                ``load_state_dict`` copies the values into the model's own
                parameters, wherever those live.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        load_net = torch.load(load_path, map_location=map_location)
        model.load_state_dict(load_net, strict=strict)
        model.eval()
if __name__ == "__main__":
    # Script entry: build an evaluator with the default configuration and
    # run a single pass over the validation loader.
    Evaluator().evaluate()