
Commit

Nits.
fchollet committed Jun 20, 2023
1 parent 37dfbde commit 62b2ae6
Showing 3 changed files with 14 additions and 6 deletions.
13 changes: 8 additions & 5 deletions keras_core/backend/torch/__init__.py
@@ -1,14 +1,17 @@
 """Torch backend APIs.
-Torch has a different logic of device management compared to TF and JAX. In
-short variables/tensors are not by default created on GPU, and GPU cannot
-directly communicate with CPU. Therefore, we are doing the following to automate
-device management for Torch backend, if GPU is available:
+# Note on device placement
+Torch has a different device placement style compared to TF and JAX.
+In short, variables/tensors are not created on GPU by default,
+and the GPU cannot directly communicate with the CPU.
+To bring Torch behavior in line with TF and JAX automated device placement,
+we are doing the following to automate device placement if a GPU is available:
 - Variables are created on GPU.
 - Input data will be placed on GPU at the first `keras_core.layers.Layer` call.
 - Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU.
-- `convert_to_numpy` will bring the tensor to CPU and convert to numpy array.
+- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy.
 """
 
 from keras_core.backend.torch import core
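The placement rules described in the docstring above can be sketched outside keras_core. This is a minimal illustration only: `convert_to_numpy` here is a local stand-in for the backend helper, and the real keras_core code differs.

```python
import numpy as np
import torch

# Hypothetical helper mirroring the `convert_to_numpy` behavior described
# above: move the tensor to CPU first, then convert. A GPU tensor cannot
# be handed to NumPy directly, so the `.cpu()` step is required.
def convert_to_numpy(x):
    return x.detach().cpu().numpy()

# Mirror the "create on GPU when available" placement rule.
device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.zeros((2, 3), device=device)
arr = convert_to_numpy(t)
assert isinstance(arr, np.ndarray) and arr.shape == (2, 3)
```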
5 changes: 5 additions & 0 deletions keras_core/operations/core.py
@@ -4,6 +4,11 @@
 slice
 slice_update
 while_loop
+stop_gradient
+shape
+cast
+convert_to_tensor
+convert_to_numpy
 """
 
 from keras_core import backend
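The ops newly listed in the docstring above can be illustrated with hedged NumPy stand-ins. The real keras_core ops dispatch to the active backend (TF, JAX, or Torch); these local functions only sketch the intended semantics and are not the library implementation.

```python
import numpy as np

# NumPy stand-ins for three of the newly documented ops.
def shape(x):
    # Return the static shape of a tensor-like value as a tuple.
    return tuple(np.asarray(x).shape)

def cast(x, dtype):
    # Convert a tensor-like value to the given dtype.
    return np.asarray(x).astype(dtype)

def convert_to_numpy(x):
    # Convert a tensor-like value to a NumPy array.
    return np.asarray(x)

x = [[1, 2], [3, 4]]
assert shape(x) == (2, 2)
assert cast(x, "float32").dtype == np.float32
```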
2 changes: 1 addition & 1 deletion keras_core/trainers/trainer.py
@@ -43,7 +43,7 @@ def compile(
         else:
             self._compile_metrics = None
         if jit_compile == "auto":
-            if model_supports_jit(self):
+            if not run_eagerly and model_supports_jit(self):
                 jit_compile = True
             else:
                 jit_compile = False
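The effect of this one-line fix can be sketched as a hypothetical standalone helper (not the actual trainer code): under `"auto"`, JIT compilation is enabled only when the model supports it and the trainer is not forced to run eagerly.

```python
# Hypothetical sketch of the patched "auto" resolution logic.
def resolve_jit_compile(jit_compile, run_eagerly, supports_jit):
    if jit_compile == "auto":
        # The fix: eager execution disables JIT even if the model supports it.
        return (not run_eagerly) and supports_jit
    return bool(jit_compile)

# Before the fix, run_eagerly=True could still resolve to jit_compile=True.
assert resolve_jit_compile("auto", run_eagerly=True, supports_jit=True) is False
assert resolve_jit_compile("auto", run_eagerly=False, supports_jit=True) is True
```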
