Add EfficientNet benchmark (#394)
* init

* refactor

* minor improvements

* Remove softmax activation

* small

* rename

* redo the tf change

---------

Co-authored-by: chenmoneygithub <chenmoney@chenmoney-gpu-scratch.us-west1-a.c.keras-team-gcp.internal>
chenmoneygithub authored Jun 23, 2023
1 parent 5973be2 commit df95e7c
Showing 5 changed files with 187 additions and 0 deletions.
Empty file.
24 changes: 24 additions & 0 deletions benchmarks/model_benchmark/benchmark_utils.py
@@ -0,0 +1,24 @@
import time

import keras_core


class BenchmarkMetricsCallback(keras_core.callbacks.Callback):
    def __init__(self, start_batch=1, stop_batch=None):
        self.start_batch = start_batch
        self.stop_batch = stop_batch

        # Store the throughput of each epoch.
        self.state = {"throughput": []}

    def on_train_batch_begin(self, batch, logs=None):
        if batch == self.start_batch:
            self.state["epoch_begin_time"] = time.time()

    def on_train_batch_end(self, batch, logs=None):
        if batch == self.stop_batch:
            epoch_end_time = time.time()
            throughput = (self.stop_batch - self.start_batch + 1) / (
                epoch_end_time - self.state["epoch_begin_time"]
            )
            self.state["throughput"].append(throughput)
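For readers skimming the diff, a minimal usage sketch of the callback above. The toy model, synthetic data, and batch numbers are illustrative placeholders, not part of this commit; only the import path mirrors the benchmark script. Attach the callback to fit() and read state["throughput"] afterwards, which holds one steps-per-second measurement per epoch.

# Hypothetical usage sketch for BenchmarkMetricsCallback; the model and data
# below are placeholders chosen only to make the example self-contained.
import numpy as np
import keras_core as keras

from model_benchmark.benchmark_utils import BenchmarkMetricsCallback

toy_model = keras.models.Sequential(
    [keras.Input((8,)), keras.layers.Dense(1)]
)
toy_model.compile(optimizer="adam", loss="mse")

x = np.random.rand(640, 8)
y = np.random.rand(640, 1)

# 640 samples at batch size 32 gives 20 batches per epoch (indices 0-19);
# measure between batch 1 and batch 18, mirroring how the benchmark script
# skips the first batch.
callback = BenchmarkMetricsCallback(start_batch=1, stop_batch=18)
toy_model.fit(x, y, batch_size=32, epochs=2, callbacks=[callback])

# One steps-per-second value per epoch; multiply by batch size for examples/sec.
for steps_per_sec in callback.state["throughput"]:
    print(f"{steps_per_sec * 32:.1f} examples/sec")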
@@ -0,0 +1,148 @@
"""Image classification benchmark with EfficientNetV2B0.

To run the benchmark, make sure you are in the model_benchmark/ directory, and
run the command below:

python3 -m model_benchmark.resnet_image_classification_benchmark \
    --epochs=2 \
    --batch_size=32
"""

import time

import keras_core as keras
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import app
from absl import flags
from absl import logging


from model_benchmark.benchmark_utils import BenchmarkMetricsCallback
from keras_core.applications import EfficientNetV2B0

flags.DEFINE_integer("epochs", 1, "The number of epochs.")
flags.DEFINE_integer("batch_size", 4, "Batch Size.")
flags.DEFINE_string(
    "mixed_precision_policy",
    "mixed_float16",
    "The global precision policy to use, e.g., 'mixed_float16' or 'float32'.",
)

FLAGS = flags.FLAGS

BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
CHANNELS = 3


def load_data():
    # Load cats vs dogs dataset, and split into train and validation sets.
    train_dataset, val_dataset = tfds.load(
        "cats_vs_dogs", split=["train[:90%]", "train[90%:]"], as_supervised=True
    )

    resizing = keras.layers.Resizing(
        IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True
    )

    def preprocess_inputs(image, label):
        image = tf.cast(image, "float32")
        return resizing(image), label

    train_dataset = (
        train_dataset.map(
            preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE
        )
        .batch(FLAGS.batch_size)
        .cache()
        .prefetch(tf.data.AUTOTUNE)
    )
    val_dataset = (
        val_dataset.map(preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(FLAGS.batch_size)
        .cache()
        .prefetch(tf.data.AUTOTUNE)
    )
    return train_dataset, val_dataset


def load_model():
    # Load the EfficientNetV2B0 model and add a classification head.
    model = EfficientNetV2B0(include_top=False, weights="imagenet")
    classifier = keras.models.Sequential(
        [
            keras.Input([IMAGE_SIZE[0], IMAGE_SIZE[1], CHANNELS]),
            model,
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(2),
        ]
    )
    return classifier


def main(_):
    keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy)

    logging.info(
        "Benchmarking configs...\n"
        "=========================\n"
        "MODEL: EfficientNetV2B0\n"
        "TASK: image classification/dogs-vs-cats\n"
        f"BATCH_SIZE: {FLAGS.batch_size}\n"
        f"EPOCHS: {FLAGS.epochs}\n"
        "=========================\n"
    )

    # Load datasets.
    train_ds, validation_ds = load_data()

    # Load the model.
    classifier = load_model()

    lr = keras.optimizers.schedules.PolynomialDecay(
        5e-4,
        decay_steps=train_ds.cardinality() * FLAGS.epochs,
        end_learning_rate=0.0,
    )
    optimizer = keras.optimizers.Adam(lr)
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    benchmark_metrics_callback = BenchmarkMetricsCallback(
        start_batch=1,
        stop_batch=train_ds.cardinality().numpy() - 1,
    )

    classifier.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=["sparse_categorical_accuracy"],
    )

    # Start training.
    logging.info("Starting Training...")

    st = time.time()

    history = classifier.fit(
        train_ds,
        validation_data=validation_ds,
        epochs=FLAGS.epochs,
        callbacks=[benchmark_metrics_callback],
    )

    wall_time = time.time() - st
    validation_accuracy = history.history["val_sparse_categorical_accuracy"][-1]

    examples_per_second = np.mean(
        np.array(benchmark_metrics_callback.state["throughput"])
    ) * FLAGS.batch_size

    logging.info("Training Finished!")
    logging.info(f"Wall Time: {wall_time:.4f} seconds.")
    logging.info(f"Validation Accuracy: {validation_accuracy:.4f}")
    logging.info(f"examples_per_second: {examples_per_second:.4f}")


if __name__ == "__main__":
    app.run(main)
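As a side note on the schedule configured in main(): PolynomialDecay with its default power of 1.0 decays the learning rate linearly from 5e-4 down to end_learning_rate over decay_steps. A small standalone sketch, where the step count of 100 is an arbitrary stand-in for train_ds.cardinality() * FLAGS.epochs:

# Illustrative only: inspect the linear decay used by the benchmark's schedule.
import keras_core as keras

total_steps = 100  # stand-in for train_ds.cardinality() * FLAGS.epochs
lr_schedule = keras.optimizers.schedules.PolynomialDecay(
    5e-4, decay_steps=total_steps, end_learning_rate=0.0
)

for step in (0, 50, 100):
    # With the default power=1.0 this is a straight line from 5e-4 down to 0.
    print(step, float(lr_schedule(step)))
# Expected roughly: 5e-4 at step 0, 2.5e-4 at step 50, 0.0 at step 100.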
1 change: 1 addition & 0 deletions keras_core/__init__.py
@@ -1,4 +1,5 @@
from keras_core import activations
from keras_core import applications
from keras_core import backend
from keras_core import constraints
from keras_core import datasets
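The single added line above exposes the applications submodule on the top-level keras_core package, so attribute access works without a separate submodule import. A hedged sanity-check sketch; weights=None is chosen here only to avoid downloading the ImageNet weights:

# Illustrative check of what the added import line enables.
import keras_core

# With `from keras_core import applications` in __init__.py, the submodule is
# reachable as an attribute of the package; otherwise an explicit
# `import keras_core.applications` would be needed first.
model = keras_core.applications.EfficientNetV2B0(include_top=False, weights=None)
print(type(model).__name__)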
14 changes: 14 additions & 0 deletions keras_core/backend/tensorflow/random.py
@@ -61,8 +61,22 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
)


def _get_concrete_noise_shape(inputs, noise_shape):
    if noise_shape is None:
        return tf.shape(inputs)

    concrete_inputs_shape = tf.shape(inputs)
    concrete_noise_shape = []
    for i, value in enumerate(noise_shape):
        concrete_noise_shape.append(
            concrete_inputs_shape[i] if value is None else value
        )
    return concrete_noise_shape


def dropout(inputs, rate, noise_shape=None, seed=None):
    seed = tf_draw_seed(seed)
    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
    return tf.nn.experimental.stateless_dropout(
        inputs,
        rate=rate,
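To see what the new _get_concrete_noise_shape helper buys, here is a hedged standalone sketch; the resolve_noise_shape function below re-implements the resolution logic outside keras_core purely for illustration. None entries in noise_shape are replaced by the corresponding concrete input dimensions, so a noise shape like (None, 1, None) drops the same feature channels at every timestep.

# Standalone illustration of the noise-shape resolution added above; the
# resolve_noise_shape helper mirrors _get_concrete_noise_shape.
import tensorflow as tf


def resolve_noise_shape(inputs, noise_shape):
    if noise_shape is None:
        return tf.shape(inputs)
    concrete = tf.shape(inputs)
    return [
        concrete[i] if dim is None else dim for i, dim in enumerate(noise_shape)
    ]


x = tf.ones((2, 5, 8))  # (batch, timesteps, features)

# Broadcast the dropout mask over the timestep axis: each feature column is
# either kept for all 5 timesteps or zeroed for all 5 timesteps.
shape = resolve_noise_shape(x, (None, 1, None))
dropped = tf.nn.experimental.stateless_dropout(
    x, rate=0.5, noise_shape=shape, seed=[1, 2]
)

# Kept entries are scaled by 1 / (1 - rate) = 2.0; dropped entries are 0.
print(dropped[0, :, 0])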
