From 68d34cb126d17ee60f45aea236bcc293c81b4418 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 29 Jun 2023 14:04:06 -0300 Subject: [PATCH] Fixes a few warnings in the example scripts (#505) * fixes Nx.random_uniform deprecation warning * fixes 'warning: Nx.to_batched_list/2 is deprecated. Use to_batched/3 instead' * Update examples/vision/cifar10.exs Co-authored-by: Sean Moriarity * Update examples/vision/cifar10.exs Co-authored-by: Sean Moriarity --------- Co-authored-by: Sean Moriarity --- examples/basics/multi_output_example.exs | 13 ++++++++----- examples/vision/cifar10.exs | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/basics/multi_output_example.exs b/examples/basics/multi_output_example.exs index 74a13a2a..45ae9d8e 100644 --- a/examples/basics/multi_output_example.exs +++ b/examples/basics/multi_output_example.exs @@ -31,11 +31,14 @@ defmodule Power do # model input and y is the target. Because we have multiple targets, we represent # y as a tuple. In the future, Axon will support any Nx container as an output data = - Stream.repeatedly(fn -> - # Batch size of 32 - x = Nx.random_uniform({32, 1}, -10, 10, type: {:f, 32}) - {x, {Nx.pow(x, 2), Nx.pow(x, 3)}} - end) + Stream.unfold( + Nx.Random.key(:erlang.system_time()), + fn key -> + # Batch size of 32 + {x, next_key} = Nx.Random.uniform(key, -10, 10, shape: {32, 1}, type: {:f, 32}) + {{x, {Nx.pow(x, 2), Nx.pow(x, 3)}}, next_key} + end + ) # Create the training loop, notice we specify 2 MSE objectives, 1 for the first # output and 1 for the second output. This will create a loss function which is diff --git a/examples/vision/cifar10.exs b/examples/vision/cifar10.exs index edbdbd96..37a290b7 100644 --- a/examples/vision/cifar10.exs +++ b/examples/vision/cifar10.exs @@ -13,7 +13,7 @@ defmodule Cifar do |> Nx.from_binary(type) |> Nx.reshape({elem(shape, 0), 32, 32, 3}) |> Nx.divide(255.0) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) |> Enum.split(1500) end @@ -22,7 +22,7 @@ defmodule Cifar do |> Nx.from_binary(type) |> Nx.new_axis(-1) |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) |> Enum.split(1500) end