Training a CNN to automatically classify storm morphology

Preprocessing

1) Set the number of classes and the image size (including the channel dimension, even though the images are greyscale)

2) Read in the files using the get_example_data utility

3) Normalize the data by 80 dBZ

4) Verify the shape of the training, validation, and testing datasets

5) Transform the "single number" classifications into Keras-friendly one-hot arrays.
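Step 5 can be sketched without Keras: a minimal numpy equivalent of `keras.utils.to_categorical` (the helper name `to_one_hot` is mine):

```python
import numpy as np

def to_one_hot(labels, num_classes):
    """Turn integer class labels into one-hot rows, like keras.utils.to_categorical."""
    labels = np.asarray(labels, dtype=int)
    one_hot = np.zeros((labels.size, num_classes), dtype="float32")
    one_hot[np.arange(labels.size), labels] = 1.0
    return one_hot

print(to_one_hot([0, 2, 5], num_classes=6))
```

Each row has a single 1.0 in the column matching its class code, which is the layout the softmax output layer is trained against.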

In [1]:
#Based on examples from the Keras documentation
import numpy as np
np.random.seed(42)
from tensorflow import keras
from tensorflow.keras import layers
import pickle
from svrimg.utils.get_images import get_example_data

num_classes = 6
input_shape = (136, 136, 1)

(x_train, y_train) = get_example_data('training', data_dir="../data/pkls/")
(x_val, y_val) = get_example_data('validation', data_dir="../data/pkls/")
(x_test, y_test) = get_example_data('testing', data_dir="../data/pkls/")

#Normalize by 80 dBZ
x_train = x_train.astype("float32") / 80
x_test = x_test.astype("float32") / 80
x_val = x_val.astype("float32") / 80

print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_val.shape[0], "validate samples")
print(x_test.shape[0], "test samples")

y_train = keras.utils.to_categorical(y_train, num_classes)
y_val = keras.utils.to_categorical(y_val, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
x_train shape: (1331, 136, 136, 1)
1331 train samples
110 validate samples
300 test samples

Create a simple CNN with three convolutional layers and one dense layer

In [2]:
model = keras.Sequential(
    [
        keras.Input(shape=(136, 136, 1)),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.SpatialDropout2D(0.3),
        layers.MaxPooling2D(pool_size=(3, 3)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.SpatialDropout2D(0.3),
        layers.MaxPooling2D(pool_size=(3, 3)),
        layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        layers.SpatialDropout2D(0.3),
        layers.MaxPooling2D(pool_size=(3, 3)),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.6),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

keras.utils.plot_model(model, show_shapes=True)
WARNING:tensorflow:Large dropout rate: 0.6 (>0.5). In TensorFlow 2.x, dropout() uses dropout rate instead of keep_prob. Please ensure that this is intended.
Out[2]:

Use Data Augmentation to reduce overfitting

We can show how this works with one example.

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = 8, 8

from svrimg.utils.map_helper import radar_colormap, draw_box_plot
from matplotlib.colors import BoundaryNorm

cmap = radar_colormap()
classes = np.array(list(range(0, 85, 5)))
norm = BoundaryNorm(classes, ncolors=cmap.N)

sample = x_test[52]
ax = plt.subplot(1,1,1)
draw_box_plot(ax, sample.squeeze()*80)
Out[3]:
<matplotlib.axes._subplots.AxesSubplot at 0x1a7c960c240>

We should avoid shifting the image left or right, because the location of the storm report is in the middle of each image. Instead, rotate the image and zoom in and out slightly. It is also important to ask yourself: does the image augmentation make physical sense?

We can visualize this with 9 randomly generated examples.

In [4]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from numpy import expand_dims

plt.rcParams['figure.figsize'] = 20, 20

samples = expand_dims(sample, 0)

datagen = ImageDataGenerator(rotation_range=55, zoom_range=[0.9,1.0], fill_mode="reflect")

aug_imgs = datagen.flow(samples, batch_size=1)

for i in range(9):
   
    ax = plt.subplot(3,3,i+1)

    batch = next(aug_imgs)

    draw_box_plot(ax, batch[0].squeeze()*80)

Create an image generator for the training data and validation data and pass these values into model.fit(). Wait for the model to finish 100 epochs and test how it did!

Note: adjust the "workers" argument to match your CPU; this was tested on a 20-core machine.
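The "42 steps" reported per epoch in the log below is just the training-set size divided by the batch size, rounded up; a quick sanity check using the counts printed earlier in this notebook:

```python
import math

# Training sample count and batch size from this notebook
n_train, batch_size = 1331, 32
steps_per_epoch = math.ceil(n_train / batch_size)
print(steps_per_epoch)  # 42
```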

In [5]:
epochs = 100

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

history = model.fit(datagen.flow(x_train, y_train, batch_size=32),
                    epochs=epochs, validation_data=(x_val, y_val), workers=8)
Train for 42 steps, validate on 110 samples
Epoch 1/100
42/42 [==============================] - 4s 104ms/step - loss: 1.3260 - accuracy: 0.4771 - val_loss: 0.9010 - val_accuracy: 0.7273
Epoch 2/100
42/42 [==============================] - 1s 28ms/step - loss: 1.0469 - accuracy: 0.6281 - val_loss: 0.8009 - val_accuracy: 0.6818
Epoch 3/100
42/42 [==============================] - 1s 23ms/step - loss: 0.9465 - accuracy: 0.6619 - val_loss: 0.7362 - val_accuracy: 0.7636
[Epochs 4-99 omitted: training accuracy climbs from ~0.71 to ~0.92 while validation accuracy levels off near 0.89]
Epoch 100/100
42/42 [==============================] - 1s 25ms/step - loss: 0.2409 - accuracy: 0.9286 - val_loss: 0.3455 - val_accuracy: 0.8909

Check the change in training and validation accuracy over epochs.

Divergence of these two curves generally suggests overfitting, which can be addressed with image augmentation, dropout, and more training data.
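One simple way to quantify that divergence is the gap between the final training and validation accuracies; a small sketch (the `demo` values are hypothetical, shaped like the `History.history` dict returned by `model.fit`):

```python
def accuracy_gap(history_dict):
    """Final-epoch training accuracy minus validation accuracy.
    A persistently large positive gap suggests overfitting."""
    return history_dict["accuracy"][-1] - history_dict["val_accuracy"][-1]

# Hypothetical values shaped like History.history
demo = {"accuracy": [0.48, 0.93], "val_accuracy": [0.73, 0.89]}
print(round(accuracy_gap(demo), 2))  # 0.04
```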

In [6]:
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = 10, 6

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val.'], loc='upper left')
plt.show()

Check the testing accuracy

If it is similar to the validation accuracy, the model may be generalizing (which is good).

In [7]:
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
Test loss: 0.6227933420737585
Test accuracy: 0.87666667

Check the per class FAR/POD

You can see that the high accuracy is mostly due to the model doing well on Cellular, QLCS, and Tropical cases. The model has excellent POD (sensitivity/recall) overall, but a higher FAR (an aspect of precision) for Cellular and Tropical cases. At least for the test set, the model detects almost all QLCSs and produces very few QLCS false alarms. This could reflect my bias toward identifying QLCSs, as I have been looking at images of QLCSs in my free time for... longer than is healthy at this point.

76 out of 78 cellular cases, 144 out of 152 QLCS cases, and 31 out of 33 Tropical cases are properly identified!

11 Noise, 1 Missing, 3 Other, and 3 QLCSs cases are identified as Cellular (76 out of 94 (81%) Cellular Predictions were correct)

1 Noise, 2 Tropical, and 1 Cellular cases are identified as QLCS (144 out of 148 (97%) QLCS Predictions were correct)

1 Noise, 7 Other, 4 QLCS, and 1 Cellular cases are identified as Tropical (31 out of 44 (70%) Tropical Predictions were correct)

The other classes have very low counts, but these results suggest the classifications of Other, Missing, and Noise are not reliable. The work continues!
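POD and FAR fall straight out of the confusion matrix (rows are truth, columns are predictions): POD for a class is its diagonal count over its row sum (recall), and FAR is the off-diagonal portion of its column (1 minus precision). A numpy sketch using the test-set matrix printed in the next cell:

```python
import numpy as np

# Test-set confusion matrix: rows = true class, columns = predicted class
# (Cellular, QLCS, Tropical, Other, Missing, Noise).
cm = np.array([[ 76,   1,  1, 0, 0,  0],
               [  3, 144,  4, 0, 0,  1],
               [  0,   2, 31, 0, 0,  0],
               [  3,   0,  7, 1, 0,  0],
               [  1,   0,  0, 0, 0,  1],
               [ 11,   1,  1, 0, 0, 11]])

hits = np.diag(cm).astype(float)
row_sums = cm.sum(axis=1)                 # total truths per class
col_sums = cm.sum(axis=0).astype(float)   # total predictions per class

pod = hits / row_sums                     # recall: hits / (hits + misses)
# FAR = false alarms / predictions; guard against classes never predicted
far = np.divide(col_sums - hits, col_sums,
                out=np.zeros_like(hits), where=col_sums > 0)

print("QLCS POD: %.2f, FAR: %.2f" % (pod[1], far[1]))      # 0.95, 0.03
print("Tropical POD: %.2f, FAR: %.2f" % (pod[2], far[2]))  # 0.94, 0.30
```

These match the per-class numbers discussed above (e.g. 144 of 148 QLCS predictions correct gives a QLCS FAR of about 0.03).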

In [9]:
from sklearn.metrics import classification_report, confusion_matrix

y_pred = model.predict(x_test)
y_pred = np.argmax(y_pred, axis=1)
y_test_ = np.argmax(y_test, axis=1)

print('Confusion Matrix')
print(confusion_matrix(y_test_, y_pred))
print('Classification Report')
target_names = ['Cellular', 'QLCS', 'Tropical', 'Other', 'Missing', 'Noise']
print(classification_report(y_test_, y_pred, target_names=target_names))
Confusion Matrix
[[ 76   1   1   0   0   0]
 [  3 144   4   0   0   1]
 [  0   2  31   0   0   0]
 [  3   0   7   1   0   0]
 [  1   0   0   0   0   1]
 [ 11   1   1   0   0  11]]
Classification Report
              precision    recall  f1-score   support

    Cellular       0.81      0.97      0.88        78
        QLCS       0.97      0.95      0.96       152
    Tropical       0.70      0.94      0.81        33
       Other       1.00      0.09      0.17        11
     Missing       0.00      0.00      0.00         2
       Noise       0.85      0.46      0.59        24

    accuracy                           0.88       300
   macro avg       0.72      0.57      0.57       300
weighted avg       0.89      0.88      0.86       300

Apply the model to the testing images and compare the predictions to the actual labels

In [10]:
from svrimg.utils.get_tables import get_svrgis_table, get_pred_tables

actual = get_pred_tables(data_dir="../data/csvs/", example=True, remove_first_row=True)

svrgis = get_svrgis_table(data_dir="../data/csvs/")

actual = actual.join(svrgis)

actual.head()
Out[10]:
Class Code Class Name om tz st stf stn mag inj fat ... f3 f4 fc init_date fmt_date date_utc yr mo dy hr
UNID
199604200208z000000206 0 Cellular 206 3 IL 17 38 0 0 0 ... 0 0 0 1996-04-19-20:08:00 4/19/1996 20:08 4/20/1996 2:08 1996 4 20 2
199604192244z000000197 0 Cellular 197 3 IL 17 8 0 0 0 ... 0 0 0 1996-04-19-16:44:00 4/19/1996 16:44 4/19/1996 22:44 1996 4 19 22
199605280130z000000300 0 Cellular 300 3 IL 17 49 0 0 0 ... 0 0 0 1996-05-27-19:30:00 5/27/1996 19:30 5/28/1996 1:30 1996 5 28 1
199605280140z000000298 0 Cellular 298 3 IL 17 50 0 0 0 ... 0 0 0 1996-05-27-19:40:00 5/27/1996 19:40 5/28/1996 1:40 1996 5 28 1
199604192307z000000207 0 Cellular 207 3 IL 17 11 2 1 0 ... 0 0 0 1996-04-19-17:07:00 4/19/1996 17:07 4/19/1996 23:07 1996 4 19 23

5 rows × 33 columns

In [14]:
from svrimg.utils.get_images import get_img_list

plt.rcParams['figure.figsize'] = 25, 25
plt.rcParams['xtick.labelsize'] = 20
plt.rcParams['ytick.labelsize'] = 20
plt.rcParams['axes.labelsize'] = 20

#Testing data are from 2014 onward.  It is "cheating" to look at earlier data.
sample = actual[actual.yr>=2014].sample(9)

#Load the images and transform them to be "CNN-friendly"
imgs = get_img_list(sample.index.values, "../data/tor/")
imgs = expand_dims(imgs, 3)
imgs = imgs / 80 #normalize

#Identify the column with the highest probability
pred = np.argmax(model.predict(imgs), axis=1)
truth = sample['Class Code'].values

lookup = {0:'Cellular', 1:'QLCS', 2:'Tropical', 3:'Other', 4:'Missing', 5:'Noise'}

for i, (img, p) in enumerate(zip(imgs, pred)):
    
    ax = plt.subplot(3, 3, i+1)
    
    ax = draw_box_plot(ax, img.squeeze()*80, cbar_shrink=0.8)
    
    ax.set_title("Prediction: {}\nActual: {}".format(lookup[p], lookup[truth[i]]), fontsize=25)

Save model

In [15]:
model.save("../data/models/morph_model_v01.h5")