Skip to content

Commit

Permalink
Merge pull request #1363 from Kaggle/upgrade-keras-3
Browse files Browse the repository at this point in the history
Upgrade keras 3
  • Loading branch information
calderjo authored Feb 8, 2024
2 parents 71487b6 + 8bcef0f commit 9c637d1
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 24 deletions.
25 changes: 19 additions & 6 deletions Dockerfile.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ ENV KMP_SETTINGS=false
ENV PIP_ROOT_USER_ACTION=ignore

ADD clean-layer.sh /tmp/clean-layer.sh
ADD patches/keras_patch.sh /tmp/keras_patch.sh
ADD patches/nbconvert-extensions.tpl /opt/kaggle/nbconvert-extensions.tpl
ADD patches/template_conf.json /opt/kaggle/conf.json

Expand Down Expand Up @@ -202,18 +203,30 @@ RUN apt-get install -y default-jre && \

RUN pip install -f http://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o && /tmp/clean-layer.sh

# b/318672158 Use simply tensorflow-probability once > 0.23.0 is released.
RUN pip install \
"tensorflow==${TENSORFLOW_VERSION}" \
"tensorflow-io==${TENSORFLOW_IO_VERSION}" \
tensorflow_decision_forests \
git+https://github.com/tensorflow/probability.git@fbc5ebe9b1d343113fb917010096cfd88b32eecf \
tensorflow_text \
tensorflowjs \
tensorflow_hub && \
"tensorflow_hub>=0.16.0" \
tf-keras && \
/tmp/clean-layer.sh

# TODO(b/318672158): Upgrade to Keras 3 once compatible with other TF libries.
# See blockers here: https://b.corp.google.com/issues/319722433#comment8
RUN pip install keras keras-cv keras-nlp && \
# b/318672158 Use simply tensorflow_decision_forests on next release, expected with tf 2.16
RUN pip install tensorflow_decision_forests --no-deps && \
/tmp/clean-layer.sh

RUN chmod +x /tmp/keras_patch.sh && \
/tmp/keras_patch.sh

ADD patches/keras_internal.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal.py
ADD patches/keras_internal_test.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal_test.py

# Remove "--no-deps" flag and "namex" package once Keras 3.* is included in our base image.
# We ignore dependencies since tf2.15 and Keras 3.* should work despite pip saying it won't.
# Currently, keras tries to install a nightly version of tf 2.16: https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/requirements.txt#L2
RUN pip install --no-deps "keras>3" keras-cv keras-nlp namex && \
/tmp/clean-layer.sh

RUN pip install pysal
Expand Down
24 changes: 24 additions & 0 deletions patches/keras_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Access to Keras function with a different internal and external path."""

from tf_keras.src.engine import data_adapter as _data_adapter
from tf_keras.src.models import Functional
from tf_keras.layers import DenseFeatures
from tf_keras.src.utils.dataset_creator import DatasetCreator


unpack_x_y_sample_weight = _data_adapter.unpack_x_y_sample_weight
get_data_handler = _data_adapter.get_data_handler
23 changes: 23 additions & 0 deletions patches/keras_internal_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow_decision_forests.keras import keras_internal


# Does nothing. Ensures keras_internal can be loaded.

if __name__ == "__main__":
tf.test.main()

41 changes: 41 additions & 0 deletions patches/keras_patch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/bin/bash

# The following "sed" are to patch the current version of tf-df with
# a fix for keras 3. In essence, replaces the use of package name "tf.keras" with
# "tf_keras"

sed -i "/import tensorflow_decision_forests as tfdf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/__init__.py && \
sed -i -e "/import tensorflow as tf/a import tf_keras" \
-e "/from yggdrasil_decision_forests.utils.distribute.implementations.grpc/a from tensorflow_decision_forests.keras import keras_internal" \
-e '/try:/{:a;N;/backend = tf.keras.backend/!ba;d}'\
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core.py && \
sed -i -e "s/from typing import Optional, List, Dict, Any, Union, NamedTuple/from typing import Any, Dict, List, NamedTuple, Optional, Union/g" \
-e "/import tensorflow as tf/a from tensorflow_decision_forests.keras import keras_internal" \
-e "/import tensorflow as tf/a import tf_keras" \
-e '/layers = tf.keras.layers/{:a;N;/backend = tf.keras.backend/!ba;d}' \
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core_inference.py && \
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests -type f -exec sed -i \
-e "s/get_data_handler/keras_internal.get_data_handler/g" \
-e 's/"models.Functional"/keras_internal.Functional/g' \
-e "s/tf.keras.utils.unpack_x_y_sample_weight/keras_internal.unpack_x_y_sample_weight/g" \
-e "s/tf.keras.utils.experimental/keras_internal/g" \
{} \; && \
sed -i -e "/import tensorflow as tf/a import tf_keras" \
-e "/from tensorflow_decision_forests.keras import core/a from tensorflow_decision_forests.keras import keras_internal" \
-e '/layers = tf.keras.layers/{:a;N;/callbacks = tf.keras.callbacks/!ba;d}' \
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_test.py && \
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras -type f -exec sed -i \
-e "s/ layers.Input/ tf_keras.layers.Input/g" \
-e "s/layers.minimum/tf_keras.layers.minimum/g" \
-e "s/layers.Concatenate/tf_keras.layers.Concatenate/g" \
-e "s/layers.Dense/tf_keras.layers.Dense/g" \
-e "s/layers.experimental.preprocessing./tf_keras.layers./g" \
-e "s/layers.DenseFeatures/keras_internal.layers.DenseFeatures/g" \
-e "s/models.Model/tf_keras.models.Model/g" {} \; && \
sed -i "s/ models.load_model/ tf_keras.models.load_model/g" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_test.py && \
sed -i "/import tensorflow as tf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/test_runner.py && \
sed -i "/import tensorflow as tf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/wrappers.py && \
sed -i -e "/import tensorflow as tf/a import tf_keras" \
-e "s/optimizer=optimizers.Adam()/optimizer=tf_keras.optimizers.Adam()/g" \
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/wrappers_pre_generated.py && \
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests -type f -exec sed -i "s/tf.keras./tf_keras./g" {} \;
18 changes: 0 additions & 18 deletions tests/test_tensorflowjs.py

This file was deleted.

0 comments on commit 9c637d1

Please sign in to comment.