Skip to content

Commit

Permalink
Fixes a few warnings in the example scripts (#505)
Browse files Browse the repository at this point in the history
* fixes Nx.random_uniform deprecation warning

* fixes 'warning: Nx.to_batched_list/2 is deprecated. Use to_batched/3 instead'

* Update examples/vision/cifar10.exs

Co-authored-by: Sean Moriarity <[email protected]>

* Update examples/vision/cifar10.exs

Co-authored-by: Sean Moriarity <[email protected]>

---------

Co-authored-by: Sean Moriarity <[email protected]>
  • Loading branch information
grzuy and seanmor5 committed Jun 29, 2023
1 parent 9a6efae commit 68d34cb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
13 changes: 8 additions & 5 deletions examples/basics/multi_output_example.exs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ defmodule Power do
# model input and y is the target. Because we have multiple targets, we represent
# y as a tuple. In the future, Axon will support any Nx container as an output
data =
Stream.repeatedly(fn ->
# Batch size of 32
x = Nx.random_uniform({32, 1}, -10, 10, type: {:f, 32})
{x, {Nx.pow(x, 2), Nx.pow(x, 3)}}
end)
Stream.unfold(
Nx.Random.key(:erlang.system_time()),
fn key ->
# Batch size of 32
{x, next_key} = Nx.Random.uniform(key, -10, 10, shape: {32, 1}, type: {:f, 32})
{{x, {Nx.pow(x, 2), Nx.pow(x, 3)}}, next_key}
end
)

# Create the training loop, notice we specify 2 MSE objectives, 1 for the first
# output and 1 for the second output. This will create a loss function which is
Expand Down
4 changes: 2 additions & 2 deletions examples/vision/cifar10.exs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ defmodule Cifar do
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 32, 32, 3})
|> Nx.divide(255.0)
|> Nx.to_batched_list(32)
|> Nx.to_batched(32)
|> Enum.split(1500)
end

Expand All @@ -22,7 +22,7 @@ defmodule Cifar do
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
|> Nx.to_batched(32)
|> Enum.split(1500)
end

Expand Down

0 comments on commit 68d34cb

Please sign in to comment.