Transfer Learning in TensorFlow: Scaling Up

This notebook is an account of my work for the Udemy course "TensorFlow Developer Certificate in 2022: Zero to Mastery".

Concepts covered in this Notebook:

  • Downloading and preparing 10% of all Food101 classes (7,500+ training images)
  • Training a transfer learning feature extraction model
  • Fine-tuning the feature extraction model to beat the original Food101 paper with only 10% of the data
  • Evaluating Food Vision Mini's predictions
    • Finding the most wrong predictions (on the test dataset)
  • Making predictions with Food Vision Mini on our own custom images

In this notebook, we scale our models up to a larger dataset with more classes: all 101 food classes of the Food101 dataset.

Downloading helper functions

Instead of rewriting them, we can download a helper functions file containing the custom-built functions we created in previous notebooks.

!wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
--2022-02-20 18:37:18--  https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10246 (10K) [text/plain]
Saving to: ‘helper_functions.py’

helper_functions.py 100%[===================>]  10.01K  --.-KB/s    in 0s      

2022-02-20 18:37:18 (49.1 MB/s) - ‘helper_functions.py’ saved [10246/10246]

from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, compare_historys, walk_through_dir

101 Food Classes: working with less data

Our goal is to beat the original Food101 paper with 10% of the training data, so let's download it.

The data we're downloading comes from the original Food101 dataset but has been preprocessed with the image_modification.ipynb notebook.
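The preprocessing notebook itself isn't shown here. Food101 has 750 training images per class, so 10% works out to 75 training images per class, which matches the directory listing further down. The 250 test images per class are kept in full. One plausible way to build such a subset is to randomly sample a fixed fraction of each class's filenames. This is an assumption, not necessarily what image_modification.ipynb does, and `sample_fraction_per_class` is a hypothetical helper:

```python
import random

def sample_fraction_per_class(files_by_class, fraction=0.1, seed=42):
    """Return a dict mapping each class name to a random `fraction` of its files.

    files_by_class: dict of class name -> list of image filenames (hypothetical input).
    """
    rng = random.Random(seed)  # seeded so the same subset is drawn every time
    subset = {}
    for class_name, files in files_by_class.items():
        n = round(len(files) * fraction)  # e.g. 750 * 0.1 -> 75 images
        subset[class_name] = rng.sample(files, n)  # sample without replacement
    return subset

# Toy example: 2 classes with 750 training filenames each, as in Food101
full = {c: [f"{c}_{i}.jpg" for i in range(750)] for c in ["pizza", "sushi"]}
small = sample_fraction_per_class(full)
print({c: len(f) for c, f in small.items()})  # {'pizza': 75, 'sushi': 75}
```

Sampling per class, rather than from one pooled list, keeps every class at exactly 75 images, so the reduced training set stays balanced.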

!wget https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip
unzip_data("101_food_classes_10_percent.zip")

train_dir = "101_food_classes_10_percent/train/"
test_dir = "101_food_classes_10_percent/test/"
--2022-02-20 14:26:13--  https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.193.128, 173.194.194.128, 173.194.195.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.193.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1625420029 (1.5G) [application/zip]
Saving to: ‘101_food_classes_10_percent.zip’

101_food_classes_10 100%[===================>]   1.51G   216MB/s    in 7.3s    

2022-02-20 14:26:20 (213 MB/s) - ‘101_food_classes_10_percent.zip’ saved [1625420029/1625420029]
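`unzip_data` comes from helper_functions.py, and its source isn't shown here. Assuming it simply extracts the archive into the current directory, it behaves roughly like the sketch below, which uses Python's standard `zipfile` module. `unzip_to` is a hypothetical name:

```python
import os
import tempfile
import zipfile

def unzip_to(filename, dest="."):
    """Extract every member of the zip archive `filename` into the directory `dest`."""
    with zipfile.ZipFile(filename, "r") as zip_ref:
        zip_ref.extractall(dest)

# Demo on a tiny throwaway archive that mimics the train/<class>/<image> layout
with tempfile.TemporaryDirectory() as tmp:
    archive = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("demo/train/pizza/0.jpg", b"not really an image")
    unzip_to(archive, dest=tmp)
    extracted_ok = os.path.isfile(os.path.join(tmp, "demo", "train", "pizza", "0.jpg"))

print(extracted_ok)  # True
```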

# How many image classes (and images per class) are there?
walk_through_dir("101_food_classes_10_percent")

There are 2 directories and 0 images in '101_food_classes_10_percent'.
There are 101 directories and 0 images in '101_food_classes_10_percent/train'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/prime_rib'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/strawberry_shortcake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/takoyaki'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pizza'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/filet_mignon'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/paella'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/mussels'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/eggs_benedict'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ravioli'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/apple_pie'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/peking_duck'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/baklava'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pork_chop'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/spaghetti_bolognese'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/huevos_rancheros'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pad_thai'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/breakfast_burrito'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cannoli'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/fried_rice'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_quesadilla'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_toast'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/garlic_bread'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chocolate_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/fried_calamari'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/macarons'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ceviche'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/crab_cakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/club_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/dumplings'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/spring_rolls'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/bread_pudding'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/sashimi'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/risotto'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pho'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/foie_gras'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/frozen_yogurt'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/grilled_salmon'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cheese_plate'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/poutine'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/bruschetta'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cheesecake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/caesar_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beet_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/lasagna'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/carrot_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/tuna_tartare'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beef_carpaccio'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/falafel'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/greek_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/guacamole'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hummus'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hamburger'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beef_tartare'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beignets'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/waffles'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/gnocchi'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/baby_back_ribs'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/grilled_cheese_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/steak'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/edamame'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/spaghetti_carbonara'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/sushi'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/macaroni_and_cheese'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/churros'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cup_cakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_wings'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/clam_chowder'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/scallops'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hot_dog'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_onion_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ramen'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pancakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_fries'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/panna_cotta'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/tiramisu'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hot_and_sour_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pulled_pork_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/escargots'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/gyoza'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/seaweed_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/donuts'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/miso_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chocolate_mousse'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/shrimp_and_grits'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/red_velvet_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ice_cream'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/bibimbap'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/lobster_roll_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/nachos'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/samosa'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/fish_and_chips'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/caprese_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/croque_madame'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/tacos'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_curry'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/omelette'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/deviled_eggs'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/oysters'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/creme_brulee'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/lobster_bisque'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/onion_rings'.
There are 101 directories and 0 images in '101_food_classes_10_percent/test'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/prime_rib'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/strawberry_shortcake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/takoyaki'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pizza'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/filet_mignon'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/paella'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/mussels'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/eggs_benedict'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ravioli'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/apple_pie'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/peking_duck'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/baklava'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pork_chop'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/spaghetti_bolognese'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/huevos_rancheros'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pad_thai'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/breakfast_burrito'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cannoli'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/fried_rice'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_quesadilla'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_toast'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/garlic_bread'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chocolate_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/fried_calamari'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/macarons'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ceviche'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/crab_cakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/club_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/dumplings'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/spring_rolls'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/bread_pudding'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/sashimi'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/risotto'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pho'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/foie_gras'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/frozen_yogurt'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/grilled_salmon'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cheese_plate'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/poutine'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/bruschetta'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cheesecake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/caesar_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beet_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/lasagna'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/carrot_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tuna_tartare'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beef_carpaccio'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/falafel'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/greek_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/guacamole'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hummus'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hamburger'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beef_tartare'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beignets'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/waffles'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/gnocchi'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/baby_back_ribs'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/grilled_cheese_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/steak'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/edamame'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/spaghetti_carbonara'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/sushi'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/macaroni_and_cheese'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/churros'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cup_cakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_wings'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/clam_chowder'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/scallops'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hot_dog'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_onion_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ramen'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pancakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_fries'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/panna_cotta'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tiramisu'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hot_and_sour_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pulled_pork_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/escargots'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/gyoza'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/seaweed_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/donuts'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/miso_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chocolate_mousse'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/shrimp_and_grits'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/red_velvet_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ice_cream'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/bibimbap'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/lobster_roll_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/nachos'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/samosa'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/fish_and_chips'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/caprese_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/croque_madame'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tacos'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_curry'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/omelette'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/deviled_eggs'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/oysters'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/creme_brulee'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/lobster_bisque'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/onion_rings'.
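The listing above comes from `walk_through_dir` in helper_functions.py. Below is a minimal standard-library sketch of a similar check. `count_images_per_class` is a hypothetical helper, not part of helper_functions.py. It counts the files in each class folder, so the totals can be compared with the 101 x 75 = 7,575 training images and 101 x 250 = 25,250 test images reported when the datasets are loaded:

```python
import os
import tempfile

def count_images_per_class(split_dir):
    """Map each class subdirectory of split_dir to the number of files it contains."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_path = os.path.join(split_dir, class_name)
        if os.path.isdir(class_path):  # skip stray files at the top level
            counts[class_name] = len(os.listdir(class_path))
    return counts

# Demo on a tiny fake split: 2 classes x 3 empty "images"
with tempfile.TemporaryDirectory() as tmp:
    for cls in ["pizza", "sushi"]:
        os.makedirs(os.path.join(tmp, cls))
        for i in range(3):
            open(os.path.join(tmp, cls, f"{i}.jpg"), "w").close()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)  # {'pizza': 3, 'sushi': 3}
```

On the real data, `count_images_per_class(train_dir)` should return 101 entries of 75 each, and `count_images_per_class(test_dir)` should return 101 entries of 250 each.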
import tensorflow as tf
IMG_SIZE = (224,224)
train_data_all_10_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir,
                                                                                 label_mode = "categorical",
                                                                                 image_size = IMG_SIZE)

test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
                                                                 label_mode = "categorical",
                                                                 image_size  = IMG_SIZE,
                                                                 shuffle = False) # don't shuffle test data for prediction analysis
Found 7575 files belonging to 101 classes.
Found 25250 files belonging to 101 classes.
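No `batch_size` is passed, so `image_dataset_from_directory` uses its default of 32 images per batch, and the final batch of each split is partial. That is where the step counts in the training and evaluation logs later in this notebook come from (237 steps per epoch, 790 test batches). A quick check:

```python
import math

BATCH_SIZE = 32            # image_dataset_from_directory's default batch_size
num_train_images = 7575    # from "Found 7575 files"
num_test_images = 25250    # from "Found 25250 files"

# ceil because the last, partially filled batch still counts as one step
train_batches = math.ceil(num_train_images / BATCH_SIZE)
test_batches = math.ceil(num_test_images / BATCH_SIZE)

print(train_batches, test_batches)  # 237 790
```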

Train a large model with transfer learning on 10% of 101 food classes

Here are the steps we're going to take:

  • Create a ModelCheckpoint callback
  • Create a data augmentation layer to build data augmentation right into the model
  • Build a headless (no top layers) Functional API model with an EfficientNetB0 backbone
  • Compile our model
  • Feature extract for 5 full passes (5 epochs on the training dataset, validating on 25% of the test data to save time per epoch)
# Save only the model's weights, and only when validation accuracy improves
checkpoint_path = "101_classes_10_percent_data_model_checkpoint"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                         save_weights_only = True,
                                                         monitor = "val_accuracy",
                                                         save_best_only = True)
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.models import Sequential

# Setup data augmentation (these layers are only active during training)
data_augmentation = Sequential([
  preprocessing.RandomFlip("horizontal"),
  preprocessing.RandomRotation(0.2),
  preprocessing.RandomHeight(0.2),
  preprocessing.RandomWidth(0.2),
  preprocessing.RandomZoom(0.2)
  # No Rescaling(1/255.) layer needed: Keras' EfficientNet models rescale inputs internally
  # (ResNet50-like models would need one)
], name = "data_augmentation")
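Keras' random augmentation layers only transform images when the model runs in training mode, as inside `model.fit`. At inference they pass inputs through unchanged. The NumPy sketch below illustrates that behaviour for a horizontal flip only. `random_horizontal_flip` is a hypothetical function: a simplified stand-in, not how `preprocessing.RandomFlip` is actually implemented:

```python
import numpy as np

def random_horizontal_flip(image, training, rng):
    """Mirror an (height, width, channels) image left-right with probability 0.5,
    but only when training=True; otherwise return it unchanged."""
    if training and rng.random() < 0.5:
        return image[:, ::-1, :]  # reverse the width axis
    return image

rng = np.random.default_rng(0)
img = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # tiny fake 2x3 image with 3 channels

# Inference: the image always passes through untouched
print(np.array_equal(random_horizontal_flip(img, training=False, rng=rng), img))  # True

# Training: each call returns either the original or its mirror image
out = random_horizontal_flip(img, training=True, rng=rng)
print(np.array_equal(out, img) or np.array_equal(out, img[:, ::-1, :]))  # True
```

Because augmentation lives inside the model, it runs on the fly during training and is skipped automatically when evaluating or predicting.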
base_model = tf.keras.applications.EfficientNetB0(include_top=False)
base_model.trainable = False

# Setup model architecture with trainable top layers
inputs = layers.Input(shape = (224,224,3), name = "input_layer")
x = data_augmentation(inputs) # augment images (only happens during training)
# Run the base model in inference mode so its BatchNormalization layers keep using
# their stored statistics; base_model.trainable = False is what freezes the weights
x = base_model(x, training = False)
x = layers.GlobalAveragePooling2D(name = "global_avg_pool_layer")(x)
outputs = layers.Dense(len(train_data_all_10_percent.class_names), activation = "softmax", name = "output_layer")(x)

model = tf.keras.Model(inputs,outputs)
WARNING:tensorflow:Model was constructed with shape (None, 224, 224) for input KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224), dtype=tf.float32, name='random_flip_2_input'), name='random_flip_2_input', description="created by layer 'random_flip_2_input'"), but it was called on an input with incompatible shape (None, 224, 224, 3).
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_layer (InputLayer)    [(None, 224, 224, 3)]     0         
                                                                 
 data_augmentation (Sequenti  (None, 224, 224)         0         
 al)                                                             
                                                                 
 efficientnetb0 (Functional)  (None, None, None, 1280)  4049571  
                                                                 
 global_avg_pool_layer (Glob  (None, 1280)             0         
 alAveragePooling2D)                                             
                                                                 
 output_layer (Dense)        (None, 101)               129381    
                                                                 
=================================================================
Total params: 4,178,952
Trainable params: 129,381
Non-trainable params: 4,049,571
_________________________________________________________________
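The parameter counts in the summary can be checked by hand. For a 224x224 input, EfficientNetB0's final feature map is 7x7x1280. GlobalAveragePooling2D averages each of the 1280 channels over the 7x7 grid. The Dense output layer then needs 1280 x 101 weights plus 101 biases. A NumPy sketch using random stand-in features:

```python
import numpy as np

# Random stand-in for EfficientNetB0's output on a batch of 2 images (7x7x1280 each)
features = np.random.rand(2, 7, 7, 1280).astype("float32")

# GlobalAveragePooling2D: average over the spatial (height, width) axes
pooled = features.mean(axis=(1, 2))
print(pooled.shape)  # (2, 1280)

# Dense output layer: one weight per (feature, class) pair plus one bias per class
num_features, num_classes = 1280, 101
dense_params = num_features * num_classes + num_classes
print(dense_params)  # 129381

frozen_params = 4_049_571  # EfficientNetB0 parameters, from model.summary()
print(dense_params + frozen_params)  # 4178952
```

This is why only 129,381 parameters are trainable during feature extraction: everything in the frozen backbone is excluded.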
model.compile(loss = "categorical_crossentropy",
              optimizer = tf.keras.optimizers.Adam(),
              metrics = ["accuracy"])

# Fit the feature extraction model, validating on 25% of the test batches to save time
history_all_classes_10_percent = model.fit(train_data_all_10_percent,
                                           epochs = 5,
                                           validation_data = test_data,
                                           validation_steps = int(0.25 * len(test_data)),
                                           callbacks = [checkpoint_callback])
Epoch 1/5
237/237 [==============================] - 147s 541ms/step - loss: 3.4688 - accuracy: 0.2467 - val_loss: 2.6962 - val_accuracy: 0.3731
Epoch 2/5
237/237 [==============================] - 106s 447ms/step - loss: 2.3543 - accuracy: 0.4524 - val_loss: 2.2650 - val_accuracy: 0.4410
Epoch 3/5
237/237 [==============================] - 97s 409ms/step - loss: 1.9770 - accuracy: 0.5350 - val_loss: 2.0941 - val_accuracy: 0.4676
Epoch 4/5
237/237 [==============================] - 93s 391ms/step - loss: 1.7549 - accuracy: 0.5781 - val_loss: 2.0071 - val_accuracy: 0.4789
Epoch 5/5
237/237 [==============================] - 84s 354ms/step - loss: 1.6097 - accuracy: 0.6057 - val_loss: 1.9830 - val_accuracy: 0.4765
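With `monitor = "val_accuracy"` and `save_best_only = True`, the checkpoint file is only overwritten when validation accuracy improves. The loop below is a simplified pure-Python model of that rule, not Keras' code, replayed on the values from the log above:

```python
# val_accuracy after each of the 5 epochs, copied from the training log
val_accuracies = [0.3731, 0.4410, 0.4676, 0.4789, 0.4765]

best = float("-inf")
saved_epochs = []
for epoch, val_acc in enumerate(val_accuracies, start=1):
    if val_acc > best:            # "max" mode: higher val_accuracy is better
        best = val_acc
        saved_epochs.append(epoch)  # ModelCheckpoint would overwrite the weights file here

print(saved_epochs)  # [1, 2, 3, 4]
print(best)          # 0.4789
```

Epoch 5's val_accuracy (0.4765) is lower than epoch 4's (0.4789), so the saved weights are from epoch 4, not from the final model. To use them, call `model.load_weights(checkpoint_path)` on a model with the same architecture.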
feature_extraction_results = model.evaluate(test_data) # evaluate on the whole test dataset
feature_extraction_results
790/790 [==============================] - 106s 133ms/step - loss: 1.6050 - accuracy: 0.5757
[1.6050441265106201, 0.5757227540016174]
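The accuracy on the full test set (0.5757) is noticeably higher than the final val_accuracy during training (0.4765). One likely reason is the validation subset. `validation_steps = int(0.25 * len(test_data))` makes `fit()` validate on only the first 197 of the 790 test batches. Because `test_data` was created with `shuffle = False`, those batches come in class order (`image_dataset_from_directory` sorts the class folders alphabetically). Validation therefore probably saw only the first ~26 classes. The arithmetic, assuming the default batch size of 32:

```python
import math

BATCH_SIZE = 32
IMAGES_PER_TEST_CLASS = 250

test_batches = math.ceil(25250 / BATCH_SIZE)        # 790, as in the evaluate() log
validation_steps = int(0.25 * test_batches)         # same expression as in model.fit
validated_images = validation_steps * BATCH_SIZE    # only the very last batch is partial
classes_covered = math.ceil(validated_images / IMAGES_PER_TEST_CLASS)

print(validation_steps, validated_images, classes_covered)  # 197 6304 26
```

If a representative validation score matters, one option is to validate on a shuffled copy of the test data while keeping the unshuffled `test_data` for prediction analysis.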
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plot_loss_curves(history_all_classes_10_percent)

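Ideally, the training and validation curves should stay close to each other. If they drift apart, the model may be overfitting: it performs well on the training data but doesn't generalize to unseen data. In the log above, training accuracy ends at 0.6057 while validation accuracy ends at 0.4765, which hints at some overfitting.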
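A rough numeric companion to the plot is the gap between the final training and validation accuracy. `final_gap` is a hypothetical helper. It works on a dict shaped like `history.history` (keys such as `"accuracy"` and `"val_accuracy"`), filled in here by hand from the training log. Keep in mind that the validation numbers come from only 25% of the test batches:

```python
def final_gap(history_dict, metric="accuracy"):
    """Last training value of `metric` minus its last validation value."""
    return history_dict[metric][-1] - history_dict["val_" + metric][-1]

# Values copied from the 5-epoch log; with a real run you could pass
# history_all_classes_10_percent.history instead of this hand-built dict.
history_dict = {
    "accuracy": [0.2467, 0.4524, 0.5350, 0.5781, 0.6057],
    "val_accuracy": [0.3731, 0.4410, 0.4676, 0.4789, 0.4765],
    "loss": [3.4688, 2.3543, 1.9770, 1.7549, 1.6097],
    "val_loss": [2.6962, 2.2650, 2.0941, 2.0071, 1.9830],
}

gap = final_gap(history_dict)
print(round(gap, 4))  # 0.1292
```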

Fine-tuning

base_model.trainable = True

# Refreeze every layer except the last 5
for layer in base_model.layers[:-5]:
  layer.trainable = False
# Recompile with a learning rate 10x lower than Adam's default (0.001),
# so fine-tuning doesn't drastically change the pretrained weights
model.compile(loss = "categorical_crossentropy",
              optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0001),
              metrics = ["accuracy"])
# Check the trainable layers
for layer in model.layers:
  print(layer.name, layer.trainable)

input_layer True
data_augmentation True
efficientnetb0 True
global_avg_pool_layer True
output_layer True
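Note that `model.layers` reports `efficientnetb0` as trainable because `base_model.trainable` was set back to `True`. That outer flag alone doesn't tell you which weights will update; the per-layer flags inside the backbone decide that. The refreeze loop keeps only the last five layers trainable. The toy sketch below illustrates the slicing with a hypothetical `ToyLayer` class standing in for Keras layers (it is not the Keras API):

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    def __init__(self, name):
        self.name = name
        self.trainable = True

# Pretend backbone with 10 layers (EfficientNetB0 has many more)
backbone_layers = [ToyLayer(f"layer_{i}") for i in range(10)]

# Same pattern as the fine-tuning setup: unfreeze everything, then
# refreeze every layer except the last 5
for layer in backbone_layers:
    layer.trainable = True
for layer in backbone_layers[:-5]:
    layer.trainable = False

trainable_names = [layer.name for layer in backbone_layers if layer.trainable]
print(trainable_names)  # ['layer_5', 'layer_6', 'layer_7', 'layer_8', 'layer_9']
```

In the real base model, the layer-by-layer listing should show the same pattern: `False` for every layer except the last five.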
# Check which layers are trainable in our base model (model.layers[2] is efficientnetb0)
for layer_number, layer in enumerate(model.layers[2].layers):
  print(layer_number, layer.name, layer.trainable)

0 input_4 False
1 rescaling_3 False
2 normalization_3 False
3 stem_conv_pad False
4 stem_conv False
5 stem_bn False
6 stem_activation False
7 block1a_dwconv False
8 block1a_bn False
9 block1a_activation False
10 block1a_se_squeeze False
11 block1a_se_reshape False
12 block1a_se_reduce False
13 block1a_se_expand False
14 block1a_se_excite False
15 block1a_project_conv False
16 block1a_project_bn False
17 block2a_expand_conv False
18 block2a_expand_bn False
19 block2a_expand_activation False
20 block2a_dwconv_pad False
21 block2a_dwconv False
22 block2a_bn False
23 block2a_activation False
24 block2a_se_squeeze False
25 block2a_se_reshape False
26 block2a_se_reduce False
27 block2a_se_expand False
28 block2a_se_excite False
29 block2a_project_conv False
30 block2a_project_bn False
31 block2b_expand_conv False
32 block2b_expand_bn False
33 block2b_expand_activation False
34 block2b_dwconv False
35 block2b_bn False
36 block2b_activation False
37 block2b_se_squeeze False
38 block2b_se_reshape False
39 block2b_se_reduce False
40 block2b_se_expand False
41 block2b_se_excite False
42 block2b_project_conv False
43 block2b_project_bn False
44 block2b_drop False
45 block2b_add False
46 block3a_expand_conv False
47 block3a_expand_bn False
48 block3a_expand_activation False
49 block3a_dwconv_pad False
50 block3a_dwconv False
51 block3a_bn False
52 block3a_activation False
53 block3a_se_squeeze False
54 block3a_se_reshape False
55 block3a_se_reduce False
56 block3a_se_expand False
57 block3a_se_excite False
58 block3a_project_conv False
59 block3a_project_bn False
60 block3b_expand_conv False
61 block3b_expand_bn False
62 block3b_expand_activation False
63 block3b_dwconv False
64 block3b_bn False
65 block3b_activation False
66 block3b_se_squeeze False
67 block3b_se_reshape False
68 block3b_se_reduce False
69 block3b_se_expand False
70 block3b_se_excite False
71 block3b_project_conv False
72 block3b_project_bn False
73 block3b_drop False
74 block3b_add False
75 block4a_expand_conv False
76 block4a_expand_bn False
77 block4a_expand_activation False
78 block4a_dwconv_pad False
79 block4a_dwconv False
80 block4a_bn False
81 block4a_activation False
82 block4a_se_squeeze False
83 block4a_se_reshape False
84 block4a_se_reduce False
85 block4a_se_expand False
86 block4a_se_excite False
87 block4a_project_conv False
88 block4a_project_bn False
89 block4b_expand_conv False
90 block4b_expand_bn False
91 block4b_expand_activation False
92 block4b_dwconv False
93 block4b_bn False
94 block4b_activation False
95 block4b_se_squeeze False
96 block4b_se_reshape False
97 block4b_se_reduce False
98 block4b_se_expand False
99 block4b_se_excite False
100 block4b_project_conv False
101 block4b_project_bn False
102 block4b_drop False
103 block4b_add False
104 block4c_expand_conv False
105 block4c_expand_bn False
106 block4c_expand_activation False
107 block4c_dwconv False
108 block4c_bn False
109 block4c_activation False
110 block4c_se_squeeze False
111 block4c_se_reshape False
112 block4c_se_reduce False
113 block4c_se_expand False
114 block4c_se_excite False
115 block4c_project_conv False
116 block4c_project_bn False
117 block4c_drop False
118 block4c_add False
119 block5a_expand_conv False
120 block5a_expand_bn False
121 block5a_expand_activation False
122 block5a_dwconv False
123 block5a_bn False
124 block5a_activation False
125 block5a_se_squeeze False
126 block5a_se_reshape False
127 block5a_se_reduce False
128 block5a_se_expand False
129 block5a_se_excite False
130 block5a_project_conv False
131 block5a_project_bn False
132 block5b_expand_conv False
133 block5b_expand_bn False
134 block5b_expand_activation False
135 block5b_dwconv False
136 block5b_bn False
137 block5b_activation False
138 block5b_se_squeeze False
139 block5b_se_reshape False
140 block5b_se_reduce False
141 block5b_se_expand False
142 block5b_se_excite False
143 block5b_project_conv False
144 block5b_project_bn False
145 block5b_drop False
146 block5b_add False
147 block5c_expand_conv False
148 block5c_expand_bn False
149 block5c_expand_activation False
150 block5c_dwconv False
151 block5c_bn False
152 block5c_activation False
153 block5c_se_squeeze False
154 block5c_se_reshape False
155 block5c_se_reduce False
156 block5c_se_expand False
157 block5c_se_excite False
158 block5c_project_conv False
159 block5c_project_bn False
160 block5c_drop False
161 block5c_add False
162 block6a_expand_conv False
163 block6a_expand_bn False
164 block6a_expand_activation False
165 block6a_dwconv_pad False
166 block6a_dwconv False
167 block6a_bn False
168 block6a_activation False
169 block6a_se_squeeze False
170 block6a_se_reshape False
171 block6a_se_reduce False
172 block6a_se_expand False
173 block6a_se_excite False
174 block6a_project_conv False
175 block6a_project_bn False
176 block6b_expand_conv False
177 block6b_expand_bn False
178 block6b_expand_activation False
179 block6b_dwconv False
180 block6b_bn False
181 block6b_activation False
182 block6b_se_squeeze False
183 block6b_se_reshape False
184 block6b_se_reduce False
185 block6b_se_expand False
186 block6b_se_excite False
187 block6b_project_conv False
188 block6b_project_bn False
189 block6b_drop False
190 block6b_add False
191 block6c_expand_conv False
192 block6c_expand_bn False
193 block6c_expand_activation False
194 block6c_dwconv False
195 block6c_bn False
196 block6c_activation False
197 block6c_se_squeeze False
198 block6c_se_reshape False
199 block6c_se_reduce False
200 block6c_se_expand False
201 block6c_se_excite False
202 block6c_project_conv False
203 block6c_project_bn False
204 block6c_drop False
205 block6c_add False
206 block6d_expand_conv False
207 block6d_expand_bn False
208 block6d_expand_activation False
209 block6d_dwconv False
210 block6d_bn False
211 block6d_activation False
212 block6d_se_squeeze False
213 block6d_se_reshape False
214 block6d_se_reduce False
215 block6d_se_expand False
216 block6d_se_excite False
217 block6d_project_conv False
218 block6d_project_bn False
219 block6d_drop False
220 block6d_add False
221 block7a_expand_conv False
222 block7a_expand_bn False
223 block7a_expand_activation False
224 block7a_dwconv False
225 block7a_bn False
226 block7a_activation False
227 block7a_se_squeeze False
228 block7a_se_reshape False
229 block7a_se_reduce False
230 block7a_se_expand False
231 block7a_se_excite False
232 block7a_project_conv True
233 block7a_project_bn True
234 top_conv True
235 top_bn True
236 top_activation True
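The listing above shows every layer of the EfficientNetB0 base frozen except the last five (indices 232 to 236: `block7a_project_conv`, `block7a_project_bn`, `top_conv`, `top_bn`, `top_activation`). In Keras this is usually done by setting the whole base model to trainable and then re-freezing all but the last N layers, e.g. `for layer in base_model.layers[:-5]: layer.trainable = False`. The name `base_model` is an assumption here. The pure-Python sketch below mimics that logic with a hypothetical `Layer` stand-in so it runs without TensorFlow. It is not the notebook's actual code.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    name: str
    trainable: bool = True


def unfreeze_last_n(layers, n):
    """Make only the last n layers trainable (n >= 1), mirroring
    base_model.trainable = True followed by freezing layers[:-n]."""
    for layer in layers:
        layer.trainable = True
    for layer in layers[:-n]:
        layer.trainable = False
    return layers


# 237 layers, as in the listing (indices 0 to 236); only the last five names are real
names = [f"layer_{i}" for i in range(232)] + [
    "block7a_project_conv", "block7a_project_bn", "top_conv", "top_bn", "top_activation",
]
layers = unfreeze_last_n([Layer(name) for name in names], n=5)

trainable = [(i, layer.name) for i, layer in enumerate(layers) if layer.trainable]
for i, name in trainable:
    print(i, name, True)  # prints indices 232 to 236, matching the listing above
```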
fine_tune_epochs = 10 # total epochs: 5 of feature extraction plus the fine-tuning epochs
# initial_epoch is the last feature-extraction epoch index (4), so training resumes at epoch 5 and runs to epoch 10

# Fine-tune model
history_all_classes_10_percentp_fine_tune = model.fit(train_data_all_10_percent,
                                                      epochs=fine_tune_epochs,
                                                      validation_data=test_data,
                                                      validation_steps=int(0.25 * len(test_data)),
                                                      initial_epoch=history_all_classes_10_percent.epoch[-1])
Epoch 5/10
WARNING:tensorflow:Model was constructed with shape (None, 224, 224) for input KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224), dtype=tf.float32, name='random_flip_2_input'), name='random_flip_2_input', description="created by layer 'random_flip_2_input'"), but it was called on an input with incompatible shape (None, 224, 224, 3).
237/237 [==============================] - 100s 382ms/step - loss: 1.3739 - accuracy: 0.6396 - val_loss: 1.9134 - val_accuracy: 0.4916
Epoch 6/10
237/237 [==============================] - 86s 361ms/step - loss: 1.2250 - accuracy: 0.6817 - val_loss: 1.9096 - val_accuracy: 0.4914
Epoch 7/10
237/237 [==============================] - 80s 334ms/step - loss: 1.1593 - accuracy: 0.6944 - val_loss: 1.8845 - val_accuracy: 0.5040
Epoch 8/10
237/237 [==============================] - 81s 340ms/step - loss: 1.0929 - accuracy: 0.7067 - val_loss: 1.9180 - val_accuracy: 0.4975
Epoch 9/10
237/237 [==============================] - 76s 320ms/step - loss: 1.0192 - accuracy: 0.7291 - val_loss: 1.8902 - val_accuracy: 0.5030
Epoch 10/10
237/237 [==============================] - 77s 323ms/step - loss: 0.9608 - accuracy: 0.7430 - val_loss: 1.9129 - val_accuracy: 0.4992
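Why does the log start at `Epoch 5/10` and show six epochs? `History.epoch` is a 0-indexed list, so after five feature-extraction epochs it holds `[0, 1, 2, 3, 4]` and `history_all_classes_10_percent.epoch[-1]` is `4`. Keras then trains from `initial_epoch` up to, but not including, `epochs`, printing each epoch 1-indexed. The plain-Python sketch below reproduces that arithmetic without calling Keras:

```python
# History.epoch after 5 epochs of feature extraction is 0-indexed
feature_extraction_epochs = list(range(5))  # [0, 1, 2, 3, 4]
fine_tune_epochs = 10

# What the cell above does: resume from the last completed epoch index
initial_epoch = feature_extraction_epochs[-1]  # 4
epochs_run = [e + 1 for e in range(initial_epoch, fine_tune_epochs)]  # 1-indexed, as printed
print(epochs_run)       # [5, 6, 7, 8, 9, 10]
print(len(epochs_run))  # 6

# Alternative: resume after the last completed epoch to get exactly 5 more
initial_epoch_alt = len(feature_extraction_epochs)  # 5
epochs_run_alt = [e + 1 for e in range(initial_epoch_alt, fine_tune_epochs)]
print(epochs_run_alt)   # [6, 7, 8, 9, 10]
```

Passing `initial_epoch=len(history_all_classes_10_percent.epoch)` instead would give exactly five fine-tuning epochs (6 to 10); the run above fine-tuned for six.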
all_classes_10_percent_fine_tune_results = model.evaluate(test_data)
all_classes_10_percent_fine_tune_results
790/790 [==============================] - 106s 134ms/step - loss: 1.6050 - accuracy: 0.5757
[1.6050441265106201, 0.5757227540016174]
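The full-test-set accuracy from `evaluate` (about 57.6%) is higher than the `val_accuracy` printed during training (about 50%). One likely contributor is that `validation_steps=int(0.25*len(test_data))` checks only a quarter of the test batches each epoch. The step counts in the logs are consistent with a batch size of 32 (an assumption, inferred from the 237 training and 790 test steps) and with Food101's split of 750 training and 250 test images per class:

```python
import math

batch_size = 32                   # assumed; matches the step counts in the logs
num_classes = 101
train_images = 75 * num_classes   # 10% of Food101's 750 training images per class
test_images = 250 * num_classes   # Food101 test split: 250 images per class

train_steps = math.ceil(train_images / batch_size)
test_steps = math.ceil(test_images / batch_size)
validation_steps = int(0.25 * test_steps)

print(train_images, train_steps)                        # 7575 237
print(test_images, test_steps)                          # 25250 790
print(validation_steps, validation_steps * batch_size)  # 197 6304
```

So each epoch's `val_accuracy` is measured on roughly 6,300 of the 25,250 test images, while `evaluate` uses all of them. If `test_data` was created without shuffling, that quarter also covers only some of the classes, which could bias the validation figure further.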
compare_historys(original_history = history_all_classes_10_percent,
                 new_history = history_all_classes_10_percentp_fine_tune,
                 initial_epochs = 5)
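`compare_historys` comes from `helper_functions.py`, so its exact implementation is not shown here. Conceptually, comparing the two runs means appending the fine-tuning metrics to the feature-extraction metrics so they share one epoch axis, with a marker at `initial_epochs` where fine-tuning begins. A hypothetical pure-Python sketch of that merging step (plotting omitted; `combine_histories` is an illustrative name, not part of the helper file):

```python
METRICS = ("accuracy", "val_accuracy", "loss", "val_loss")


def combine_histories(original, new, keys=METRICS):
    """Hypothetical helper: concatenate two History.history-style dicts metric by metric."""
    return {key: list(original[key]) + list(new[key]) for key in keys}


# Toy values only, not results from this notebook: 5 feature-extraction + 6 fine-tuning epochs
original = {key: [round(0.1 * i, 2) for i in range(5)] for key in METRICS}
new = {key: [round(0.5 + 0.1 * i, 2) for i in range(6)] for key in METRICS}

combined = combine_histories(original, new)
initial_epochs = 5
print(len(combined["accuracy"]))  # 11

# A plot would typically draw a vertical line between these two segments
print(combined["accuracy"][:initial_epochs])
print(combined["accuracy"][initial_epochs:])
```

With the run above, the fine-tuning history holds six entries (epochs 5 to 10), so the combined curves span 11 points.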

Saving and loading our Model

To use our model in an external application, we first need to save it somewhere it can be loaded from later (here, Google Drive).

model.save("/content/drive/MyDrive/tensorflowcourseudemy/101_food_classes_10_percent_fine_tuned_model")
WARNING:tensorflow:Model was constructed with shape (None, 224, 224) for input KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224), dtype=tf.float32, name='random_flip_2_input'), name='random_flip_2_input', description="created by layer 'random_flip_2_input'"), but it was called on an input with incompatible shape (None, 224, 224, 3).
INFO:tensorflow:Assets written to: /content/drive/MyDrive/tensorflowcourseudemy/101_food_classes_10_percent_fine_tuned_model/assets
# Load and evaluate the saved model
loaded_model = tf.keras.models.load_model("/content/drive/MyDrive/tensorflowcourseudemy/101_food_classes_10_percent_fine_tuned_model")

WARNING:tensorflow:Model was constructed with shape (None, 224, 224) for input KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224), dtype=tf.float32, name='random_flip_2_input'), name='random_flip_2_input', description="created by layer 'random_flip_2_input'"), but it was called on an input with incompatible shape (None, 224, 224, 3).
loaded_model_results = loaded_model.evaluate(test_data)
loaded_model_results
WARNING:tensorflow:Model was constructed with shape (None, 224, 224) for input KerasTensor(type_spec=TensorSpec(shape=(None, 224, 224), dtype=tf.float32, name='random_flip_2_input'), name='random_flip_2_input', description="created by layer 'random_flip_2_input'"), but it was called on an input with incompatible shape (None, 224, 224, 3).
790/790 [==============================] - 108s 134ms/step - loss: 1.6050 - accuracy: 0.5757
[1.6050441265106201, 0.5757227540016174]
all_classes_10_percent_fine_tune_results
[1.6050441265106201, 0.5757227540016174]

Our loaded model gives exactly the same results as the model we trained in this Notebook, so any application that uses the saved model will get the same predictions.
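Rather than comparing the two lists by eye, the check can be made explicit. A minimal sketch using the values printed above and Python's `math.isclose`; a small tolerance is used because recomputed floating-point metrics can differ in the last digits across runs or hardware:

```python
import math

# Values copied from the two evaluate() outputs above: [loss, accuracy]
all_classes_10_percent_fine_tune_results = [1.6050441265106201, 0.5757227540016174]
loaded_model_results = [1.6050441265106201, 0.5757227540016174]

same = all(
    math.isclose(a, b, rel_tol=1e-6)
    for a, b in zip(all_classes_10_percent_fine_tune_results, loaded_model_results)
)
print(same)  # True
```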

Evaluating the performance of our model across all different classes

Let's make some predictions, visualize them and then later find out which predictions were the "most" wrong.
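One common way to define the "most wrong" predictions: among the predictions whose class differs from the true label, rank them by the model's confidence in the predicted class, highest first. Those are the cases where the model was confidently mistaken. A plain-Python sketch with made-up toy data (the labels and probabilities below are illustrative, not outputs of this model; `most_wrong` is a hypothetical name):

```python
def most_wrong(y_true, pred_probs, class_names, top_k=3):
    """Return the top_k incorrect predictions, most confident first.

    y_true: list of true class indices
    pred_probs: list of per-class probability lists (e.g. softmax outputs)
    """
    wrong = []
    for i, (label, probs) in enumerate(zip(y_true, pred_probs)):
        pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
        if pred != label:
            wrong.append({
                "index": i,
                "true": class_names[label],
                "pred": class_names[pred],
                "pred_conf": probs[pred],
            })
    return sorted(wrong, key=lambda row: row["pred_conf"], reverse=True)[:top_k]


class_names = ["apple_pie", "pizza", "sushi"]  # toy subset of Food101 class names
y_true = [0, 1, 2, 1]
pred_probs = [
    [0.10, 0.85, 0.05],  # wrong: apple_pie predicted as pizza, 0.85 confidence
    [0.05, 0.90, 0.05],  # correct
    [0.60, 0.10, 0.30],  # wrong: sushi predicted as apple_pie, 0.60 confidence
    [0.20, 0.30, 0.50],  # wrong: pizza predicted as sushi, 0.50 confidence
]

for row in most_wrong(y_true, pred_probs, class_names):
    print(row)
```

With the real model, `y_true` would come from the test labels and `pred_probs` from `model.predict(test_data)`, assuming `test_data` was built without shuffling so that labels and predictions line up.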

import tensorflow as tf
# Download a pre-trained model
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
--2022-02-20 16:21:59--  https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 209.85.145.128, 209.85.146.128, 142.250.125.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|209.85.145.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 46760742 (45M) [application/zip]
Saving to: ‘06_101_food_class_10_percent_saved_big_dog_model.zip’

06_101_food_class_1 100%[===================>]  44.59M  55.6MB/s    in 0.8s    

2022-02-20 16:22:00 (55.6 MB/s) - ‘06_101_food_class_10_percent_saved_big_dog_model.zip’ saved [46760742/46760742]

unzip_data("/content/06_101_food_class_10_percent_saved_big_dog_model.zip")
# Load in saved model
model = tf.keras.models.load_model("/content/06_101_food_class_10_percent_saved_big_dog_model")

WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_419470) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_446460) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_activation_layer_call_and_return_conditional_losses_450449) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_415747) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_activation_layer_call_and_return_conditional_losses_416083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_activation_layer_call_and_return_conditional_losses_450775) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_activation_layer_call_and_return_conditional_losses_451847) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_expand_activation_layer_call_and_return_conditional_losses_417915) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_se_reduce_layer_call_and_return_conditional_losses_451887) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_expand_activation_layer_call_and_return_conditional_losses_452467) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_functional_17_layer_call_and_return_conditional_losses_438312) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_expand_activation_layer_call_and_return_conditional_losses_417583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_activation_layer_call_and_return_conditional_losses_418582) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_se_reduce_layer_call_and_return_conditional_losses_454031) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_activation_layer_call_and_return_conditional_losses_455436) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_activation_layer_call_and_return_conditional_losses_415524) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_activation_layer_call_and_return_conditional_losses_451474) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_expand_activation_layer_call_and_return_conditional_losses_451768) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_441729) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_se_reduce_layer_call_and_return_conditional_losses_454357) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_activation_layer_call_and_return_conditional_losses_416695) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_expand_activation_layer_call_and_return_conditional_losses_454238) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_functional_17_layer_call_and_return_conditional_losses_436681) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_activation_layer_call_and_return_conditional_losses_415804) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_activation_layer_call_and_return_conditional_losses_452919) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_se_reduce_layer_call_and_return_conditional_losses_453658) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_448082) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_activation_layer_call_and_return_conditional_losses_418915) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_expand_activation_layer_call_and_return_conditional_losses_453539) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_se_reduce_layer_call_and_return_conditional_losses_452586) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_se_reduce_layer_call_and_return_conditional_losses_450163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_se_reduce_layer_call_and_return_conditional_losses_418018) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_expand_activation_layer_call_and_return_conditional_losses_455357) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_activation_layer_call_and_return_conditional_losses_417639) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_se_reduce_layer_call_and_return_conditional_losses_451188) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_activation_layer_call_and_return_conditional_losses_420190) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_stem_activation_layer_call_and_return_conditional_losses_415468) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_se_reduce_layer_call_and_return_conditional_losses_455476) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_se_reduce_layer_call_and_return_conditional_losses_417354) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_se_reduce_layer_call_and_return_conditional_losses_452213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_activation_layer_call_and_return_conditional_losses_452173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_se_reduce_layer_call_and_return_conditional_losses_415571) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_se_reduce_layer_call_and_return_conditional_losses_451514) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_activation_layer_call_and_return_conditional_losses_417971) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_se_reduce_layer_call_and_return_conditional_losses_454730) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_se_reduce_layer_call_and_return_conditional_losses_416742) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_se_reduce_layer_call_and_return_conditional_losses_450489) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_activation_layer_call_and_return_conditional_losses_451148) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_expand_activation_layer_call_and_return_conditional_losses_418194) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_se_reduce_layer_call_and_return_conditional_losses_416463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_429711) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_443351) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_expand_activation_layer_call_and_return_conditional_losses_418526) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_activation_layer_call_and_return_conditional_losses_453245) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_activation_layer_call_and_return_conditional_losses_416416) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_428089) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_expand_activation_layer_call_and_return_conditional_losses_416027) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_expand_activation_layer_call_and_return_conditional_losses_453912) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_activation_layer_call_and_return_conditional_losses_452546) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_se_reduce_layer_call_and_return_conditional_losses_420237) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_se_reduce_layer_call_and_return_conditional_losses_418629) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_expand_activation_layer_call_and_return_conditional_losses_416359) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_expand_activation_layer_call_and_return_conditional_losses_451395) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_activation_layer_call_and_return_conditional_losses_454690) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_se_reduce_layer_call_and_return_conditional_losses_419905) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_activation_layer_call_and_return_conditional_losses_419526) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_se_reduce_layer_call_and_return_conditional_losses_418297) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_expand_activation_layer_call_and_return_conditional_losses_452094) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference__wrapped_model_408990) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_activation_layer_call_and_return_conditional_losses_453618) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_expand_activation_layer_call_and_return_conditional_losses_454984) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_expand_activation_layer_call_and_return_conditional_losses_450696) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_expand_activation_layer_call_and_return_conditional_losses_418858) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_stem_activation_layer_call_and_return_conditional_losses_450044) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_activation_layer_call_and_return_conditional_losses_418250) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_activation_layer_call_and_return_conditional_losses_453991) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_se_reduce_layer_call_and_return_conditional_losses_453285) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_expand_activation_layer_call_and_return_conditional_losses_416971) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_455683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_se_reduce_layer_call_and_return_conditional_losses_415851) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_expand_activation_layer_call_and_return_conditional_losses_453166) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_420413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_activation_layer_call_and_return_conditional_losses_450123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_se_reduce_layer_call_and_return_conditional_losses_417075) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_expand_activation_layer_call_and_return_conditional_losses_452840) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_activation_layer_call_and_return_conditional_losses_417307) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_activation_layer_call_and_return_conditional_losses_455063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
results_downloaded_model = model.evaluate(test_data)
results_downloaded_model
790/790 [==============================] - 108s 134ms/step - loss: 1.8027 - accuracy: 0.6078
[1.8027207851409912, 0.6077623963356018]

Making predictions with our trained model

preds_probs = model.predict(test_data, verbose = 1)
790/790 [==============================] - 100s 125ms/step
preds_probs
array([[5.9541941e-02, 3.5742332e-06, 4.1376889e-02, ..., 1.4138899e-09,
        8.3530460e-05, 3.0897565e-03],
       [9.6401680e-01, 1.3753089e-09, 8.4779976e-04, ..., 5.4286684e-05,
        7.8363253e-12, 9.8467334e-10],
       [9.5925868e-01, 3.2534019e-05, 1.4867033e-03, ..., 7.1891884e-07,
        5.4398350e-07, 4.0276311e-05],
       ...,
       [1.5138583e-05, 4.0972975e-04, 8.0249712e-10, ..., 2.1742959e-05,
        1.0797195e-05, 5.3789973e-01],
       [5.9317499e-03, 4.9236189e-03, 9.9823205e-03, ..., 1.1989424e-04,
        1.6889933e-05, 4.5217723e-02],
       [3.1363364e-02, 7.5052544e-03, 4.2974975e-04, ..., 5.0347066e-04,
        5.2056580e-06, 6.9062799e-01]], dtype=float32)
len(preds_probs)
25250
preds_probs.shape
(25250, 101)

The model made predictions on 25,250 test images, so the predictions array has shape (25250, 101): one row per image and one column (a predicted probability) per class.

preds_probs[:10]
array([[5.9541941e-02, 3.5742332e-06, 4.1376889e-02, ..., 1.4138899e-09,
        8.3530460e-05, 3.0897565e-03],
       [9.6401680e-01, 1.3753089e-09, 8.4779976e-04, ..., 5.4286684e-05,
        7.8363253e-12, 9.8467334e-10],
       [9.5925868e-01, 3.2534019e-05, 1.4867033e-03, ..., 7.1891884e-07,
        5.4398350e-07, 4.0276311e-05],
       ...,
       [4.7313324e-01, 1.2931301e-07, 1.4805583e-03, ..., 5.9749611e-04,
        6.6969820e-05, 2.3469329e-05],
       [4.4571780e-02, 4.7265351e-07, 1.2258515e-01, ..., 6.3498578e-06,
        7.5319103e-06, 3.6778715e-03],
       [7.2438985e-01, 1.9249777e-09, 5.2310857e-05, ..., 1.2291447e-03,
        1.5793171e-09, 9.6395503e-05]], dtype=float32)
preds_probs[0], len(preds_probs[0]), sum(preds_probs[0])
(array([5.9541941e-02, 3.5742332e-06, 4.1376889e-02, 1.0660903e-09,
        8.1613996e-09, 8.6639682e-09, 8.0926134e-07, 8.5652442e-07,
        1.9858850e-05, 8.0977554e-07, 3.1727692e-09, 9.8673388e-07,
        2.8532100e-04, 7.8049661e-10, 7.4230990e-04, 3.8915794e-05,
        6.4740016e-06, 2.4977169e-06, 3.7891397e-05, 2.0678806e-07,
        1.5538471e-05, 8.1506892e-07, 2.6230925e-06, 2.0010653e-07,
        8.3827712e-07, 5.4215743e-06, 3.7391112e-06, 1.3150788e-08,
        2.7761345e-03, 2.8051816e-05, 6.8561651e-10, 2.5574524e-05,
        1.6688934e-04, 7.6409645e-10, 4.0452869e-04, 1.3150487e-08,
        1.7957433e-06, 1.4448400e-06, 2.3062853e-02, 8.2465459e-07,
        8.5366531e-07, 1.7138503e-06, 7.0526130e-06, 1.8402382e-08,
        2.8553984e-07, 7.9482870e-06, 2.0682012e-06, 1.8525193e-07,
        3.3619781e-08, 3.1522335e-04, 1.0410886e-05, 8.5448306e-07,
        8.4741890e-01, 1.0555387e-05, 4.4094719e-07, 3.7404192e-05,
        3.5306137e-05, 3.2489079e-05, 6.7313988e-05, 1.2852399e-08,
        2.6220215e-10, 1.0318094e-05, 8.5742751e-05, 1.0569768e-06,
        2.1293156e-06, 3.7636986e-05, 7.5972878e-08, 2.5340833e-04,
        9.2906589e-07, 1.2598188e-04, 6.2621680e-06, 1.2458612e-08,
        4.0519622e-05, 6.8728390e-08, 1.2546213e-06, 5.2887103e-08,
        7.5424801e-08, 7.5397300e-05, 7.7540310e-05, 6.4025420e-07,
        9.9033900e-07, 2.2225931e-05, 1.5013910e-05, 1.4038655e-07,
        1.2232513e-05, 1.9044673e-02, 5.0000424e-05, 4.6225618e-06,
        1.5388186e-07, 3.3824463e-07, 3.9227444e-09, 1.6563394e-07,
        8.1322025e-05, 4.8964989e-06, 2.4068495e-07, 2.3124319e-05,
        3.1040644e-04, 3.1380074e-05, 1.4138899e-09, 8.3530460e-05,
        3.0897565e-03], dtype=float32), 101, 1.0000000616546507)

Each prediction is a vector of N probabilities, where N is the number of classes (here 101), and ideally these probabilities sum to 1. The sum we got, 1.0000000616546507, is very close to 1 but not exact. This doesn't mean something is wrong with the model. It is a consequence of how computers store numbers in memory: the probabilities are float32 values with limited precision, so small rounding errors creep in when they are added up. To understand this better, look into the precision of the various floating-point data types.
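A minimal sketch of the effect, using a made-up probability vector rather than the model's real output: random logits are pushed through a softmax in float64 and then cast to float32 (the dtype of `preds_probs`). The float32 copy may no longer sum to exactly 1, which is why checks like this are better done with a tolerance such as `np.isclose` than with `==`.

```python
import numpy as np

# Hypothetical logits for one image over 101 classes (random, for illustration only)
rng = np.random.default_rng(seed=0)
logits = rng.normal(size=101)

# Softmax computed in float64, then cast to float32 (the dtype of preds_probs)
probs64 = np.exp(logits) / np.exp(logits).sum()
probs32 = probs64.astype(np.float32)

sum64 = float(probs64.sum())
sum32 = float(probs32.sum())
print(f"float64 sum: {sum64!r}")
print(f"float32 sum: {sum32!r}")
print(f"float32 machine epsilon: {np.finfo(np.float32).eps}")

# Compare with a tolerance rather than exact equality
print(np.isclose(sum32, 1.0))  # True
```

The float32 machine epsilon (about 1.2e-07) gives a feel for how large these rounding errors can be, which matches the size of the deviation seen in the sum above.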

print(f"Number of prediction probabilities for sample 0: {len(preds_probs[0])}")
print(f"What prediction probability sample 0 looks like: \n {preds_probs[0]}")
print(f"The class with the  highest predicted probability by the model for sample 0: {preds_probs[0].argmax()}")
Number of prediction probabilities for sample 0: 101
What prediction probability sample 0 looks like: 
 [5.9541941e-02 3.5742332e-06 4.1376889e-02 1.0660903e-09 8.1613996e-09
 8.6639682e-09 8.0926134e-07 8.5652442e-07 1.9858850e-05 8.0977554e-07
 3.1727692e-09 9.8673388e-07 2.8532100e-04 7.8049661e-10 7.4230990e-04
 3.8915794e-05 6.4740016e-06 2.4977169e-06 3.7891397e-05 2.0678806e-07
 1.5538471e-05 8.1506892e-07 2.6230925e-06 2.0010653e-07 8.3827712e-07
 5.4215743e-06 3.7391112e-06 1.3150788e-08 2.7761345e-03 2.8051816e-05
 6.8561651e-10 2.5574524e-05 1.6688934e-04 7.6409645e-10 4.0452869e-04
 1.3150487e-08 1.7957433e-06 1.4448400e-06 2.3062853e-02 8.2465459e-07
 8.5366531e-07 1.7138503e-06 7.0526130e-06 1.8402382e-08 2.8553984e-07
 7.9482870e-06 2.0682012e-06 1.8525193e-07 3.3619781e-08 3.1522335e-04
 1.0410886e-05 8.5448306e-07 8.4741890e-01 1.0555387e-05 4.4094719e-07
 3.7404192e-05 3.5306137e-05 3.2489079e-05 6.7313988e-05 1.2852399e-08
 2.6220215e-10 1.0318094e-05 8.5742751e-05 1.0569768e-06 2.1293156e-06
 3.7636986e-05 7.5972878e-08 2.5340833e-04 9.2906589e-07 1.2598188e-04
 6.2621680e-06 1.2458612e-08 4.0519622e-05 6.8728390e-08 1.2546213e-06
 5.2887103e-08 7.5424801e-08 7.5397300e-05 7.7540310e-05 6.4025420e-07
 9.9033900e-07 2.2225931e-05 1.5013910e-05 1.4038655e-07 1.2232513e-05
 1.9044673e-02 5.0000424e-05 4.6225618e-06 1.5388186e-07 3.3824463e-07
 3.9227444e-09 1.6563394e-07 8.1322025e-05 4.8964989e-06 2.4068495e-07
 2.3124319e-05 3.1040644e-04 3.1380074e-05 1.4138899e-09 8.3530460e-05
 3.0897565e-03]
The class with the  highest predicted probability by the model for sample 0: 52
test_data.class_names[52]
'gyoza'
pred_classes = preds_probs.argmax(axis =1)

pred_classes[:10]
array([52,  0,  0, 80, 79, 61, 29,  0, 85,  0])
len(pred_classes)
25250

Now that we have an array of all our model's predictions, we need to compare them to the ground truth labels to evaluate them.

y_labels = []
for images, labels in test_data.unbatch():
  y_labels.append(labels.numpy().argmax()) # test labels are one-hot encoded (e.g. [0, 0, 0, 1, 0, ...]); we want the index of the 1
y_labels[:10] # look at the first 10
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
len(y_labels)
25250
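Since the labels are one-hot encoded, `argmax` recovers each class index. A vectorized sketch of the same conversion is shown below on hypothetical NumPy batches, standing in for the arrays you would get from `labels.numpy()` on each batch of `test_data`. Note that the resulting order only lines up with `preds_probs` if the test dataset is not shuffled (for example, if it was created with `shuffle=False` in Keras' `image_dataset_from_directory` utility); that is an assumption about how `test_data` was built.

```python
import numpy as np

# Hypothetical batches of one-hot labels, each shaped (batch_size, num_classes)
batches = [
    np.array([[1, 0, 0], [0, 0, 1]]),
    np.array([[0, 1, 0], [1, 0, 0]]),
]

# argmax(axis=1) gives the index of the 1 in each row; concatenating
# the per-batch results preserves the original dataset order
y_labels = np.concatenate([batch.argmax(axis=1) for batch in batches])
print(y_labels)  # [0 2 1 0]
```

Working per batch avoids unbatching the dataset one element at a time, which can be slow for 25,250 images.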

Evaluating our model's predictions

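One way to check that our model's predictions array is in the same order as our test labels array is to compute the accuracy score with scikit-learn and compare it with the accuracy returned by model.evaluate(). If the two match, the arrays line up.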

results_downloaded_model
[1.8027207851409912, 0.6077623963356018]
from sklearn.metrics import accuracy_score
sklearn_accuracy = accuracy_score(y_true = y_labels,
                                  y_pred = pred_classes)
sklearn_accuracy
0.6077623762376237
import numpy as np
np.isclose(results_downloaded_model[1], sklearn_accuracy)
True
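Accuracy itself is simple to compute by hand: it is the fraction of positions where the predicted class equals the true class. The sketch below uses small hypothetical arrays, then compares the two accuracy values obtained above. They differ around the 8th decimal place (presumably because Keras computes its metric in float32), which is why `np.isclose` is used instead of `==`.

```python
import numpy as np

# Hypothetical ground-truth and predicted class indices
y_true = np.array([0, 0, 1, 2, 2])
y_pred = np.array([0, 1, 1, 2, 0])

# Accuracy is the fraction of positions where the prediction matches the label
manual_accuracy = np.mean(y_true == y_pred)
print(manual_accuracy)  # 0.6

# The Keras and scikit-learn accuracies from this notebook
keras_acc, sklearn_acc = 0.6077623963356018, 0.6077623762376237
print(keras_acc == sklearn_acc, np.isclose(keras_acc, sklearn_acc))  # False True
```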

Making a Confusion Matrix

from helper_functions import make_confusion_matrix
class_names = test_data.class_names
class_names[:10]
['apple_pie',
 'baby_back_ribs',
 'baklava',
 'beef_carpaccio',
 'beef_tartare',
 'beet_salad',
 'beignets',
 'bibimbap',
 'bread_pudding',
 'breakfast_burrito']
# plot_confusion_matrix function - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html
import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# Our function needs a different name to sklearn's plot_confusion_matrix
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False): 
  """Makes a labelled confusion matrix comparing predictions and ground truth labels.

  If classes is passed, confusion matrix will be labelled, if not, integer class values
  will be used.

  Args:
    y_true: Array of truth labels (must be same shape as y_pred).
    y_pred: Array of predicted labels (must be same shape as y_true).
    classes: Array of class labels (e.g. string form). If `None`, integer labels are used.
    figsize: Size of output figure (default=(10, 10)).
    text_size: Size of output figure text (default=15).
    norm: normalize values or not (default=False).
    savefig: save confusion matrix to file (default=False).
  
  Returns:
    A labelled confusion matrix plot comparing y_true and y_pred.

  Example usage:
    make_confusion_matrix(y_true=test_labels, # ground truth test labels
                          y_pred=y_preds, # predicted labels
                          classes=class_names, # array of class label names
                          figsize=(15, 15),
                          text_size=10)
  """  
  # Create the confusion matrix
  cm = confusion_matrix(y_true, y_pred)
  cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
  n_classes = cm.shape[0] # find the number of classes we're dealing with

  # Plot the figure and make it pretty
  fig, ax = plt.subplots(figsize=figsize)
  cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
  fig.colorbar(cax)

  # Is there a list of classes?
  if classes:
    labels = classes
  else:
    labels = np.arange(cm.shape[0])
  
  # Label the axes
  ax.set(title="Confusion Matrix",
         xlabel="Predicted label",
         ylabel="True label",
         xticks=np.arange(n_classes), # create enough axis slots for each class
         yticks=np.arange(n_classes), 
         xticklabels=labels, # axes will be labeled with class names (if they exist) or ints
         yticklabels=labels)
  
  # Make x-axis labels appear on bottom
  ax.xaxis.set_label_position("bottom")
  ax.xaxis.tick_bottom()

  ### Added: Rotate xticks for readability & increase font size (required due to such a large confusion matrix)
  plt.xticks(rotation=70, fontsize=text_size)
  plt.yticks(fontsize=text_size)

  # Set the threshold for different colors
  threshold = (cm.max() + cm.min()) / 2.

  # Plot the text on each cell
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    if norm:
      plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)
    else:
      plt.text(j, i, f"{cm[i, j]}",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)

  # Save the figure to the current working directory
  if savefig:
    fig.savefig("confusion_matrix.png")
make_confusion_matrix(y_true = y_labels,
                      y_pred = pred_classes,
                      classes = class_names,
                      figsize= (100,100),
                      text_size = 20,
                      savefig = True)
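`make_confusion_matrix` relies on scikit-learn's `confusion_matrix`, where rows are true labels and columns are predicted labels, and divides each row by its total to get `cm_norm`. A tiny NumPy-only sketch with hypothetical 3-class labels shows what that normalization means; `simple_confusion_matrix` is a hypothetical helper written for illustration, not part of `helper_functions.py`. Each diagonal entry of the row-normalized matrix is that class's recall, the same quantity reported per class by `classification_report`.

```python
import numpy as np

def simple_confusion_matrix(y_true, y_pred, n_classes):
    """Hypothetical helper: rows are true labels, columns are predicted labels
    (the same layout as sklearn.metrics.confusion_matrix)."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

# Hypothetical labels for a 3-class problem
y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2, 2]

cm = simple_confusion_matrix(y_true, y_pred, n_classes=3)
# Same normalization as make_confusion_matrix: divide each row by its total
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

print(cm)
# [[1 1 0]
#  [0 2 1]
#  [0 0 1]]
print(cm_norm.diagonal())  # per-class recall, roughly [0.5, 0.667, 1.0]
```

With 101 classes the off-diagonal cells are where the interesting information lives: a large value at row i, column j means many images of class i were predicted as class j.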

Making a Classification Report

Scikit-learn has a helpful function called classification_report that computes several classification metrics per class (e.g. precision, recall and F1-score).

# Make a classification report
from sklearn.metrics import classification_report
print(classification_report(y_true = y_labels,
                            y_pred = pred_classes))

              precision    recall  f1-score   support

           0       0.29      0.20      0.24       250
           1       0.51      0.69      0.59       250
           2       0.56      0.65      0.60       250
           3       0.74      0.53      0.62       250
           4       0.73      0.43      0.54       250
           5       0.34      0.54      0.42       250
           6       0.67      0.79      0.72       250
           7       0.82      0.76      0.79       250
           8       0.40      0.37      0.39       250
           9       0.62      0.44      0.51       250
          10       0.62      0.42      0.50       250
          11       0.84      0.49      0.62       250
          12       0.52      0.74      0.61       250
          13       0.56      0.60      0.58       250
          14       0.56      0.59      0.57       250
          15       0.44      0.32      0.37       250
          16       0.45      0.75      0.57       250
          17       0.37      0.51      0.43       250
          18       0.43      0.60      0.50       250
          19       0.68      0.60      0.64       250
          20       0.68      0.75      0.71       250
          21       0.35      0.64      0.45       250
          22       0.30      0.37      0.33       250
          23       0.66      0.77      0.71       250
          24       0.83      0.72      0.77       250
          25       0.76      0.71      0.73       250
          26       0.51      0.42      0.46       250
          27       0.78      0.72      0.75       250
          28       0.70      0.69      0.69       250
          29       0.70      0.68      0.69       250
          30       0.92      0.63      0.75       250
          31       0.78      0.70      0.74       250
          32       0.75      0.83      0.79       250
          33       0.89      0.98      0.94       250
          34       0.68      0.78      0.72       250
          35       0.78      0.66      0.72       250
          36       0.53      0.56      0.55       250
          37       0.30      0.55      0.39       250
          38       0.78      0.63      0.69       250
          39       0.27      0.33      0.30       250
          40       0.72      0.81      0.76       250
          41       0.81      0.62      0.70       250
          42       0.50      0.58      0.54       250
          43       0.75      0.60      0.67       250
          44       0.74      0.45      0.56       250
          45       0.77      0.85      0.81       250
          46       0.81      0.46      0.58       250
          47       0.44      0.49      0.46       250
          48       0.45      0.81      0.58       250
          49       0.50      0.44      0.47       250
          50       0.54      0.39      0.46       250
          51       0.71      0.86      0.78       250
          52       0.51      0.77      0.61       250
          53       0.67      0.68      0.68       250
          54       0.88      0.75      0.81       250
          55       0.86      0.69      0.76       250
          56       0.56      0.24      0.34       250
          57       0.62      0.45      0.52       250
          58       0.68      0.58      0.62       250
          59       0.70      0.37      0.49       250
          60       0.83      0.59      0.69       250
          61       0.54      0.81      0.65       250
          62       0.72      0.49      0.58       250
          63       0.94      0.86      0.90       250
          64       0.78      0.85      0.81       250
          65       0.82      0.82      0.82       250
          66       0.69      0.32      0.44       250
          67       0.41      0.58      0.48       250
          68       0.90      0.78      0.83       250
          69       0.84      0.82      0.83       250
          70       0.62      0.83      0.71       250
          71       0.81      0.46      0.59       250
          72       0.64      0.65      0.65       250
          73       0.51      0.44      0.47       250
          74       0.72      0.61      0.66       250
          75       0.84      0.90      0.87       250
          76       0.78      0.78      0.78       250
          77       0.36      0.27      0.31       250
          78       0.79      0.74      0.76       250
          79       0.44      0.81      0.57       250
          80       0.57      0.60      0.59       250
          81       0.65      0.70      0.68       250
          82       0.38      0.31      0.34       250
          83       0.58      0.80      0.67       250
          84       0.61      0.38      0.47       250
          85       0.44      0.74      0.55       250
          86       0.71      0.86      0.78       250
          87       0.41      0.39      0.40       250
          88       0.83      0.80      0.81       250
          89       0.71      0.31      0.43       250
          90       0.92      0.69      0.79       250
          91       0.83      0.87      0.85       250
          92       0.68      0.65      0.67       250
          93       0.31      0.38      0.34       250
          94       0.61      0.54      0.57       250
          95       0.74      0.61      0.67       250
          96       0.56      0.29      0.38       250
          97       0.45      0.74      0.56       250
          98       0.47      0.33      0.39       250
          99       0.52      0.27      0.35       250
         100       0.59      0.70      0.64       250

    accuracy                           0.61     25250
   macro avg       0.63      0.61      0.61     25250
weighted avg       0.63      0.61      0.61     25250

The numbers above give a detailed class-by-class evaluation of our model's predictions, but with so many classes they're quite hard to take in at a glance.
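To see where those per-class numbers come from, here is a small worked example for class 0 (apple_pie). The counts are back-calculated from its row in the report: support = 250 and recall = 0.204 give 51 true positives and 199 false negatives, and a precision of 0.2931... then implies 123 false positives (174 images predicted as apple_pie in total).

```python
# Counts for class 0 (apple_pie), back-calculated from the classification report
tp, fp, fn = 51, 123, 199

precision = tp / (tp + fp)  # of everything predicted as apple_pie, the fraction that really was apple_pie
recall = tp / (tp + fn)     # of all real apple_pie images, the fraction the model found
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of precision and recall

print(f"precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}")
```

These reproduce the values reported for class 0 (precision 0.2931, recall 0.2040, F1-score 0.2406). Because F1 is a harmonic mean, it is pulled towards the lower of the two, which is why classes with very unbalanced precision and recall (e.g. class 89: precision 0.71, recall 0.31) still end up with a low F1-score.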

# Get a dictionary of the classification report
classification_report_dict = classification_report(y_labels, pred_classes, output_dict = True)
classification_report_dict

{'0': {'f1-score': 0.24056603773584903,
  'precision': 0.29310344827586204,
  'recall': 0.204,
  'support': 250},
 '1': {'f1-score': 0.5864406779661017,
  'precision': 0.5088235294117647,
  'recall': 0.692,
  'support': 250},
 '10': {'f1-score': 0.5047619047619047,
  'precision': 0.6235294117647059,
  'recall': 0.424,
  'support': 250},
 '100': {'f1-score': 0.641025641025641,
  'precision': 0.5912162162162162,
  'recall': 0.7,
  'support': 250},
 '11': {'f1-score': 0.6161616161616161,
  'precision': 0.8356164383561644,
  'recall': 0.488,
  'support': 250},
 '12': {'f1-score': 0.6105610561056106,
  'precision': 0.5196629213483146,
  'recall': 0.74,
  'support': 250},
 '13': {'f1-score': 0.5775193798449612,
  'precision': 0.5601503759398496,
  'recall': 0.596,
  'support': 250},
 '14': {'f1-score': 0.574757281553398,
  'precision': 0.5584905660377358,
  'recall': 0.592,
  'support': 250},
 '15': {'f1-score': 0.36744186046511623,
  'precision': 0.4388888888888889,
  'recall': 0.316,
  'support': 250},
 '16': {'f1-score': 0.5654135338345864,
  'precision': 0.4530120481927711,
  'recall': 0.752,
  'support': 250},
 '17': {'f1-score': 0.42546063651591287,
  'precision': 0.3659942363112392,
  'recall': 0.508,
  'support': 250},
 '18': {'f1-score': 0.5008403361344538,
  'precision': 0.4318840579710145,
  'recall': 0.596,
  'support': 250},
 '19': {'f1-score': 0.6411889596602972,
  'precision': 0.6832579185520362,
  'recall': 0.604,
  'support': 250},
 '2': {'f1-score': 0.6022304832713754,
  'precision': 0.5625,
  'recall': 0.648,
  'support': 250},
 '20': {'f1-score': 0.7123809523809523,
  'precision': 0.68,
  'recall': 0.748,
  'support': 250},
 '21': {'f1-score': 0.45261669024045265,
  'precision': 0.350109409190372,
  'recall': 0.64,
  'support': 250},
 '22': {'f1-score': 0.3291592128801431,
  'precision': 0.2977346278317152,
  'recall': 0.368,
  'support': 250},
 '23': {'f1-score': 0.7134935304990757,
  'precision': 0.6632302405498282,
  'recall': 0.772,
  'support': 250},
 '24': {'f1-score': 0.7708779443254817,
  'precision': 0.8294930875576036,
  'recall': 0.72,
  'support': 250},
 '25': {'f1-score': 0.734020618556701,
  'precision': 0.7574468085106383,
  'recall': 0.712,
  'support': 250},
 '26': {'f1-score': 0.4625550660792952,
  'precision': 0.5147058823529411,
  'recall': 0.42,
  'support': 250},
 '27': {'f1-score': 0.7494824016563146,
  'precision': 0.776824034334764,
  'recall': 0.724,
  'support': 250},
 '28': {'f1-score': 0.6935483870967742,
  'precision': 0.6991869918699187,
  'recall': 0.688,
  'support': 250},
 '29': {'f1-score': 0.6910569105691057,
  'precision': 0.7024793388429752,
  'recall': 0.68,
  'support': 250},
 '3': {'f1-score': 0.616822429906542,
  'precision': 0.7415730337078652,
  'recall': 0.528,
  'support': 250},
 '30': {'f1-score': 0.7476190476190476,
  'precision': 0.9235294117647059,
  'recall': 0.628,
  'support': 250},
 '31': {'f1-score': 0.7357293868921776,
  'precision': 0.7802690582959642,
  'recall': 0.696,
  'support': 250},
 '32': {'f1-score': 0.7855787476280836,
  'precision': 0.7472924187725631,
  'recall': 0.828,
  'support': 250},
 '33': {'f1-score': 0.9371428571428572,
  'precision': 0.8945454545454545,
  'recall': 0.984,
  'support': 250},
 '34': {'f1-score': 0.7238805970149255,
  'precision': 0.6783216783216783,
  'recall': 0.776,
  'support': 250},
 '35': {'f1-score': 0.715835140997831,
  'precision': 0.7819905213270142,
  'recall': 0.66,
  'support': 250},
 '36': {'f1-score': 0.5475728155339805,
  'precision': 0.5320754716981132,
  'recall': 0.564,
  'support': 250},
 '37': {'f1-score': 0.3870056497175141,
  'precision': 0.29912663755458513,
  'recall': 0.548,
  'support': 250},
 '38': {'f1-score': 0.6946902654867257,
  'precision': 0.7772277227722773,
  'recall': 0.628,
  'support': 250},
 '39': {'f1-score': 0.29749103942652333,
  'precision': 0.2694805194805195,
  'recall': 0.332,
  'support': 250},
 '4': {'f1-score': 0.544080604534005,
  'precision': 0.7346938775510204,
  'recall': 0.432,
  'support': 250},
 '40': {'f1-score': 0.7622641509433963,
  'precision': 0.7214285714285714,
  'recall': 0.808,
  'support': 250},
 '41': {'f1-score': 0.7029478458049886,
  'precision': 0.8115183246073299,
  'recall': 0.62,
  'support': 250},
 '42': {'f1-score': 0.537037037037037,
  'precision': 0.5,
  'recall': 0.58,
  'support': 250},
 '43': {'f1-score': 0.6651884700665188,
  'precision': 0.746268656716418,
  'recall': 0.6,
  'support': 250},
 '44': {'f1-score': 0.5586034912718205,
  'precision': 0.7417218543046358,
  'recall': 0.448,
  'support': 250},
 '45': {'f1-score': 0.8114285714285714,
  'precision': 0.7745454545454545,
  'recall': 0.852,
  'support': 250},
 '46': {'f1-score': 0.5831202046035805,
  'precision': 0.8085106382978723,
  'recall': 0.456,
  'support': 250},
 '47': {'f1-score': 0.4641509433962264,
  'precision': 0.4392857142857143,
  'recall': 0.492,
  'support': 250},
 '48': {'f1-score': 0.577524893314367,
  'precision': 0.4481236203090508,
  'recall': 0.812,
  'support': 250},
 '49': {'f1-score': 0.47234042553191485,
  'precision': 0.5045454545454545,
  'recall': 0.444,
  'support': 250},
 '5': {'f1-score': 0.41860465116279066,
  'precision': 0.34177215189873417,
  'recall': 0.54,
  'support': 250},
 '50': {'f1-score': 0.45581395348837206,
  'precision': 0.5444444444444444,
  'recall': 0.392,
  'support': 250},
 '51': {'f1-score': 0.7783783783783783,
  'precision': 0.7081967213114754,
  'recall': 0.864,
  'support': 250},
 '52': {'f1-score': 0.6124401913875598,
  'precision': 0.5092838196286472,
  'recall': 0.768,
  'support': 250},
 '53': {'f1-score': 0.6759443339960238,
  'precision': 0.6719367588932806,
  'recall': 0.68,
  'support': 250},
 '54': {'f1-score': 0.8103448275862069,
  'precision': 0.8785046728971962,
  'recall': 0.752,
  'support': 250},
 '55': {'f1-score': 0.7644444444444444,
  'precision': 0.86,
  'recall': 0.688,
  'support': 250},
 '56': {'f1-score': 0.3398328690807799,
  'precision': 0.5596330275229358,
  'recall': 0.244,
  'support': 250},
 '57': {'f1-score': 0.5209302325581396,
  'precision': 0.6222222222222222,
  'recall': 0.448,
  'support': 250},
 '58': {'f1-score': 0.6233766233766233,
  'precision': 0.6792452830188679,
  'recall': 0.576,
  'support': 250},
 '59': {'f1-score': 0.486910994764398,
  'precision': 0.7045454545454546,
  'recall': 0.372,
  'support': 250},
 '6': {'f1-score': 0.7229357798165138,
  'precision': 0.6677966101694915,
  'recall': 0.788,
  'support': 250},
 '60': {'f1-score': 0.6885245901639344,
  'precision': 0.8305084745762712,
  'recall': 0.588,
  'support': 250},
 '61': {'f1-score': 0.6495176848874598,
  'precision': 0.543010752688172,
  'recall': 0.808,
  'support': 250},
 '62': {'f1-score': 0.5823389021479712,
  'precision': 0.7218934911242604,
  'recall': 0.488,
  'support': 250},
 '63': {'f1-score': 0.895397489539749,
  'precision': 0.9385964912280702,
  'recall': 0.856,
  'support': 250},
 '64': {'f1-score': 0.8129770992366412,
  'precision': 0.7773722627737226,
  'recall': 0.852,
  'support': 250},
 '65': {'f1-score': 0.82, 'precision': 0.82, 'recall': 0.82, 'support': 250},
 '66': {'f1-score': 0.44141689373297005,
  'precision': 0.6923076923076923,
  'recall': 0.324,
  'support': 250},
 '67': {'f1-score': 0.47840531561461797,
  'precision': 0.4090909090909091,
  'recall': 0.576,
  'support': 250},
 '68': {'f1-score': 0.832618025751073,
  'precision': 0.8981481481481481,
  'recall': 0.776,
  'support': 250},
 '69': {'f1-score': 0.8340080971659919,
  'precision': 0.8442622950819673,
  'recall': 0.824,
  'support': 250},
 '7': {'f1-score': 0.7908902691511386,
  'precision': 0.8197424892703863,
  'recall': 0.764,
  'support': 250},
 '70': {'f1-score': 0.7101200686106347,
  'precision': 0.6216216216216216,
  'recall': 0.828,
  'support': 250},
 '71': {'f1-score': 0.5903307888040712,
  'precision': 0.8111888111888111,
  'recall': 0.464,
  'support': 250},
 '72': {'f1-score': 0.6468253968253969,
  'precision': 0.6417322834645669,
  'recall': 0.652,
  'support': 250},
 '73': {'f1-score': 0.4743589743589744,
  'precision': 0.5091743119266054,
  'recall': 0.444,
  'support': 250},
 '74': {'f1-score': 0.658008658008658,
  'precision': 0.7169811320754716,
  'recall': 0.608,
  'support': 250},
 '75': {'f1-score': 0.8665377176015473,
  'precision': 0.8389513108614233,
  'recall': 0.896,
  'support': 250},
 '76': {'f1-score': 0.7808764940239045,
  'precision': 0.7777777777777778,
  'recall': 0.784,
  'support': 250},
 '77': {'f1-score': 0.30875576036866365,
  'precision': 0.3641304347826087,
  'recall': 0.268,
  'support': 250},
 '78': {'f1-score': 0.7603305785123966,
  'precision': 0.7863247863247863,
  'recall': 0.736,
  'support': 250},
 '79': {'f1-score': 0.571830985915493,
  'precision': 0.44130434782608696,
  'recall': 0.812,
  'support': 250},
 '8': {'f1-score': 0.3866943866943867,
  'precision': 0.4025974025974026,
  'recall': 0.372,
  'support': 250},
 '80': {'f1-score': 0.5870841487279843,
  'precision': 0.5747126436781609,
  'recall': 0.6,
  'support': 250},
 '81': {'f1-score': 0.6756756756756757,
  'precision': 0.6529850746268657,
  'recall': 0.7,
  'support': 250},
 '82': {'f1-score': 0.34285714285714286,
  'precision': 0.3804878048780488,
  'recall': 0.312,
  'support': 250},
 '83': {'f1-score': 0.6711409395973154,
  'precision': 0.5780346820809249,
  'recall': 0.8,
  'support': 250},
 '84': {'f1-score': 0.4653465346534653,
  'precision': 0.6103896103896104,
  'recall': 0.376,
  'support': 250},
 '85': {'f1-score': 0.5525525525525525,
  'precision': 0.4423076923076923,
  'recall': 0.736,
  'support': 250},
 '86': {'f1-score': 0.7783783783783783,
  'precision': 0.7081967213114754,
  'recall': 0.864,
  'support': 250},
 '87': {'f1-score': 0.3975409836065574,
  'precision': 0.40756302521008403,
  'recall': 0.388,
  'support': 250},
 '88': {'f1-score': 0.8130081300813008,
  'precision': 0.8264462809917356,
  'recall': 0.8,
  'support': 250},
 '89': {'f1-score': 0.4301675977653631,
  'precision': 0.7129629629629629,
  'recall': 0.308,
  'support': 250},
 '9': {'f1-score': 0.5117370892018779,
  'precision': 0.6193181818181818,
  'recall': 0.436,
  'support': 250},
 '90': {'f1-score': 0.7881548974943051,
  'precision': 0.9153439153439153,
  'recall': 0.692,
  'support': 250},
 '91': {'f1-score': 0.84765625,
  'precision': 0.8282442748091603,
  'recall': 0.868,
  'support': 250},
 '92': {'f1-score': 0.6652977412731006,
  'precision': 0.6835443037974683,
  'recall': 0.648,
  'support': 250},
 '93': {'f1-score': 0.34234234234234234,
  'precision': 0.3114754098360656,
  'recall': 0.38,
  'support': 250},
 '94': {'f1-score': 0.5714285714285714,
  'precision': 0.6118721461187214,
  'recall': 0.536,
  'support': 250},
 '95': {'f1-score': 0.6710526315789473,
  'precision': 0.7427184466019418,
  'recall': 0.612,
  'support': 250},
 '96': {'f1-score': 0.3809523809523809,
  'precision': 0.5625,
  'recall': 0.288,
  'support': 250},
 '97': {'f1-score': 0.5644916540212443,
  'precision': 0.4547677261613692,
  'recall': 0.744,
  'support': 250},
 '98': {'f1-score': 0.3858823529411765,
  'precision': 0.4685714285714286,
  'recall': 0.328,
  'support': 250},
 '99': {'f1-score': 0.35356200527704484,
  'precision': 0.5193798449612403,
  'recall': 0.268,
  'support': 250},
 'accuracy': 0.6077623762376237,
 'macro avg': {'f1-score': 0.6061252197245781,
  'precision': 0.6328666845830312,
  'recall': 0.6077623762376237,
  'support': 25250},
 'weighted avg': {'f1-score': 0.606125219724578,
  'precision': 0.6328666845830311,
  'recall': 0.6077623762376237,
  'support': 25250}}
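Two things stand out in the report above. First, 'macro avg' and 'weighted avg' agree up to floating point. Second, the averaged recall equals 'accuracy'. The sketch below shows why, using made-up per-class counts rather than our model's results. The weighted average weights each class by its support. Recall times support is just that class's true positives, so the weighted recall always equals accuracy. When every class has the same support (250 test images each here), the macro and weighted averages also coincide.

```python
def averaged_recalls(per_class):
    """Return (macro recall, support-weighted recall, accuracy).

    per_class maps class name -> (true positives, support).
    """
    total = sum(support for _, support in per_class.values())
    recalls = [tp / support for tp, support in per_class.values()]
    supports = [support for _, support in per_class.values()]
    macro = sum(recalls) / len(recalls)                                  # plain mean over classes
    weighted = sum(r * s for r, s in zip(recalls, supports)) / total     # mean weighted by support
    accuracy = sum(tp for tp, _ in per_class.values()) / total           # correct / all examples
    return macro, weighted, accuracy

# Hypothetical balanced case: 250 examples per class, like the Food101 test set
balanced = {"a": (200, 250), "b": (150, 250), "c": (100, 250)}
# Hypothetical imbalanced case: class "a" has far more examples
imbalanced = {"a": (900, 1000), "b": (150, 250), "c": (100, 250)}

print(averaged_recalls(balanced))    # all three values are 0.6 (up to float rounding)
print(averaged_recalls(imbalanced))  # weighted == accuracy, but macro differs
```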

Let's plot the F1-score for each of our classes. First, a quick check of how to pull a single class's F1-score out of the report:

classification_report_dict["99"]["f1-score"]
0.35356200527704484
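The F1-score is the harmonic mean of precision and recall, F1 = 2PR / (P + R). Because it is a harmonic mean, a low value in either term drags the score down. Class '89', for example, has precision ≈ 0.71 but recall 0.308, which gives an F1-score of only ≈ 0.43. A quick sketch checks the formula against class '99' from the report above:

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (defined as 0.0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Precision and recall for class '99', copied from classification_report_dict
print(f1_score(0.5193798449612403, 0.268))  # ≈ 0.3536, matching the reported f1-score
```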
# Create empty dictionary
class_f1_scores = {}
# Loop through classification report dictionary items
for k, v in classification_report_dict.items():
  if k == "accuracy": # stop once we reach the accuracy key (only summary metrics follow)
    break
  else:
    # Add class names and f1-scores to new dictionary
    class_f1_scores[class_names[int(k)]] = v["f1-score"]
class_f1_scores

{'apple_pie': 0.24056603773584903,
 'baby_back_ribs': 0.5864406779661017,
 'baklava': 0.6022304832713754,
 'beef_carpaccio': 0.616822429906542,
 'beef_tartare': 0.544080604534005,
 'beet_salad': 0.41860465116279066,
 'beignets': 0.7229357798165138,
 'bibimbap': 0.7908902691511386,
 'bread_pudding': 0.3866943866943867,
 'breakfast_burrito': 0.5117370892018779,
 'bruschetta': 0.5047619047619047,
 'caesar_salad': 0.6161616161616161,
 'cannoli': 0.6105610561056106,
 'caprese_salad': 0.5775193798449612,
 'carrot_cake': 0.574757281553398,
 'ceviche': 0.36744186046511623,
 'cheese_plate': 0.5654135338345864,
 'cheesecake': 0.42546063651591287,
 'chicken_curry': 0.5008403361344538,
 'chicken_quesadilla': 0.6411889596602972,
 'chicken_wings': 0.7123809523809523,
 'chocolate_cake': 0.45261669024045265,
 'chocolate_mousse': 0.3291592128801431,
 'churros': 0.7134935304990757,
 'clam_chowder': 0.7708779443254817,
 'club_sandwich': 0.734020618556701,
 'crab_cakes': 0.4625550660792952,
 'creme_brulee': 0.7494824016563146,
 'croque_madame': 0.6935483870967742,
 'cup_cakes': 0.6910569105691057,
 'deviled_eggs': 0.7476190476190476,
 'donuts': 0.7357293868921776,
 'dumplings': 0.7855787476280836,
 'edamame': 0.9371428571428572,
 'eggs_benedict': 0.7238805970149255,
 'escargots': 0.715835140997831,
 'falafel': 0.5475728155339805,
 'filet_mignon': 0.3870056497175141,
 'fish_and_chips': 0.6946902654867257,
 'foie_gras': 0.29749103942652333,
 'french_fries': 0.7622641509433963,
 'french_onion_soup': 0.7029478458049886,
 'french_toast': 0.537037037037037,
 'fried_calamari': 0.6651884700665188,
 'fried_rice': 0.5586034912718205,
 'frozen_yogurt': 0.8114285714285714,
 'garlic_bread': 0.5831202046035805,
 'gnocchi': 0.4641509433962264,
 'greek_salad': 0.577524893314367,
 'grilled_cheese_sandwich': 0.47234042553191485,
 'grilled_salmon': 0.45581395348837206,
 'guacamole': 0.7783783783783783,
 'gyoza': 0.6124401913875598,
 'hamburger': 0.6759443339960238,
 'hot_and_sour_soup': 0.8103448275862069,
 'hot_dog': 0.7644444444444444,
 'huevos_rancheros': 0.3398328690807799,
 'hummus': 0.5209302325581396,
 'ice_cream': 0.6233766233766233,
 'lasagna': 0.486910994764398,
 'lobster_bisque': 0.6885245901639344,
 'lobster_roll_sandwich': 0.6495176848874598,
 'macaroni_and_cheese': 0.5823389021479712,
 'macarons': 0.895397489539749,
 'miso_soup': 0.8129770992366412,
 'mussels': 0.82,
 'nachos': 0.44141689373297005,
 'omelette': 0.47840531561461797,
 'onion_rings': 0.832618025751073,
 'oysters': 0.8340080971659919,
 'pad_thai': 0.7101200686106347,
 'paella': 0.5903307888040712,
 'pancakes': 0.6468253968253969,
 'panna_cotta': 0.4743589743589744,
 'peking_duck': 0.658008658008658,
 'pho': 0.8665377176015473,
 'pizza': 0.7808764940239045,
 'pork_chop': 0.30875576036866365,
 'poutine': 0.7603305785123966,
 'prime_rib': 0.571830985915493,
 'pulled_pork_sandwich': 0.5870841487279843,
 'ramen': 0.6756756756756757,
 'ravioli': 0.34285714285714286,
 'red_velvet_cake': 0.6711409395973154,
 'risotto': 0.4653465346534653,
 'samosa': 0.5525525525525525,
 'sashimi': 0.7783783783783783,
 'scallops': 0.3975409836065574,
 'seaweed_salad': 0.8130081300813008,
 'shrimp_and_grits': 0.4301675977653631,
 'spaghetti_bolognese': 0.7881548974943051,
 'spaghetti_carbonara': 0.84765625,
 'spring_rolls': 0.6652977412731006,
 'steak': 0.34234234234234234,
 'strawberry_shortcake': 0.5714285714285714,
 'sushi': 0.6710526315789473,
 'tacos': 0.3809523809523809,
 'takoyaki': 0.5644916540212443,
 'tiramisu': 0.3858823529411765,
 'tuna_tartare': 0.35356200527704484,
 'waffles': 0.641025641025641}
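The loop above relies on every per-class key coming before 'accuracy' in the dictionary. A sketch of an alternative that doesn't depend on key order: keep only the keys that are class indices (digit strings) and skip the summary rows. The report dict and class_names below are a tiny hypothetical stand-in for classification_report_dict and our real class names. It also assumes the report was created without target_names, so per-class keys are label indices.

```python
# Hypothetical, truncated stand-in for classification_report_dict
report = {
    "0": {"f1-score": 0.24, "precision": 0.3, "recall": 0.2, "support": 250},
    "1": {"f1-score": 0.59, "precision": 0.6, "recall": 0.58, "support": 250},
    "accuracy": 0.61,
    "macro avg": {"f1-score": 0.415},
    "weighted avg": {"f1-score": 0.415},
}
class_names = ["apple_pie", "baby_back_ribs"]

# Per-class entries are keyed by the label index as a string; summary rows are not
class_f1_scores = {
    class_names[int(k)]: v["f1-score"]
    for k, v in report.items()
    if k.isdigit()
}
print(class_f1_scores)  # {'apple_pie': 0.24, 'baby_back_ribs': 0.59}
```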
# Turn f1-scores into dataframe for visualization
import pandas as pd
f1_scores = pd.DataFrame({"class_names": list(class_f1_scores.keys()),
                          "f1-score": list(class_f1_scores.values())}).sort_values("f1-score", ascending = False)
f1_scores
class_names f1-score
33 edamame 0.937143
63 macarons 0.895397
75 pho 0.866538
91 spaghetti_carbonara 0.847656
69 oysters 0.834008
... ... ...
56 huevos_rancheros 0.339833
22 chocolate_mousse 0.329159
77 pork_chop 0.308756
39 foie_gras 0.297491
0 apple_pie 0.240566

101 rows × 2 columns

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 25))
scores = ax.barh(range(len(f1_scores)), f1_scores["f1-score"].values)
ax.set_yticks(range(len(f1_scores)))
ax.set_yticklabels(list(f1_scores["class_names"]))
ax.set_xlabel("f1-score")
ax.set_title("F1-Scores for 101 Different Classes")
ax.invert_yaxis(); # reverse the order so the highest f1-score is at the top

def autolabel(rects): # Modified version of: https://matplotlib.org/examples/api/barchart_demo.html
  """
  Attach a text label to the right of each bar displaying its width (the f1-score).
  """
  for rect in rects:
    width = rect.get_width()
    ax.text(1.03*width, rect.get_y() + rect.get_height()/1.5,
            f"{width:.2f}",
            ha='center', va='bottom')

autolabel(scores)

Visualizing predictions on custom images

How does our model perform on food images that aren't even in our test dataset?

To visualize our model's predictions on our own images, we'll need a function to load and preprocess them. Specifically, it will need to:

  • Read in the file at a target image path using tf.io.read_file()
  • Turn the image into a Tensor using tf.io.decode_image()
  • Resize the image tensor to the same size as the images our model was trained on using tf.image.resize()
  • Scale the image so all of the pixel values are between 0 & 1 (if necessary)
import tensorflow as tf

def load_and_prep_image(filename, img_shape = 224, scale=True):
  """
  Reads in an image from filename, turns it into a tensor and reshapes it into
  the specified shape (img_shape, img_shape, color_channels=3).

  Args:
    filename (str): path to target image
    img_shape (int): height/width dimension of target image size
    scale (bool): whether to scale pixel values from 0-255 to 0-1

  Returns:
    Image tensor of shape (img_shape, img_shape, 3)
  """
  # Read in the image
  img = tf.io.read_file(filename)

  # Decode the image into a tensor with 3 color channels
  img = tf.io.decode_image(img, channels = 3)

  # Resize the image
  img = tf.image.resize(img, [img_shape, img_shape])

  # Scale (yes/no)
  if scale:
    # Rescale the image (get all values between 0 and 1)
    return img/255.
  else:
    return img # Don't need to rescale images for EfficientNetB0 (it rescales inputs internally)

Now that we've got a function to load and prepare target images, let's write some code to visualize images, their target labels and our model's predictions.

Specifically, we'll write some code to:

  1. Load a few random images from the test dataset
  2. Make predictions on the loaded images
  3. Plot the original image(s) along with the model's predictions, prediction probability and ground truth labels
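Steps 2 and 3 boil down to an argmax over the predicted probabilities plus a comparison with the true class. In the notebook, pred_prob.argmax() and pred_prob.max() do this on the NumPy array returned by model.predict. The framework-free sketch below uses a hypothetical describe_prediction helper and made-up probabilities:

```python
def describe_prediction(pred_probs, class_names, true_class):
    """Return (predicted class name, its probability, title color) for one image.

    pred_probs: per-class probabilities for a single image, aligned with class_names.
    """
    # argmax: index of the highest probability (first one wins on ties, like NumPy)
    best_idx = max(range(len(pred_probs)), key=lambda i: pred_probs[i])
    pred_class = class_names[best_idx]
    color = "g" if pred_class == true_class else "r"  # green if correct, red if wrong
    return pred_class, pred_probs[best_idx], color

# Hypothetical 3-class example
names = ["pizza", "steak", "sushi"]
print(describe_prediction([0.1, 0.7, 0.2], names, "steak"))  # ('steak', 0.7, 'g')
print(describe_prediction([0.6, 0.3, 0.1], names, "sushi"))  # ('pizza', 0.6, 'r')
```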
import os
import random

plt.figure(figsize=(17, 10))
for i in range(3):
  # Choose a random image from a random class
  # (os.path.join works whether or not test_dir ends with "/")
  class_name = random.choice(class_names)
  filename = random.choice(os.listdir(os.path.join(test_dir, class_name)))
  filepath = os.path.join(test_dir, class_name, filename)

  # Load the image and make predictions
  img = load_and_prep_image(filepath, scale=False) # don't scale images for EfficientNet predictions
  pred_prob = model.predict(tf.expand_dims(img, axis=0)) # model accepts tensors of shape [None, 224, 224, 3]
  pred_class = class_names[pred_prob.argmax()] # find the predicted class 

  # Plot the image(s)
  plt.subplot(1, 3, i+1)
  plt.imshow(img/255.)
  if class_name == pred_class: # Change the color of text based on whether prediction is right or wrong
    title_color = "g"
  else:
    title_color = "r"
  plt.title(f"actual: {class_name}, pred: {pred_class}, prob: {pred_prob.max():.2f}", c=title_color)
  plt.axis(False);

Finding the most wrong predictions

  • A good way to inspect your model's performance is to view the wrong predictions with the highest prediction probability (or highest loss)
  • Can reveal insights such as:
    • Data issues (wrong labels)
    • Confusing classes (get better/more diverse data)
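One way to surface confusing classes is to count which (actual, predicted) pairs show up most often among the wrong predictions. This is not part of the original workflow. The sketch below uses collections.Counter on hypothetical labels. Applied to this notebook, y_true and y_pred would be the y_true_classname and y_pred_classname columns of the prediction DataFrame.

```python
from collections import Counter

def most_confused_pairs(y_true, y_pred, top_n=3):
    """Count (actual, predicted) pairs among wrong predictions, most frequent first."""
    wrong_pairs = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)
    return wrong_pairs.most_common(top_n)

# Hypothetical labels, not real model output
y_true = ["steak", "steak", "steak", "pork_chop", "pork_chop", "pizza"]
y_pred = ["filet_mignon", "filet_mignon", "steak", "steak", "pork_chop", "pizza"]
print(most_confused_pairs(y_true, y_pred))
# [(('steak', 'filet_mignon'), 2), (('pork_chop', 'steak'), 1)]
```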

To find out where our model is most wrong, we'll do the following:

  1. Get all of the image file paths in the test dataset using the tf.data.Dataset.list_files() method
  2. Create a pandas DataFrame of the image filepaths, ground truth labels, predicted classes (from our model), max prediction probabilities, ground truth class names and predicted class names
  3. Use our DataFrame to find all the wrong predictions (where the ground truth label doesn't match the prediction)
  4. Sort the wrong predictions by prediction probability (highest prediction probability at the top)
  5. Visualize the images with the highest prediction probabilities that were still predicted wrong
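Steps 3 to 5 amount to a filter, a sort and a slice. A plain-Python sketch of that logic, where a list of dicts stands in for the rows of the prediction DataFrame: the file names are made up, and the labels and confidences are rounded from the first few rows of this notebook's predictions.

```python
def top_k_wrong(records, k=2):
    """Return the k wrong predictions with the highest confidence.

    records: list of dicts with 'y_true', 'y_pred' and 'pred_conf' keys.
    """
    wrong = [r for r in records if r["y_true"] != r["y_pred"]]   # step 3: keep wrong predictions
    wrong.sort(key=lambda r: r["pred_conf"], reverse=True)        # step 4: most confident first
    return wrong[:k]                                              # step 5: the ones to visualize

records = [
    {"img": "a.jpg", "y_true": 0, "y_pred": 52, "pred_conf": 0.85},
    {"img": "b.jpg", "y_true": 0, "y_pred": 0, "pred_conf": 0.96},
    {"img": "c.jpg", "y_true": 0, "y_pred": 80, "pred_conf": 0.66},
    {"img": "d.jpg", "y_true": 0, "y_pred": 79, "pred_conf": 0.37},
]
print([r["img"] for r in top_k_wrong(records)])  # ['a.jpg', 'c.jpg']
```

In pandas, the same logic is the boolean filter followed by sort_values("pred_conf", ascending=False) and a slice.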
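filepaths = []
# list_files is a static method of tf.data.Dataset; shuffle=False keeps the file order
# deterministic so it can line up with our predictions (assumes test_data wasn't shuffled)
for filepath in tf.data.Dataset.list_files("101_food_classes_10_percent/test/*/*.jpg",
                                           shuffle=False):
  filepaths.append(filepath.numpy())
filepaths[:10]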
[b'101_food_classes_10_percent/test/apple_pie/1011328.jpg',
 b'101_food_classes_10_percent/test/apple_pie/101251.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1034399.jpg',
 b'101_food_classes_10_percent/test/apple_pie/103801.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1038694.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1047447.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1068632.jpg',
 b'101_food_classes_10_percent/test/apple_pie/110043.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1106961.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1113017.jpg']
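The DataFrame we build next pairs each file path with a prediction purely by position, so both lists must be in the same order. The stdlib sketch below illustrates that idea. It assumes, as the output above suggests, that test images are ordered alphabetically by path. It sorts the glob results explicitly and checks the lengths before pairing. The directory layout and predictions are hypothetical.

```python
import glob
import os
import tempfile

def sorted_image_paths(root, pattern="*/*.jpg"):
    """Return the paths matching root/pattern in alphabetical order."""
    return sorted(glob.glob(os.path.join(root, pattern)))

# Hypothetical mini test directory laid out as root/<class_name>/<image>.jpg
root = tempfile.mkdtemp()
for class_name, files in {"apple_pie": ["2.jpg", "1.jpg"], "waffles": ["9.jpg"]}.items():
    os.makedirs(os.path.join(root, class_name))
    for name in files:
        open(os.path.join(root, class_name, name), "wb").close()

paths = sorted_image_paths(root)
pred_classes = [0, 0, 1]  # hypothetical predictions, assumed to be in the same order as paths

# Pairing by position only makes sense if there is exactly one prediction per path
assert len(paths) == len(pred_classes)
for path, pred in zip(paths, pred_classes):
    print(os.path.relpath(path, root), pred)
```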
import pandas as pd
pred_df = pd.DataFrame({"img_path": filepaths,
                        "y_true": y_labels,
                        "y_pred": pred_classes,
                        "pred_conf": preds_probs.max(axis=1), # get the maximum prediction prob value
                        "y_true_classname": [class_names[i] for i in y_labels],
                        "y_pred_classname": [class_names[i] for i in pred_classes]})

pred_df
img_path y_true y_pred pred_conf y_true_classname y_pred_classname
0 b'101_food_classes_10_percent/test/apple_pie/1... 0 52 0.847419 apple_pie gyoza
1 b'101_food_classes_10_percent/test/apple_pie/1... 0 0 0.964017 apple_pie apple_pie
2 b'101_food_classes_10_percent/test/apple_pie/1... 0 0 0.959259 apple_pie apple_pie
3 b'101_food_classes_10_percent/test/apple_pie/1... 0 80 0.658607 apple_pie pulled_pork_sandwich
4 b'101_food_classes_10_percent/test/apple_pie/1... 0 79 0.367902 apple_pie prime_rib
... ... ... ... ... ... ...
25245 b'101_food_classes_10_percent/test/waffles/942... 100 100 0.972823 waffles waffles
25246 b'101_food_classes_10_percent/test/waffles/954... 100 16 0.878027 waffles cheese_plate
25247 b'101_food_classes_10_percent/test/waffles/961... 100 100 0.537900 waffles waffles
25248 b'101_food_classes_10_percent/test/waffles/970... 100 94 0.501951 waffles strawberry_shortcake
25249 b'101_food_classes_10_percent/test/waffles/971... 100 100 0.690628 waffles waffles

25250 rows × 6 columns

pred_df["pred_correct"] = pred_df["y_true"] == pred_df["y_pred"]
pred_df.head()
img_path y_true y_pred pred_conf y_true_classname y_pred_classname pred_correct
0 b'101_food_classes_10_percent/test/apple_pie/1... 0 52 0.847419 apple_pie gyoza False
1 b'101_food_classes_10_percent/test/apple_pie/1... 0 0 0.964017 apple_pie apple_pie True
2 b'101_food_classes_10_percent/test/apple_pie/1... 0 0 0.959259 apple_pie apple_pie True
3 b'101_food_classes_10_percent/test/apple_pie/1... 0 80 0.658607 apple_pie pulled_pork_sandwich False
4 b'101_food_classes_10_percent/test/apple_pie/1... 0 79 0.367902 apple_pie prime_rib False
top_100_wrong = pred_df[pred_df["pred_correct"] == False].sort_values("pred_conf", ascending = False)[:100]
top_100_wrong
img_path y_true y_pred pred_conf y_true_classname y_pred_classname pred_correct
21810 b'101_food_classes_10_percent/test/scallops/17... 87 29 0.999997 scallops cup_cakes False
231 b'101_food_classes_10_percent/test/apple_pie/8... 0 100 0.999995 apple_pie waffles False
15359 b'101_food_classes_10_percent/test/lobster_rol... 61 53 0.999988 lobster_roll_sandwich hamburger False
23539 b'101_food_classes_10_percent/test/strawberry_... 94 83 0.999987 strawberry_shortcake red_velvet_cake False
21400 b'101_food_classes_10_percent/test/samosa/3140... 85 92 0.999981 samosa spring_rolls False
... ... ... ... ... ... ... ...
8763 b'101_food_classes_10_percent/test/escargots/1... 35 41 0.997169 escargots french_onion_soup False
2663 b'101_food_classes_10_percent/test/bruschetta/... 10 61 0.997055 bruschetta lobster_roll_sandwich False
7924 b'101_food_classes_10_percent/test/donuts/3454... 31 29 0.997020 donuts cup_cakes False
18586 b'101_food_classes_10_percent/test/peking_duck... 74 39 0.996885 peking_duck foie_gras False
3519 b'101_food_classes_10_percent/test/carrot_cake... 14 21 0.996842 carrot_cake chocolate_cake False

100 rows × 7 columns

images_to_view = 9
start_index = 0 # change the index to view more of the wrong predictions
plt.figure(figsize=(15, 10))
for i, row in enumerate(top_100_wrong[start_index:start_index+images_to_view].itertuples()):
  plt.subplot(3, 3, i+1) # 3x3 grid fits images_to_view = 9
  img = load_and_prep_image(row[1], scale=False) # row[1] is img_path (row[0] is the DataFrame index)
  _, _, _, _, pred_prob, y_true_classname, y_pred_classname, _ = row # only interested in a few parameters of each row
  plt.imshow(img/255.)
  plt.title(f"actual: {y_true_classname}, pred: {y_pred_classname},\n prob: {pred_prob:.2f}")
  plt.axis(False)

These are the images our model got wrong (according to the ground truth labels) while still giving its prediction a very high probability.

Test our model on custom images

from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, compare_historys
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip

unzip_data("custom_food_images.zip")
--2022-02-20 18:37:33--  https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.216.128, 173.194.217.128, 173.194.218.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.216.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13192985 (13M) [application/zip]
Saving to: ‘custom_food_images.zip.2’

custom_food_images. 100%[===================>]  12.58M  --.-KB/s    in 0.07s   

2022-02-20 18:37:33 (174 MB/s) - ‘custom_food_images.zip.2’ saved [13192985/13192985]

import os
custom_food_images = ["custom_food_images/" + img_path for img_path in os.listdir("custom_food_images")]
custom_food_images
['custom_food_images/chicken_wings.jpeg',
 'custom_food_images/hamburger.jpeg',
 'custom_food_images/steak.jpeg',
 'custom_food_images/ramen.jpeg',
 'custom_food_images/pizza-dad.jpeg',
 'custom_food_images/sushi.jpeg']
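os.listdir returns every entry in a folder in arbitrary order, so a stray non-image file (such as a .DS_Store) would break the prediction loop. Below is a small sketch that keeps only files with common image extensions and sorts them. The extension set and the example folder contents are assumptions for illustration.

```python
import os
import tempfile

# Assumption: these are the image formats we expect in the folder
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def list_images(folder):
    """Return sorted paths of files in folder whose extension looks like an image."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    )

# Hypothetical folder with two images and two non-image files
folder = tempfile.mkdtemp()
for name in ["sushi.jpeg", "Ramen.JPG", ".DS_Store", "notes.txt"]:
    open(os.path.join(folder, name), "w").close()

print([os.path.basename(p) for p in list_images(folder)])  # ['Ramen.JPG', 'sushi.jpeg']
```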
for img_path in custom_food_images:
  img = load_and_prep_image(img_path, scale=False) # load in target image and turn it into a tensor
  pred_prob = model.predict(tf.expand_dims(img, axis=0)) # make prediction on image with shape [None, 224, 224, 3]
  pred_class = class_names[pred_prob.argmax()] # find the predicted class label
  # Plot the image with appropriate annotations
  plt.figure()
  plt.imshow(img/255.) # imshow() expects float RGB values in the range 0-1
  plt.title(f"pred: {pred_class}, prob: {pred_prob.max():.2f}")
  plt.axis(False)