testen van het bloem model 'bloem_model.h5' in een confusion matrix. terug naar de inleiding
De schets is een python 3.7.6 schets. De pakketten zijn in anaconda binnen een virtuele omgeving geinstalleerd.
De schets maakt gebruik van keras versie 2.3.1, numpy versie 1.18.1, tensorflow versie 2.1.0
De schets maakt gebruik van keras versie 2.3.1, numpy versie 1.18.1, tensorflow versie 2.1.0
voor het pakket sklearn, scikit-learn versie 0.22.1 en matplotlib versie 3.1.3
Stap 1, De benodigde pakketten importeren
Stap 2, variabele met pad naar testmap
Stap 1, De benodigde pakketten importeren
import numpy as np
import keras
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
from keras.applications import imagenet_utils
from sklearn.metrics import confusion_matrix
import itertools
import matplotlib.pyplot as plt
%matplotlib inline
Stap 2, variabele met pad naar testmap
Stap 3, ImageDataGenererator maakt batches van de afbeeldingen om het model te trainen en te valideren, zie stap 6
zie https://keras.io/preprocessing/image/
test_path = 'bloemen/test'
test_batches = ImageDataGenerator(preprocessing_function=keras.applications.mobilenet.preprocess_input).flow_from_directory(test_path, target_size=(224,224),batch_size=10,shuffle=False)
Stap 4, test_batches bevat een numpy ndarray met de testlabels
test_batches.class_indices # is een python dict deze printen geeft: {'pasiebloem': 0, 'waterlelie': 1}
test_labels = test_batches.classes # is eem numpy.ndarray
test_batches.class_indices # is een python dict
Stap 5, Het model laden
model = load_model('bloem_model.h5')
model.summary()
Stap 6, De test batch door het model sturen
predictions = model.predict_generator(test_batches, steps=1, verbose=0)
Stap 7, De functie plot_confusion_matrix
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix',cmap=plt.cm.Blues):
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
if normalize:
cm = cm.astype('float')/ cm.sum(axis=1)[:, np.newaxis]
print("Genormaliseerde confusion matrix")
else:
print("confusion matrix zonder normalisatie")
print(cm)
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, cm[i,j],
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('de werkelijke labels')
plt.xlabel('de voorspelde labels')
Stap 8, De variabele cm aanmaken, cm is een argument in de plot_confusion_matrix functie
testlabels en predictions worden zo met elkaar vergeleken
cm = confusion_matrix(test_labels, predictions.argmax(axis=1))
Stap 9, De confusion matrix plotten
cm_plot_labels = ['passiebloem', 'waterlelie']
plot_confusion_matrix(cm, cm_plot_labels, title=('Confusion Matrix'))