test.py
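"""Evaluate a trained segmentation model on a test dataset and save side-by-side visualizations.

Example usage (the dataset name is a placeholder; pass any directory under "Dataset/"):

    python test.py --dataset <dataset_name>
    python test.py --dataset <dataset_name> --fulltest
"""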
import argparse
import os
import time
from operator import add

# Suppress TensorFlow info and warning logs; must be set before tensorflow is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import cv2
import numpy as np
import tensorflow as tf

from metrics import intersection_over_union, dice_coefficient, weighted_f_score, s_score, e_score, max_e_score, mean_absolute_error
from utils import create_directory, load_model
from data import load_test_dataset


def compute_metrics(y_true, y_pred):
    """Binarize both masks at 0.5 and return [IoU, Dice, weighted F, S-measure, E-measure, max E-measure, MAE]."""
    y_pred = y_pred > 0.5
    y_pred_flat = y_pred.reshape(-1).astype(np.uint8)
    y_true = y_true > 0.5
    y_true_flat = y_true.reshape(-1).astype(np.uint8)

    iou_score = intersection_over_union(y_true_flat, y_pred_flat).numpy()
    dice_score = dice_coefficient(y_true_flat, y_pred_flat).numpy()
    f_score = weighted_f_score(y_true_flat, y_pred_flat)
    s_measure_score = s_score(y_true_flat, y_pred_flat)
    e_measure_score = e_score(y_true_flat, y_pred_flat)
    max_e_measure_score = max_e_score(y_true_flat, y_pred_flat)
    mae_score = mean_absolute_error(y_true_flat, y_pred_flat)

    return [iou_score, dice_score, f_score, s_measure_score, e_measure_score, max_e_measure_score, mae_score]


def parse_mask(mask):
    """Replicate a single-channel mask into 3 channels so it can be tiled next to a BGR image."""
    mask = np.squeeze(mask)
    mask = np.stack([mask, mask, mask], axis=-1)
    return mask


if __name__ == "__main__":
    np.random.seed(42)
    tf.random.set_seed(42)

    parser = argparse.ArgumentParser(description='Test model on a dataset.')
    parser.add_argument('--dataset', type=str, required=True, help='Name of the dataset directory in the "Dataset" folder.')
    parser.add_argument('--fulltest', action='store_true', help='Test on the full dataset instead of loading from val.txt')
    args = parser.parse_args()

    dataset_name = args.dataset
    fulltest = args.fulltest
    print(f"Testing on {dataset_name}")

    dataset_path = os.path.join("Dataset", dataset_name)
    test_images, test_masks = load_test_dataset(dataset_path, fulltest)

    image_size = (224, 224)
    model_path = "model/model.keras"
    model = load_model(model_path)

    # Warm-up prediction so the first timed inference is not skewed by graph tracing.
    dummy_image = np.zeros((1, 224, 224, 3))
    _ = model.predict(dummy_image)
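
    # Ensure the output directory exists before visualizations are written into it.
    # (Assumes utils.create_directory accepts the directory path; it is imported above but was never called.)
    create_directory("results")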
    metrics_scores = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    inference_times = []
    num_processed = 0

    for image_path, mask_path in zip(test_images, test_masks):
        name = os.path.basename(mask_path).split(".")[0]

        # Load and preprocess the input image.
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            print(f"Failed to load image: {image_path}. Skipping.")
            continue
        image = cv2.resize(image, image_size)
        original_image = image.copy()
        image = image / 255.0
        image = np.expand_dims(image, axis=0).astype(np.float32)

        # Load and preprocess the ground-truth mask.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Failed to load mask: {mask_path}. Skipping.")
            continue
        mask = cv2.resize(mask, image_size)
        original_mask = mask.copy()
        mask = np.expand_dims(mask, axis=0) / 255.0
        mask = mask.astype(np.float32)

        # Time a single forward pass.
        start_time = time.time()
        predicted_mask = model.predict(image)
        inference_time = time.time() - start_time
        inference_times.append(inference_time)
        print(f"{name}: {inference_time:1.5f}")

        scores = compute_metrics(mask, predicted_mask)
        metrics_scores = list(map(add, metrics_scores, scores))
        num_processed += 1

        # Binarize the prediction and save image | ground truth | prediction side by side.
        predicted_mask = (predicted_mask[0] > 0.5) * 255
        predicted_mask = np.array(predicted_mask, dtype=np.uint8)
        original_mask = parse_mask(original_mask)
        predicted_mask = parse_mask(predicted_mask)
        separator_line = np.ones((image_size[0], 10, 3), dtype=np.uint8) * 255
        concatenated_images = np.concatenate([original_image, separator_line, original_mask, separator_line, predicted_mask], axis=1)
        cv2.imwrite(f"results/{name}.png", concatenated_images)

    # Average over the samples that were actually processed, in case any were skipped above.
    average_scores = [score_sum / num_processed for score_sum in metrics_scores]
    metric_labels = ["mIoU", "mDice", "Fw", "Sm", "Em", "maxEm", "MAE"]
    print("\nAverage Scores:")
    for label, score in zip(metric_labels, average_scores):
        print(f"{label}: {score:1.4f}")

    mean_inference_time = np.mean(inference_times)
    mean_fps = 1 / mean_inference_time
    print(f"Mean FPS: {mean_fps:1.2f}")