-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init * refactor * minor improvements * Remove softmax activation * small * rename * redo the tf change --------- Co-authored-by: chenmoneygithub <chenmoney@chenmoney-gpu-scratch.us-west1-a.c.keras-team-gcp.internal>
- Loading branch information
1 parent
5973be2
commit df95e7c
Showing
5 changed files
with
187 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import time | ||
|
||
import keras_core | ||
|
||
|
||
class BenchmarkMetricsCallback(keras_core.callbacks.Callback): | ||
def __init__(self, start_batch=1, stop_batch=None): | ||
self.start_batch = start_batch | ||
self.stop_batch = stop_batch | ||
|
||
# Store the throughput of each epoch. | ||
self.state = {"throughput": []} | ||
|
||
def on_train_batch_begin(self, batch, logs=None): | ||
if batch == self.start_batch: | ||
self.state["epoch_begin_time"] = time.time() | ||
|
||
def on_train_batch_end(self, batch, logs=None): | ||
if batch == self.stop_batch: | ||
epoch_end_time = time.time() | ||
throughput = (self.stop_batch - self.start_batch + 1) / ( | ||
epoch_end_time - self.state["epoch_begin_time"] | ||
) | ||
self.state["throughput"].append(throughput) |
148 changes: 148 additions & 0 deletions
148
benchmarks/model_benchmark/efficient_net_image_classification_benchmark.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
"""Image classification benchmark with EfficientNetV2B0. | ||
To run the benchmark, make sure you are in model_benchmark/ directory, and run | ||
the command below: | ||
python3 -m model_benchmark.resnet_image_classification_benchmark \ | ||
--epochs=2 \ | ||
--batch_size=32 | ||
""" | ||
|
||
import time | ||
|
||
import keras_core as keras | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
from absl import app | ||
from absl import flags | ||
from absl import logging | ||
|
||
|
||
from model_benchmark.benchmark_utils import BenchmarkMetricsCallback | ||
from keras_core.applications import EfficientNetV2B0 | ||
|
||
flags.DEFINE_integer("epochs", 1, "The number of epochs.") | ||
flags.DEFINE_integer("batch_size", 4, "Batch Size.") | ||
flags.DEFINE_string( | ||
"mixed_precision_policy", | ||
"mixed_float16", | ||
"The global precision policy to use, e.g., 'mixed_float16' or 'float32'.", | ||
) | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
BATCH_SIZE = 32 | ||
IMAGE_SIZE = (224, 224) | ||
CHANNELS = 3 | ||
|
||
|
||
def load_data(): | ||
# Load cats vs dogs dataset, and split into train and validation sets. | ||
train_dataset, val_dataset = tfds.load( | ||
"cats_vs_dogs", split=["train[:90%]", "train[90%:]"], as_supervised=True | ||
) | ||
|
||
resizing = keras.layers.Resizing( | ||
IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True | ||
) | ||
|
||
def preprocess_inputs(image, label): | ||
image = tf.cast(image, "float32") | ||
return resizing(image), label | ||
|
||
train_dataset = ( | ||
train_dataset.map( | ||
preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE | ||
) | ||
.batch(FLAGS.batch_size) | ||
.cache() | ||
.prefetch(tf.data.AUTOTUNE) | ||
) | ||
val_dataset = ( | ||
val_dataset.map(preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE) | ||
.batch(FLAGS.batch_size) | ||
.cache() | ||
.prefetch(tf.data.AUTOTUNE) | ||
) | ||
return train_dataset, val_dataset | ||
|
||
|
||
def load_model(): | ||
# Load the EfficientNetV2B0 model and add a classification head. | ||
model = EfficientNetV2B0(include_top=False, weights="imagenet") | ||
classifier = keras.models.Sequential( | ||
[ | ||
keras.Input([IMAGE_SIZE[0], IMAGE_SIZE[1], CHANNELS]), | ||
model, | ||
keras.layers.GlobalAveragePooling2D(), | ||
keras.layers.Dense(2), | ||
] | ||
) | ||
return classifier | ||
|
||
|
||
def main(_): | ||
keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy) | ||
|
||
logging.info( | ||
"Benchmarking configs...\n" | ||
"=========================\n" | ||
f"MODEL: EfficientNetV2B0\n" | ||
f"TASK: image classification/dogs-vs-cats \n" | ||
f"BATCH_SIZE: {FLAGS.batch_size}\n" | ||
f"EPOCHS: {FLAGS.epochs}\n" | ||
"=========================\n" | ||
) | ||
|
||
# Load datasets. | ||
train_ds, validation_ds = load_data() | ||
|
||
# Load the model. | ||
classifier = load_model() | ||
|
||
lr = keras.optimizers.schedules.PolynomialDecay( | ||
5e-4, | ||
decay_steps=train_ds.cardinality() * FLAGS.epochs, | ||
end_learning_rate=0.0, | ||
) | ||
optimizer = keras.optimizers.Adam(lr) | ||
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) | ||
|
||
benchmark_metrics_callback = BenchmarkMetricsCallback( | ||
start_batch=1, | ||
stop_batch=train_ds.cardinality().numpy()-1, | ||
) | ||
|
||
classifier.compile( | ||
optimizer=optimizer, | ||
loss=loss, | ||
metrics=["sparse_categorical_accuracy"], | ||
) | ||
# Start training. | ||
logging.info("Starting Training...") | ||
|
||
st = time.time() | ||
|
||
history = classifier.fit( | ||
train_ds, | ||
validation_data=validation_ds, | ||
epochs=FLAGS.epochs, | ||
callbacks=[benchmark_metrics_callback], | ||
) | ||
|
||
wall_time = time.time() - st | ||
validation_accuracy = history.history["val_sparse_categorical_accuracy"][-1] | ||
|
||
examples_per_second = np.mean( | ||
np.array(benchmark_metrics_callback.state["throughput"]) | ||
) * FLAGS.batch_size | ||
|
||
logging.info("Training Finished!") | ||
logging.info(f"Wall Time: {wall_time:.4f} seconds.") | ||
logging.info(f"Validation Accuracy: {validation_accuracy:.4f}") | ||
logging.info(f"examples_per_second: {examples_per_second:.4f}") | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters