Transfer Learning in TensorFlow: Fine-Tuning

This notebook is an account of my work through the Udemy course: TensorFlow Developer Certificate in 2022: Zero to Mastery.

Concepts covered in this Notebook:

  • Introducing fine-tuning transfer learning with TensorFlow
  • Introducing the Keras Functional API to build models
  • Using a small dataset to experiment faster (e.g. 10% of training samples)
  • Data augmentation (making your training set more diverse without adding samples)
  • Running a series of experiments on our Food Vision data
  • Introducing the ModelCheckpoint callback to save intermediate training results

Create helper functions

In previous notebooks, we created several helper functions to reuse when evaluating and visualizing the results of our models.

So, it's a good idea to put functions you'll want to use again in a script you can download and import into your notebooks.

Below we download a file that contains all the functions we have created to help us during the model training in previous notebooks.

!wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
--2022-02-20 05:18:46--  https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10246 (10K) [text/plain]
Saving to: ‘helper_functions.py’

helper_functions.py 100%[===================>]  10.01K  --.-KB/s    in 0s      

2022-02-20 05:18:46 (90.7 MB/s) - ‘helper_functions.py’ saved [10246/10246]

from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, walk_through_dir

Note: If you are running this notebook in Google Colab, when the runtime times out Colab will delete helper_functions.py, so you'll have to re-download it if you want access to your helper functions.
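One way to guard against this (a minimal sketch, reusing the URL above) is to only download the file when it's missing:

import os

# Re-download helper_functions.py only if it's missing,
# e.g. after a Colab runtime reset
if not os.path.exists("helper_functions.py"):
  !wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py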

Get the Data

Let's see how we can use the pretrained models within tf.keras.applications and apply them to our own problem (e.g. recognizing images of food).

!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_10_percent.zip

unzip_data("10_food_classes_10_percent.zip")
--2022-02-20 05:18:50--  https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_10_percent.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.120.128, 74.125.70.128, 74.125.69.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.120.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 168546183 (161M) [application/zip]
Saving to: ‘10_food_classes_10_percent.zip’

10_food_classes_10_ 100%[===================>] 160.74M   210MB/s    in 0.8s    

2022-02-20 05:18:50 (210 MB/s) - ‘10_food_classes_10_percent.zip’ saved [168546183/168546183]

walk_through_dir("10_food_classes_10_percent")
There are 2 directories and 0 images in '10_food_classes_10_percent'.
There are 10 directories and 0 images in '10_food_classes_10_percent/train'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/pizza'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/fried_rice'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/grilled_salmon'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/hamburger'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/steak'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/sushi'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_wings'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/ramen'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/ice_cream'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_curry'.
There are 10 directories and 0 images in '10_food_classes_10_percent/test'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/pizza'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/fried_rice'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/grilled_salmon'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/hamburger'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/steak'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/sushi'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_wings'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/ramen'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/ice_cream'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_curry'.
train_dir = "10_food_classes_10_percent/train"
test_dir = "10_food_classes_10_percent/test"
import tensorflow as tf
IMG_SIZE = (224,224)
BATCH_SIZE = 32
train_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(directory = train_dir,
                                                                            batch_size = BATCH_SIZE,
                                                                            image_size = IMG_SIZE,
                                                                            label_mode = "categorical")


test_data = tf.keras.preprocessing.image_dataset_from_directory(directory = test_dir,
                                                                batch_size = BATCH_SIZE,
                                                                image_size = IMG_SIZE,
                                                                label_mode = "categorical")
Found 750 files belonging to 10 classes.
Found 2500 files belonging to 10 classes.
train_data_10_percent
<BatchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>
train_data_10_percent.class_names
['chicken_curry',
 'chicken_wings',
 'fried_rice',
 'grilled_salmon',
 'hamburger',
 'ice_cream',
 'pizza',
 'ramen',
 'steak',
 'sushi']
# See an example of one batch of data
for images, labels in train_data_10_percent.take(1):
  print(images, labels)

tf.Tensor(
[[[[2.49285721e+02 2.37285721e+02 2.15285721e+02]
   [2.49642853e+02 2.37642853e+02 2.15642853e+02]
   ...
   [7.25765381e+01 3.89490776e+01 2.45051785e+01]]
  ...]]], shape=(32, 224, 224, 3), dtype=float32) tf.Tensor(
[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 ...
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(32, 10), dtype=float32)
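
Raw tensor printouts aren't very readable. A quicker sanity check is to plot a few images from a batch with their labels (a minimal sketch using matplotlib):

import matplotlib.pyplot as plt

# Plot the first 9 images of a batch along with their class names
class_names = train_data_10_percent.class_names
for images, labels in train_data_10_percent.take(1):
  plt.figure(figsize=(10, 10))
  for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[int(tf.argmax(labels[i]))])
    plt.axis("off")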

Modelling experiments we are going to run:


  • Model 0 (baseline): 10 classes of Food101 data (random 10% of training data only); no preprocessing; feature extractor: EfficientNetB0 (pre-trained on ImageNet, all layers frozen) with no top
  • Model 1: 10 classes of Food101 data (random 1% of training data only); random flip, rotation, zoom, height and width data augmentation; same model as Model 0
  • Model 2: same data as Model 0; same preprocessing as Model 1; same model as Model 0
  • Model 3: same data as Model 0; same preprocessing as Model 1; fine-tuning: Model 2 (EfficientNetB0 pre-trained on ImageNet) with the top layer trained on custom data and the top 10 layers unfrozen
  • Model 4: 10 classes of Food101 data (100% of training data); same preprocessing as Model 1; same model as Model 3
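
The data augmentation stage named for Models 1-4 (random flip, rotation, zoom, height and width) can be built as a small Keras Sequential model of preprocessing layers. Below is a minimal sketch; the 0.2 factors are illustrative, not the exact values used in the experiments.

# A sketch of the data augmentation stage listed in the experiments above
# (the 0.2 factors are illustrative choices)
from tensorflow.keras.layers.experimental import preprocessing

data_augmentation = tf.keras.Sequential([
    preprocessing.RandomFlip("horizontal"),
    preprocessing.RandomRotation(0.2),
    preprocessing.RandomZoom(0.2),
    preprocessing.RandomHeight(0.2),
    preprocessing.RandomWidth(0.2)
], name="data_augmentation")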

Keras Functional API:

Here's the general pattern of the Functional API on a toy model before we apply it to transfer learning (X_train and y_train below are placeholders):

# Creating a model with the Functional API
inputs = tf.keras.layers.Input(shape = (28, 28))
x = tf.keras.layers.Flatten()(inputs)
x = tf.keras.layers.Dense(64, activation = "relu")(x)
x = tf.keras.layers.Dense(64, activation = "relu")(x)
outputs = tf.keras.layers.Dense(10, activation = "softmax")(x)
functional_model = tf.keras.Model(inputs, outputs, name = "functional_model")

functional_model.compile(
    loss = tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer = tf.keras.optimizers.Adam(),
    metrics = ["accuracy"]
)

# X_train and y_train are placeholders: any 28x28 image dataset with
# integer labels (e.g. Fashion MNIST) would fit this toy model
functional_model.fit(X_train, y_train,
                     batch_size = 32,
                     epochs = 5)

Model 0: Building a transfer learning model using the Keras Functional API

The Sequential API is straightforward: it runs our layers in sequential order.

But the Functional API gives us more flexibility with our models.
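
To make the contrast concrete, here's the same toy model from the Functional API demo above, rewritten with the Sequential API (a sketch):

# The same toy model as above, written with the Sequential API for comparison
sequential_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax")
], name="sequential_model")

For a single linear stack of layers the two APIs are equivalent; the Functional API pays off when a model needs multiple inputs, multiple outputs or branching, which the Sequential API can't express.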

# 1. Create the base model with tf.keras.applications, leaving the top (classification) layers off
base_model = tf.keras.applications.EfficientNetB0(include_top = False)

# 2. Freeze the base model (so the underlying pre-trained patterns aren't updated)
base_model.trainable = False

# 3. Create inputs into our model
inputs = tf.keras.layers.Input(shape=(224,224,3), name= "input_layer")

# 4. If using a model like ResNet50V2 you will need to normalize the inputs (you don't have to for EfficientNet(s))
# x = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)(inputs)

# 5. Pass the inputs to the base_model
x = base_model(inputs)
print(f"Shape after passing inputs through base model: {x.shape}")

# 6. Average pool the outputs of the base model (aggregate all the most important info, reduce the number of computations)
x = tf.keras.layers.GlobalAveragePooling2D(name = "global_average_pooling_layer")(x)
print(f"Shape after GlobalAveragePooling2D: {x.shape}")

# 7. Create the output activation layer
outputs = tf.keras.layers.Dense(10,activation = "softmax", name = "output_layer")(x)

# 8. Combine the inputs with the outputs into a model
model_0 = tf.keras.Model(inputs,outputs)

# 9. Compile the model
model_0.compile(loss = "categorical_crossentropy",
                optimizer = tf.keras.optimizers.Adam(),
                metrics = ["accuracy"])

# 10. Fit the model
history_10_percent = model_0.fit(train_data_10_percent,
                                 epochs = 5,
                                 steps_per_epoch = len(train_data_10_percent),
                                 validation_data = test_data,
                                 validation_steps = int(0.25*len(test_data)),
                                 callbacks = [create_tensorboard_callback(dir_name = "transfer_learning",
                                                                          experiment_name ="10_percent_feature_extraction")])
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
16711680/16705208 [==============================] - 0s 0us/step
16719872/16705208 [==============================] - 0s 0us/step
Shape after passing inputs through base model: (None, 7, 7, 1280)
Shape after GlobalAveragePooling2D: (None, 1280)
Saving TensorBoard log files to: transfer_learning/10_percent_feature_extraction/20220220-051900
Epoch 1/5
24/24 [==============================] - 28s 455ms/step - loss: 1.8831 - accuracy: 0.4107 - val_loss: 1.3535 - val_accuracy: 0.7023
Epoch 2/5
24/24 [==============================] - 9s 361ms/step - loss: 1.1441 - accuracy: 0.7613 - val_loss: 0.9226 - val_accuracy: 0.8092
Epoch 3/5
24/24 [==============================] - 9s 357ms/step - loss: 0.8421 - accuracy: 0.8080 - val_loss: 0.7516 - val_accuracy: 0.8240
Epoch 4/5
24/24 [==============================] - 7s 274ms/step - loss: 0.6958 - accuracy: 0.8400 - val_loss: 0.6333 - val_accuracy: 0.8454
Epoch 5/5
24/24 [==============================] - 9s 357ms/step - loss: 0.5975 - accuracy: 0.8627 - val_loss: 0.6452 - val_accuracy: 0.8355
model_0.evaluate(test_data)
79/79 [==============================] - 12s 139ms/step - loss: 0.6102 - accuracy: 0.8412
[0.6101745367050171, 0.8411999940872192]
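We can also pass the training history to the plot_loss_curves helper we imported earlier to visualise how training went:

# Visualise the model's loss and accuracy curves per epoch
plot_loss_curves(history_10_percent)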
for layer_number, layer in enumerate(base_model.layers):
  print(layer_number, layer.name)

0 input_1
1 rescaling
2 normalization
3 stem_conv_pad
4 stem_conv
5 stem_bn
6 stem_activation
7 block1a_dwconv
8 block1a_bn
9 block1a_activation
10 block1a_se_squeeze
11 block1a_se_reshape
12 block1a_se_reduce
13 block1a_se_expand
14 block1a_se_excite
15 block1a_project_conv
16 block1a_project_bn
17 block2a_expand_conv
18 block2a_expand_bn
19 block2a_expand_activation
20 block2a_dwconv_pad
21 block2a_dwconv
22 block2a_bn
23 block2a_activation
24 block2a_se_squeeze
25 block2a_se_reshape
26 block2a_se_reduce
27 block2a_se_expand
28 block2a_se_excite
29 block2a_project_conv
30 block2a_project_bn
31 block2b_expand_conv
32 block2b_expand_bn
33 block2b_expand_activation
34 block2b_dwconv
35 block2b_bn
36 block2b_activation
37 block2b_se_squeeze
38 block2b_se_reshape
39 block2b_se_reduce
40 block2b_se_expand
41 block2b_se_excite
42 block2b_project_conv
43 block2b_project_bn
44 block2b_drop
45 block2b_add
46 block3a_expand_conv
47 block3a_expand_bn
48 block3a_expand_activation
49 block3a_dwconv_pad
50 block3a_dwconv
51 block3a_bn
52 block3a_activation
53 block3a_se_squeeze
54 block3a_se_reshape
55 block3a_se_reduce
56 block3a_se_expand
57 block3a_se_excite
58 block3a_project_conv
59 block3a_project_bn
60 block3b_expand_conv
61 block3b_expand_bn
62 block3b_expand_activation
63 block3b_dwconv
64 block3b_bn
65 block3b_activation
66 block3b_se_squeeze
67 block3b_se_reshape
68 block3b_se_reduce
69 block3b_se_expand
70 block3b_se_excite
71 block3b_project_conv
72 block3b_project_bn
73 block3b_drop
74 block3b_add
75 block4a_expand_conv
76 block4a_expand_bn
77 block4a_expand_activation
78 block4a_dwconv_pad
79 block4a_dwconv
80 block4a_bn
81 block4a_activation
82 block4a_se_squeeze
83 block4a_se_reshape
84 block4a_se_reduce
85 block4a_se_expand
86 block4a_se_excite
87 block4a_project_conv
88 block4a_project_bn
89 block4b_expand_conv
90 block4b_expand_bn
91 block4b_expand_activation
92 block4b_dwconv
93 block4b_bn
94 block4b_activation
95 block4b_se_squeeze
96 block4b_se_reshape
97 block4b_se_reduce
98 block4b_se_expand
99 block4b_se_excite
100 block4b_project_conv
101 block4b_project_bn
102 block4b_drop
103 block4b_add
104 block4c_expand_conv
105 block4c_expand_bn
106 block4c_expand_activation
107 block4c_dwconv
108 block4c_bn
109 block4c_activation
110 block4c_se_squeeze
111 block4c_se_reshape
112 block4c_se_reduce
113 block4c_se_expand
114 block4c_se_excite
115 block4c_project_conv
116 block4c_project_bn
117 block4c_drop
118 block4c_add
119 block5a_expand_conv
120 block5a_expand_bn
121 block5a_expand_activation
122 block5a_dwconv
123 block5a_bn
124 block5a_activation
125 block5a_se_squeeze
126 block5a_se_reshape
127 block5a_se_reduce
128 block5a_se_expand
129 block5a_se_excite
130 block5a_project_conv
131 block5a_project_bn
132 block5b_expand_conv
133 block5b_expand_bn
134 block5b_expand_activation
135 block5b_dwconv
136 block5b_bn
137 block5b_activation
138 block5b_se_squeeze
139 block5b_se_reshape
140 block5b_se_reduce
141 block5b_se_expand
142 block5b_se_excite
143 block5b_project_conv
144 block5b_project_bn
145 block5b_drop
146 block5b_add
147 block5c_expand_conv
148 block5c_expand_bn
149 block5c_expand_activation
150 block5c_dwconv
151 block5c_bn
152 block5c_activation
153 block5c_se_squeeze
154 block5c_se_reshape
155 block5c_se_reduce
156 block5c_se_expand
157 block5c_se_excite
158 block5c_project_conv
159 block5c_project_bn
160 block5c_drop
161 block5c_add
162 block6a_expand_conv
163 block6a_expand_bn
164 block6a_expand_activation
165 block6a_dwconv_pad
166 block6a_dwconv
167 block6a_bn
168 block6a_activation
169 block6a_se_squeeze
170 block6a_se_reshape
171 block6a_se_reduce
172 block6a_se_expand
173 block6a_se_excite
174 block6a_project_conv
175 block6a_project_bn
176 block6b_expand_conv
177 block6b_expand_bn
178 block6b_expand_activation
179 block6b_dwconv
180 block6b_bn
181 block6b_activation
182 block6b_se_squeeze
183 block6b_se_reshape
184 block6b_se_reduce
185 block6b_se_expand
186 block6b_se_excite
187 block6b_project_conv
188 block6b_project_bn
189 block6b_drop
190 block6b_add
191 block6c_expand_conv
192 block6c_expand_bn
193 block6c_expand_activation
194 block6c_dwconv
195 block6c_bn
196 block6c_activation
197 block6c_se_squeeze
198 block6c_se_reshape
199 block6c_se_reduce
200 block6c_se_expand
201 block6c_se_excite
202 block6c_project_conv
203 block6c_project_bn
204 block6c_drop
205 block6c_add
206 block6d_expand_conv
207 block6d_expand_bn
208 block6d_expand_activation
209 block6d_dwconv
210 block6d_bn
211 block6d_activation
212 block6d_se_squeeze
213 block6d_se_reshape
214 block6d_se_reduce
215 block6d_se_expand
216 block6d_se_excite
217 block6d_project_conv
218 block6d_project_bn
219 block6d_drop
220 block6d_add
221 block7a_expand_conv
222 block7a_expand_bn
223 block7a_expand_activation
224 block7a_dwconv
225 block7a_bn
226 block7a_activation
227 block7a_se_squeeze
228 block7a_se_reshape
229 block7a_se_reduce
230 block7a_se_expand
231 block7a_se_excite
232 block7a_project_conv
233 block7a_project_bn
234 top_conv
235 top_bn
236 top_activation

There are 237 layers in EfficientNetB0 (numbered 0-236 above). The architecture already includes rescaling and normalization layers near the input, so we don't need to rescale the data ourselves.
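Since we froze the base model, it's worth a quick check that the frozen layers really are non-trainable inside model_0:

# Check the trainable status of each layer in model_0;
# the efficientnetb0 base should show False
for layer_number, layer in enumerate(model_0.layers):
  print(layer_number, layer.name, layer.trainable)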

# Let's check the summary of the base model i.e. EfficientNetB0 Model
base_model.summary()

Model: "efficientnetb0"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 rescaling (Rescaling)          (None, None, None,   0           ['input_1[0][0]']                
                                3)                                                                
                                                                                                  
 normalization (Normalization)  (None, None, None,   7           ['rescaling[0][0]']              
                                3)                                                                
                                                                                                  
 stem_conv_pad (ZeroPadding2D)  (None, None, None,   0           ['normalization[0][0]']          
                                3)                                                                
                                                                                                  
 stem_conv (Conv2D)             (None, None, None,   864         ['stem_conv_pad[0][0]']          
                                32)                                                               
                                                                                                  
 stem_bn (BatchNormalization)   (None, None, None,   128         ['stem_conv[0][0]']              
                                32)                                                               
                                                                                                  
 stem_activation (Activation)   (None, None, None,   0           ['stem_bn[0][0]']                
                                32)                                                               
                                                                                                  
 block1a_dwconv (DepthwiseConv2  (None, None, None,   288        ['stem_activation[0][0]']        
 D)                             32)                                                               
                                                                                                  
 block1a_bn (BatchNormalization  (None, None, None,   128        ['block1a_dwconv[0][0]']         
 )                              32)                                                               
                                                                                                  
 block1a_activation (Activation  (None, None, None,   0          ['block1a_bn[0][0]']             
 )                              32)                                                               
                                                                                                  
 block1a_se_squeeze (GlobalAver  (None, 32)          0           ['block1a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block1a_se_reshape (Reshape)   (None, 1, 1, 32)     0           ['block1a_se_squeeze[0][0]']     
                                                                                                  
 block1a_se_reduce (Conv2D)     (None, 1, 1, 8)      264         ['block1a_se_reshape[0][0]']     
                                                                                                  
 block1a_se_expand (Conv2D)     (None, 1, 1, 32)     288         ['block1a_se_reduce[0][0]']      
                                                                                                  
 block1a_se_excite (Multiply)   (None, None, None,   0           ['block1a_activation[0][0]',     
                                32)                               'block1a_se_expand[0][0]']      
                                                                                                  
 block1a_project_conv (Conv2D)  (None, None, None,   512         ['block1a_se_excite[0][0]']      
                                16)                                                               
                                                                                                  
 block1a_project_bn (BatchNorma  (None, None, None,   64         ['block1a_project_conv[0][0]']   
 lization)                      16)                                                               
                                                                                                  
 block2a_expand_conv (Conv2D)   (None, None, None,   1536        ['block1a_project_bn[0][0]']     
                                96)                                                               
                                                                                                  
 block2a_expand_bn (BatchNormal  (None, None, None,   384        ['block2a_expand_conv[0][0]']    
 ization)                       96)                                                               
                                                                                                  
 block2a_expand_activation (Act  (None, None, None,   0          ['block2a_expand_bn[0][0]']      
 ivation)                       96)                                                               
                                                                                                  
 block2a_dwconv_pad (ZeroPaddin  (None, None, None,   0          ['block2a_expand_activation[0][0]
 g2D)                           96)                              ']                               
                                                                                                  
 block2a_dwconv (DepthwiseConv2  (None, None, None,   864        ['block2a_dwconv_pad[0][0]']     
 D)                             96)                                                               
                                                                                                  
 block2a_bn (BatchNormalization  (None, None, None,   384        ['block2a_dwconv[0][0]']         
 )                              96)                                                               
                                                                                                  
 block2a_activation (Activation  (None, None, None,   0          ['block2a_bn[0][0]']             
 )                              96)                                                               
                                                                                                  
 block2a_se_squeeze (GlobalAver  (None, 96)          0           ['block2a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block2a_se_reshape (Reshape)   (None, 1, 1, 96)     0           ['block2a_se_squeeze[0][0]']     
                                                                                                  
 block2a_se_reduce (Conv2D)     (None, 1, 1, 4)      388         ['block2a_se_reshape[0][0]']     
                                                                                                  
 block2a_se_expand (Conv2D)     (None, 1, 1, 96)     480         ['block2a_se_reduce[0][0]']      
                                                                                                  
 block2a_se_excite (Multiply)   (None, None, None,   0           ['block2a_activation[0][0]',     
                                96)                               'block2a_se_expand[0][0]']      
                                                                                                  
 block2a_project_conv (Conv2D)  (None, None, None,   2304        ['block2a_se_excite[0][0]']      
                                24)                                                               
                                                                                                  
 block2a_project_bn (BatchNorma  (None, None, None,   96         ['block2a_project_conv[0][0]']   
 lization)                      24)                                                               
                                                                                                  
 block2b_expand_conv (Conv2D)   (None, None, None,   3456        ['block2a_project_bn[0][0]']     
                                144)                                                              
                                                                                                  
 block2b_expand_bn (BatchNormal  (None, None, None,   576        ['block2b_expand_conv[0][0]']    
 ization)                       144)                                                              
                                                                                                  
 block2b_expand_activation (Act  (None, None, None,   0          ['block2b_expand_bn[0][0]']      
 ivation)                       144)                                                              
                                                                                                  
 block2b_dwconv (DepthwiseConv2  (None, None, None,   1296       ['block2b_expand_activation[0][0]
 D)                             144)                             ']                               
                                                                                                  
 block2b_bn (BatchNormalization  (None, None, None,   576        ['block2b_dwconv[0][0]']         
 )                              144)                                                              
                                                                                                  
 block2b_activation (Activation  (None, None, None,   0          ['block2b_bn[0][0]']             
 )                              144)                                                              
                                                                                                  
 block2b_se_squeeze (GlobalAver  (None, 144)         0           ['block2b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block2b_se_reshape (Reshape)   (None, 1, 1, 144)    0           ['block2b_se_squeeze[0][0]']     
                                                                                                  
 block2b_se_reduce (Conv2D)     (None, 1, 1, 6)      870         ['block2b_se_reshape[0][0]']     
                                                                                                  
 block2b_se_expand (Conv2D)     (None, 1, 1, 144)    1008        ['block2b_se_reduce[0][0]']      
                                                                                                  
 block2b_se_excite (Multiply)   (None, None, None,   0           ['block2b_activation[0][0]',     
                                144)                              'block2b_se_expand[0][0]']      
                                                                                                  
 block2b_project_conv (Conv2D)  (None, None, None,   3456        ['block2b_se_excite[0][0]']      
                                24)                                                               
                                                                                                  
 block2b_project_bn (BatchNorma  (None, None, None,   96         ['block2b_project_conv[0][0]']   
 lization)                      24)                                                               
                                                                                                  
 block2b_drop (Dropout)         (None, None, None,   0           ['block2b_project_bn[0][0]']     
                                24)                                                               
                                                                                                  
 block2b_add (Add)              (None, None, None,   0           ['block2b_drop[0][0]',           
                                24)                               'block2a_project_bn[0][0]']     
                                                                                                  
 block3a_expand_conv (Conv2D)   (None, None, None,   3456        ['block2b_add[0][0]']            
                                144)                                                              
                                                                                                  
 block3a_expand_bn (BatchNormal  (None, None, None,   576        ['block3a_expand_conv[0][0]']    
 ization)                       144)                                                              
                                                                                                  
 block3a_expand_activation (Act  (None, None, None,   0          ['block3a_expand_bn[0][0]']      
 ivation)                       144)                                                              
                                                                                                  
 block3a_dwconv_pad (ZeroPaddin  (None, None, None,   0          ['block3a_expand_activation[0][0]
 g2D)                           144)                             ']                               
                                                                                                  
 block3a_dwconv (DepthwiseConv2  (None, None, None,   3600       ['block3a_dwconv_pad[0][0]']     
 D)                             144)                                                              
                                                                                                  
 block3a_bn (BatchNormalization  (None, None, None,   576        ['block3a_dwconv[0][0]']         
 )                              144)                                                              
                                                                                                  
 block3a_activation (Activation  (None, None, None,   0          ['block3a_bn[0][0]']             
 )                              144)                                                              
                                                                                                  
 block3a_se_squeeze (GlobalAver  (None, 144)         0           ['block3a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block3a_se_reshape (Reshape)   (None, 1, 1, 144)    0           ['block3a_se_squeeze[0][0]']     
                                                                                                  
 block3a_se_reduce (Conv2D)     (None, 1, 1, 6)      870         ['block3a_se_reshape[0][0]']     
                                                                                                  
 block3a_se_expand (Conv2D)     (None, 1, 1, 144)    1008        ['block3a_se_reduce[0][0]']      
                                                                                                  
 block3a_se_excite (Multiply)   (None, None, None,   0           ['block3a_activation[0][0]',     
                                144)                              'block3a_se_expand[0][0]']      
                                                                                                  
 block3a_project_conv (Conv2D)  (None, None, None,   5760        ['block3a_se_excite[0][0]']      
                                40)                                                               
                                                                                                  
 block3a_project_bn (BatchNorma  (None, None, None,   160        ['block3a_project_conv[0][0]']   
 lization)                      40)                                                               
                                                                                                  
 block3b_expand_conv (Conv2D)   (None, None, None,   9600        ['block3a_project_bn[0][0]']     
                                240)                                                              
                                                                                                  
 block3b_expand_bn (BatchNormal  (None, None, None,   960        ['block3b_expand_conv[0][0]']    
 ization)                       240)                                                              
                                                                                                  
 block3b_expand_activation (Act  (None, None, None,   0          ['block3b_expand_bn[0][0]']      
 ivation)                       240)                                                              
                                                                                                  
 block3b_dwconv (DepthwiseConv2  (None, None, None,   6000       ['block3b_expand_activation[0][0]
 D)                             240)                             ']                               
                                                                                                  
 block3b_bn (BatchNormalization  (None, None, None,   960        ['block3b_dwconv[0][0]']         
 )                              240)                                                              
                                                                                                  
 block3b_activation (Activation  (None, None, None,   0          ['block3b_bn[0][0]']             
 )                              240)                                                              
                                                                                                  
 block3b_se_squeeze (GlobalAver  (None, 240)         0           ['block3b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block3b_se_reshape (Reshape)   (None, 1, 1, 240)    0           ['block3b_se_squeeze[0][0]']     
                                                                                                  
 block3b_se_reduce (Conv2D)     (None, 1, 1, 10)     2410        ['block3b_se_reshape[0][0]']     
                                                                                                  
 block3b_se_expand (Conv2D)     (None, 1, 1, 240)    2640        ['block3b_se_reduce[0][0]']      
                                                                                                  
 block3b_se_excite (Multiply)   (None, None, None,   0           ['block3b_activation[0][0]',     
                                240)                              'block3b_se_expand[0][0]']      
                                                                                                  
 block3b_project_conv (Conv2D)  (None, None, None,   9600        ['block3b_se_excite[0][0]']      
                                40)                                                               
                                                                                                  
 block3b_project_bn (BatchNorma  (None, None, None,   160        ['block3b_project_conv[0][0]']   
 lization)                      40)                                                               
                                                                                                  
 block3b_drop (Dropout)         (None, None, None,   0           ['block3b_project_bn[0][0]']     
                                40)                                                               
                                                                                                  
 block3b_add (Add)              (None, None, None,   0           ['block3b_drop[0][0]',           
                                40)                               'block3a_project_bn[0][0]']     
                                                                                                  
 block4a_expand_conv (Conv2D)   (None, None, None,   9600        ['block3b_add[0][0]']            
                                240)                                                              
                                                                                                  
 block4a_expand_bn (BatchNormal  (None, None, None,   960        ['block4a_expand_conv[0][0]']    
 ization)                       240)                                                              
                                                                                                  
 block4a_expand_activation (Act  (None, None, None,   0          ['block4a_expand_bn[0][0]']      
 ivation)                       240)                                                              
                                                                                                  
 block4a_dwconv_pad (ZeroPaddin  (None, None, None,   0          ['block4a_expand_activation[0][0]
 g2D)                           240)                             ']                               
                                                                                                  
 block4a_dwconv (DepthwiseConv2  (None, None, None,   2160       ['block4a_dwconv_pad[0][0]']     
 D)                             240)                                                              
                                                                                                  
 block4a_bn (BatchNormalization  (None, None, None,   960        ['block4a_dwconv[0][0]']         
 )                              240)                                                              
                                                                                                  
 block4a_activation (Activation  (None, None, None,   0          ['block4a_bn[0][0]']             
 )                              240)                                                              
                                                                                                  
 block4a_se_squeeze (GlobalAver  (None, 240)         0           ['block4a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block4a_se_reshape (Reshape)   (None, 1, 1, 240)    0           ['block4a_se_squeeze[0][0]']     
                                                                                                  
 block4a_se_reduce (Conv2D)     (None, 1, 1, 10)     2410        ['block4a_se_reshape[0][0]']     
                                                                                                  
 block4a_se_expand (Conv2D)     (None, 1, 1, 240)    2640        ['block4a_se_reduce[0][0]']      
                                                                                                  
 block4a_se_excite (Multiply)   (None, None, None,   0           ['block4a_activation[0][0]',     
                                240)                              'block4a_se_expand[0][0]']      
                                                                                                  
 block4a_project_conv (Conv2D)  (None, None, None,   19200       ['block4a_se_excite[0][0]']      
                                80)                                                               
                                                                                                  
 block4a_project_bn (BatchNorma  (None, None, None,   320        ['block4a_project_conv[0][0]']   
 lization)                      80)                                                               
                                                                                                  
 block4b_expand_conv (Conv2D)   (None, None, None,   38400       ['block4a_project_bn[0][0]']     
                                480)                                                              
                                                                                                  
 block4b_expand_bn (BatchNormal  (None, None, None,   1920       ['block4b_expand_conv[0][0]']    
 ization)                       480)                                                              
                                                                                                  
 block4b_expand_activation (Act  (None, None, None,   0          ['block4b_expand_bn[0][0]']      
 ivation)                       480)                                                              
                                                                                                  
 block4b_dwconv (DepthwiseConv2  (None, None, None,   4320       ['block4b_expand_activation[0][0]
 D)                             480)                             ']                               
                                                                                                  
 block4b_bn (BatchNormalization  (None, None, None,   1920       ['block4b_dwconv[0][0]']         
 )                              480)                                                              
                                                                                                  
 block4b_activation (Activation  (None, None, None,   0          ['block4b_bn[0][0]']             
 )                              480)                                                              
                                                                                                  
 block4b_se_squeeze (GlobalAver  (None, 480)         0           ['block4b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block4b_se_reshape (Reshape)   (None, 1, 1, 480)    0           ['block4b_se_squeeze[0][0]']     
                                                                                                  
 block4b_se_reduce (Conv2D)     (None, 1, 1, 20)     9620        ['block4b_se_reshape[0][0]']     
                                                                                                  
 block4b_se_expand (Conv2D)     (None, 1, 1, 480)    10080       ['block4b_se_reduce[0][0]']      
                                                                                                  
 block4b_se_excite (Multiply)   (None, None, None,   0           ['block4b_activation[0][0]',     
                                480)                              'block4b_se_expand[0][0]']      
                                                                                                  
 block4b_project_conv (Conv2D)  (None, None, None,   38400       ['block4b_se_excite[0][0]']      
                                80)                                                               
                                                                                                  
 block4b_project_bn (BatchNorma  (None, None, None,   320        ['block4b_project_conv[0][0]']   
 lization)                      80)                                                               
                                                                                                  
 block4b_drop (Dropout)         (None, None, None,   0           ['block4b_project_bn[0][0]']     
                                80)                                                               
                                                                                                  
 block4b_add (Add)              (None, None, None,   0           ['block4b_drop[0][0]',           
                                80)                               'block4a_project_bn[0][0]']     
                                                                                                  
 block4c_expand_conv (Conv2D)   (None, None, None,   38400       ['block4b_add[0][0]']            
                                480)                                                              
                                                                                                  
 block4c_expand_bn (BatchNormal  (None, None, None,   1920       ['block4c_expand_conv[0][0]']    
 ization)                       480)                                                              
                                                                                                  
 block4c_expand_activation (Act  (None, None, None,   0          ['block4c_expand_bn[0][0]']      
 ivation)                       480)                                                              
                                                                                                  
 block4c_dwconv (DepthwiseConv2  (None, None, None,   4320       ['block4c_expand_activation[0][0]
 D)                             480)                             ']                               
                                                                                                  
 block4c_bn (BatchNormalization  (None, None, None,   1920       ['block4c_dwconv[0][0]']         
 )                              480)                                                              
                                                                                                  
 block4c_activation (Activation  (None, None, None,   0          ['block4c_bn[0][0]']             
 )                              480)                                                              
                                                                                                  
 block4c_se_squeeze (GlobalAver  (None, 480)         0           ['block4c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block4c_se_reshape (Reshape)   (None, 1, 1, 480)    0           ['block4c_se_squeeze[0][0]']     
                                                                                                  
 block4c_se_reduce (Conv2D)     (None, 1, 1, 20)     9620        ['block4c_se_reshape[0][0]']     
                                                                                                  
 block4c_se_expand (Conv2D)     (None, 1, 1, 480)    10080       ['block4c_se_reduce[0][0]']      
                                                                                                  
 block4c_se_excite (Multiply)   (None, None, None,   0           ['block4c_activation[0][0]',     
                                480)                              'block4c_se_expand[0][0]']      
                                                                                                  
 block4c_project_conv (Conv2D)  (None, None, None,   38400       ['block4c_se_excite[0][0]']      
                                80)                                                               
                                                                                                  
 block4c_project_bn (BatchNorma  (None, None, None,   320        ['block4c_project_conv[0][0]']   
 lization)                      80)                                                               
                                                                                                  
 block4c_drop (Dropout)         (None, None, None,   0           ['block4c_project_bn[0][0]']     
                                80)                                                               
                                                                                                  
 block4c_add (Add)              (None, None, None,   0           ['block4c_drop[0][0]',           
                                80)                               'block4b_add[0][0]']            
                                                                                                  
 block5a_expand_conv (Conv2D)   (None, None, None,   38400       ['block4c_add[0][0]']            
                                480)                                                              
                                                                                                  
 block5a_expand_bn (BatchNormal  (None, None, None,   1920       ['block5a_expand_conv[0][0]']    
 ization)                       480)                                                              
                                                                                                  
 block5a_expand_activation (Act  (None, None, None,   0          ['block5a_expand_bn[0][0]']      
 ivation)                       480)                                                              
                                                                                                  
 block5a_dwconv (DepthwiseConv2  (None, None, None,   12000      ['block5a_expand_activation[0][0]
 D)                             480)                             ']                               
                                                                                                  
 block5a_bn (BatchNormalization  (None, None, None,   1920       ['block5a_dwconv[0][0]']         
 )                              480)                                                              
                                                                                                  
 block5a_activation (Activation  (None, None, None,   0          ['block5a_bn[0][0]']             
 )                              480)                                                              
                                                                                                  
 block5a_se_squeeze (GlobalAver  (None, 480)         0           ['block5a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5a_se_reshape (Reshape)   (None, 1, 1, 480)    0           ['block5a_se_squeeze[0][0]']     
                                                                                                  
 block5a_se_reduce (Conv2D)     (None, 1, 1, 20)     9620        ['block5a_se_reshape[0][0]']     
                                                                                                  
 block5a_se_expand (Conv2D)     (None, 1, 1, 480)    10080       ['block5a_se_reduce[0][0]']      
                                                                                                  
 block5a_se_excite (Multiply)   (None, None, None,   0           ['block5a_activation[0][0]',     
                                480)                              'block5a_se_expand[0][0]']      
                                                                                                  
 block5a_project_conv (Conv2D)  (None, None, None,   53760       ['block5a_se_excite[0][0]']      
                                112)                                                              
                                                                                                  
 block5a_project_bn (BatchNorma  (None, None, None,   448        ['block5a_project_conv[0][0]']   
 lization)                      112)                                                              
                                                                                                  
 block5b_expand_conv (Conv2D)   (None, None, None,   75264       ['block5a_project_bn[0][0]']     
                                672)                                                              
                                                                                                  
 block5b_expand_bn (BatchNormal  (None, None, None,   2688       ['block5b_expand_conv[0][0]']    
 ization)                       672)                                                              
                                                                                                  
 block5b_expand_activation (Act  (None, None, None,   0          ['block5b_expand_bn[0][0]']      
 ivation)                       672)                                                              
                                                                                                  
 block5b_dwconv (DepthwiseConv2  (None, None, None,   16800      ['block5b_expand_activation[0][0]
 D)                             672)                             ']                               
                                                                                                  
 block5b_bn (BatchNormalization  (None, None, None,   2688       ['block5b_dwconv[0][0]']         
 )                              672)                                                              
                                                                                                  
 block5b_activation (Activation  (None, None, None,   0          ['block5b_bn[0][0]']             
 )                              672)                                                              
                                                                                                  
 block5b_se_squeeze (GlobalAver  (None, 672)         0           ['block5b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5b_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block5b_se_squeeze[0][0]']     
                                                                                                  
 block5b_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block5b_se_reshape[0][0]']     
                                                                                                  
 block5b_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block5b_se_reduce[0][0]']      
                                                                                                  
 block5b_se_excite (Multiply)   (None, None, None,   0           ['block5b_activation[0][0]',     
                                672)                              'block5b_se_expand[0][0]']      
                                                                                                  
 block5b_project_conv (Conv2D)  (None, None, None,   75264       ['block5b_se_excite[0][0]']      
                                112)                                                              
                                                                                                  
 block5b_project_bn (BatchNorma  (None, None, None,   448        ['block5b_project_conv[0][0]']   
 lization)                      112)                                                              
                                                                                                  
 block5b_drop (Dropout)         (None, None, None,   0           ['block5b_project_bn[0][0]']     
                                112)                                                              
                                                                                                  
 block5b_add (Add)              (None, None, None,   0           ['block5b_drop[0][0]',           
                                112)                              'block5a_project_bn[0][0]']     
                                                                                                  
 block5c_expand_conv (Conv2D)   (None, None, None,   75264       ['block5b_add[0][0]']            
                                672)                                                              
                                                                                                  
 block5c_expand_bn (BatchNormal  (None, None, None,   2688       ['block5c_expand_conv[0][0]']    
 ization)                       672)                                                              
                                                                                                  
 block5c_expand_activation (Act  (None, None, None,   0          ['block5c_expand_bn[0][0]']      
 ivation)                       672)                                                              
                                                                                                  
 block5c_dwconv (DepthwiseConv2  (None, None, None,   16800      ['block5c_expand_activation[0][0]
 D)                             672)                             ']                               
                                                                                                  
 block5c_bn (BatchNormalization  (None, None, None,   2688       ['block5c_dwconv[0][0]']         
 )                              672)                                                              
                                                                                                  
 block5c_activation (Activation  (None, None, None,   0          ['block5c_bn[0][0]']             
 )                              672)                                                              
                                                                                                  
 block5c_se_squeeze (GlobalAver  (None, 672)         0           ['block5c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5c_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block5c_se_squeeze[0][0]']     
                                                                                                  
 block5c_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block5c_se_reshape[0][0]']     
                                                                                                  
 block5c_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block5c_se_reduce[0][0]']      
                                                                                                  
 block5c_se_excite (Multiply)   (None, None, None,   0           ['block5c_activation[0][0]',     
                                672)                              'block5c_se_expand[0][0]']      
                                                                                                  
 block5c_project_conv (Conv2D)  (None, None, None,   75264       ['block5c_se_excite[0][0]']      
                                112)                                                              
                                                                                                  
 block5c_project_bn (BatchNorma  (None, None, None,   448        ['block5c_project_conv[0][0]']   
 lization)                      112)                                                              
                                                                                                  
 block5c_drop (Dropout)         (None, None, None,   0           ['block5c_project_bn[0][0]']     
                                112)                                                              
                                                                                                  
 block5c_add (Add)              (None, None, None,   0           ['block5c_drop[0][0]',           
                                112)                              'block5b_add[0][0]']            
                                                                                                  
 block6a_expand_conv (Conv2D)   (None, None, None,   75264       ['block5c_add[0][0]']            
                                672)                                                              
                                                                                                  
 block6a_expand_bn (BatchNormal  (None, None, None,   2688       ['block6a_expand_conv[0][0]']    
 ization)                       672)                                                              
                                                                                                  
 block6a_expand_activation (Act  (None, None, None,   0          ['block6a_expand_bn[0][0]']      
 ivation)                       672)                                                              
                                                                                                  
 block6a_dwconv_pad (ZeroPaddin  (None, None, None,   0          ['block6a_expand_activation[0][0]
 g2D)                           672)                             ']                               
                                                                                                  
 block6a_dwconv (DepthwiseConv2  (None, None, None,   16800      ['block6a_dwconv_pad[0][0]']     
 D)                             672)                                                              
                                                                                                  
 block6a_bn (BatchNormalization  (None, None, None,   2688       ['block6a_dwconv[0][0]']         
 )                              672)                                                              
                                                                                                  
 block6a_activation (Activation  (None, None, None,   0          ['block6a_bn[0][0]']             
 )                              672)                                                              
                                                                                                  
 block6a_se_squeeze (GlobalAver  (None, 672)         0           ['block6a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6a_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block6a_se_squeeze[0][0]']     
                                                                                                  
 block6a_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block6a_se_reshape[0][0]']     
                                                                                                  
 block6a_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block6a_se_reduce[0][0]']      
                                                                                                  
 block6a_se_excite (Multiply)   (None, None, None,   0           ['block6a_activation[0][0]',     
                                672)                              'block6a_se_expand[0][0]']      
                                                                                                  
 block6a_project_conv (Conv2D)  (None, None, None,   129024      ['block6a_se_excite[0][0]']      
                                192)                                                              
                                                                                                  
 block6a_project_bn (BatchNorma  (None, None, None,   768        ['block6a_project_conv[0][0]']   
 lization)                      192)                                                              
                                                                                                  
 block6b_expand_conv (Conv2D)   (None, None, None,   221184      ['block6a_project_bn[0][0]']     
                                1152)                                                             
                                                                                                  
 block6b_expand_bn (BatchNormal  (None, None, None,   4608       ['block6b_expand_conv[0][0]']    
 ization)                       1152)                                                             
                                                                                                  
 block6b_expand_activation (Act  (None, None, None,   0          ['block6b_expand_bn[0][0]']      
 ivation)                       1152)                                                             
                                                                                                  
 block6b_dwconv (DepthwiseConv2  (None, None, None,   28800      ['block6b_expand_activation[0][0]
 D)                             1152)                            ']                               
                                                                                                  
 block6b_bn (BatchNormalization  (None, None, None,   4608       ['block6b_dwconv[0][0]']         
 )                              1152)                                                             
                                                                                                  
 block6b_activation (Activation  (None, None, None,   0          ['block6b_bn[0][0]']             
 )                              1152)                                                             
                                                                                                  
 block6b_se_squeeze (GlobalAver  (None, 1152)        0           ['block6b_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6b_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block6b_se_squeeze[0][0]']     
                                                                                                  
 block6b_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block6b_se_reshape[0][0]']     
                                                                                                  
 block6b_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block6b_se_reduce[0][0]']      
                                                                                                  
 block6b_se_excite (Multiply)   (None, None, None,   0           ['block6b_activation[0][0]',     
                                1152)                             'block6b_se_expand[0][0]']      
                                                                                                  
 block6b_project_conv (Conv2D)  (None, None, None,   221184      ['block6b_se_excite[0][0]']      
                                192)                                                              
                                                                                                  
 block6b_project_bn (BatchNorma  (None, None, None,   768        ['block6b_project_conv[0][0]']   
 lization)                      192)                                                              
                                                                                                  
 block6b_drop (Dropout)         (None, None, None,   0           ['block6b_project_bn[0][0]']     
                                192)                                                              
                                                                                                  
 block6b_add (Add)              (None, None, None,   0           ['block6b_drop[0][0]',           
                                192)                              'block6a_project_bn[0][0]']     
                                                                                                  
 block6c_expand_conv (Conv2D)   (None, None, None,   221184      ['block6b_add[0][0]']            
                                1152)                                                             
                                                                                                  
 block6c_expand_bn (BatchNormal  (None, None, None,   4608       ['block6c_expand_conv[0][0]']    
 ization)                       1152)                                                             
                                                                                                  
 block6c_expand_activation (Act  (None, None, None,   0          ['block6c_expand_bn[0][0]']      
 ivation)                       1152)                                                             
                                                                                                  
 block6c_dwconv (DepthwiseConv2  (None, None, None,   28800      ['block6c_expand_activation[0][0]
 D)                             1152)                            ']                               
                                                                                                  
 block6c_bn (BatchNormalization  (None, None, None,   4608       ['block6c_dwconv[0][0]']         
 )                              1152)                                                             
                                                                                                  
 block6c_activation (Activation  (None, None, None,   0          ['block6c_bn[0][0]']             
 )                              1152)                                                             
                                                                                                  
 block6c_se_squeeze (GlobalAver  (None, 1152)        0           ['block6c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6c_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block6c_se_squeeze[0][0]']     
                                                                                                  
 block6c_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block6c_se_reshape[0][0]']     
                                                                                                  
 block6c_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block6c_se_reduce[0][0]']      
                                                                                                  
 block6c_se_excite (Multiply)   (None, None, None,   0           ['block6c_activation[0][0]',     
                                1152)                             'block6c_se_expand[0][0]']      
                                                                                                  
 block6c_project_conv (Conv2D)  (None, None, None,   221184      ['block6c_se_excite[0][0]']      
                                192)                                                              
                                                                                                  
 block6c_project_bn (BatchNorma  (None, None, None,   768        ['block6c_project_conv[0][0]']   
 lization)                      192)                                                              
                                                                                                  
 block6c_drop (Dropout)         (None, None, None,   0           ['block6c_project_bn[0][0]']     
                                192)                                                              
                                                                                                  
 block6c_add (Add)              (None, None, None,   0           ['block6c_drop[0][0]',           
                                192)                              'block6b_add[0][0]']            
                                                                                                  
 block6d_expand_conv (Conv2D)   (None, None, None,   221184      ['block6c_add[0][0]']            
                                1152)                                                             
                                                                                                  
 block6d_expand_bn (BatchNormal  (None, None, None,   4608       ['block6d_expand_conv[0][0]']    
 ization)                       1152)                                                             
                                                                                                  
 block6d_expand_activation (Act  (None, None, None,   0          ['block6d_expand_bn[0][0]']      
 ivation)                       1152)                                                             
                                                                                                  
 block6d_dwconv (DepthwiseConv2  (None, None, None,   28800      ['block6d_expand_activation[0][0]
 D)                             1152)                            ']                               
                                                                                                  
 block6d_bn (BatchNormalization  (None, None, None,   4608       ['block6d_dwconv[0][0]']         
 )                              1152)                                                             
                                                                                                  
 block6d_activation (Activation  (None, None, None,   0          ['block6d_bn[0][0]']             
 )                              1152)                                                             
                                                                                                  
 block6d_se_squeeze (GlobalAver  (None, 1152)        0           ['block6d_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block6d_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block6d_se_squeeze[0][0]']     
                                                                                                  
 block6d_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block6d_se_reshape[0][0]']     
                                                                                                  
 block6d_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block6d_se_reduce[0][0]']      
                                                                                                  
 block6d_se_excite (Multiply)   (None, None, None,   0           ['block6d_activation[0][0]',     
                                1152)                             'block6d_se_expand[0][0]']      
                                                                                                  
 block6d_project_conv (Conv2D)  (None, None, None,   221184      ['block6d_se_excite[0][0]']      
                                192)                                                              
                                                                                                  
 block6d_project_bn (BatchNorma  (None, None, None,   768        ['block6d_project_conv[0][0]']   
 lization)                      192)                                                              
                                                                                                  
 block6d_drop (Dropout)         (None, None, None,   0           ['block6d_project_bn[0][0]']     
                                192)                                                              
                                                                                                  
 block6d_add (Add)              (None, None, None,   0           ['block6d_drop[0][0]',           
                                192)                              'block6c_add[0][0]']            
                                                                                                  
 block7a_expand_conv (Conv2D)   (None, None, None,   221184      ['block6d_add[0][0]']            
                                1152)                                                             
                                                                                                  
 block7a_expand_bn (BatchNormal  (None, None, None,   4608       ['block7a_expand_conv[0][0]']    
 ization)                       1152)                                                             
                                                                                                  
 block7a_expand_activation (Act  (None, None, None,   0          ['block7a_expand_bn[0][0]']      
 ivation)                       1152)                                                             
                                                                                                  
 block7a_dwconv (DepthwiseConv2  (None, None, None,   10368      ['block7a_expand_activation[0][0]
 D)                             1152)                            ']                               
                                                                                                  
 block7a_bn (BatchNormalization  (None, None, None,   4608       ['block7a_dwconv[0][0]']         
 )                              1152)                                                             
                                                                                                  
 block7a_activation (Activation  (None, None, None,   0          ['block7a_bn[0][0]']             
 )                              1152)                                                             
                                                                                                  
 block7a_se_squeeze (GlobalAver  (None, 1152)        0           ['block7a_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block7a_se_reshape (Reshape)   (None, 1, 1, 1152)   0           ['block7a_se_squeeze[0][0]']     
                                                                                                  
 block7a_se_reduce (Conv2D)     (None, 1, 1, 48)     55344       ['block7a_se_reshape[0][0]']     
                                                                                                  
 block7a_se_expand (Conv2D)     (None, 1, 1, 1152)   56448       ['block7a_se_reduce[0][0]']      
                                                                                                  
 block7a_se_excite (Multiply)   (None, None, None,   0           ['block7a_activation[0][0]',     
                                1152)                             'block7a_se_expand[0][0]']      
                                                                                                  
 block7a_project_conv (Conv2D)  (None, None, None,   368640      ['block7a_se_excite[0][0]']      
                                320)                                                              
                                                                                                  
 block7a_project_bn (BatchNorma  (None, None, None,   1280       ['block7a_project_conv[0][0]']   
 lization)                      320)                                                              
                                                                                                  
 top_conv (Conv2D)              (None, None, None,   409600      ['block7a_project_bn[0][0]']     
                                1280)                                                             
                                                                                                  
 top_bn (BatchNormalization)    (None, None, None,   5120        ['top_conv[0][0]']               
                                1280)                                                             
                                                                                                  
 top_activation (Activation)    (None, None, None,   0           ['top_bn[0][0]']                 
                                1280)                                                             
                                                                                                  
==================================================================================================
Total params: 4,049,571
Trainable params: 0
Non-trainable params: 4,049,571
__________________________________________________________________________________________________
model_0.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_layer (InputLayer)    [(None, 224, 224, 3)]     0         
                                                                 
 efficientnetb0 (Functional)  (None, None, None, 1280)  4049571  
                                                                 
 global_average_pooling_laye  (None, 1280)             0         
 r (GlobalAveragePooling2D)                                      
                                                                 
 output_layer (Dense)        (None, 10)                12810     
                                                                 
=================================================================
Total params: 4,062,381
Trainable params: 12,810
Non-trainable params: 4,049,571
_________________________________________________________________
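
As a quick sanity check (a sketch assuming the base_model and model_0 objects built above), the frozen base should expose zero trainable variables, while the Dense output layer accounts for all 12,810 trainable parameters (1280 * 10 weights + 10 biases):

# The frozen base contributes no trainable variables;
# model_0's trainable parameters all live in the output Dense layer.
print(len(base_model.trainable_variables)) # 0
print(len(model_0.trainable_variables))    # 2 (output_layer kernel and bias)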
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plot_loss_curves(history_10_percent)

Getting a Feature vector from a trained model

Let's demonstrate the GlobalAveragePooling2D layer...

After passing through base_model, our inputs become a tensor of shape (None, 7, 7, 1280).

But when that tensor passes through GlobalAveragePooling2D, it becomes (None, 1280).

Let's create a similarly shaped tensor of (1, 4, 4, 3) and pass it to GlobalAveragePooling2D.

input_shape = (1,4,4,3)

# Create a random tensor
tf.random.set_seed(42)
input_tensor = tf.random.normal(input_shape)
print(f"Random input tensor: \n {input_tensor} \n")

# Pass the random tensor to the GlobalAveragePooling2D layer
global_average_pooled_tensor = tf.keras.layers.GlobalAveragePooling2D()(input_tensor)
print(f"2D global average pooled random tensor: \n {global_average_pooled_tensor}\n")

# Check the shape of the different tensors
print(f"Shape of input tensor: {input_tensor.shape}\n")
print(f"Shape of Global Average Pooled 2D tensor : {global_average_pooled_tensor.shape}\n")
Random input tensor: 
 [[[[ 0.3274685  -0.8426258   0.3194337 ]
   [-1.4075519  -2.3880599  -1.0392479 ]
   [-0.5573232   0.539707    1.6994323 ]
   [ 0.28893656 -1.5066116  -0.2645474 ]]

  [[-0.59722406 -1.9171132  -0.62044144]
   [ 0.8504023  -0.40604794 -3.0258412 ]
   [ 0.9058464   0.29855987 -0.22561555]
   [-0.7616443  -1.8917141  -0.93847126]]

  [[ 0.77852213 -0.47338897  0.97772694]
   [ 0.24694404  0.20573747 -0.5256233 ]
   [ 0.32410017  0.02545409 -0.10638497]
   [-0.6369475   1.1603122   0.2507359 ]]

  [[-0.41728503  0.4012578  -1.4145443 ]
   [-0.5931857  -1.6617213   0.33567193]
   [ 0.10815629  0.23479682 -0.56668764]
   [-0.35819843  0.88698614  0.52744764]]]] 

2D global average pooled random tensor: 
 [[-0.09368646 -0.45840448 -0.2885598 ]]

Shape of input tensor: (1, 4, 4, 3)

Shape of Global Average Pooled 2D tensor : (1, 3)

tf.reduce_mean(input_tensor, axis = [1,2])
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-0.09368646, -0.45840448, -0.2885598 ]], dtype=float32)>
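Notice the values match the global average pooled tensor above: GlobalAveragePooling2D is equivalent to taking tf.reduce_mean across the spatial axes (height and width, axes 1 and 2).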
# Define the input shape
input_shape = (1,4,4,3)

# Create a random tensor
tf.random.set_seed(42)
input_tensor = tf.random.normal(input_shape)
print(f"Random input tensor: \n {input_tensor} \n")

# Pass the random tensor to the GlobalMaxPooling2D layer
global_max_pooled_tensor = tf.keras.layers.GlobalMaxPooling2D()(input_tensor)
print(f"2D global max pooled random tensor: \n {global_max_pooled_tensor}\n")

# Check the shape of the different tensors
print(f"Shape of input tensor: {input_tensor.shape}\n")
print(f"Shape of Global Max Pooled 2D tensor : {global_max_pooled_tensor.shape}\n")
Random input tensor: 
 [[[[ 0.3274685  -0.8426258   0.3194337 ]
   [-1.4075519  -2.3880599  -1.0392479 ]
   [-0.5573232   0.539707    1.6994323 ]
   [ 0.28893656 -1.5066116  -0.2645474 ]]

  [[-0.59722406 -1.9171132  -0.62044144]
   [ 0.8504023  -0.40604794 -3.0258412 ]
   [ 0.9058464   0.29855987 -0.22561555]
   [-0.7616443  -1.8917141  -0.93847126]]

  [[ 0.77852213 -0.47338897  0.97772694]
   [ 0.24694404  0.20573747 -0.5256233 ]
   [ 0.32410017  0.02545409 -0.10638497]
   [-0.6369475   1.1603122   0.2507359 ]]

  [[-0.41728503  0.4012578  -1.4145443 ]
   [-0.5931857  -1.6617213   0.33567193]
   [ 0.10815629  0.23479682 -0.56668764]
   [-0.35819843  0.88698614  0.52744764]]]] 

2D global max pooled random tensor: 
 [[0.9058464 1.1603122 1.6994323]]

Shape of input tensor: (1, 4, 4, 3)

Shape of Global Max Pooled 2D tensor : (1, 3)
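
And just as global average pooling mirrors tf.reduce_mean, GlobalMaxPooling2D is equivalent to reducing with tf.reduce_max over the spatial axes (a quick check, reusing input_tensor from above):

tf.reduce_max(input_tensor, axis=[1,2]) # should equal the global max pooled tensor above: [[0.9058464 1.1603122 1.6994323]]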

Note: Feature extraction transfer learning gets its name because the pre-trained model outputs a feature vector, a long tensor of numbers that represents the model's learned representation of a particular sample (in our case, the output of the tf.keras.layers.GlobalAveragePooling2D() layer), which can then be used to extract patterns for our own specific problem (see the short sketch after the definition below).

Feature Vector:

  • A feature vector is a learned representation of the input data (a compressed form of the input data based on how the model sees it)
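
To make this concrete, here's a minimal sketch (hypothetical, assuming the model_0 and test_data objects created earlier) of pulling per-image feature vectors out of model_0 by reusing it up to its pooling layer:

# Build a hypothetical feature extractor: everything in model_0 up to the
# "global_average_pooling_layer" layer outputs a 1280-d feature vector per image.
feature_extractor = tf.keras.Model(
    inputs=model_0.input,
    outputs=model_0.get_layer("global_average_pooling_layer").output)

images, labels = next(iter(test_data)) # grab one batch of test images
print(feature_extractor(images).shape) # (32, 1280) -> one feature vector per image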

Running transfer learning experiments

We have seen the incredible results transfer learning can get with only 10% of the training data, but how does it go with just 1%? We will set up a series of experiments to find out:

  1. model_1 - feature extraction transfer learning on 1% of the training data with data augmentation.

  2. model_2 - feature extraction transfer learning on 10% of the training data with data augmentation.

  3. model_3 - fine-tuning transfer learning on 10% of the training data with data augmentation.

  4. model_4 - fine-tuning transfer learning on 100% of the training data with data augmentation.

Note: throughout all experiments the same test dataset will be used to evaluate our model. This ensures consistency across evaluation metrics.

Getting and preprocessing data for model_1

!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_1_percent.zip

unzip_data("10_food_classes_1_percent.zip")
--2022-02-20 05:20:41--  https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_1_percent.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.120.128, 74.125.70.128, 74.125.69.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.120.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 133612354 (127M) [application/zip]
Saving to: ‘10_food_classes_1_percent.zip’

10_food_classes_1_p 100%[===================>] 127.42M   160MB/s    in 0.8s    

2022-02-20 05:20:42 (160 MB/s) - ‘10_food_classes_1_percent.zip’ saved [133612354/133612354]

train_dir_1_percent = "10_food_classes_1_percent/train"
test_dir = "10_food_classes_1_percent/test"
walk_through_dir("10_food_classes_1_percent")
There are 2 directories and 0 images in '10_food_classes_1_percent'.
There are 10 directories and 0 images in '10_food_classes_1_percent/train'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/pizza'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/fried_rice'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/grilled_salmon'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/hamburger'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/steak'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/sushi'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/chicken_wings'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/ramen'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/ice_cream'.
There are 0 directories and 7 images in '10_food_classes_1_percent/train/chicken_curry'.
There are 10 directories and 0 images in '10_food_classes_1_percent/test'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/pizza'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/fried_rice'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/grilled_salmon'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/hamburger'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/steak'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/sushi'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/chicken_wings'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/ramen'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/ice_cream'.
There are 0 directories and 250 images in '10_food_classes_1_percent/test/chicken_curry'.
IMG_SIZE = (224,224)
BATCH_SIZE = 32
train_data_1_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir_1_percent,
                                                                           label_mode = "categorical",
                                                                           image_size = IMG_SIZE,
                                                                           batch_size = BATCH_SIZE) # default is 32

test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
                                                                label_mode = "categorical",
                                                                image_size = IMG_SIZE,
                                                                batch_size=BATCH_SIZE)
Found 70 files belonging to 10 classes.
Found 2500 files belonging to 10 classes.

Let's look at the Food Vision datasets we are using:

| Dataset Name | Source | Classes | Training data | Testing data |
| --- | --- | --- | --- | --- |
| pizza_steak | Food101 | pizza, steak (2) | 750 images of pizza and steak (same as original Food101 dataset) | 250 images of pizza and steak (same as original Food101 dataset) |
| 10_food_classes_1_percent | Same as above | chicken curry, chicken wings, fried rice, grilled salmon, hamburger, ice cream, pizza, ramen, steak, sushi (10) | 7 randomly selected images of each class (1% of original training data) | 250 images of each class (same as original Food101 dataset) |
| 10_food_classes_10_percent | Same as above | Same as above | 75 randomly selected images of each class (10% of original training data) | Same as above |
| 10_food_classes_100_percent | Same as above | Same as above | 750 images of each class (100% of original training data) | Same as above |
| 101_food_classes_10_percent | Same as above | All classes from Food101 (101) | 75 images of each class (10% of original training data) | 250 images of each class (same as original Food101 dataset) |

Adding data augmentation into the model

To add data augmentation right into our models, we can use the layers inside:

  • the tf.keras.layers.experimental.preprocessing module

We can see the benefits of doing this within the TensorFlow data augmentation documentation: https://www.tensorflow.org/tutorials/images/data_augmentation

Main Benefits:

  • Preprocessing of images (augmenting them) happens on the GPU (much faster) rather than on the CPU.
  • Image data augmentation only happens during training, so we can still export our whole model and use it elsewhere (a short check of this is shown after the code below).

Example of using data augmentation as the first layer within a model (EfficientNetB0).

The data augmentation transformations we're going to use are:

  • RandomFlip - flips an image on the horizontal and/or vertical axis.
  • RandomRotation - randomly rotates an image by a specified amount.
  • RandomZoom - randomly zooms into an image by a specified amount.
  • RandomHeight - randomly shifts an image's height by a specified amount.
  • RandomWidth - randomly shifts an image's width by a specified amount.
  • Rescaling - normalizes the image pixel values to be between 0 and 1. This is worth mentioning because some models (such as ResNet50V2) require it, but since we're using the tf.keras.applications implementation of EfficientNetB0, it's not required.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing

# Create data augmentation stage with horizontal flipping, rotation, zoom, etc.
data_augmentation = keras.Sequential([
  preprocessing.RandomFlip("horizontal"),
  preprocessing.RandomRotation(0.2),
  preprocessing.RandomZoom(0.2),
  preprocessing.RandomHeight(0.2),
  preprocessing.RandomWidth(0.2),
  # preprocessing.Rescaling(1./255) # keep for models like ResNet50V2, but EfficientNetB0 has built-in rescaling
], name = "data_augmentation")
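
A quick way to check the second benefit listed earlier (augmentation is only active during training) is to call the layer with training=False; a small sketch using a random image-shaped tensor:

# Preprocessing layers only transform inputs when training=True;
# with training=False (the behaviour at inference/export) images pass through unchanged.
sample_image = tf.random.uniform(shape=(1, 224, 224, 3))
passthrough = data_augmentation(sample_image, training=False)
print(tf.reduce_all(sample_image == passthrough).numpy()) # True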

Visualize our data augmentation layer

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import random
target_class = random.choice(train_data_1_percent.class_names) # choose a random class
target_dir = "10_food_classes_1_percent/train/" + target_class # create the target directory
random_image = random.choice(os.listdir(target_dir)) # choose a random image from target directory
random_image_path = target_dir + "/" + random_image # create the chosen random image path
img = mpimg.imread(random_image_path) # read in the chosen target image
plt.imshow(img) # plot the target image
plt.title(f"Original random image from class: {target_class}")
plt.axis(False); # turn off the axes

# Augment the image
augmented_img = data_augmentation(tf.expand_dims(img, axis=0)) # data augmentation model requires shape (None, height, width, 3)
plt.figure()
plt.imshow(tf.squeeze(augmented_img)/255.) # augmented values are floats in [0, 255], so scale to [0, 1] for plt.imshow
plt.title(f"Augmented random image from class: {target_class}")
plt.axis(False);

Model 1: Feature extraction transfer learning on 1% of data with data augmentation

input_shape = (224,224,3)

# Create a frozen base model (its pretrained weights won't be updated)
base_model = tf.keras.applications.EfficientNetB0(include_top = False)
base_model.trainable = False

# Create the input layer
inputs = layers.Input(shape = input_shape)

# Add in data augmentation Sequential model as a layer
x = data_augmentation(inputs)

# Give base model the inputs (after augmentation) and don't train it
x = base_model(x, training = False)

# Pool output features of the base model
x = layers.GlobalAveragePooling2D()(x)

# Put a dense layer on as the output
outputs = layers.Dense(10, activation = "softmax", name= "output_layer")(x)

# Make a model using the inputs and outputs
model_1 = keras.Model(inputs,outputs)

# Compile the Model
model_1.compile(loss = "categorical_crossentropy",
                optimizer = tf.keras.optimizers.Adam(),
                metrics = ["accuracy"])

# Fit the Model
history_1_percent = model_1.fit(train_data_1_percent,
                                epochs = 5,
                                steps_per_epoch = len(train_data_1_percent),
                                validation_data = test_data,
                                validation_steps = int(0.25*len(test_data)),
                                callbacks = [create_tensorboard_callback(dir_name="transfer_learning",
                                                                         experiment_name = "1_percent_data_aug")])
Saving TensorBoard log files to: transfer_learning/1_percent_data_aug/20220220-052221
Epoch 1/5
3/3 [==============================] - 15s 3s/step - loss: 2.3906 - accuracy: 0.1000 - val_loss: 2.2103 - val_accuracy: 0.1760
Epoch 2/5
3/3 [==============================] - 5s 2s/step - loss: 2.1522 - accuracy: 0.2000 - val_loss: 2.1057 - val_accuracy: 0.2599
Epoch 3/5
3/3 [==============================] - 5s 2s/step - loss: 1.9507 - accuracy: 0.3571 - val_loss: 2.0142 - val_accuracy: 0.3125
Epoch 4/5
3/3 [==============================] - 6s 3s/step - loss: 1.7978 - accuracy: 0.5143 - val_loss: 1.9147 - val_accuracy: 0.3882
Epoch 5/5
3/3 [==============================] - 5s 2s/step - loss: 1.6239 - accuracy: 0.6286 - val_loss: 1.8459 - val_accuracy: 0.4293
model_1.summary()
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 data_augmentation (Sequenti  (None, None, None, 3)    0         
 al)                                                             
                                                                 
 efficientnetb0 (Functional)  (None, None, None, 1280)  4049571  
                                                                 
 global_average_pooling2d_1   (None, 1280)             0         
 (GlobalAveragePooling2D)                                        
                                                                 
 output_layer (Dense)        (None, 10)                12810     
                                                                 
=================================================================
Total params: 4,062,381
Trainable params: 12,810
Non-trainable params: 4,049,571
_________________________________________________________________
results_1_percent_data_aug = model_1.evaluate(test_data)
results_1_percent_data_aug
79/79 [==============================] - 11s 132ms/step - loss: 1.8305 - accuracy: 0.4460
[1.8304558992385864, 0.44600000977516174]
plot_loss_curves(history_1_percent)

Model 2: Feature extraction transfer learning model with 10% of data and data augmentation

# Data already downloaded and unzipped earlier in the notebook:
#!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_10_percent.zip
# unzip_data("10_food_classes_10_percent.zip")

train_dir_10_percent = "10_food_classes_10_percent/train"
test_dir = "10_food_classes_10_percent/test"
IMG_SIZE
(224, 224)
walk_through_dir("10_food_classes_10_percent")
There are 2 directories and 0 images in '10_food_classes_10_percent'.
There are 10 directories and 0 images in '10_food_classes_10_percent/train'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/pizza'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/fried_rice'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/grilled_salmon'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/hamburger'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/steak'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/sushi'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_wings'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/ramen'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/ice_cream'.
There are 0 directories and 75 images in '10_food_classes_10_percent/train/chicken_curry'.
There are 10 directories and 0 images in '10_food_classes_10_percent/test'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/pizza'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/fried_rice'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/grilled_salmon'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/hamburger'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/steak'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/sushi'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_wings'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/ramen'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/ice_cream'.
There are 0 directories and 250 images in '10_food_classes_10_percent/test/chicken_curry'.
import tensorflow as tf
IMG_SIZE = (224,224)
train_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir_10_percent,
                                                                            label_mode = "categorical",
                                                                            image_size = IMG_SIZE)

test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
                                                                label_mode = "categorical",
                                                                image_size = IMG_SIZE)
Found 750 files belonging to 10 classes.
Found 2500 files belonging to 10 classes.
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.models import Sequential

# Build data augmentation layer
data_augmentation = Sequential([
  preprocessing.RandomFlip('horizontal'),
  preprocessing.RandomHeight(0.2),
  preprocessing.RandomWidth(0.2),
  preprocessing.RandomZoom(0.2),
  preprocessing.RandomRotation(0.2),
  # preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNet                 
], name="data_augmentation")

# Setup the input shape to our model
input_shape = (224, 224, 3)

# Create a frozen base model
base_model = tf.keras.applications.EfficientNetB0(include_top=False)
base_model.trainable = False

# Create input and output layers
inputs = layers.Input(shape=input_shape, name="input_layer") # create input layer
x = data_augmentation(inputs) # augment our training images
x = base_model(x, training=False) # pass augmented images to base model but keep it in inference mode, so batchnorm layers don't get updated: https://keras.io/guides/transfer_learning/#build-a-model 
x = layers.GlobalAveragePooling2D(name="global_average_pooling_layer")(x)
outputs = layers.Dense(10, activation="softmax", name="output_layer")(x)
model_2 = tf.keras.Model(inputs, outputs)

# Compile
model_2.compile(loss="categorical_crossentropy",
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), # use Adam optimizer with the base learning rate
                metrics=["accuracy"])
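Before fitting, it can be worth sanity checking what the augmentation layer actually does. Below is a minimal sketch (not part of the original experiments) that plots one training image next to an augmented version of it. Note the preprocessing layers are inactive at inference time, so we pass training=True to force augmentation; it assumes matplotlib is available and train_data_10_percent has been created as above.

import matplotlib.pyplot as plt

# Grab one batch of training images and visualize original vs. augmented
for images, labels in train_data_10_percent.take(1):
  first_image = images[0]
  augmented_image = data_augmentation(tf.expand_dims(first_image, axis=0), training=True)

  plt.figure(figsize=(8, 4))
  plt.subplot(1, 2, 1)
  plt.imshow(first_image / 255.) # images come in as float32 in [0, 255]
  plt.title("Original")
  plt.axis("off")
  plt.subplot(1, 2, 2)
  plt.imshow(augmented_image[0] / 255.)
  plt.title("Augmented")
  plt.axis("off")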
model_2.summary()
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_layer (InputLayer)    [(None, 224, 224, 3)]     0         
                                                                 
 data_augmentation (Sequenti  (None, 224, 224, 3)      0         
 al)                                                             
                                                                 
 efficientnetb0 (Functional)  (None, None, None, 1280)  4049571  
                                                                 
 global_average_pooling_laye  (None, 1280)             0         
 r (GlobalAveragePooling2D)                                      
                                                                 
 output_layer (Dense)        (None, 10)                12810     
                                                                 
=================================================================
Total params: 4,062,381
Trainable params: 12,810
Non-trainable params: 4,049,571
_________________________________________________________________

Creating a ModelCheckpoint Callback

  • Callbacks are a tool that can add helpful functionality to your models during training, evaluation or inference.
  • Some popular callbacks include:

Callback name | Use case | Code
TensorBoard | Log the performance of multiple models and then view and compare them visually in TensorBoard. Helpful for comparing the results of different models on your data. | tf.keras.callbacks.TensorBoard()
Model Checkpointing | Save your model as it trains so you can stop training if needed and come back to continue where you left off. Helpful if training takes a long time and can't be done in one sitting. | tf.keras.callbacks.ModelCheckpoint()
Early Stopping | Leave your model training for an arbitrary amount of time and have it stop automatically when it ceases to improve. Helpful when you've got a large dataset and don't know how long training will take (see the sketch below this table). | tf.keras.callbacks.EarlyStopping()
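We only use the TensorBoard and ModelCheckpoint callbacks below, but for reference, a minimal EarlyStopping setup might look like the sketch that follows (the monitor and patience values are illustrative assumptions, not tuned):

# Illustrative EarlyStopping setup (not used in the experiments below)
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor="val_loss", # quantity to watch
                                                           patience=3, # epochs with no improvement before stopping
                                                           restore_best_weights=True) # roll back to the best-performing weights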
checkpoint_path = "ten_percent_model_checkpoints_weights/checkpoint.ckpt"

# Create a ModelCheckpoint callback that saves the model's weights only
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_path,
                                                         save_weights_only = True, # set False to save the entire model
                                                         save_best_only = False,
                                                         save_freq = "epoch", # save every epoch
                                                         verbose = 1)

Fit Model 2 passing in the ModelCheckpoint callback

initial_epochs = 5
history_10_percent_data_aug = model_2.fit(train_data_10_percent,
                                          epochs = initial_epochs,
                                          validation_data= test_data,
                                          validation_steps = int(0.25* len(test_data)),
                                          callbacks = [create_tensorboard_callback(dir_name = "transfer_learning",
                                                                                    experiment_name = "10_percent_data_aug"),
                                                       checkpoint_callback])
Saving TensorBoard log files to: transfer_learning/10_percent_data_aug/20220220-062505
Epoch 1/5
24/24 [==============================] - ETA: 0s - loss: 2.0047 - accuracy: 0.3213
Epoch 1: saving model to ten_percent_model_checkpoints_weights/checkpoint.ckpt
INFO:tensorflow:Assets written to: ten_percent_model_checkpoints_weights/checkpoint.ckpt/assets
24/24 [==============================] - 81s 3s/step - loss: 2.0047 - accuracy: 0.3213 - val_loss: 1.4956 - val_accuracy: 0.6250
Epoch 2/5
24/24 [==============================] - ETA: 0s - loss: 1.3760 - accuracy: 0.6747
Epoch 2: saving model to ten_percent_model_checkpoints_weights/checkpoint.ckpt
INFO:tensorflow:Assets written to: ten_percent_model_checkpoints_weights/checkpoint.ckpt/assets
24/24 [==============================] - 93s 4s/step - loss: 1.3760 - accuracy: 0.6747 - val_loss: 1.0669 - val_accuracy: 0.7632
Epoch 3/5
24/24 [==============================] - ETA: 0s - loss: 1.0676 - accuracy: 0.7360
Epoch 3: saving model to ten_percent_model_checkpoints_weights/checkpoint.ckpt
INFO:tensorflow:Assets written to: ten_percent_model_checkpoints_weights/checkpoint.ckpt/assets
24/24 [==============================] - 69s 3s/step - loss: 1.0676 - accuracy: 0.7360 - val_loss: 0.8876 - val_accuracy: 0.7796
Epoch 4/5
24/24 [==============================] - ETA: 0s - loss: 0.9203 - accuracy: 0.7640
Epoch 4: saving model to ten_percent_model_checkpoints_weights/checkpoint.ckpt
INFO:tensorflow:Assets written to: ten_percent_model_checkpoints_weights/checkpoint.ckpt/assets
24/24 [==============================] - 66s 3s/step - loss: 0.9203 - accuracy: 0.7640 - val_loss: 0.7766 - val_accuracy: 0.8125
Epoch 5/5
24/24 [==============================] - ETA: 0s - loss: 0.8018 - accuracy: 0.7987
Epoch 5: saving model to ten_percent_model_checkpoints_weights/checkpoint.ckpt
INFO:tensorflow:Assets written to: ten_percent_model_checkpoints_weights/checkpoint.ckpt/assets
24/24 [==============================] - 68s 3s/step - loss: 0.8018 - accuracy: 0.7987 - val_loss: 0.6841 - val_accuracy: 0.8174
model_0.evaluate(test_data)
79/79 [==============================] - 12s 133ms/step - loss: 0.6102 - accuracy: 0.8412
[0.6101743578910828, 0.8411999940872192]
results_10_percent_data_aug = model_2.evaluate(test_data)
results_10_percent_data_aug
79/79 [==============================] - 11s 132ms/step - loss: 0.6865 - accuracy: 0.8160
[0.6864515542984009, 0.8159999847412109]
plot_loss_curves(history_10_percent_data_aug)

Loading in checkpointed weights

Loading in checkpointed weights returns the model to the state it was in at that specific checkpoint.

model_2.load_weights(checkpoint_path)
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fcd9fd4e290>
loaded_weights_model_results = model_2.evaluate(test_data)
79/79 [==============================] - 12s 140ms/step - loss: 0.6865 - accuracy: 0.8160
results_10_percent_data_aug == loaded_weights_model_results
True
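One caveat: evaluate() returns floating point numbers, so an exact == comparison can fail due to tiny numerical differences even when two results are effectively the same (you can see such a difference in the last decimal places later in this notebook). A more robust check is an element-wise closeness test, sketched below (assumes NumPy is available):

import numpy as np

# Check the two [loss, accuracy] lists are element-wise close rather than exactly equal
np.isclose(results_10_percent_data_aug, loaded_weights_model_results).all()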

Model 3: Fine-tuning an existing model on 10% of the data

Note: Fine-tuning usually works best after training a feature extraction model with large amounts of custom data.

model_2.layers
[<keras.engine.input_layer.InputLayer at 0x7fcda877ea10>,
 <keras.engine.sequential.Sequential at 0x7fcda8755810>,
 <keras.engine.functional.Functional at 0x7fcda82e9650>,
 <keras.layers.pooling.GlobalAveragePooling2D at 0x7fcda877fcd0>,
 <keras.layers.core.dense.Dense at 0x7fcda8231690>]
for layer in model_2.layers:
  print(layer, layer.trainable)
# Checkout the trainable layers in our model (EfficientNetB0)
for i, layer in enumerate(model_2.layers[2].layers):
  print(i, layer.name , layer.trainable)

0 input_7 False
1 rescaling_4 False
2 normalization_4 False
3 stem_conv_pad False
4 stem_conv False
5 stem_bn False
6 stem_activation False
7 block1a_dwconv False
8 block1a_bn False
9 block1a_activation False
10 block1a_se_squeeze False
11 block1a_se_reshape False
12 block1a_se_reduce False
13 block1a_se_expand False
14 block1a_se_excite False
15 block1a_project_conv False
16 block1a_project_bn False
17 block2a_expand_conv False
18 block2a_expand_bn False
19 block2a_expand_activation False
20 block2a_dwconv_pad False
21 block2a_dwconv False
22 block2a_bn False
23 block2a_activation False
24 block2a_se_squeeze False
25 block2a_se_reshape False
26 block2a_se_reduce False
27 block2a_se_expand False
28 block2a_se_excite False
29 block2a_project_conv False
30 block2a_project_bn False
31 block2b_expand_conv False
32 block2b_expand_bn False
33 block2b_expand_activation False
34 block2b_dwconv False
35 block2b_bn False
36 block2b_activation False
37 block2b_se_squeeze False
38 block2b_se_reshape False
39 block2b_se_reduce False
40 block2b_se_expand False
41 block2b_se_excite False
42 block2b_project_conv False
43 block2b_project_bn False
44 block2b_drop False
45 block2b_add False
46 block3a_expand_conv False
47 block3a_expand_bn False
48 block3a_expand_activation False
49 block3a_dwconv_pad False
50 block3a_dwconv False
51 block3a_bn False
52 block3a_activation False
53 block3a_se_squeeze False
54 block3a_se_reshape False
55 block3a_se_reduce False
56 block3a_se_expand False
57 block3a_se_excite False
58 block3a_project_conv False
59 block3a_project_bn False
60 block3b_expand_conv False
61 block3b_expand_bn False
62 block3b_expand_activation False
63 block3b_dwconv False
64 block3b_bn False
65 block3b_activation False
66 block3b_se_squeeze False
67 block3b_se_reshape False
68 block3b_se_reduce False
69 block3b_se_expand False
70 block3b_se_excite False
71 block3b_project_conv False
72 block3b_project_bn False
73 block3b_drop False
74 block3b_add False
75 block4a_expand_conv False
76 block4a_expand_bn False
77 block4a_expand_activation False
78 block4a_dwconv_pad False
79 block4a_dwconv False
80 block4a_bn False
81 block4a_activation False
82 block4a_se_squeeze False
83 block4a_se_reshape False
84 block4a_se_reduce False
85 block4a_se_expand False
86 block4a_se_excite False
87 block4a_project_conv False
88 block4a_project_bn False
89 block4b_expand_conv False
90 block4b_expand_bn False
91 block4b_expand_activation False
92 block4b_dwconv False
93 block4b_bn False
94 block4b_activation False
95 block4b_se_squeeze False
96 block4b_se_reshape False
97 block4b_se_reduce False
98 block4b_se_expand False
99 block4b_se_excite False
100 block4b_project_conv False
101 block4b_project_bn False
102 block4b_drop False
103 block4b_add False
104 block4c_expand_conv False
105 block4c_expand_bn False
106 block4c_expand_activation False
107 block4c_dwconv False
108 block4c_bn False
109 block4c_activation False
110 block4c_se_squeeze False
111 block4c_se_reshape False
112 block4c_se_reduce False
113 block4c_se_expand False
114 block4c_se_excite False
115 block4c_project_conv False
116 block4c_project_bn False
117 block4c_drop False
118 block4c_add False
119 block5a_expand_conv False
120 block5a_expand_bn False
121 block5a_expand_activation False
122 block5a_dwconv False
123 block5a_bn False
124 block5a_activation False
125 block5a_se_squeeze False
126 block5a_se_reshape False
127 block5a_se_reduce False
128 block5a_se_expand False
129 block5a_se_excite False
130 block5a_project_conv False
131 block5a_project_bn False
132 block5b_expand_conv False
133 block5b_expand_bn False
134 block5b_expand_activation False
135 block5b_dwconv False
136 block5b_bn False
137 block5b_activation False
138 block5b_se_squeeze False
139 block5b_se_reshape False
140 block5b_se_reduce False
141 block5b_se_expand False
142 block5b_se_excite False
143 block5b_project_conv False
144 block5b_project_bn False
145 block5b_drop False
146 block5b_add False
147 block5c_expand_conv False
148 block5c_expand_bn False
149 block5c_expand_activation False
150 block5c_dwconv False
151 block5c_bn False
152 block5c_activation False
153 block5c_se_squeeze False
154 block5c_se_reshape False
155 block5c_se_reduce False
156 block5c_se_expand False
157 block5c_se_excite False
158 block5c_project_conv False
159 block5c_project_bn False
160 block5c_drop False
161 block5c_add False
162 block6a_expand_conv False
163 block6a_expand_bn False
164 block6a_expand_activation False
165 block6a_dwconv_pad False
166 block6a_dwconv False
167 block6a_bn False
168 block6a_activation False
169 block6a_se_squeeze False
170 block6a_se_reshape False
171 block6a_se_reduce False
172 block6a_se_expand False
173 block6a_se_excite False
174 block6a_project_conv False
175 block6a_project_bn False
176 block6b_expand_conv False
177 block6b_expand_bn False
178 block6b_expand_activation False
179 block6b_dwconv False
180 block6b_bn False
181 block6b_activation False
182 block6b_se_squeeze False
183 block6b_se_reshape False
184 block6b_se_reduce False
185 block6b_se_expand False
186 block6b_se_excite False
187 block6b_project_conv False
188 block6b_project_bn False
189 block6b_drop False
190 block6b_add False
191 block6c_expand_conv False
192 block6c_expand_bn False
193 block6c_expand_activation False
194 block6c_dwconv False
195 block6c_bn False
196 block6c_activation False
197 block6c_se_squeeze False
198 block6c_se_reshape False
199 block6c_se_reduce False
200 block6c_se_expand False
201 block6c_se_excite False
202 block6c_project_conv False
203 block6c_project_bn False
204 block6c_drop False
205 block6c_add False
206 block6d_expand_conv False
207 block6d_expand_bn False
208 block6d_expand_activation False
209 block6d_dwconv False
210 block6d_bn False
211 block6d_activation False
212 block6d_se_squeeze False
213 block6d_se_reshape False
214 block6d_se_reduce False
215 block6d_se_expand False
216 block6d_se_excite False
217 block6d_project_conv False
218 block6d_project_bn False
219 block6d_drop False
220 block6d_add False
221 block7a_expand_conv False
222 block7a_expand_bn False
223 block7a_expand_activation False
224 block7a_dwconv False
225 block7a_bn False
226 block7a_activation False
227 block7a_se_squeeze False
228 block7a_se_reshape False
229 block7a_se_reduce False
230 block7a_se_expand False
231 block7a_se_excite False
232 block7a_project_conv False
233 block7a_project_bn False
234 top_conv False
235 top_bn False
236 top_activation False
print(len(model_2.layers[2].trainable_variables))
0

Currently there are zero trainable variables in our base model

base_model.trainable = True

# Freeze all layers except for the last 10
for layer in base_model.layers[:-10]:
  layer.trainable = False

# Recompile the model (we have to recompile the model every time we make a change)
model_2.compile(loss = "categorical_crossentropy",
                optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0001), # when fine-tuning, you typically want to lower the learning rate by 10x
                metrics = ["accuracy"])

Note: When fine-tuning, it's best practice to lower your learning rate. This is a hyperparameter you can tune, but a good rule of thumb is to lower it by at least 10x (though different sources will claim other values).
Resource: Universal Language Model Fine-tuning for Text Classification paper by Jeremy Howard and Sebastian Ruder.

# Check which layers are tunable(trainable)
for layer_number, layer in enumerate(model_2.layers[2].layers):
  print(layer_number, layer.name, layer.trainable)

0 input_7 False
1 rescaling_4 False
2 normalization_4 False
3 stem_conv_pad False
4 stem_conv False
5 stem_bn False
6 stem_activation False
7 block1a_dwconv False
8 block1a_bn False
9 block1a_activation False
10 block1a_se_squeeze False
11 block1a_se_reshape False
12 block1a_se_reduce False
13 block1a_se_expand False
14 block1a_se_excite False
15 block1a_project_conv False
16 block1a_project_bn False
17 block2a_expand_conv False
18 block2a_expand_bn False
19 block2a_expand_activation False
20 block2a_dwconv_pad False
21 block2a_dwconv False
22 block2a_bn False
23 block2a_activation False
24 block2a_se_squeeze False
25 block2a_se_reshape False
26 block2a_se_reduce False
27 block2a_se_expand False
28 block2a_se_excite False
29 block2a_project_conv False
30 block2a_project_bn False
31 block2b_expand_conv False
32 block2b_expand_bn False
33 block2b_expand_activation False
34 block2b_dwconv False
35 block2b_bn False
36 block2b_activation False
37 block2b_se_squeeze False
38 block2b_se_reshape False
39 block2b_se_reduce False
40 block2b_se_expand False
41 block2b_se_excite False
42 block2b_project_conv False
43 block2b_project_bn False
44 block2b_drop False
45 block2b_add False
46 block3a_expand_conv False
47 block3a_expand_bn False
48 block3a_expand_activation False
49 block3a_dwconv_pad False
50 block3a_dwconv False
51 block3a_bn False
52 block3a_activation False
53 block3a_se_squeeze False
54 block3a_se_reshape False
55 block3a_se_reduce False
56 block3a_se_expand False
57 block3a_se_excite False
58 block3a_project_conv False
59 block3a_project_bn False
60 block3b_expand_conv False
61 block3b_expand_bn False
62 block3b_expand_activation False
63 block3b_dwconv False
64 block3b_bn False
65 block3b_activation False
66 block3b_se_squeeze False
67 block3b_se_reshape False
68 block3b_se_reduce False
69 block3b_se_expand False
70 block3b_se_excite False
71 block3b_project_conv False
72 block3b_project_bn False
73 block3b_drop False
74 block3b_add False
75 block4a_expand_conv False
76 block4a_expand_bn False
77 block4a_expand_activation False
78 block4a_dwconv_pad False
79 block4a_dwconv False
80 block4a_bn False
81 block4a_activation False
82 block4a_se_squeeze False
83 block4a_se_reshape False
84 block4a_se_reduce False
85 block4a_se_expand False
86 block4a_se_excite False
87 block4a_project_conv False
88 block4a_project_bn False
89 block4b_expand_conv False
90 block4b_expand_bn False
91 block4b_expand_activation False
92 block4b_dwconv False
93 block4b_bn False
94 block4b_activation False
95 block4b_se_squeeze False
96 block4b_se_reshape False
97 block4b_se_reduce False
98 block4b_se_expand False
99 block4b_se_excite False
100 block4b_project_conv False
101 block4b_project_bn False
102 block4b_drop False
103 block4b_add False
104 block4c_expand_conv False
105 block4c_expand_bn False
106 block4c_expand_activation False
107 block4c_dwconv False
108 block4c_bn False
109 block4c_activation False
110 block4c_se_squeeze False
111 block4c_se_reshape False
112 block4c_se_reduce False
113 block4c_se_expand False
114 block4c_se_excite False
115 block4c_project_conv False
116 block4c_project_bn False
117 block4c_drop False
118 block4c_add False
119 block5a_expand_conv False
120 block5a_expand_bn False
121 block5a_expand_activation False
122 block5a_dwconv False
123 block5a_bn False
124 block5a_activation False
125 block5a_se_squeeze False
126 block5a_se_reshape False
127 block5a_se_reduce False
128 block5a_se_expand False
129 block5a_se_excite False
130 block5a_project_conv False
131 block5a_project_bn False
132 block5b_expand_conv False
133 block5b_expand_bn False
134 block5b_expand_activation False
135 block5b_dwconv False
136 block5b_bn False
137 block5b_activation False
138 block5b_se_squeeze False
139 block5b_se_reshape False
140 block5b_se_reduce False
141 block5b_se_expand False
142 block5b_se_excite False
143 block5b_project_conv False
144 block5b_project_bn False
145 block5b_drop False
146 block5b_add False
147 block5c_expand_conv False
148 block5c_expand_bn False
149 block5c_expand_activation False
150 block5c_dwconv False
151 block5c_bn False
152 block5c_activation False
153 block5c_se_squeeze False
154 block5c_se_reshape False
155 block5c_se_reduce False
156 block5c_se_expand False
157 block5c_se_excite False
158 block5c_project_conv False
159 block5c_project_bn False
160 block5c_drop False
161 block5c_add False
162 block6a_expand_conv False
163 block6a_expand_bn False
164 block6a_expand_activation False
165 block6a_dwconv_pad False
166 block6a_dwconv False
167 block6a_bn False
168 block6a_activation False
169 block6a_se_squeeze False
170 block6a_se_reshape False
171 block6a_se_reduce False
172 block6a_se_expand False
173 block6a_se_excite False
174 block6a_project_conv False
175 block6a_project_bn False
176 block6b_expand_conv False
177 block6b_expand_bn False
178 block6b_expand_activation False
179 block6b_dwconv False
180 block6b_bn False
181 block6b_activation False
182 block6b_se_squeeze False
183 block6b_se_reshape False
184 block6b_se_reduce False
185 block6b_se_expand False
186 block6b_se_excite False
187 block6b_project_conv False
188 block6b_project_bn False
189 block6b_drop False
190 block6b_add False
191 block6c_expand_conv False
192 block6c_expand_bn False
193 block6c_expand_activation False
194 block6c_dwconv False
195 block6c_bn False
196 block6c_activation False
197 block6c_se_squeeze False
198 block6c_se_reshape False
199 block6c_se_reduce False
200 block6c_se_expand False
201 block6c_se_excite False
202 block6c_project_conv False
203 block6c_project_bn False
204 block6c_drop False
205 block6c_add False
206 block6d_expand_conv False
207 block6d_expand_bn False
208 block6d_expand_activation False
209 block6d_dwconv False
210 block6d_bn False
211 block6d_activation False
212 block6d_se_squeeze False
213 block6d_se_reshape False
214 block6d_se_reduce False
215 block6d_se_expand False
216 block6d_se_excite False
217 block6d_project_conv False
218 block6d_project_bn False
219 block6d_drop False
220 block6d_add False
221 block7a_expand_conv False
222 block7a_expand_bn False
223 block7a_expand_activation False
224 block7a_dwconv False
225 block7a_bn False
226 block7a_activation False
227 block7a_se_squeeze True
228 block7a_se_reshape True
229 block7a_se_reduce True
230 block7a_se_expand True
231 block7a_se_excite True
232 block7a_project_conv True
233 block7a_project_bn True
234 top_conv True
235 top_bn True
236 top_activation True
print(len(model_2.trainable_variables))
12
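If you're curious where those 12 trainable variables come from, you can list them directly. Roughly (the exact breakdown depends on which layers were unfrozen), each unfrozen conv/batch norm/squeeze-excite layer contributes its weight tensors, plus the kernel and bias of our output Dense layer:

# Inspect the names and shapes of the trainable variables
for variable in model_2.trainable_variables:
  print(variable.name, variable.shape)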
fine_tune_epochs = initial_epochs + 5

# Refit the model (same as model_2 except with more trainable layers)
history_fine_10_percent_data_aug = model_2.fit(train_data_10_percent,
                                               epochs = fine_tune_epochs,
                                               validation_data = test_data,
                                               validation_steps = int(0.25 * len(test_data)),
                                               initial_epoch = history_10_percent_data_aug.epoch[-1], # start training from the previous last epoch
                                               callbacks = [create_tensorboard_callback(dir_name = "transfer_learning",
                                                                                        experiment_name = "10_percent_fine_tune_last_10")])
Saving TensorBoard log files to: transfer_learning/10_percent_fine_tune_last_10/20220220-071428
Epoch 5/10
24/24 [==============================] - 28s 734ms/step - loss: 0.6990 - accuracy: 0.8000 - val_loss: 0.5796 - val_accuracy: 0.8224
Epoch 6/10
24/24 [==============================] - 14s 556ms/step - loss: 0.5585 - accuracy: 0.8360 - val_loss: 0.5457 - val_accuracy: 0.8273
Epoch 7/10
24/24 [==============================] - 13s 531ms/step - loss: 0.5267 - accuracy: 0.8360 - val_loss: 0.5517 - val_accuracy: 0.8240
Epoch 8/10
24/24 [==============================] - 14s 586ms/step - loss: 0.4685 - accuracy: 0.8520 - val_loss: 0.5281 - val_accuracy: 0.8273
Epoch 9/10
24/24 [==============================] - 12s 493ms/step - loss: 0.3966 - accuracy: 0.8827 - val_loss: 0.5169 - val_accuracy: 0.8273
Epoch 10/10
24/24 [==============================] - 13s 514ms/step - loss: 0.3901 - accuracy: 0.8827 - val_loss: 0.5140 - val_accuracy: 0.8355
results_fine_tune_10_percent = model_2.evaluate(test_data)
79/79 [==============================] - 11s 136ms/step - loss: 0.4815 - accuracy: 0.8392
plot_loss_curves(history_fine_10_percent_data_aug)

The plot_loss_curves function works great with models which have only been fit once. However, we want something that can compare one run of fit() with another (e.g. before and after fine-tuning).

import matplotlib.pyplot as plt

def compare_history(original_history, new_history, initial_epochs = 5):
  """
  Compares two TensorFlow History objects by plotting their combined
  accuracy and loss curves, marking where fine-tuning began.
  """

  # Get original history measurements
  acc = original_history.history["accuracy"]
  loss = original_history.history["loss"]

  val_acc = original_history.history["val_accuracy"]
  val_loss = original_history.history["val_loss"]

  # Combine original history with new history
  total_acc = acc + new_history.history["accuracy"]
  total_loss = loss + new_history.history["loss"]

  total_val_acc = val_acc + new_history.history["val_accuracy"]
  total_val_loss = val_loss + new_history.history["val_loss"]

  # Make plot for accuracy
  plt.figure(figsize = (8, 8))
  plt.subplot(2, 1, 1)
  plt.plot(total_acc, label = "Training Accuracy")
  plt.plot(total_val_acc, label = "Val Accuracy")
  plt.plot([initial_epochs-1, initial_epochs-1], plt.ylim(), label = "Start Fine-tuning")
  plt.legend(loc = "lower right")
  plt.title("Training and Validation Accuracy")

  # Make plot for loss (second subplot of the same figure)
  plt.subplot(2, 1, 2)
  plt.plot(total_loss, label = "Training Loss")
  plt.plot(total_val_loss, label = "Val Loss")
  plt.plot([initial_epochs-1, initial_epochs-1], plt.ylim(), label = "Start Fine-tuning")
  plt.legend(loc = "upper right")
  plt.title("Training and Validation Loss")
  plt.xlabel("epoch")
compare_history(history_10_percent_data_aug,
                history_fine_10_percent_data_aug,
                initial_epochs = 5)

Model 4: Fine-tuning an existing model on the full dataset

!wget https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_all_data.zip
unzip_data("10_food_classes_all_data.zip")
--2022-02-20 07:57:52--  https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_all_data.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.136.128, 142.250.148.128, 108.177.112.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.136.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 519183241 (495M) [application/zip]
Saving to: ‘10_food_classes_all_data.zip’

10_food_classes_all 100%[===================>] 495.13M   189MB/s    in 2.6s    

2022-02-20 07:57:55 (189 MB/s) - ‘10_food_classes_all_data.zip’ saved [519183241/519183241]

train_dir_all_data = "10_food_classes_all_data/train"
test_dir = "10_food_classes_all_data/test"
walk_through_dir("10_food_classes_all_data")
There are 2 directories and 0 images in '10_food_classes_all_data'.
There are 10 directories and 0 images in '10_food_classes_all_data/train'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/pizza'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/fried_rice'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/grilled_salmon'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/hamburger'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/steak'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/sushi'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/chicken_wings'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/ramen'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/ice_cream'.
There are 0 directories and 750 images in '10_food_classes_all_data/train/chicken_curry'.
There are 10 directories and 0 images in '10_food_classes_all_data/test'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/pizza'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/fried_rice'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/grilled_salmon'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/hamburger'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/steak'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/sushi'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/chicken_wings'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/ramen'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/ice_cream'.
There are 0 directories and 250 images in '10_food_classes_all_data/test/chicken_curry'.
import tensorflow as tf
IMG_SIZE = (224,224)
train_data_10_classes_full = tf.keras.preprocessing.image_dataset_from_directory(train_dir_all_data,
                                                                                 label_mode = "categorical",
                                                                                 image_size = IMG_SIZE)

test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
                                                                label_mode = "categorical",
                                                                image_size = IMG_SIZE)
Found 7500 files belonging to 10 classes.
Found 2500 files belonging to 10 classes.

The test dataset we have loaded in is the same one we've used in all previous experiments (every experiment has used the same test dataset).

Let's verify this...

model_2.evaluate(test_data)
79/79 [==============================] - 12s 141ms/step - loss: 0.4815 - accuracy: 0.8392
[0.48150259256362915, 0.8392000198364258]
results_fine_tune_10_percent
[0.48150259256362915, 0.8392000198364258]

To train a fine-tuning model (model_4), we need to revert model_2 back to its feature extraction weights.

# Revert to the same stage the 10 percent data model was fine-tuned from
model_2.load_weights(checkpoint_path)
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fcdab2c5950>
model_2.evaluate(test_data)
79/79 [==============================] - 12s 140ms/step - loss: 0.6865 - accuracy: 0.8160
[0.6864517331123352, 0.8159999847412109]
results_10_percent_data_aug
[0.6864515542984009, 0.8159999847412109]

What we have done so far:

  1. Trained a feature extraction transfer learning model for 5 epochs on 10% of the data with data augmentation (model_2) and saved the model's weights using the ModelCheckpoint callback.

  2. Fine-tuned the same model on the same 10% of the data for a further 5 epochs with the top 10 layers of the base model unfrozen (model_3).

  3. Saved the results and training logs each time.

  4. Reloaded the model from step 1 so we can repeat step 2, except this time using all of the data (model_4).

for layer_number, layer in enumerate(model_2.layers):
  print(layer_number, layer.name, layer.trainable)
0 input_layer True
1 data_augmentation True
2 efficientnetb0 True
3 global_average_pooling_layer True
4 output_layer True
# Let's drill into our base_model (EfficientNetB0) check what layers are trainable
for layer_number, layer in enumerate(model_2.layers[2].layers):
  print(layer_number, layer.name, layer.trainable)

0 input_7 False
1 rescaling_4 False
2 normalization_4 False
3 stem_conv_pad False
4 stem_conv False
5 stem_bn False
6 stem_activation False
7 block1a_dwconv False
8 block1a_bn False
9 block1a_activation False
10 block1a_se_squeeze False
11 block1a_se_reshape False
12 block1a_se_reduce False
13 block1a_se_expand False
14 block1a_se_excite False
15 block1a_project_conv False
16 block1a_project_bn False
17 block2a_expand_conv False
18 block2a_expand_bn False
19 block2a_expand_activation False
20 block2a_dwconv_pad False
21 block2a_dwconv False
22 block2a_bn False
23 block2a_activation False
24 block2a_se_squeeze False
25 block2a_se_reshape False
26 block2a_se_reduce False
27 block2a_se_expand False
28 block2a_se_excite False
29 block2a_project_conv False
30 block2a_project_bn False
31 block2b_expand_conv False
32 block2b_expand_bn False
33 block2b_expand_activation False
34 block2b_dwconv False
35 block2b_bn False
36 block2b_activation False
37 block2b_se_squeeze False
38 block2b_se_reshape False
39 block2b_se_reduce False
40 block2b_se_expand False
41 block2b_se_excite False
42 block2b_project_conv False
43 block2b_project_bn False
44 block2b_drop False
45 block2b_add False
46 block3a_expand_conv False
47 block3a_expand_bn False
48 block3a_expand_activation False
49 block3a_dwconv_pad False
50 block3a_dwconv False
51 block3a_bn False
52 block3a_activation False
53 block3a_se_squeeze False
54 block3a_se_reshape False
55 block3a_se_reduce False
56 block3a_se_expand False
57 block3a_se_excite False
58 block3a_project_conv False
59 block3a_project_bn False
60 block3b_expand_conv False
61 block3b_expand_bn False
62 block3b_expand_activation False
63 block3b_dwconv False
64 block3b_bn False
65 block3b_activation False
66 block3b_se_squeeze False
67 block3b_se_reshape False
68 block3b_se_reduce False
69 block3b_se_expand False
70 block3b_se_excite False
71 block3b_project_conv False
72 block3b_project_bn False
73 block3b_drop False
74 block3b_add False
75 block4a_expand_conv False
76 block4a_expand_bn False
77 block4a_expand_activation False
78 block4a_dwconv_pad False
79 block4a_dwconv False
80 block4a_bn False
81 block4a_activation False
82 block4a_se_squeeze False
83 block4a_se_reshape False
84 block4a_se_reduce False
85 block4a_se_expand False
86 block4a_se_excite False
87 block4a_project_conv False
88 block4a_project_bn False
89 block4b_expand_conv False
90 block4b_expand_bn False
91 block4b_expand_activation False
92 block4b_dwconv False
93 block4b_bn False
94 block4b_activation False
95 block4b_se_squeeze False
96 block4b_se_reshape False
97 block4b_se_reduce False
98 block4b_se_expand False
99 block4b_se_excite False
100 block4b_project_conv False
101 block4b_project_bn False
102 block4b_drop False
103 block4b_add False
104 block4c_expand_conv False
105 block4c_expand_bn False
106 block4c_expand_activation False
107 block4c_dwconv False
108 block4c_bn False
109 block4c_activation False
110 block4c_se_squeeze False
111 block4c_se_reshape False
112 block4c_se_reduce False
113 block4c_se_expand False
114 block4c_se_excite False
115 block4c_project_conv False
116 block4c_project_bn False
117 block4c_drop False
118 block4c_add False
119 block5a_expand_conv False
120 block5a_expand_bn False
121 block5a_expand_activation False
122 block5a_dwconv False
123 block5a_bn False
124 block5a_activation False
125 block5a_se_squeeze False
126 block5a_se_reshape False
127 block5a_se_reduce False
128 block5a_se_expand False
129 block5a_se_excite False
130 block5a_project_conv False
131 block5a_project_bn False
132 block5b_expand_conv False
133 block5b_expand_bn False
134 block5b_expand_activation False
135 block5b_dwconv False
136 block5b_bn False
137 block5b_activation False
138 block5b_se_squeeze False
139 block5b_se_reshape False
140 block5b_se_reduce False
141 block5b_se_expand False
142 block5b_se_excite False
143 block5b_project_conv False
144 block5b_project_bn False
145 block5b_drop False
146 block5b_add False
147 block5c_expand_conv False
148 block5c_expand_bn False
149 block5c_expand_activation False
150 block5c_dwconv False
151 block5c_bn False
152 block5c_activation False
153 block5c_se_squeeze False
154 block5c_se_reshape False
155 block5c_se_reduce False
156 block5c_se_expand False
157 block5c_se_excite False
158 block5c_project_conv False
159 block5c_project_bn False
160 block5c_drop False
161 block5c_add False
162 block6a_expand_conv False
163 block6a_expand_bn False
164 block6a_expand_activation False
165 block6a_dwconv_pad False
166 block6a_dwconv False
167 block6a_bn False
168 block6a_activation False
169 block6a_se_squeeze False
170 block6a_se_reshape False
171 block6a_se_reduce False
172 block6a_se_expand False
173 block6a_se_excite False
174 block6a_project_conv False
175 block6a_project_bn False
176 block6b_expand_conv False
177 block6b_expand_bn False
178 block6b_expand_activation False
179 block6b_dwconv False
180 block6b_bn False
181 block6b_activation False
182 block6b_se_squeeze False
183 block6b_se_reshape False
184 block6b_se_reduce False
185 block6b_se_expand False
186 block6b_se_excite False
187 block6b_project_conv False
188 block6b_project_bn False
189 block6b_drop False
190 block6b_add False
191 block6c_expand_conv False
192 block6c_expand_bn False
193 block6c_expand_activation False
194 block6c_dwconv False
195 block6c_bn False
196 block6c_activation False
197 block6c_se_squeeze False
198 block6c_se_reshape False
199 block6c_se_reduce False
200 block6c_se_expand False
201 block6c_se_excite False
202 block6c_project_conv False
203 block6c_project_bn False
204 block6c_drop False
205 block6c_add False
206 block6d_expand_conv False
207 block6d_expand_bn False
208 block6d_expand_activation False
209 block6d_dwconv False
210 block6d_bn False
211 block6d_activation False
212 block6d_se_squeeze False
213 block6d_se_reshape False
214 block6d_se_reduce False
215 block6d_se_expand False
216 block6d_se_excite False
217 block6d_project_conv False
218 block6d_project_bn False
219 block6d_drop False
220 block6d_add False
221 block7a_expand_conv False
222 block7a_expand_bn False
223 block7a_expand_activation False
224 block7a_dwconv False
225 block7a_bn False
226 block7a_activation False
227 block7a_se_squeeze True
228 block7a_se_reshape True
229 block7a_se_reduce True
230 block7a_se_expand True
231 block7a_se_excite True
232 block7a_project_conv True
233 block7a_project_bn True
234 top_conv True
235 top_bn True
236 top_activation True
model_2.compile(loss = "categorical_crossentropy",
                optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0001),
                metrics = ["accuracy"])
fine_tune_epochs = initial_epochs + 5

history_fine_10_classes_full = model_2.fit(train_data_10_classes_full,
                                           epochs = fine_tune_epochs,
                                           validation_data = test_data,
                                           validation_steps = int(0.25*len(test_data)),
                                           initial_epoch = history_10_percent_data_aug.epoch[-1],
                                           callbacks = [create_tensorboard_callback(dir_name = "transfer_learning",
                                                                                    experiment_name = "full_10_classes_fine_tune_last_10")])
Saving TensorBoard log files to: transfer_learning/full_10_classes_fine_tune_last_10/20220220-085835
Epoch 5/10
235/235 [==============================] - 97s 366ms/step - loss: 0.7321 - accuracy: 0.7609 - val_loss: 0.4013 - val_accuracy: 0.8635
Epoch 6/10
235/235 [==============================] - 79s 335ms/step - loss: 0.6010 - accuracy: 0.8061 - val_loss: 0.3560 - val_accuracy: 0.8882
Epoch 7/10
235/235 [==============================] - 75s 314ms/step - loss: 0.5271 - accuracy: 0.8305 - val_loss: 0.3115 - val_accuracy: 0.9013
Epoch 8/10
235/235 [==============================] - 70s 293ms/step - loss: 0.4879 - accuracy: 0.8433 - val_loss: 0.3053 - val_accuracy: 0.9013
Epoch 9/10
235/235 [==============================] - 64s 270ms/step - loss: 0.4517 - accuracy: 0.8539 - val_loss: 0.2934 - val_accuracy: 0.9161
Epoch 10/10
235/235 [==============================] - 63s 265ms/step - loss: 0.4240 - accuracy: 0.8677 - val_loss: 0.2869 - val_accuracy: 0.9194
results_fine_tune_full_data = model_2.evaluate(test_data)
results_fine_tune_full_data
79/79 [==============================] - 11s 133ms/step - loss: 0.3047 - accuracy: 0.9052
[0.3046596944332123, 0.9052000045776367]
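To put the experiments side by side, here's a quick summary sketch (it assumes the results_* variables from the evaluation cells above are still in memory; the labels are just descriptive):

# Gather the evaluation results (each is a [loss, accuracy] list) into one table
experiment_results = {
    "1% data + aug (feature extraction)": results_1_percent_data_aug,
    "10% data + aug (feature extraction)": results_10_percent_data_aug,
    "10% data + aug (fine-tune last 10)": results_fine_tune_10_percent,
    "100% data + aug (fine-tune last 10)": results_fine_tune_full_data,
}
for name, (loss, accuracy) in experiment_results.items():
  print(f"{name}: loss = {loss:.4f}, accuracy = {accuracy:.4f}")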
compare_history(original_history= history_10_percent_data_aug,
                new_history = history_fine_10_classes_full,
                initial_epochs = 5)

Viewing our experiment data on TensorBoard

Note: Anything you upload to TensorBoard.dev is going to be public. So, if you have private data, do not upload it.

# Upload TensorBoard dev records
#!tensorboard dev upload --logdir ./transfer_learning \
#  --name "Transfer Learning Experiments with 10 FOOD101 classes" \
#  --description "A series of different transfer learning experiments with varying amounts of data" \
#  --one_shot # Exits the uploader once it's finished uploading

# Run the above by removing the comments from the last four lines to upload the experiments to TensorBoard

My TensorBoard experiments are available at: https://tensorboard.dev/experiment/0ZGWuA1vTv2NAdhoRmATDw/#scalars

# !tensorboard dev list
# To delete a particular experiment:
#!tensorboard dev delete --experiment_id {type the id here}