TensorFlow Hub with Keras

TensorFlow Hub is a way to share and reuse pretrained model components. See the TensorFlow Module Hub for a searchable listing of pretrained models. This tutorial demonstrates:

  1. How to use TensorFlow Hub with Keras.
  2. How to do image classification using TensorFlow Hub.
  3. How to do simple transfer learning.



An ImageNet classifier

Download the classifier

Use layer_hub to load a MobileNet classifier and wrap it as a Keras layer. Any TensorFlow 2 compatible image classifier URL from tfhub.dev will work here.

library(keras)
library(tfhub)
classifier_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2"
mobilenet_layer <- layer_hub(handle = classifier_url)

We can then create our Keras model:

input <- layer_input(shape = c(224, 224, 3))
output <- input %>% 
  mobilenet_layer()

model <- keras_model(input, output)

Run it on a single image

Download a single image to try the model on.

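library(magick)

# Read the image, force it to 224x224, and convert it to a numeric array
img <- image_read('https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg') %>%
  image_resize(geometry = "224x224x3!") %>% 
  image_data() %>% 
  as.numeric() %>% 
  abind::abind(along = 0) # expand to batch dimension

result <- predict(model, img)
# Drop the first column (an extra "background" class) before decoding,
# because mobilenet_decode_predictions expects the 1000 ImageNet classes
mobilenet_decode_predictions(result[,-1, drop = FALSE])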
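
To make the column handling above concrete, here is a minimal base R sketch of picking the five highest-scoring classes from a score matrix. The `scores` matrix is a hypothetical stand-in for `result` (random values, not real model output), and the sketch assumes the classifier emits 1001 scores per image with the first column reserved for a background class.

```r
# Hypothetical stand-in for `result`: 1 image x 1001 class scores (random, not model output)
set.seed(42)
scores <- matrix(runif(1001), nrow = 1)

# Drop the first (background) column so that 1000 ImageNet classes remain
imagenet_scores <- scores[, -1, drop = FALSE]

# Column indices of the five highest scores for the first image, best first
top5 <- order(imagenet_scores[1, ], decreasing = TRUE)[1:5]
top5_scores <- imagenet_scores[1, top5]
```

mobilenet_decode_predictions performs this ranking for you and also maps the indices to ImageNet class descriptions.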

Simple transfer learning

Using TF Hub it is simple to retrain the top layer of the model to recognize the classes in our dataset.


For this example you will use the TensorFlow flowers dataset:

data_root <- pins::pin("https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", "flower_photos")
data_root <- fs::path_dir(fs::path_dir(data_root[100])) # go up 2 levels, from an image file to the dataset root

The simplest way to load this data into our model is with image_data_generator.

All of TensorFlow Hub’s image modules expect float inputs in the [0, 1] range. Use the image_data_generator’s rescale parameter to achieve this.

image_generator <- image_data_generator(rescale = 1/255, validation_split = 0.2)
training_data <- flow_images_from_directory(
  directory = data_root, 
  generator = image_generator,
  target_size = c(224, 224), 
  subset = "training"
)

validation_data <- flow_images_from_directory(
  directory = data_root, 
  generator = image_generator,
  target_size = c(224, 224), 
  subset = "validation"
)

Each resulting object is an iterator that returns (image_batch, label_batch) pairs.
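
As a rough illustration of what one such pair looks like, the sketch below summarises a batch in base R. The helper `describe_batch` and the toy batch are hypothetical (they are not part of keras or tfhub); the toy data assumes 4 RGB images of 224x224 rescaled by 1/255 and one-hot labels for 5 classes.

```r
# Hypothetical helper: report array shapes and pixel range of one (images, labels) batch
describe_batch <- function(batch) {
  images <- batch[[1]]
  labels <- batch[[2]]
  list(
    image_dim = dim(images),
    label_dim = dim(labels),
    pixel_range = range(images)
  )
}

# Toy batch: raw 0-255 pixel values rescaled by 1/255, plus one-hot labels
set.seed(1)
raw_pixels <- array(
  sample(0:255, 4 * 224 * 224 * 3, replace = TRUE),
  dim = c(4, 224, 224, 3)
)
toy_batch <- list(raw_pixels * (1 / 255), diag(5)[c(1, 3, 5, 2), ])
summary_info <- describe_batch(toy_batch)
```

With the real iterator, `describe_batch(generator_next(training_data))` should report images shaped batch size x 224 x 224 x 3 and pixel values inside [0, 1].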

Download the headless model

TensorFlow Hub also distributes models without the top classification layer. These can be used to easily do transfer learning.

Any TensorFlow 2 compatible image feature vector URL from tfhub.dev will work here.

feature_extractor_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2"
feature_extractor_layer <- layer_hub(handle = feature_extractor_url)

Attach a classification head

Now we can create our classification model by attaching a classification head to the feature extractor layer. We define the following model:

input <- layer_input(shape = c(224, 224, 3))
output <- input %>% 
  feature_extractor_layer() %>% 
  layer_dense(units = training_data$num_classes, activation = "softmax")

model <- keras_model(input, output)
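
To get a feel for how small the new head is, the following sketch counts the parameters of the dense layer by hand. Both sizes are assumptions rather than values taken from this tutorial: a feature vector of length 1280 for this MobileNet V2 module and 5 flower classes. Compare the result with the output of summary(model).

```r
# Assumed sizes; verify both against summary(model)
feature_dim <- 1280  # assumption: length of the feature vector from the extractor
num_classes <- 5     # assumption: number of flower categories

# A dense layer has one weight per (input, unit) pair plus one bias per unit
dense_params <- feature_dim * num_classes + num_classes
```

If the feature extractor stays frozen, these are the only parameters updated during training.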

Train the model

We can now train our model in the same way we would train any other Keras model. We first use compile to configure the training process:

model %>% 
  compile(
    loss = "categorical_crossentropy",
    optimizer = "adam",
    metrics = "acc"
  )
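
For intuition about the loss named above, here is a minimal base R sketch of categorical cross-entropy for a single example. The function name `crossentropy_one_example`, the clipping constant, and the probabilities are illustrative assumptions; Keras computes the loss internally and averages it over the batch.

```r
# Categorical cross-entropy for one example: -sum(y_true * log(y_pred))
crossentropy_one_example <- function(y_true, y_pred, eps = 1e-7) {
  y_pred <- pmin(pmax(y_pred, eps), 1 - eps)  # clip to avoid log(0)
  -sum(y_true * log(y_pred))
}

y_true <- c(0, 1, 0, 0, 0)                  # one-hot label: class 2 of 5
y_pred <- c(0.10, 0.70, 0.10, 0.05, 0.05)   # hypothetical softmax output
loss <- crossentropy_one_example(y_true, y_pred)  # equals -log(0.7)
```

Only the probability assigned to the true class contributes, so the loss shrinks toward zero as that probability approaches 1.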

We can then fit our model. Because the training data comes from a generator, the call below uses fit_generator:

model %>% 
  fit_generator(
    training_data, 
    steps_per_epoch = training_data$n/training_data$batch_size,
    validation_data = validation_data
  )
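
The number of steps per epoch is just the number of images divided by the batch size, which is rarely a whole number. The sketch below uses hypothetical counts (1000 images, batches of 32) to show the two usual ways of rounding.

```r
n_images <- 1000   # hypothetical number of training images
batch_size <- 32   # hypothetical batch size

exact_steps <- n_images / batch_size          # 31.25, not a whole number
full_batches <- as.integer(exact_steps)       # 31: skips the final partial batch
all_images <- ceiling(n_images / batch_size)  # 32: includes the final partial batch
```

Truncating with as.integer keeps every batch the same size, while ceiling makes sure each image is seen once per epoch.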

You can then export your model with:

save_model_tf(model, "model")

You can reload the model with the load_model_tf function. Because KerasLayer is not a default Keras layer, you may need to pass its definition through the custom_objects argument; the call below loads the saved model without it.

reloaded_model <- load_model_tf("model")

We can verify that the predictions of both the trained model and the reloaded model are equal:

steps <- as.integer(validation_data$n/validation_data$batch_size)
all.equal(
  predict_generator(model, validation_data, steps = steps),
  predict_generator(reloaded_model, validation_data, steps = steps)
)
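
When comparing predictions numerically, it helps to know how all.equal behaves: it tolerates tiny floating-point differences and returns a descriptive message instead of FALSE when the values differ. The vectors below are hypothetical toy values, not model predictions.

```r
a <- c(0.1, 0.2, 0.7)
b <- a + 1e-10             # tiny numerical noise
c_far <- c(0.1, 0.3, 0.6)  # clearly different values

near_equal <- isTRUE(all.equal(a, b))      # TRUE: difference is below the tolerance
far_equal <- isTRUE(all.equal(a, c_far))   # FALSE: all.equal returned a message
```

Wrapping the call in isTRUE gives a plain logical that is safe to use in an if statement or stopifnot.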

The saved model can also be loaded for inference later, or converted to TensorFlow Lite or TensorFlow.js.