From f14362eb0f0c0eeccaa51d7556ff9f9602bc43bd Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 12:44:46 +0000 Subject: [PATCH 1/9] #0: Use small vector instead of std::vector for shapes --- cmake/dependencies.cmake | 1 + .../unit_tests/gtests/test_async_runtime.cpp | 2 +- tt_metal/CMakeLists.txt | 1 + ttnn/cpp/pybind11/pytensor.cpp | 12 +- .../moreh_clip_grad_norm_step1.cpp | 6 +- .../moreh_clip_grad_norm_step2.cpp | 6 +- .../moreh_clip_grad_norm_step3.cpp | 6 +- ttnn/cpp/ttnn/graph/graph_trace_utils.cpp | 2 +- .../conv2d_op_sharded_program_factory.cpp | 2 +- ...onv2d_op_width_sharded_program_factory.cpp | 2 +- .../core/to_layout/to_layout_op.cpp | 16 +- .../data_movement/concat/concat.cpp | 4 +- .../data_movement/fold/fold_pybind.cpp | 2 +- .../ttnn/operations/data_movement/pad/pad.cpp | 25 +- .../ttnn/operations/data_movement/pad/pad.hpp | 2 +- .../data_movement/pad/pad_pybind.hpp | 2 +- .../data_movement/permute/permute.cpp | 30 +-- .../data_movement/permute/permute.hpp | 6 +- .../data_movement/permute/permute_pybind.hpp | 2 +- .../data_movement/repeat/repeat.cpp | 12 +- .../repeat_interleave/repeat_interleave.cpp | 2 +- .../reshape_on_device/reshape.cpp | 8 +- .../reshape_on_device/reshape.hpp | 6 +- .../reshape_on_device/reshape_pybind.cpp | 2 +- .../data_movement/reshape_view/reshape.cpp | 12 +- .../data_movement/reshape_view/reshape.hpp | 2 +- .../reshape_view/reshape_pybind.cpp | 2 +- .../data_movement/slice/device/slice_op.cpp | 2 +- .../operations/data_movement/slice/slice.cpp | 246 +++--------------- .../operations/data_movement/slice/slice.hpp | 49 ++-- .../data_movement/slice/slice_pybind.hpp | 13 +- .../operations/data_movement/split/split.cpp | 10 +- .../data_movement/squeeze/squeeze.cpp | 4 +- .../data_movement/transpose/transpose.cpp | 10 +- .../data_movement/unsqueeze/unsqueeze.cpp | 2 +- .../device/untilize_with_unpadding_op.cpp | 2 +- .../binary/device/binary_device_operation.cpp | 2 +- .../unary/device/unary_composite_op.cpp | 16 +- .../eltwise/unary_backward/unary_backward.cpp | 36 +-- .../experimental/auto_format/auto_format.cpp | 4 +- .../ccl/all_reduce/device/all_reduce_op.cpp | 2 +- .../experimental/reduction/argmax/argmax.cpp | 4 +- .../fast_reduce_nc_device_operation.cpp | 6 +- .../fast_reduce_nc_device_operation.hpp | 2 +- .../fast_reduce_nc/fast_reduce_nc.cpp | 4 +- .../fast_reduce_nc/fast_reduce_nc.hpp | 4 +- .../fast_reduce_nc/fast_reduce_nc_pybind.cpp | 4 +- ...p_kv_cache_load_slice_device_operation.cpp | 2 +- .../full/device/full_device_operation.cpp | 8 +- .../full/device/full_device_operation.hpp | 6 +- ttnn/cpp/ttnn/operations/full/full.cpp | 2 +- ttnn/cpp/ttnn/operations/full/full.hpp | 2 +- .../device/moreh_arange_device_operation.cpp | 4 +- .../device/moreh_getitem_device_operation.cpp | 10 +- .../device/moreh_getitem_device_operation.hpp | 8 +- .../moreh/moreh_getitem/moreh_getitem.cpp | 2 +- .../moreh/moreh_getitem/moreh_getitem.hpp | 2 +- .../moreh_group_norm_device_operation.cpp | 2 +- .../moreh/moreh_helper_functions.cpp | 24 +- .../moreh/moreh_helper_functions.hpp | 8 +- .../moreh_layer_norm_device_operation.cpp | 4 +- .../moreh_linear_backward.cpp | 16 +- .../device/moreh_matmul_device_operation.cpp | 16 +- .../device/moreh_matmul_device_operation.hpp | 4 +- .../device/moreh_matmul_program_factory.cpp | 40 +-- .../device/moreh_mean_device_operation.cpp | 4 +- .../moreh/moreh_mean/moreh_mean.cpp | 4 +- .../moreh/moreh_mean/moreh_mean.hpp | 2 +- .../moreh_mean_backward_device_operation.cpp | 6 +- 
.../moreh_mean_backward_device_operation.hpp | 4 +- .../moreh_mean_backward_program_factory.cpp | 10 +- .../moreh_mean_backward.cpp | 6 +- .../moreh_mean_backward.hpp | 2 +- .../moreh_nll_loss_step2_device_operation.cpp | 4 +- .../device/moreh_norm_device_operation.cpp | 16 +- .../moreh/moreh_norm/moreh_norm.cpp | 8 +- .../moreh/moreh_norm/moreh_norm.hpp | 2 +- .../moreh_norm_backward_device_operation.cpp | 4 +- .../moreh_norm_backward_device_operation.hpp | 8 +- .../moreh_norm_backward_program_factory.cpp | 10 +- .../moreh_norm_backward.cpp | 2 +- .../moreh_norm_backward.hpp | 2 +- .../device/moreh_sum_device_operation.cpp | 4 +- .../operations/moreh/moreh_sum/moreh_sum.cpp | 4 +- .../operations/moreh/moreh_sum/moreh_sum.hpp | 2 +- .../moreh_sum_backward_device_operation.cpp | 4 +- .../moreh_sum_backward_device_operation.hpp | 4 +- .../moreh_sum_backward_program_factory.cpp | 10 +- .../moreh_sum_backward/moreh_sum_backward.cpp | 4 +- .../moreh_sum_backward/moreh_sum_backward.hpp | 2 +- .../maxpool/device/max_pool2d_device_op.cpp | 2 +- .../pool/upsample/device/upsample_op.cpp | 2 +- .../reduction/generic/generic_reductions.cpp | 18 +- .../reduction/generic/generic_reductions.hpp | 2 +- .../reduction/prod/device/prod_nc_op.cpp | 4 +- .../reduction/prod/device/prod_nc_op.hpp | 2 +- .../ttnn/operations/reduction/prod/prod.cpp | 22 +- .../ttnn/operations/reduction/prod/prod.hpp | 2 +- .../operations/reduction/prod/prod_pybind.hpp | 4 +- .../sliding_window/sliding_window.cpp | 10 +- ttnn/cpp/ttnn/run_operation.cpp | 2 +- ttnn/cpp/ttnn/tensor/CMakeLists.txt | 1 + ttnn/cpp/ttnn/tensor/tensor.cpp | 6 +- ttnn/cpp/ttnn/tensor/tensor_impl.cpp | 12 +- ttnn/cpp/ttnn/tensor/tensor_impl.hpp | 14 +- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 12 +- ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 12 +- ttnn/cpp/ttnn/tensor/tensor_utils.hpp | 8 +- ttnn/cpp/ttnn/tensor/types.cpp | 30 +-- ttnn/cpp/ttnn/tensor/types.hpp | 61 +++-- ttnn/cpp/ttnn/tensor/vector_base.cpp | 83 ++++++ ttnn/cpp/ttnn/tensor/vector_base.hpp | 99 +++++++ 112 files changed, 669 insertions(+), 624 deletions(-) create mode 100644 ttnn/cpp/ttnn/tensor/vector_base.cpp create mode 100644 ttnn/cpp/ttnn/tensor/vector_base.hpp diff --git a/cmake/dependencies.cmake b/cmake/dependencies.cmake index 55b1db35a7c..33dfcbae00b 100644 --- a/cmake/dependencies.cmake +++ b/cmake/dependencies.cmake @@ -8,6 +8,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/fetch_boost.cmake) fetch_boost_library(core) fetch_boost_library(smart_ptr) +fetch_boost_library(container) add_library(span INTERFACE) target_link_libraries(span INTERFACE Boost::core) diff --git a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp index 4d50192f234..8537a6778b8 100644 --- a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp +++ b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp @@ -43,7 +43,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { Tensor np_tensor = ttnn::numpy::full(input_shape.value, static_cast(1), DataType::BFLOAT16) .to(Layout::TILE) .to(device); - std::vector reduce_dims = {3}; + ttnn::SmallVector reduce_dims = {3}; Tensor np_out = ttnn::moreh_sum(np_tensor, reduce_dims, false, std::nullopt, std::nullopt, std::nullopt); Tensor np_out_host = np_out.cpu(); const bfloat16* golden_output = std::get>(std::get(np_out_host.get_storage()).buffer).begin(); diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 005383e4a77..e5dfd69335d 100644 --- a/tt_metal/CMakeLists.txt +++ 
b/tt_metal/CMakeLists.txt @@ -31,6 +31,7 @@ target_link_libraries( magic_enum fmt span + Boost::container ) target_precompile_headers( diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 9a8f28d42bf..6fa13b17842 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -64,7 +64,7 @@ void log_external_operation( #endif template -Tensor create_owned_tensor(T* data_ptr, size_t num_elements, std::vector& shape, DataType data_type, Layout layout, const std::optional& optional_tile = std::nullopt) +Tensor create_owned_tensor(T* data_ptr, size_t num_elements, std::span shape, DataType data_type, Layout layout, const std::optional& optional_tile = std::nullopt) { auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); @@ -80,7 +80,7 @@ Tensor convert_torch_tensor_to_tt_tensor( } auto torch_dtype = torch_tensor.attr("dtype"); - auto shape = py::cast>(torch_tensor.attr("shape")); + auto shape = py::cast>(torch_tensor.attr("shape")); auto contiguous_torch_tensor = torch_tensor.attr("contiguous")(); @@ -251,7 +251,7 @@ Tensor convert_numpy_tensor_to_tt_tensor( } auto np_dtype = np_tensor.attr("dtype"); - auto shape = py::cast>(np_tensor.attr("shape")); + auto shape = py::cast>(np_tensor.attr("shape")); auto contiguous_np_tensor = np.attr("ascontiguousarray")(np_tensor); @@ -1325,7 +1325,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "unpad_from_tile", - [](const Tensor &self, const std::vector &output_tensor_shape) { + [](const Tensor &self, const ttnn::SmallVector &output_tensor_shape) { return self.unpad_from_tile(ttnn::SimpleShape(output_tensor_shape)); }, R"doc( @@ -1593,7 +1593,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "reshape", - [](Tensor &self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, {N, C, H, W})); }, + [](Tensor &self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector{N, C, H, W})); }, R"doc( Reshapes TT tensor @@ -1613,7 +1613,7 @@ void pytensor_module(py::module &m_tensor) { )doc") .def( "reshape", - [](Tensor &self, const std::vector &shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); }, + [](Tensor &self, const ttnn::SmallVector &shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); }, R"doc( Reshapes TT tensor diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp index bee109f031b..29b9f51b546 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp @@ -144,7 +144,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl( const auto [origin_h, origin_w] = origin_hw_vec.at(i); // reader - const std::vector reader_runtime_args{ + const std::array reader_runtime_args{ input_addr, static_cast(is_dram(input)), num_tiles, @@ -154,12 +154,12 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step1_impl( SetRuntimeArgs(program, reader_kernels_id, core, reader_runtime_args); // writer - const std::vector writer_runtime_args{ + const std::array writer_runtime_args{ output_addr, 
static_cast(is_dram(tmp_pow_sum)), tile_offset}; SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args); // compute - const std::vector compute_runtime_args{ + const std::array compute_runtime_args{ num_tiles, p, static_cast(p_is_negative), diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp index 221ed7c15c7..b65a1a7ea58 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp @@ -101,16 +101,16 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step2_impl( const auto output_addr = total_norm.buffer()->address(); // reader - const std::vector reader_runtime_args{ + const std::array reader_runtime_args{ input_addr, static_cast(is_dram(tmp_pow_sum)), num_tiles, *reinterpret_cast(&decimal)}; SetRuntimeArgs(program, reader_kernels_id, single_core, reader_runtime_args); // writer - const std::vector writer_runtime_args{output_addr, static_cast(is_dram(total_norm))}; + const std::array writer_runtime_args{output_addr, static_cast(is_dram(total_norm))}; SetRuntimeArgs(program, writer_kernels_id, single_core, writer_runtime_args); // compute - const std::vector compute_runtime_args{num_tiles, p, static_cast(p_is_negative)}; + const std::array compute_runtime_args{num_tiles, p, static_cast(p_is_negative)}; SetRuntimeArgs(program, compute_kernels_id, single_core, compute_runtime_args); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp index a8016a54f5d..f62b9eace7a 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp @@ -107,7 +107,7 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl( const auto num_tiles = input.volume() / tt::constants::TILE_HW; // reader - const std::vector reader_runtime_args{ + const std::array reader_runtime_args{ input_addr, static_cast(is_dram(input)), clip_coef_clamped_addr, @@ -116,11 +116,11 @@ operation::ProgramWithCallbacks moreh_clip_grad_norm_step3_impl( SetRuntimeArgs(program, reader_kernels_id, core, reader_runtime_args); // writer - const std::vector writer_runtime_args{input_addr, static_cast(is_dram(input)), num_tiles}; + const std::array writer_runtime_args{input_addr, static_cast(is_dram(input)), num_tiles}; SetRuntimeArgs(program, writer_kernels_id, core, writer_runtime_args); // compute - const std::vector compute_runtime_args{num_tiles}; + const std::array compute_runtime_args{num_tiles}; SetRuntimeArgs(program, compute_kernels_id, core, compute_runtime_args); } diff --git a/ttnn/cpp/ttnn/graph/graph_trace_utils.cpp b/ttnn/cpp/ttnn/graph/graph_trace_utils.cpp index cb93e40e177..8615732fdd9 100644 --- a/ttnn/cpp/ttnn/graph/graph_trace_utils.cpp +++ b/ttnn/cpp/ttnn/graph/graph_trace_utils.cpp @@ -25,7 +25,7 @@ ttnn::Shape 
parse_shape(std::string_view shape_string) { std::string_view shape_values = shape_string.substr(start, end - start); // Vector to hold the parsed shape values - std::vector shape; + SmallVector shape; const char* str = shape_values.data(); const char* end_str = str + shape_values.size(); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index 8bd6bd51a0d..49bc3b5052b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -491,7 +491,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( block_config.act_block_h_ntiles % block_config.out_subblock_h_ntiles == 0, "Out_block_h must be divisible by out_subblock_h!"); } - ttnn::Shape ashape_with_channels_padded(std::vector({ashape[0], ashape[1], ashape[2], input_channels_padded})); + ttnn::Shape ashape_with_channels_padded(ttnn::SmallVector({ashape[0], ashape[1], ashape[2], input_channels_padded})); uint32_t conv_act_size_h = ashape_with_channels_padded[1]; uint32_t conv_act_size_w = ashape_with_channels_padded[2]; uint32_t conv_act_size_c = ashape_with_channels_padded[3]; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp index f105f06f256..9cc49a47f22 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp @@ -183,7 +183,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_width_sharded_v2_impl( "Out_block_h must be divisible by out_subblock_h!"); } - ttnn::Shape ashape_with_channels_padded(std::vector{ashape[0], ashape[1], ashape[2], input_channels_padded}); + ttnn::Shape ashape_with_channels_padded({ashape[0], ashape[1], ashape[2], input_channels_padded}); uint32_t conv_act_size_h = ashape_with_channels_padded[1]; uint32_t conv_act_size_w = ashape_with_channels_padded[2]; diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index 6b23191b0c3..23e44509c06 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -94,14 +94,14 @@ Tensor to_layout_impl( auto tensor = tensor_arg; - std::vector output_shape; + SmallVector output_shape; if (layout == ttnn::TILE_LAYOUT and intended_shape.rank() < 2) { output_shape.push_back(1); tensor = ttnn::reshape( tensor, ttnn::Shape( - std::vector{1, intended_shape[0]}, - std::vector{1, tensor_arg.get_shape().with_tile_padding()[0]})); + SmallVector{1, intended_shape[0]}, + SmallVector{1, tensor_arg.get_shape().with_tile_padding()[0]})); } for (auto index = 0; index < intended_shape.rank(); ++index) { output_shape.push_back(intended_shape[index]); @@ -144,7 +144,7 @@ Tensor to_layout_impl( output_memory_config = tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type}; } - std::vector output_tensor_end; + SmallVector output_tensor_end; for (auto index = 0; index < tensor.get_shape().rank(); ++index) { output_tensor_end.push_back(tensor.get_shape()[index] - 1); } @@ -154,7 +154,7 @@ Tensor to_layout_impl( return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape}); } 
else if (layout == ttnn::TILE_LAYOUT) { - std::vector padded_output_shape; + SmallVector padded_output_shape; for (int index = 0; index < tensor.get_shape().rank(); ++index) { if (index >= tensor.get_shape().rank() - 2) { @@ -166,7 +166,7 @@ Tensor to_layout_impl( if (tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { // ttnn::tilize_with_val_padding doesn't support height sharded tensors // workaround by applying padding and then tilizing - std::vector> padding = { + SmallVector> padding = { {0, 0}, {0, 0}, {0, padded_output_shape[2] - output_shape[2]}, @@ -192,8 +192,8 @@ Tensor to_layout_impl( tensor = tensor.unpad_from_tile(tensor.get_logical_shape()); return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape}); } else if (layout == ttnn::TILE_LAYOUT) { - std::vector padded_output_shape; - std::vector padded_input_start; + SmallVector padded_output_shape; + SmallVector padded_input_start; for (int index = 0; index < tensor.get_shape().rank(); ++index) { if (index >= tensor.get_shape().rank() - 2) { padded_output_shape.push_back(ttnn::pad_to_multiple_of_tile_size(tensor.get_shape()[index])); diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp index 322e467360c..3ffc8a9f0df 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp @@ -86,8 +86,8 @@ namespace data_movement { while (output_tensor.get_shape().rank() > rank) { const auto shape = output_tensor.get_shape(); const auto full_shape = output_tensor.get_shape().with_tile_padding(); - std::vector shape_vec{}; - std::vector full_shape_vec{}; + SmallVector shape_vec{}; + SmallVector full_shape_vec{}; // int i = 0; // while(i < 3 and shape[i] == 1) i++; for (int i = 1; i < shape.rank(); i++) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp index 012d0531e43..806c1ec28c8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold_pybind.cpp @@ -27,7 +27,7 @@ void bind_fold_operation(py::module& module) { )doc", ttnn::pybind_overload_t{ [](const decltype(ttnn::fold)& op, const ttnn::Tensor& input, uint32_t stride_h, uint32_t stride_w, - bool use_transpose_as_fold, std::optional> output_shape, uint32_t pad_c, uint32_t pad_h, uint32_t pad_w, std::optional grid_size, std::optional override_memory_config, + bool use_transpose_as_fold, std::optional> output_shape, uint32_t pad_c, uint32_t pad_h, uint32_t pad_w, std::optional grid_size, std::optional override_memory_config, const uint8_t& queue_id) -> ttnn::Tensor { return op(queue_id, input, stride_h, stride_w, use_transpose_as_fold, output_shape, pad_c, pad_h, pad_w, grid_size, override_memory_config); diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp index 72cce5c4f32..a1bae5e6f95 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp @@ -53,7 +53,7 @@ template static ttnn::Tensor pad_impl( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::vector> padding, + ttnn::SmallVector> padding, const float value, const bool use_multicore, const std::optional& memory_config_arg) { @@ -116,31 +116,32 @@ static ttnn::Tensor pad_impl( ttnn::Tensor ExecutePad::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const 
std::vector>& padding, + std::span> padding, const float value, const bool use_multicore, const std::optional& memory_config_arg) { const int original_rank = input_tensor.get_shape().rank(); + ttnn::SmallVector> padding_vec(padding.begin(), padding.end()); ttnn::Tensor output_tensor; if (input_tensor.storage_type() != StorageType::DEVICE) { switch (original_rank) { - case 1: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; - case 2: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; - case 3: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; - case 4: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; - case 5: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; - case 6: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; - case 7: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; - case 8: output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); break; + case 1: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; + case 2: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; + case 3: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; + case 4: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; + case 5: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; + case 6: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; + case 7: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; + case 8: output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); break; default: TT_THROW("Unsupported tensor rank of {}. Needs to be between 1 and 8 inclusively.", original_rank); } } else { - output_tensor = pad_impl(queue_id, input_tensor, padding, value, use_multicore, memory_config_arg); + output_tensor = pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); } // output_tensor is currently 4D. 
We have to squeeze back to the original rank - auto to_vec = [](const auto& arr) {return std::vector(arr.begin(), arr.end());}; + auto to_vec = [](const auto& arr) {return ttnn::SmallVector(arr.begin(), arr.end());}; auto shape = to_vec(output_tensor.get_shape().value); auto padded_shape = to_vec(output_tensor.get_shape().with_tile_padding().value); if (auto rank_diff = shape.size() - original_rank; rank_diff) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp index 2117be64f60..5c006212cbf 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp @@ -32,7 +32,7 @@ struct ExecutePad { // Any rank tensor supported static ttnn::Tensor invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, - const std::vector>& padding, + std::span> padding, const float value, const bool use_multicore, const std::optional& memory_config_arg); diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad_pybind.hpp index 3f2799cd908..3effaaeac67 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad_pybind.hpp @@ -51,7 +51,7 @@ void bind_pad(py::module& module) { ttnn::pybind_overload_t{ [] (const OperationType& self, const ttnn::Tensor& input_tensor, - std::vector> padding, + ttnn::SmallVector> padding, const float value, const bool use_multicore, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index 29a07c7e71d..9efdcf72c38 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -33,7 +33,7 @@ inline bool has_tile_padding(const Tensor& t) { return false; } -ttnn::Tensor permute_impl(const ttnn::Tensor &a, const std::vector& dims, const MemoryConfig& output_mem_config) { +ttnn::Tensor permute_impl(const ttnn::Tensor &a, const SmallVector& dims, const MemoryConfig& output_mem_config) { using ttnn::operations::experimental::auto_format::AutoFormat; Device * device; @@ -54,8 +54,8 @@ ttnn::Tensor permute_impl(const ttnn::Tensor &a, const std::vector& di auto input_shape = a.get_logical_shape(); // create_output_tensor shape is useless when we potentially have new padding to deal with - std::vector output_shape = {input_shape[N], input_shape[C], input_shape[H], input_shape[W]}; - std::vector padded_output_shape = output_shape; + SmallVector output_shape = {input_shape[N], input_shape[C], input_shape[H], input_shape[W]}; + SmallVector padded_output_shape = output_shape; uint32_t input_rank = a.get_logical_shape().rank(); if (a.layout() == Layout::TILE) { @@ -128,14 +128,14 @@ ttnn::Tensor permute_impl(const ttnn::Tensor &a, const std::vector& di return output; } -ttnn::Tensor permute_launch(const ttnn::Tensor &a, const std::vector& dims, const MemoryConfig& output_mem_config) { +ttnn::Tensor permute_launch(const ttnn::Tensor &a, std::span dims, const MemoryConfig& output_mem_config) { std::vector output_tensors = {ttnn::Tensor(operation::get_workers_for_op_output({a}))}; operation::launch_with_autoformat( [dims, output_mem_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { auto& a = input_tensors.at(0); - std::vector normalized_dims(dims.size()); + 
SmallVector normalized_dims(dims.size()); std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [a](std::int64_t idx) {return a.get_legacy_shape().get_normalized_index(idx);}); - std::vector seq_dims(dims.size()); + SmallVector seq_dims(dims.size()); std::iota(seq_dims.begin(), seq_dims.end(), 0); if (normalized_dims == seq_dims) { return {ttnn::operations::experimental::auto_format::AutoFormat::move_tensor_to_mem_config(a, output_mem_config)}; @@ -147,7 +147,7 @@ ttnn::Tensor permute_launch(const ttnn::Tensor &a, const std::vector& dims, + std::span dims, const std::optional& memory_config) { auto output_tensor = permute_launch(input_tensor, dims, memory_config.value_or(input_tensor.memory_config())); @@ -159,7 +159,7 @@ Tensor composite_invoke( ttnn::Tensor ExecutePermute::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const std::vector& dims, + std::span dims, const std::optional& memory_config, bool composite) { @@ -175,15 +175,15 @@ ttnn::Tensor ExecutePermute::invoke( input_rank == dims.size(), "The number of dimensions in the tensor input does not match the length of the desired ordering"); - auto adjust_order = [](const std::vector& dims) { - std::vector new_order; + auto adjust_order = [](std::span dims) { + ttnn::SmallVector new_order; TT_FATAL(dims.size() <= 4, "Error"); int additional_ranks = 4 - dims.size(); for (int i = 0; i < additional_ranks; i++) { new_order.push_back(i); } for (int i = 0; i < dims.size(); i++) { - new_order.push_back(dims.at(i) + additional_ranks); + new_order.push_back(dims[i] + additional_ranks); } return new_order; }; @@ -197,8 +197,8 @@ ttnn::Tensor ExecutePermute::invoke( if (input_rank < 4) { const auto shape = output_tensor.get_shape(); const auto full_shape = output_tensor.get_shape().with_tile_padding(); - std::vector shape_vec{}; - std::vector full_shape_vec{}; + SmallVector shape_vec{}; + SmallVector full_shape_vec{}; int i = 0; while (i < 3 and shape[i] == 1) i++; for (; i < shape.rank(); i++) { @@ -218,12 +218,12 @@ ttnn::Tensor ExecutePermute::invoke( ttnn::Tensor ExecutePermute::invoke( const ttnn::Tensor& input_tensor, - const std::vector& dims, + std::span dims, const std::optional& memory_config) { return invoke(DefaultQueueId, input_tensor, dims, memory_config); } -ttnn::Tensor ExecutePermute::invoke(const ttnn::Tensor& input_tensor, const std::vector& dims) { +ttnn::Tensor ExecutePermute::invoke(const ttnn::Tensor& input_tensor, std::span dims) { return invoke(input_tensor, dims, std::nullopt); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp index 6d7c8190c85..a1b37e4994c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp @@ -13,16 +13,16 @@ struct ExecutePermute { static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const std::vector& dims, + std::span dims, const std::optional& memory_config, bool composite = true); static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const std::vector& dims, + std::span dims, const std::optional& memory_config); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const std::vector& dims); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span dims); }; } // namespace operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.hpp 
b/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.hpp index 4ef8c3ae7e6..4f2b82a14ea 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.hpp @@ -46,7 +46,7 @@ void bind_permute(py::module& module) { ttnn::pybind_overload_t{ [] (const OperationType& self, const ttnn::Tensor& input_tensor, - const std::vector &dims, + const ttnn::SmallVector &dims, const std::optional& memory_config, uint8_t queue_id) { return self(queue_id, input_tensor, dims, memory_config, false); diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp index 8c4863e50c0..436f2fb3482 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp @@ -61,9 +61,9 @@ ttnn::Tensor RepeatOperation::invoke( TT_FATAL(output_tensors.size() == 1, "ttnn.repeat: expected 1 output tensor, but got {}", output_tensors.size()); if (input_tensor.get_layout() != Layout::ROW_MAJOR && logical_input_shape != padded_input_shape) { - auto zero_indices = std::vector(input_rank, 0); - auto end_indices = repeated_logical_shape.as_vector(); - auto step = std::vector(input_rank, 1); + auto zero_indices = ttnn::SmallVector(input_rank, 0); + auto end_indices = repeated_logical_shape.view(); + auto step = ttnn::SmallVector(input_rank, 1); if (repeated_logical_shape.volume() % tt::constants::TILE_HW != 0) { // volume of the repeated tensor doesn't fit neatly into tiles. @@ -86,7 +86,7 @@ ttnn::Tensor RepeatOperation::invoke( auto padded_width = tt::round_up(sliced_padded_shape[-1], tt::constants::TILE_WIDTH); TT_ASSERT(input_rank >= 2, "ttnn.repeat: rank of tiled input tensor must be >= 2"); uint32_t num_non_hw_dims = input_rank - 2u; - auto padding_vec = std::vector>(num_non_hw_dims, {0, 0}); + auto padding_vec = ttnn::SmallVector>(num_non_hw_dims, {0, 0}); padding_vec.reserve(input_rank); padding_vec.emplace_back(0, padded_height - sliced_padded_shape[-2]); padding_vec.emplace_back(0, padded_width - sliced_padded_shape[-1]); @@ -95,8 +95,8 @@ ttnn::Tensor RepeatOperation::invoke( auto padded_output = ttnn::pad(queue_id, sliced_output, padding_vec, 0.0f, pad_use_multicore, std::nullopt); auto tiled_output = ttnn::tilize(padded_output, input_tensor.memory_config()); - auto padded_to_tiled_shape = ttnn::Shape(sliced_logical_shape.as_vector(), - tiled_output.get_padded_shape().as_vector()); + auto padded_to_tiled_shape = ttnn::Shape(sliced_logical_shape.view(), + tiled_output.get_padded_shape().view()); tiled_output.set_shape(padded_to_tiled_shape); return tiled_output; } else { diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp index 515e8d64df5..332879c9358 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp @@ -37,7 +37,7 @@ ttnn::Tensor ExecuteRepeatInterleave::invoke(const ttnn::Tensor& input_a, uint32 } rm_input = ttnn::to_layout(rm_input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device*)nullptr); - std::vector final_shape; + SmallVector final_shape; final_shape.reserve(input_rank); for (uint32_t i = 0; i < rm_input.get_shape().rank(); i++) { final_shape.push_back(rm_input.get_shape()[i]); diff --git 
a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp index 93532abcfb9..6b479264c73 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp @@ -94,15 +94,15 @@ ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const tt return invoke(DefaultQueueId, input_tensor, shape, std::nullopt); } -ttnn::Tensor ReshapeOperation::invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, const std::vector & shape_vector, const std::optional& memory_config_arg) { - return invoke(queue_id, input_tensor, ttnn::Shape(infer_dims_for_reshape(input_tensor, shape_vector).as_vector()), memory_config_arg); +ttnn::Tensor ReshapeOperation::invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg) { + return invoke(queue_id, input_tensor, ttnn::Shape(infer_dims_for_reshape(input_tensor, shape_vector).view()), memory_config_arg); } -ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector, const std::optional& memory_config_arg) { +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg) { return invoke(DefaultQueueId, input_tensor, shape_vector, memory_config_arg); } -ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector) { +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, std::span shape_vector) { return invoke(input_tensor, shape_vector, std::nullopt); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp index e6044087170..cdd34ac34df 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp @@ -24,9 +24,9 @@ struct ReshapeOperation { static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); - static ttnn::Tensor invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, const std::vector& shape_vector, const std::optional& memory_config_arg); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector, const std::optional& memory_config_arg); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const std::vector& shape_vector); + static ttnn::Tensor invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span shape_vector); }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp index b3e9f335567..9eaecc3c166 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp @@ -30,7 +30,7 @@ void bind_reshape(pybind11::module& module, const data_movement_operation_t& ope int X, const std::optional& memory_config, uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor, std::vector{W, 
Z, Y, X}, memory_config); + return self(queue_id, input_tensor, ttnn::SmallVector{W, Z, Y, X}, memory_config); }, py::arg("input_tensor"), py::arg("W"), diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 7f459d8046e..ebbe216841d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -33,14 +33,12 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) auto rm_tensor = ttnn::to_layout(host_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr); if(tensor_shape.has_tile_padding()) { ttnn::Tensor slice_input; - std::vector begins; - std::vector ends; TT_FATAL(tensor_shape.rank() <= 4, "Only up to 4D tensors"); auto host_tensor_4d = unsqueeze_to_4D(rm_tensor); auto tensor_shape_4d = host_tensor_4d.shape(); - begins = std::vector({0, 0, 0, 0}); - ends = std::vector({tensor_shape_4d[0], tensor_shape_4d[1], tensor_shape_4d[2], tensor_shape_4d[3]}); - auto step = std::vector({1, 1, 1, 1}); + ttnn::SmallVector begins({0, 0, 0, 0}); + ttnn::SmallVector ends({tensor_shape_4d[0], tensor_shape_4d[1], tensor_shape_4d[2], tensor_shape_4d[3]}); + ttnn::SmallVector step({1, 1, 1, 1}); host_tensor_4d = ttnn::slice(host_tensor_4d, begins, ends, step, std::nullopt); host_tensor = squeeze_from_4D(host_tensor_4d, tensor_shape.rank()); } @@ -128,12 +126,12 @@ ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn } ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) { - return invoke(tensor, ttnn::Shape(shape.as_vector())); + return invoke(tensor, ttnn::Shape(shape.view())); } ttnn::Tensor ReshapeViewOperation::invoke( const ttnn::Tensor& tensor, - const std::vector & shape_vector + std::span shape_vector ) { return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector)); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp index c6c1941d2a2..38c49b5ff32 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp @@ -13,7 +13,7 @@ namespace operations::data_movement { struct ReshapeViewOperation { static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& logical_shape); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const std::vector & shape_vector); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span shape_vector); }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp index 851c452c3fc..b57eddcc3f0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp @@ -34,7 +34,7 @@ void bind_reshape_view(pybind11::module& module, const data_movement_operation_t ttnn::pybind_overload_t{ [](const data_movement_operation_t& self, const ttnn::Tensor& input_tensor, - const std::vector& shape + std::span shape ) -> ttnn::Tensor { return self(input_tensor, shape); }, diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp 
b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp index b5f6a2db555..386a76fef93 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.cpp @@ -118,7 +118,7 @@ void SliceDeviceOperation::validate_with_output_tensors( } std::vector SliceDeviceOperation::compute_output_shapes(const std::vector &input_tensors) const { - std::vector out_shape; + SmallVector out_shape; auto rank = input_tensors[0].get_legacy_shape().rank(); out_shape.reserve(rank); diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp index e65f1bba9ce..9a77bba6f2d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp @@ -15,13 +15,14 @@ namespace ttnn::operations::data_movement { +namespace { template -ttnn::Tensor SliceOperation::invoke( +ttnn::Tensor slice_operation_invoke_impl( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const std::vector &begins, - const std::vector &ends, - const std::vector &step, + std::span begins, + std::span ends, + std::span step, const std::optional& memory_config_arg, const std::optional& optional_output_tensor) { @@ -98,8 +99,7 @@ ttnn::Tensor SliceOperation::invoke( padded_ends[adjusted_rank - 2] = std::max(tt::round_up(padded_ends[adjusted_rank - 2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT); padded_ends[adjusted_rank - 1] = std::max(tt::round_up(padded_ends[adjusted_rank - 1], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH); } - - std::vector actual_shape, final_padded_shape; + SmallVector actual_shape, padded_shape; actual_shape.reserve(input_rank); final_padded_shape.reserve(input_rank); bool empty = false; @@ -177,210 +177,48 @@ ttnn::Tensor SliceOperation::invoke( return rm_only ? ttnn::to_layout(res, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : res; } } +} -template ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - const std::vector &begins, - const std::vector &ends, - const std::vector &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg); - } - -// Specialization for uint32_t and N=4 -template<> -ttnn::Tensor SliceOperation::invoke( - uint8_t queue_id, - const ttnn::Tensor& input_tensor, - const std::array &begins, - const std::array &ends, - const std::array &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - - const auto& padded_input_shape = input_tensor.get_padded_shape(); - TT_FATAL(padded_input_shape.rank() == 4, "Input tensor must have rank 4"); - - bool no_step = step[0] == 1 && step[1] == 1 && step[2] == 1 && step[3] == 1; - bool starts_zero = begins[0]==0 && begins[1]==0 && begins[2]==0 && begins[3]==0; - bool ends_max = ends[0]==padded_input_shape[0] && ends[1]==padded_input_shape[1] && ends[2]==padded_input_shape[2] && ends[3]==padded_input_shape[3]; - - if (no_step && starts_zero && ends_max) { - if (input_tensor.storage_type() == StorageType::DEVICE) { - auto memory_config = optional_output_tensor.has_value() ? 
optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); - return ttnn::to_memory_config(input_tensor, memory_config, std::nullopt); - } - return input_tensor; - } - bool rm_only = !no_step && input_tensor.get_layout() == Layout::TILE; - ttnn::Tensor input = input_tensor; - if (rm_only) { - input = ttnn::to_layout(input_tensor, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); - } - - const bool tiled = input.get_layout() == Layout::TILE; - bool on_device = input.storage_type() == StorageType::DEVICE; - - std::array actual_shape; - std::array padded_shape; - const std::array padded_ends = tiled ? std::array({ends[0], ends[1], std::max(tt::round_up(ends[2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT), std::max(tt::round_up(ends[3], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH)}) : ends; - bool empty = false; - for (int i = 0; i < 4; ++i) { - TT_FATAL(ends[i] >= begins[i], "End {} must be greater than or equal to start {}", ends[i], begins[i]); - uint32_t offset = step[i] - begins[i] - 1; - uint32_t dim_size = (ends[i] + offset) / step[i]; - empty |= dim_size == 0; - actual_shape[i] = dim_size; - padded_shape[i]= std::max((padded_ends[i] + offset) / step[i], 1u); - } - - ttnn::Shape output_shape(actual_shape, padded_shape); - - if (empty) { - TT_FATAL(on_device, "Host tensor slice cannot return a scalar or empty tensor"); - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); - return ttnn::empty(output_shape, input.dtype(), input_tensor.layout(), - input.device(), memory_config); - } - - // Early exit if slice is a no-op - if (ttnn::Shape(padded_shape) == padded_input_shape && no_step) { - if (input.storage_type() == StorageType::DEVICE) { - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); - auto res = ttnn::to_memory_config(input, memory_config, std::nullopt); - return ttnn::reshape(res, output_shape); - } - return ttnn::reshape(input, output_shape); // change to view - } - - if (on_device) { - auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); - - // Check for in-place unpad optimization - if (input.is_sharded() && input.memory_config() == memory_config && padded_input_shape.rank() > 1) { - TT_FATAL(no_step, "Sharded tensor slice implementation does not support striding"); - bool in_place_unpad = true; - for (int i = 0; i < 2; ++i) { - in_place_unpad &= begins[i] == 0 && ends[i] == 1 && padded_input_shape[i] == 1; - } - in_place_unpad &= begins[2] == 0 && - tt::div_up(ends[2], input.shard_spec().value().shape[0]) == - tt::div_up(padded_input_shape[2], input.shard_spec().value().shape[0]); - in_place_unpad &= begins[3] == 0 && ends[3] == padded_input_shape[3]; - if (in_place_unpad) { - return ttnn::reshape(input, output_shape); - } - } - - input = operation::run( - SliceDeviceOperation{ - begins, - padded_ends, - step, - memory_config}, - {input}, {}, {optional_output_tensor}, queue_id)[0]; - input = ttnn::reshape(input, output_shape); - return rm_only ? 
ttnn::to_layout(input, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : input; - } - - TT_FATAL(no_step, "Host tensor slice does not support strides"); - - if (input.get_padded_shape() == actual_shape) { - return input; - } else { - auto input_4d_rm = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); - auto output_4d = input_4d_rm.unpad(ttnn::SimpleShape(begins), ttnn::SimpleShape(ends)); - auto output_4d_rm = ttnn::to_layout(output_4d, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); - return ttnn::reshape(output_4d_rm, output_shape); - } + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + return slice_operation_invoke_impl(queue_id, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); } -template ttnn::Tensor SliceOperation::invoke( - uint8_t queue_id, - const ttnn::Tensor& input_tensor, - const std::array &output_tensor_start, - const std::array &output_tensor_end, - const std::array &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - std::vector start(output_tensor_start.begin(), output_tensor_start.end()); - std::vector end(output_tensor_end.begin(), output_tensor_end.end()); - std::vector step_vec(step.begin(), step.end()); - return SliceOperation::invoke(queue_id, input_tensor, start, end, step_vec, memory_config_arg); - } + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + return slice_operation_invoke_impl(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); +} -template ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - const std::array &output_tensor_start, - const std::array &output_tensor_end, - const std::array &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg); - } - -template ttnn::Tensor SliceOperation::invoke( - uint8_t queue_id, - const ttnn::Tensor& input_tensor, - const std::vector &begins, - const std::vector &ends, - const std::vector &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor); - -template ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - const std::vector &begins, - const std::vector &ends, - const std::vector &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor); - - -template ttnn::Tensor SliceOperation::invoke( - uint8_t queue_id, - const ttnn::Tensor& input_tensor, - const std::vector &begins, - const std::vector &ends, - const std::vector &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor); - -template ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - const std::vector &begins, - const std::vector &ends, - const std::vector &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor); - -template ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - const std::array &output_tensor_start, - const std::array &output_tensor_end, - const std::array 
&step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor); - -template ttnn::Tensor SliceOperation::invoke( - uint8_t queue_id, - const ttnn::Tensor& input_tensor, - const std::array &output_tensor_start, - const std::array &output_tensor_end, - const std::array &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor); - -template ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - const std::array &output_tensor_start, - const std::array &output_tensor_end, - const std::array &step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor); + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + return slice_operation_invoke_impl(queue_id, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); +} +ttnn::Tensor SliceOperation::invoke( + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + return slice_operation_invoke_impl(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); +} +slice.cpp } // namespace operations diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp index f814df01cdc..0206e0dd4ea 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp @@ -11,45 +11,62 @@ namespace operations { namespace data_movement { struct SliceOperation { - template static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const std::vector &begins, - const std::vector &ends, - const std::vector &step, + std::span begins, + std::span ends, + std::span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); - template static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const std::vector &output_tensor_start, - const std::vector &output_tensor_end, - const std::vector &step, + std::span begins, + std::span ends, + std::span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); - template static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const std::array &output_tensor_start, - const std::array &output_tensor_end, - const std::array &step, + std::span begins, + std::span ends, + std::span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); - template static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const std::array &output_tensor_start, - const std::array &output_tensor_end, - const std::array &step, + std::span begins, + std::span ends, + std::span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); + template + static ttnn::Tensor invoke( + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + const std::array& begins, + const std::array& ends, + const std::array& step, + const std::optional& memory_config_arg = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt) { + return invoke(queue_id, input_tensor, 
std::span(begins), std::span(ends), std::span(step), memory_config_arg, optional_output_tensor); + } + template + static ttnn::Tensor invoke( + const ttnn::Tensor& input_tensor, + const std::array& begins, + const std::array& ends, + const std::array& step, + const std::optional& memory_config_arg = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt) { + return invoke(ttnn::DefaultQueueId, input_tensor, std::span(begins), std::span(ends), std::span(step), memory_config_arg, optional_output_tensor); + } }; } // namespace data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp index 1aeae87eefd..963b235f9a8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp @@ -52,14 +52,17 @@ void bind_slice(py::module& module) { ttnn::pybind_overload_t{ [] (const OperationType& self, const ttnn::Tensor& input_tensor, - const std::vector &slice_start, - const std::vector &slice_end, - const std::optional> &step, + std::span slice_start, + std::span slice_end, + std::optional> step, const std::optional& memory_config, const std::optional& optional_output_tensor, uint8_t queue_id) { - const auto step_value = step.value_or(std::vector(slice_end.size(), 1)); - return self(queue_id, input_tensor, slice_start, slice_end, step_value, memory_config, optional_output_tensor); + if (step.has_value()) { + return self(queue_id, input_tensor, slice_start, slice_end, step.value(), memory_config, optional_output_tensor); + } else { + return self(queue_id, input_tensor, slice_start, slice_end, ttnn::SmallVector(slice_end.size(), 1), memory_config, optional_output_tensor); + } }, py::arg("input_tensor"), py::arg("slice_start"), diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp index eb2bbdbed53..e5439fd02b0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp @@ -46,10 +46,10 @@ namespace detail { auto start = i*chunk_len; auto end = start + chunk_len; - std::vector start_shape(preproc_shape.size(), 0); + ttnn::SmallVector start_shape(preproc_shape.size(), 0); start_shape[dim] = start; - std::vector end_shape(preproc_shape.size()); + ttnn::SmallVector end_shape(preproc_shape.size()); for (int j = 0; j < end_shape.size(); j++) { if (j == dim) { end_shape[j] = end; @@ -61,7 +61,7 @@ namespace detail { Tensor output_chunk = ttnn::slice(preprocessed, start_shape, end_shape, - std::vector(end_shape.size(), 1), + ttnn::SmallVector(end_shape.size(), 1), mem_config); if (input_rank < 4) { output_chunk = ttnn::squeeze_from_4D(output_chunk, input_rank); @@ -102,13 +102,13 @@ namespace detail { } const int W = 1, Z = shape[0] * shape[1], Y = shape[2], X = shape[3]; - const Tensor &reshaped_tensor = ttnn::reshape_on_device(input_tensor, std::vector{1, -1, Y, X}, mem_config); + const Tensor &reshaped_tensor = ttnn::reshape_on_device(input_tensor, ttnn::SmallVector{1, -1, Y, X}, mem_config); auto part_reshaped = impl_split_last_dim_two_chunks_tiled(reshaped_tensor, mem_config); std::vector results; results.reserve(part_reshaped.size()); - for (auto &part : part_reshaped) results.emplace_back(ttnn::reshape_on_device(part, std::vector{-1, (int32_t)shape[1], Y, X / 2}, mem_config)); + for (auto &part : part_reshaped) results.emplace_back(ttnn::reshape_on_device(part, 
ttnn::SmallVector{-1, (int32_t)shape[1], Y, X / 2}, mem_config)); return results; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp b/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp index d0674b0bb8e..e622da4918d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp @@ -19,8 +19,8 @@ ttnn::Tensor SqueezeOperation::invoke(const ttnn::Tensor& input_tensor, const in normal_dim += input_tensor_rank; } - std::vector original_logical_shape_vector(input_tensor_rank - 1); - std::vector padded_shape_vector(input_tensor_rank - 1); + SmallVector original_logical_shape_vector(input_tensor_rank - 1); + SmallVector padded_shape_vector(input_tensor_rank - 1); uint32_t vector_id = 0; for(int i=0; i< input_tensor_rank; i++) { if(i != normal_dim or original_logical_shape[i] != 1) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp index 6baf416be2e..41aacf7948b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp @@ -69,11 +69,11 @@ inline Tensor transpose_(const Tensor &a, TransposeOpDim transpose_dim, const Me break; // bubble dim around to make it possible as these implementations don't have a kernel case TransposeOpDim::NH: - return ttnn::permute((const ttnn::Tensor)a, std::vector({2, 1, 0, 3}), output_mem_config); + return ttnn::permute((const ttnn::Tensor)a, ttnn::SmallVector({2, 1, 0, 3}), output_mem_config); case TransposeOpDim::NW: - return ttnn::permute((const ttnn::Tensor)a, std::vector({3, 1, 2, 0}), output_mem_config); + return ttnn::permute((const ttnn::Tensor)a, ttnn::SmallVector({3, 1, 2, 0}), output_mem_config); case TransposeOpDim::CW: - return ttnn::permute((const ttnn::Tensor)a, std::vector({0, 3, 2, 1}), output_mem_config); + return ttnn::permute((const ttnn::Tensor)a, ttnn::SmallVector({0, 3, 2, 1}), output_mem_config); case TransposeOpDim::CN: tiled_only = true; // CN only has a tiled implementation at the moment break; @@ -173,12 +173,12 @@ ttnn::Tensor ExecuteTranspose::invoke( auto input_shape = input_typecasted.get_logical_shape(); // create_output_tensor shape is useless when we potentially have new padding to deal with - std::vector output_shape; + SmallVector output_shape; output_shape.reserve(input_shape.rank()); for (int i = 0; i < input_shape.rank(); ++i) { output_shape.push_back(input_shape[i]); } - std::vector padded_output_shape = output_shape; + SmallVector padded_output_shape = output_shape; std::swap(output_shape[normalized_dim1], output_shape[normalized_dim2]); std::swap(padded_output_shape[normalized_dim1], padded_output_shape[normalized_dim2]); diff --git a/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp b/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp index d1fbf47978f..d2546e18e35 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp @@ -11,7 +11,7 @@ namespace ttnn::operations::data_movement { ttnn::Tensor UnsqueezeOperation::invoke(const ttnn::Tensor& input_tensor, const int dim) { const auto tensor_shape = input_tensor.get_shape(); const auto rank = tensor_shape.rank(); - std::vector output_shape_vector; + SmallVector output_shape_vector; TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR or (!tensor_shape.has_tile_padding()), "Currently 
supporing ROW-MAJOR tensors or TILE tensors with no padding"); diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp index 0959a4998b8..9a80603e5f9 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_op.cpp @@ -73,7 +73,7 @@ void UntilizeWithUnpadding::validate(const std::vector& input_tensors) c std::vector UntilizeWithUnpadding::compute_output_shapes( const std::vector& input_tensors) const { - std::vector out_shape; + SmallVector out_shape; auto rank = input_tensors[0].get_legacy_shape().rank(); out_shape.reserve(rank); for (uint32_t i = 0; i < rank; i++) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp index e42f7e72d70..668426a53b8 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp @@ -168,7 +168,7 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu // - The lambda is reused for both logical shapes and padded shapes, ensuring consistency. // ------------------------------------------------------------------------- auto compute_broadcasted_output = [rank_a, rank_b, larger_rank](const auto& shape_a, const auto& shape_b) { - std::vector output_shape(larger_rank, 1); + SmallVector output_shape(larger_rank, 1); for (int i = -1; i >= -larger_rank; --i) { auto dim_a = (i >= -rank_a) ? shape_a[i] : 1; auto dim_b = (i >= -rank_b) ? 
shape_b[i] : 1; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp index dee368d713c..20de721016c 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp @@ -398,7 +398,7 @@ Tensor _variance_impl( const Tensor& mean_y, Tensor& y_minus_mean_y, const std::optional& output_mem_config) { - std::vector dims = { 2, 3 }; + ttnn::SmallVector dims = { 2, 3 }; constexpr float correction = 0.0f; auto shape_wh = y.get_legacy_shape(); float scale = 1.0f / ((float)(shape_wh[3] * shape_wh[2]) - correction); @@ -412,7 +412,7 @@ Tensor _variance_impl(const Tensor& y, const Tensor& mean_y, const std::optional Tensor _variance(const Tensor& y, const std::optional& output_mem_config) { auto output_memory_config = output_mem_config.value_or(y.memory_config()); - std::vector dims = { 2, 3 }; + ttnn::SmallVector dims = { 2, 3 }; Tensor mean_y = ttnn::mean(y, dims, true); return _variance_impl(y, mean_y, output_memory_config); } @@ -435,7 +435,7 @@ Tensor _std_overload(const Tensor& y, const std::optional& output // Function normalize // use transformation y = (y - mean(y))/std(y) by broadcast Tensor _normalize(const Tensor& y, const std::optional& output_mem_config) { - std::vector dims = { 2, 3 }; + ttnn::SmallVector dims = { 2, 3 }; Tensor mean_y = ttnn::mean(y, dims, true); Tensor y_minus_mean_y = ttnn::bcast(0, y, mean_y, ttnn::BcastOpMath::SUB, ttnn::BcastOpDim::HW); Tensor std_y = _std(y, mean_y, y_minus_mean_y, output_mem_config); @@ -551,13 +551,13 @@ std::vector split_tensor_for_glu(const Tensor& input_a, int32_t dim, con std::vector t_split; tt::tt_metal::LegacyShape inshape(input_a.get_legacy_shape()); TT_FATAL(((inshape[dim] / 2) % tt::constants::TILE_WIDTH == 0), "Split tensor dimension should be in full tile"); - std::vector s_a = {0, 0, 0, 0}; - std::vector e_a = {input_a.get_legacy_shape()[0], inshape[1], inshape[2], inshape[3] / 2}; + ttnn::SmallVector s_a = {0, 0, 0, 0}; + ttnn::SmallVector e_a = {input_a.get_legacy_shape()[0], inshape[1], inshape[2], inshape[3] / 2}; - std::vector s_b = {0, 0, 0, inshape[3] / 2}; - std::vector e_b = {inshape[0], inshape[1], inshape[2], inshape[3]}; + ttnn::SmallVector s_b = {0, 0, 0, inshape[3] / 2}; + ttnn::SmallVector e_b = {inshape[0], inshape[1], inshape[2], inshape[3]}; - auto step = std::vector({1,1,1,1}); + auto step = ttnn::SmallVector({1,1,1,1}); Tensor t_a = ttnn::slice(DefaultQueueId, input_a, s_a, e_a, step, output_mem_config); Tensor t_b = ttnn::slice(DefaultQueueId, input_a, s_b, e_b, step, output_mem_config); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp index 6328d0a4b6a..eae80c8fe21 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp @@ -1399,7 +1399,7 @@ std::vector ExecuteUnaryBackwardRepeat::invoke( grad_tensor.emplace_back(zero_tensor); return grad_tensor; } else if (shape[0] > 1) { - std::vector dim = {0}; + ttnn::SmallVector dim = {0}; TT_FATAL(shape[1] == 1 && shape[2] == 1 && shape[3] == 1, "repeat[1], [2], [3] should be 1"); std::array intended_shape_array = {1, shape_wh[1], shape_wh[2], shape_wh[3]}; const ttnn::Shape required = ttnn::Shape(intended_shape_array); @@ -1413,7 +1413,7 @@ std::vector ExecuteUnaryBackwardRepeat::invoke( 
grad_tensor.emplace_back(result); return grad_tensor; } else if (shape[1] > 1) { - std::vector dim = {1}; + ttnn::SmallVector dim = {1}; TT_FATAL(shape[0] == 1 && shape[2] == 1 && shape[3] == 1, "repeat[0], [2], [3] should be 1"); std::array intended_shape_array = {shape_wh[0], 1, shape_wh[2], shape_wh[3]}; const ttnn::Shape required = ttnn::Shape(intended_shape_array); @@ -1464,13 +1464,13 @@ std::vector ExecuteUnaryBackwardProd::invoke( } // all_dimensions = False Tensor updated_grad = prod_result; - auto step = std::vector({1, 1, 1, 1}); + auto step = ttnn::SmallVector({1, 1, 1, 1}); if (prod_result.get_logical_shape() != grad.get_padded_shape()) { if (dim == 3 || dim == -1) { - std::vector after_permute_dims = {0, 3, 1, 2}; + ttnn::SmallVector after_permute_dims = {0, 3, 1, 2}; Tensor required = ttnn::permute(grad, after_permute_dims, output_memory_config); - std::vector start_index = {0, 0, 0, 0}; - std::vector end_index = { + ttnn::SmallVector start_index = {0, 0, 0, 0}; + ttnn::SmallVector end_index = { grad.get_legacy_shape()[0], 1, grad.get_legacy_shape()[1], grad.get_legacy_shape()[2]}; Tensor new_slice_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt); after_permute_dims = {0, 2, 3, 1}; @@ -1481,10 +1481,10 @@ std::vector ExecuteUnaryBackwardProd::invoke( updated_grad = pad_updated_grad.to(input.device()); } } else if (dim == 2 || dim == -2) { - std::vector after_permute_dims = {0, 2, 1, 3}; + ttnn::SmallVector after_permute_dims = {0, 2, 1, 3}; Tensor required = ttnn::permute(grad, after_permute_dims, output_memory_config); - std::vector start_index = {0, 0, 0, 0}; - std::vector end_index = { + ttnn::SmallVector start_index = {0, 0, 0, 0}; + ttnn::SmallVector end_index = { grad.get_legacy_shape()[0], 1, grad.get_legacy_shape()[1], grad.get_legacy_shape()[3]}; Tensor new_slice_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt); updated_grad = ttnn::permute(new_slice_tensor, after_permute_dims, output_memory_config); @@ -1509,13 +1509,13 @@ std::vector ExecuteUnaryBackwardProd::invoke( } else if (dim == 1 || dim == -3) { Tensor tensor_1_temp = reciprocal_input; if (reciprocal_input.get_legacy_shape()[1] % 32 != 0) { - std::vector> padding = {{0, 0}, + ttnn::SmallVector> padding = {{0, 0}, {0, 32 - (reciprocal_input.get_legacy_shape()[1] % 32)}, {0, 0}, {0, 0}}; tensor_1_temp = ttnn::pad(0, reciprocal_input, padding, 0, true, std::nullopt); } - std::vector after_permute_dims = {0, 2, 3, 1}; + ttnn::SmallVector after_permute_dims = {0, 2, 3, 1}; Tensor tensor_1 = ttnn::permute(tensor_1_temp, after_permute_dims, output_memory_config); Tensor tensor_2 = ttnn::permute(temp, after_permute_dims, output_memory_config); @@ -1530,13 +1530,13 @@ std::vector ExecuteUnaryBackwardProd::invoke( output_memory_config); Tensor grad_result = result; if (reciprocal_input.get_legacy_shape()[1] % 32 != 0) { - std::vector start_index = {0, 0, 0, 0}; - std::vector end_index = { + ttnn::SmallVector start_index = {0, 0, 0, 0}; + ttnn::SmallVector end_index = { input.get_legacy_shape()[0], input.get_legacy_shape()[1], input.get_legacy_shape()[2], input.get_legacy_shape()[3]}; - auto step = std::vector({1,1,1,1}); + auto step = ttnn::SmallVector({1,1,1,1}); grad_result = ttnn::slice(DefaultQueueId, result, start_index, end_index, step, std::nullopt); } grad_tensor.emplace_back(grad_result); @@ -1545,13 +1545,13 @@ std::vector ExecuteUnaryBackwardProd::invoke( // dim 0 Tensor tensor_1_temp = reciprocal_input; if 
(reciprocal_input.get_legacy_shape()[0] % 32 != 0) { - std::vector> padding = {{0, (32 - (reciprocal_input.get_legacy_shape()[0] % 32))}, + ttnn::SmallVector> padding = {{0, (32 - (reciprocal_input.get_legacy_shape()[0] % 32))}, {0, 0}, {0, 0}, {0, 0}}; tensor_1_temp = ttnn::pad(0, reciprocal_input, padding, 0, false, std::nullopt); } - std::vector after_permute_dims = {3, 1, 2, 0}; + ttnn::SmallVector after_permute_dims = {3, 1, 2, 0}; Tensor tensor_1 = ttnn::permute(tensor_1_temp, after_permute_dims, output_memory_config); Tensor tensor_2 = ttnn::permute(temp, after_permute_dims, output_memory_config); @@ -1565,8 +1565,8 @@ std::vector ExecuteUnaryBackwardProd::invoke( output_memory_config); Tensor grad_result = result; if (reciprocal_input.get_legacy_shape()[0] % 32 != 0) { - std::vector start_index = {0, 0, 0, 0}; - std::vector end_index = { + ttnn::SmallVector start_index = {0, 0, 0, 0}; + ttnn::SmallVector end_index = { input.get_legacy_shape()[0], input.get_legacy_shape()[1], input.get_legacy_shape()[2], diff --git a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp index ff909303e39..3f8ae4f6f68 100644 --- a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp @@ -187,7 +187,7 @@ Tensor AutoFormat::format_output_tensor( } else if (formatted_output.get_layout() == Layout::TILE && AutoFormat::legal_rm_shape(shape)) { formatted_output = ttnn::untilize_with_unpadding( formatted_output, - std::vector({shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}), + SmallVector({shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}), mem_config); return formatted_output; } @@ -196,7 +196,7 @@ Tensor AutoFormat::format_output_tensor( AutoFormat::legal_rm_shape(shape)) { formatted_output = ttnn::untilize_with_unpadding( formatted_output, - std::vector({shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}), + SmallVector({shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}), mem_config); return formatted_output; } else if ( diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp index ca82d3ba307..ee9955c8c69 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp @@ -92,7 +92,7 @@ Tensor all_reduce( merged_dim_size *= shape[i]; } - std::vector new_shape{1, merged_dim_size, shape[rank - 2], shape[rank - 1]}; + ttnn::SmallVector new_shape{1, merged_dim_size, shape[rank - 2], shape[rank - 1]}; auto reshaped_tensor = ttnn::reshape(input_tensor, new_shape); diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp index ca8e8123e18..3380cff9e69 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp @@ -67,7 +67,7 @@ Tensor ArgmaxOperation::invoke(const Tensor& input_t, int64_t _dim, bool all, co result = ttnn::min(result, (int)dim, true, output_memory_config); Tensor res_index = ttnn::zeros_like(result); result = ttnn::where(ttnn::eq(result, size), res_index, result, output_memory_config); - std::vector permute_dims = {3, 0, 1, 2}; + ttnn::SmallVector permute_dims = {3, 0, 1, 2}; if (is_width) { res_index = 
ttnn::add(res_index, result, std::nullopt, output_memory_config); } else { @@ -105,7 +105,7 @@ Tensor ArgmaxOperation::invoke(const Tensor& input_t, int64_t _dim, bool all, co Tensor res_index = ttnn::zeros_like(result); result = ttnn::where(ttnn::eq(result, full_like(result, size)), res_index, result, output_memory_config); if (is_channel) { - std::vector permute_dims = {1, 0, 2, 3}; + ttnn::SmallVector permute_dims = {1, 0, 2, 3}; Tensor transpose_res = ttnn::permute(result, permute_dims, output_memory_config); return {transpose_res}; } else { diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp index fa18cd9ab86..d794d3adda1 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp @@ -73,7 +73,7 @@ std::vector FastReduceNCDeviceOperation::compute_outp // last 2-dim output_shape[this->dim] = 1; - return {tt::tt_metal::LegacyShape(output_shape.as_vector(), padding)}; + return {tt::tt_metal::LegacyShape(output_shape.view(), padding)}; } std::vector FastReduceNCDeviceOperation::create_output_tensors( @@ -97,12 +97,12 @@ operation::ProgramWithCallbacks FastReduceNCDeviceOperation::create_program( Tensor fast_reduce_nc( uint8_t queue_id, const ttnn::Tensor& input, - const std::vector& dims, + std::span dims, const std::optional output, const MemoryConfig& output_mem_config, std::optional compute_kernel_config) { - std::vector sorted_dims = dims; + ttnn::SmallVector sorted_dims(dims.begin(), dims.end()); std::sort(sorted_dims.begin(), sorted_dims.end()); auto temp_input = input; diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp index ec18e782fb7..b19e85abd86 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp @@ -29,7 +29,7 @@ struct FastReduceNCDeviceOperation { Tensor fast_reduce_nc( uint8_t queue_id, const ttnn::Tensor &input, - const std::vector &dims, + std::span dims, const std::optional output = std::nullopt, const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional compute_kernel_config = std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp index 6ad219e20fd..8b7560cd494 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp @@ -14,7 +14,7 @@ namespace operations::experimental::reduction{ ttnn::Tensor FastReduceNCOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input, - const std::vector& dims, + std::span dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config) { @@ -23,7 +23,7 @@ ttnn::Tensor FastReduceNCOperation::invoke( ttnn::Tensor FastReduceNCOperation::invoke( const ttnn::Tensor& input, - const std::vector& dims, + std::span 
dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config) { diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp index 174bc68acc3..0a731a40838 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp @@ -16,14 +16,14 @@ struct FastReduceNCOperation { static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input, - const std::vector& dims, + std::span dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config); static ttnn::Tensor invoke( const ttnn::Tensor& input, - const std::vector& dims, + std::span dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config); diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp index bf2738b5835..6174d6bae0c 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp @@ -22,7 +22,7 @@ void bind_fast_reduce_nc(pybind11::module& module) { ttnn::pybind_overload_t{ [] (const OperationType& self, const ttnn::Tensor& input, - const std::vector& dims, + std::span dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config, @@ -31,7 +31,7 @@ void bind_fast_reduce_nc(pybind11::module& module) { }, pybind11::arg("input").noconvert(), pybind11::kw_only(), - pybind11::arg("dims").noconvert() = std::vector(), + pybind11::arg("dims").noconvert() = std::span(), pybind11::arg("output").noconvert() = std::nullopt, pybind11::arg("memory_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, pybind11::arg("compute_kernel_config").noconvert() = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/nlp_kv_cache_load_slice/device/nlp_kv_cache_load_slice_device_operation.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/nlp_kv_cache_load_slice/device/nlp_kv_cache_load_slice_device_operation.cpp index 4e4230198fe..4785ddb46bb 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/nlp_kv_cache_load_slice/device/nlp_kv_cache_load_slice_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/nlp_kv_cache_load_slice/device/nlp_kv_cache_load_slice_device_operation.cpp @@ -42,7 +42,7 @@ void NlpKVCacheLoadSliceDeviceOperation::validate(const std::vector &inp "Can only unpad tilized tensor with full tiles"); } std::vector NlpKVCacheLoadSliceDeviceOperation::compute_output_shapes(const std::vector &input_tensors) const { - std::vector out_shape; + SmallVector out_shape; auto rank = input_tensors[0].get_legacy_shape().rank(); out_shape.reserve(rank); for (uint32_t i = 0; i < rank; i++) { diff --git a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp index c0f8b270c91..7aaffe18c6b 100644 --- a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.cpp @@ -67,16 +67,16 @@ FullOperation::tensor_return_value_t 
FullOperation::create_output_tensors( } std::tuple FullOperation::invoke( - const std::vector shape, - const std::variant fill_value, + ttnn::SmallVector shape, + std::variant fill_value, const Tensor& any, const std::optional& dtype, const std::optional& layout, const std::optional& memory_config) { return { operation_attributes_t{ - shape, - fill_value, + std::move(shape), + std::move(fill_value), dtype.value_or(any.get_dtype()), layout.value_or(any.get_layout()), memory_config.value_or(any.memory_config()), diff --git a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp index f38fede912d..1a12849b814 100644 --- a/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/full/device/full_device_operation.hpp @@ -11,7 +11,7 @@ namespace ttnn::operations::full { struct FullOperation { struct operation_attributes_t { - const std::vector shape; + const ttnn::SmallVector shape; const std::variant fill_value; const DataType dtype; const Layout layout; @@ -56,8 +56,8 @@ struct FullOperation { static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static std::tuple invoke( - const std::vector shape, - const std::variant fill_value, + ttnn::SmallVector shape, + std::variant fill_value, const Tensor& any, const std::optional& dtype, const std::optional& layout, diff --git a/ttnn/cpp/ttnn/operations/full/full.cpp b/ttnn/cpp/ttnn/operations/full/full.cpp index 9ba435c480d..e634bad14cd 100644 --- a/ttnn/cpp/ttnn/operations/full/full.cpp +++ b/ttnn/cpp/ttnn/operations/full/full.cpp @@ -10,7 +10,7 @@ namespace ttnn::operations::full { Tensor Full::invoke( - const std::vector& shape, + const ttnn::SmallVector shape, const std::variant fill_value, const ttnn::Tensor& any, const std::optional& dtype, diff --git a/ttnn/cpp/ttnn/operations/full/full.hpp b/ttnn/cpp/ttnn/operations/full/full.hpp index f3173d5680e..8e8a57f79f5 100644 --- a/ttnn/cpp/ttnn/operations/full/full.hpp +++ b/ttnn/cpp/ttnn/operations/full/full.hpp @@ -9,7 +9,7 @@ namespace ttnn::operations::full { struct Full { static ttnn::Tensor invoke( - const std::vector& shape, + const ttnn::SmallVector shape, const std::variant fill_value, const ttnn::Tensor& any, const std::optional& dtype, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp index 75d7806af93..c2cfec9fdf4 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp @@ -62,10 +62,10 @@ MorehArangeOperation::shape_return_value_t MorehArangeOperation::compute_output_ if (operation_attributes.untilize_out) return ttnn::Shape(tt::tt_metal::LegacyShape({num_elems})); - std::vector output_size = { + SmallVector output_size = { tt::constants::TILE_HEIGHT, tt::round_up(num_elems, tt::constants::TILE_WIDTH)}; - auto dimensions_pads = std::vector(); + auto dimensions_pads = SmallVector(); dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = 31}); dimensions_pads.push_back( Padding::PadDimension{.front = 0, .back = tt::round_up(num_elems, tt::constants::TILE_WIDTH) - num_elems}); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp 
b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp index f7e92b15568..8ed30508837 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp @@ -55,7 +55,7 @@ void MorehGetItemOperation::validate_inputs( TT_FATAL( dim_start + i == dim, "The value of index_dims={} must be consecutive integers.", - operation_attributes.index_dims); + std::vector(operation_attributes.index_dims.begin(), operation_attributes.index_dims.end())); i++; } if (!output_tensor.has_value()) { @@ -104,8 +104,8 @@ MorehGetItemOperation::shape_return_value_t MorehGetItemOperation::compute_outpu // index_dims = 1,2 // output: (10, 1, 100, 40) auto dim_offset = 5 - input_shape.rank(); - auto dimensions_pads = std::vector(); - std::vector output_size_vec; + auto dimensions_pads = SmallVector(); + SmallVector output_size_vec; for (int dim = 0; dim < output_shape.size(); dim++) { dimensions_pads.push_back(output_shape.value.padding()[dim]); output_size_vec.push_back(output_shape.value[dim]); @@ -148,7 +148,7 @@ MorehGetItemOperation::shape_return_value_t MorehGetItemOperation::compute_outpu // index_tensor: [(100), (100)] // index_dims = 1,2 // output: (10, 100, 40) - std::vector output_size_vec; + SmallVector output_size_vec; auto input_shape = input_tensor.get_shape(); uint32_t input_rank = input_shape.rank(); @@ -193,7 +193,7 @@ std::tuple& index_tensors, - const std::vector index_dims, + const ttnn::SmallVector index_dims, const std::optional& output, const std::optional memory_config) { operation_attributes_t operation_attributes = {index_dims, memory_config.value_or(input.memory_config())}; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp index ca774e95954..1cbae242656 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp @@ -14,7 +14,7 @@ namespace ttnn::operations::moreh::moreh_getitem { struct MorehGetItemOperation { struct operation_attributes_t { - const std::vector index_dims; + const ttnn::SmallVector index_dims; // const CoreRange core_range; const MemoryConfig memory_config; }; @@ -34,7 +34,7 @@ struct MorehGetItemOperation { KernelHandle unary_writer_kernel_id; std::size_t num_cores; uint32_t core_h; - std::vector index_dims; + ttnn::SmallVector index_dims; uint32_t input_dim_offset; }; @@ -58,7 +58,7 @@ struct MorehGetItemOperation { KernelHandle unary_writer_kernel_id; std::size_t num_cores; uint32_t core_h; - std::vector index_dims; + ttnn::SmallVector index_dims; uint32_t input_dim_offset; }; @@ -87,7 +87,7 @@ struct MorehGetItemOperation { static std::tuple invoke( const Tensor& input, const std::vector& index_tensors, - const std::vector index_dims, + const ttnn::SmallVector index_dims, const std::optional& output, // const CoreRange core_range, const std::optional memory_config); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp index 35d3c8dba61..8eb34d94b38 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp @@ -8,7 +8,7 @@ namespace ttnn::operations::moreh::moreh_getitem { 
Tensor MorehGetItem::invoke( const Tensor& input, const std::vector& index_tensors, - const std::vector index_dims, + const ttnn::SmallVector index_dims, const std::optional& output, // const CoreRange core_range, const std::optional memory_config) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp index 06d16afe73b..558b7d4ecb3 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp @@ -12,7 +12,7 @@ struct MorehGetItem { static Tensor invoke( const Tensor& input, const std::vector& index_tensors, - const std::vector index_dims, + const ttnn::SmallVector index_dims, const std::optional& output, // const CoreRange core_range, const std::optional memory_config); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp index e7c907a2916..011a8150dd5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp @@ -87,7 +87,7 @@ MorehGroupNormOperation::shape_return_value_t MorehGroupNormOperation::compute_o const auto output_shape = tensor_args.input.get_logical_shape(); const auto N = output_shape[0]; const auto num_groups = operation_attributes.num_groups; - std::vector mean_rstd_origin_shape{ + SmallVector mean_rstd_origin_shape{ 1, 1, N, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp index b331eca682b..380190bde2d 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp @@ -331,7 +331,7 @@ uint32_t compute_outer(tt::tt_metal::LegacyShape shape, uint32_t dim) { return num_outer; } -void expand_to_max_dim(std::vector &dim, const ttnn::SimpleShape &shape) { +void expand_to_max_dim(ttnn::SmallVector &dim, const ttnn::SimpleShape &shape) { const auto rank = shape.rank(); for (auto i = 0; i < rank; ++i) { auto idx = rank - 1 - i; @@ -381,10 +381,10 @@ void validate_output_with_keepdim(const Tensor &input, const Tensor &output, con output_rank); } - std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector input_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector output_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector output_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); expand_to_max_dim(input_dim, input_shape); expand_to_max_dim(output_dim, output_shape); expand_to_max_dim(input_dim_wo_padding, input_shape_wo_padding); @@ -395,8 +395,8 @@ void validate_output_with_keepdim(const Tensor &input, const Tensor &output, con TT_FATAL(input_dim_wo_padding[i] == output_dim_wo_padding[i], "Error"); } } else { - std::vector expected_output_shape; - std::vector expected_output_shape_wo_padding; + ttnn::SmallVector expected_output_shape; + ttnn::SmallVector expected_output_shape_wo_padding; for (int i = 0; i < output_shape.rank(); ++i) { if (i == dim && 
!is_tile_dim) { expected_output_shape.push_back(1); @@ -418,21 +418,21 @@ void validate_output_with_keepdim(const Tensor &input, const Tensor &output, con } } -void initialize_dims_with_range(std::vector &dims, uint32_t input_rank) { +void initialize_dims_with_range(ttnn::SmallVector &dims, uint32_t input_rank) { dims.resize(input_rank); std::iota(dims.begin(), dims.end(), 0); } -std::vector get_dim( - const std::optional>> &dim, uint32_t input_rank) { - std::vector dims; +ttnn::SmallVector get_dim( + const std::optional>> &dim, uint32_t input_rank) { + ttnn::SmallVector dims; if (!dim.has_value()) { initialize_dims_with_range(dims, input_rank); } else if (std::holds_alternative(dim.value())) { auto d = std::get(dim.value()); dims.push_back(d); } else { - dims = std::get>(dim.value()); + dims = std::get>(dim.value()); if (dims.empty()) { initialize_dims_with_range(dims, input_rank); } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.hpp index 384b9097f4b..17d5a36d01a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.hpp @@ -293,16 +293,16 @@ uint32_t compute_inner(tt::tt_metal::LegacyShape shape, uint32_t dim); uint32_t compute_outer(tt::tt_metal::LegacyShape shape, uint32_t dim); -void expand_to_max_dim(std::vector &dim, const ttnn::SimpleShape &shape); +void expand_to_max_dim(ttnn::SmallVector &dim, const ttnn::SimpleShape &shape); void validate_input_with_dim(const Tensor &input, const int64_t &dim); void validate_output_with_keepdim(const Tensor &input, const Tensor &output, const int64_t &dim, const bool &keepdim); -void initialize_dims_with_range(std::vector &dims, uint32_t input_rank); +void initialize_dims_with_range(ttnn::SmallVector &dims, uint32_t input_rank); -std::vector get_dim( - const std::optional>> &dim, uint32_t input_rank); +ttnn::SmallVector get_dim( + const std::optional>> &dim, uint32_t input_rank); std::tuple extract_spatial_dims(const ttnn::SimpleShape& shape); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp index 103b178b76f..7f1a68c57d8 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_device_operation.cpp @@ -89,8 +89,8 @@ MorehLayerNormOperation::shape_return_value_t MorehLayerNormOperation::compute_o auto input_rank = input_shape.rank(); auto output_rank = input_rank - normalized_dims; - std::vector output_size_vec; - auto dimensions_pads = std::vector(); + ttnn::SmallVector output_size_vec; + ttnn::SmallVector dimensions_pads; if (output_rank == 1) { output_size_vec.push_back(32); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp index 44653a944bc..08450c6d0ed 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward.cpp @@ -18,7 +18,7 @@ std::tuple MorehLinearBackward::get_required_outputs(const std return {are_required_outputs[0], are_required_outputs[1], are_required_outputs[2]}; } -void get_tensor_dim(std::vector& dim, const tt::tt_metal::LegacyShape& shape) { +void 
get_tensor_dim(ttnn::SmallVector& dim, const tt::tt_metal::LegacyShape& shape) { const auto rank = shape.rank(); for (auto i = 0; i < rank; ++i) { auto idx = rank - 1 - i; @@ -66,15 +66,15 @@ inline void moreh_linear_backward_validate( } } -std::vector find_reduce_dim( +ttnn::SmallVector find_reduce_dim( const tt::tt_metal::LegacyShape& a_shape, const tt::tt_metal::LegacyShape& b_shape) { - std::vector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(a_dim, a_shape); get_tensor_dim(b_dim, b_shape); int32_t rank = std::max(a_shape.rank(), b_shape.rank()); log_debug(tt::LogOp, "find_reduce_dim :{} rank {} a {} b {}", __LINE__, rank, a_shape.rank(), b_shape.rank()); - std::vector dims; + ttnn::SmallVector dims; // batch dims for (int i = 0; i < rank - 2; ++i) { int idx = rank - 1 - i; @@ -91,8 +91,8 @@ bool is_same_batch_dim(const Tensor& tensor_a, const Tensor& tensor_b) { // check batch dims const auto& a_shape = tensor_a.get_shape().value; const auto& b_shape = tensor_b.get_shape().value; - std::vector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(a_dim, a_shape); get_tensor_dim(b_dim, b_shape); for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { @@ -179,7 +179,7 @@ std::vector> MorehLinearBackward::invoke( weight_grad_memory_config, compute_kernel_config); TT_FATAL(weight_grad.has_value(), "weight_grad tensor should not be std::nullopt"); - std::vector dims = + ttnn::SmallVector dims = find_reduce_dim(temp_weight_grad.get_legacy_shape(), weight_grad.value().get_legacy_shape()); ttnn::moreh_sum( temp_weight_grad, dims, true, weight_grad.value(), weight_grad_memory_config, compute_kernel_config); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp index 9595da2a375..d366c571117 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp @@ -41,8 +41,8 @@ void MorehMatmulOperation::validate_inputs( TT_FATAL(input_k == other_k, "k must be the same. input_k {}, other_k {}", input_k, other_k); // check batch dims - std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(input_dim, input_shape); get_tensor_dim(other_dim, other_shape); for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { @@ -65,7 +65,7 @@ void MorehMatmulOperation::validate_inputs( TT_FATAL(input_m == output_m, "m must be the same. input_m {}, output_m {}", input_m, output_m); TT_FATAL(other_n == output_n, "n must be the same. 
other_n {}, output_n {}", other_n, output_n); - std::vector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(output_dim, output_shape); for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { @@ -119,8 +119,8 @@ MorehMatmulOperation::shape_return_value_t compute_output_shapes( auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[-1]) : (input_shape_wo_padding[-2]); auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[-2]) : (other_shape_wo_padding[-1]); - std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(input_dim, input_shape); get_tensor_dim(other_dim, other_shape); @@ -187,8 +187,8 @@ MorehMatmulOperation::shape_return_value_t MorehMatmulOperation::compute_output_ auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[-1]) : (input_shape_wo_padding[-2]); auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[-2]) : (other_shape_wo_padding[-1]); - std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(input_dim, input_shape); get_tensor_dim(other_dim, other_shape); @@ -202,7 +202,7 @@ MorehMatmulOperation::shape_return_value_t MorehMatmulOperation::compute_output_ other_shape.rank(), output_rank); - std::vector output_dim(output_rank); + ttnn::SmallVector output_dim(output_rank); // batch dims for (int i = 0; i < output_rank - 2; ++i) { int idx = output_rank - 1 - i; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp index 7e3734fafe5..1ec9b4ddd74 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp @@ -72,8 +72,8 @@ struct MorehMatmulOperation { const std::optional& compute_kernel_config); }; -void get_tensor_dim(std::vector& dim, const tt::tt_metal::LegacyShape& shape); -std::vector find_reduce_dim( +void get_tensor_dim(ttnn::SmallVector& dim, const tt::tt_metal::LegacyShape& shape); +ttnn::SmallVector find_reduce_dim( const tt::tt_metal::LegacyShape& a_shape, const tt::tt_metal::LegacyShape& b_shape); bool is_same_batch_dim(const Tensor& tensor_a, const Tensor& tensor_b); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp index a1f27385d23..883d27409dc 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_matmul { -void get_tensor_dim(std::vector &dim, const tt::tt_metal::LegacyShape &shape) { +void get_tensor_dim(ttnn::SmallVector &dim, const tt::tt_metal::LegacyShape &shape) { const auto rank = shape.rank(); for (auto i = 0; i < rank; ++i) { auto idx = rank - 1 - i; @@ -28,15 +28,15 @@ void get_tensor_dim(std::vector &dim, const 
tt::tt_metal::LegacyShape } } -std::vector find_reduce_dim( +ttnn::SmallVector find_reduce_dim( const tt::tt_metal::LegacyShape &a_shape, const tt::tt_metal::LegacyShape &b_shape) { - std::vector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(a_dim, a_shape); get_tensor_dim(b_dim, b_shape); int32_t rank = std::max(a_shape.rank(), b_shape.rank()); log_debug(tt::LogOp, "find_reduce_dim :{} rank {} a {} b {}", __LINE__, rank, a_shape.rank(), b_shape.rank()); - std::vector dims; + ttnn::SmallVector dims; // batch dims for (int i = 0; i < rank - 2; ++i) { int idx = rank - 1 - i; @@ -53,8 +53,8 @@ bool is_same_batch_dim(const Tensor &tensor_a, const Tensor &tensor_b) { // check batch dims const auto &a_shape = tensor_a.get_shape().value; const auto &b_shape = tensor_b.get_shape().value; - std::vector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector a_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector b_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(a_dim, a_shape); get_tensor_dim(b_dim, b_shape); for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { @@ -67,7 +67,7 @@ bool is_same_batch_dim(const Tensor &tensor_a, const Tensor &tensor_b) { return true; } -void get_tensor_stride(std::vector &stride, std::vector &dim) { +void get_tensor_stride(ttnn::SmallVector &stride, ttnn::SmallVector &dim) { stride[0] = 1; for (auto i = 1; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { stride[i] = stride[i - 1] * dim[i - 1]; @@ -79,10 +79,10 @@ void get_tensor_stride(std::vector &stride, std::vector &dim } void get_not_bcast( - std::vector &input_not_bcast, - std::vector &input_dim, - std::vector &other_not_bcast, - std::vector &other_dim) { + ttnn::SmallVector &input_not_bcast, + ttnn::SmallVector &input_dim, + ttnn::SmallVector &other_not_bcast, + ttnn::SmallVector &other_dim) { // first 2-dims are M,K and K,N // TODO: refaactoring for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) { @@ -140,11 +140,11 @@ MorehMatmulOperation::MultiCoreProgramFactory::cached_program_t MorehMatmulOpera const auto &input_shape_wo_padding = input_shape.without_padding(); const auto input_rank = input_shape.rank(); log_debug(tt::LogOp, "input dim"); - std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(input_dim, input_shape); log_debug(tt::LogOp, "input stride"); - std::vector input_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); + ttnn::SmallVector input_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); get_tensor_stride(input_stride, input_dim); // other tensor @@ -152,16 +152,16 @@ MorehMatmulOperation::MultiCoreProgramFactory::cached_program_t MorehMatmulOpera const auto &other_shape_wo_padding = other_shape.without_padding(); const auto other_rank = other_shape.rank(); log_debug(tt::LogOp, "other dim"); - std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(other_dim, other_shape); log_debug(tt::LogOp, "other stride"); - std::vector other_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); + ttnn::SmallVector other_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); get_tensor_stride(other_stride, other_dim); log_debug(tt::LogOp, "not bcast"); - std::vector 
input_not_bcast(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector other_not_bcast(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_not_bcast(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector other_not_bcast(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_not_bcast(input_not_bcast, input_dim, other_not_bcast, other_dim); // output tensor @@ -169,11 +169,11 @@ MorehMatmulOperation::MultiCoreProgramFactory::cached_program_t MorehMatmulOpera const auto &output_shape_wo_padding = output_shape.without_padding(); const auto output_rank = output_shape.rank(); log_debug(tt::LogOp, "output dim"); - std::vector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); get_tensor_dim(output_dim, output_shape); log_debug(tt::LogOp, "output stride"); - std::vector output_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); + ttnn::SmallVector output_stride(tt::tt_metal::MAX_NUM_DIMENSIONS); get_tensor_stride(output_stride, output_dim); // matrix shape diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.cpp index c73c8eacf19..1c919f21ec1 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.cpp @@ -77,8 +77,8 @@ MorehMeanOperation::shape_return_value_t MorehMeanOperation::compute_output_shap return Shape(tt::tt_metal::LegacyShape(output_shape.value, padding)); } - std::vector shape; - std::vector pad_dimensions; + ttnn::SmallVector shape; + ttnn::SmallVector pad_dimensions; const bool is_tile_dim = (dim == input_rank - 1 || dim == input_rank - 2); const std::size_t output_rank = (is_tile_dim) ? 
(input_rank) : (input_rank - 1); auto input_padding = input_shape.value.padding(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.cpp index 21266c403d1..7b94384eab6 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.cpp @@ -10,13 +10,13 @@ namespace ttnn::operations::moreh::moreh_mean { Tensor MorehMean::invoke( const Tensor& input, - const std::optional>> dim, + const std::optional>> dim, const bool keepdim, const std::optional& divisor, const std::optional& output, const std::optional& memory_config, const std::optional& compute_kernel_config) { - std::vector dims = tt::operations::primary::get_dim(dim, input.get_shape().rank()); + ttnn::SmallVector dims = tt::operations::primary::get_dim(dim, input.get_shape().rank()); std::sort(dims.begin(), dims.end()); auto temp_input = input; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.hpp index 5b26e96d029..794a1d6e2e4 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/moreh_mean.hpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_mean { struct MorehMean { static Tensor invoke( const Tensor& input, - const std::optional>> dims, + const std::optional>> dims, const bool keepdim, const std::optional& divisor, const std::optional& output, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.cpp index b84577801c4..c989b3675de 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.cpp @@ -41,8 +41,8 @@ MorehMeanBackwardOperation::shape_return_value_t MorehMeanBackwardOperation::com auto input_grad_shape = operation_attributes.input_grad_shape.value(); auto rank = input_grad_shape.rank(); - std::vector shape; - std::vector dimensions_pads; + ttnn::SmallVector shape; + ttnn::SmallVector dimensions_pads; for (uint32_t dim = 0; dim < rank; dim++) { if (tt::operations::primary::is_hw_dim(dim, rank)) { @@ -81,7 +81,7 @@ MorehMeanBackwardOperation::tensor_return_value_t MorehMeanBackwardOperation::cr std::tuple MorehMeanBackwardOperation::invoke( const Tensor& output_grad, - const std::vector dims, + const ttnn::SmallVector dims, const bool keepdim, const std::optional& input_grad_shape, const std::optional& input_grad, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.hpp index 5a40a249ef1..376bd1759e6 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.hpp @@ -15,7 +15,7 @@ namespace ttnn::operations::moreh::moreh_mean_backward { struct MorehMeanBackwardOperation { struct operation_attributes_t { - const std::vector dims; + const ttnn::SmallVector dims; const bool keepdim; const std::optional input_grad_shape; const MemoryConfig memory_config; @@ -61,7 +61,7 @@ struct MorehMeanBackwardOperation { static tensor_return_value_t 
create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static std::tuple invoke( const Tensor& output_grad, - const std::vector dims, + const ttnn::SmallVector dims, const bool keepdim, const std::optional& input_grad_shape, const std::optional& input_grad, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp index d6e26dd5b27..0a8f85f70bf 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp @@ -12,7 +12,7 @@ #include "ttnn/operations/reduction/generic/device/common.hpp" #include "ttnn/operations/reduction/generic/device/reduce_op.hpp" -void get_tensor_dim(std::vector &dim, const tt::tt_metal::LegacyShape &shape) { +void get_tensor_dim(ttnn::SmallVector &dim, const tt::tt_metal::LegacyShape &shape) { const auto rank = shape.rank(); for (auto i = 0; i < rank; ++i) { auto idx = rank - 1 - i; @@ -27,7 +27,7 @@ void get_tensor_dim(std::vector &dim, const tt::tt_metal::LegacyShape } tt::tt_metal::LegacyShape get_output_grad_shape( - const Tensor &output_grad, const Tensor &input_grad, const std::vector &dims, const bool &keepdim) { + const Tensor &output_grad, const Tensor &input_grad, const ttnn::SmallVector &dims, const bool &keepdim) { if (keepdim) { return output_grad.get_shape().value; } @@ -78,15 +78,15 @@ MorehMeanBackwardOperation::MorehMeanBackwardFactory::create( const auto &input_grad_shape_wo_padding = input_grad_shape.without_padding(); const uint32_t input_grad_rank = input_grad_shape.rank(); - std::vector input_grad_dim(input_grad_rank, 1); + ttnn::SmallVector input_grad_dim(input_grad_rank, 1); get_tensor_dim(input_grad_dim, input_grad_shape); const auto &output_grad_shape = get_output_grad_shape(output_grad, input_grad, dims, keepdim); const auto &output_grad_shape_wo_padding = output_grad_shape.without_padding(); - std::vector output_grad_dim(input_grad_rank, 1); + ttnn::SmallVector output_grad_dim(input_grad_rank, 1); get_tensor_dim(output_grad_dim, output_grad_shape); - std::vector need_bcast_dim(input_grad_rank, 0); + ttnn::SmallVector need_bcast_dim(input_grad_rank, 0); for (auto i = 0; i < input_grad_rank; ++i) { auto idx = input_grad_rank - 1 - i; bool is_tile_dim = (idx == input_grad_rank - 1 || idx == input_grad_rank - 2); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.cpp index 3e825da41cd..88cb5946c56 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.cpp @@ -10,7 +10,7 @@ namespace ttnn::operations::moreh::moreh_mean_backward { Tensor MorehMeanBackward::invoke( const Tensor& output_grad, - std::optional>> dim, + std::optional>> dim, const bool keepdim, const std::optional& input_grad_shape, const std::optional& input_grad, @@ -24,11 +24,11 @@ Tensor MorehMeanBackward::invoke( } else if (std::holds_alternative(dim.value())) { input_grad_rank += 1; } else { - auto dims = std::get>(dim.value()); + auto dims = std::get>(dim.value()); input_grad_rank += dims.size(); } } - std::vector dims = tt::operations::primary::get_dim(dim, input_grad_rank); + ttnn::SmallVector dims = tt::operations::primary::get_dim(dim, 
input_grad_rank); return ttnn::prim::moreh_mean_backward( output_grad, dims, keepdim, input_grad_shape, input_grad, memory_config, compute_kernel_config); } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.hpp index 10c6239f5d6..53508360ad1 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward.hpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_mean_backward { struct MorehMeanBackward { static Tensor invoke( const Tensor& output_grad, - std::optional>> dim, + std::optional>> dim, const bool keepdim, const std::optional& input_grad_shape, const std::optional& input_grad, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_device_operation.cpp index 246d2100f36..61657a2dd7a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_device_operation.cpp @@ -72,8 +72,8 @@ MorehNllLossStep2DeviceOperation::shape_return_value_t MorehNllLossStep2DeviceOp auto C = input_shape[1]; - auto dimensions_pads = std::vector(); - std::vector output_shape_vec; + ttnn::SmallVector dimensions_pads; + ttnn::SmallVector output_shape_vec; // Need extend 1d output to 2d, because TT not support 1d tensor if (input_rank == 2) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_device_operation.cpp index d6772580dc7..f79e228fb3e 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_device_operation.cpp @@ -46,10 +46,10 @@ inline void validate_output_tensor_with_keepdim(const Tensor& input, const Tenso adjusted_input_shape[dim] = (is_tile_dim) ? 
tt::constants::TILE_HEIGHT : 1; adjusted_input_shape_wo_padding[dim] = 1; - std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector input_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); - std::vector output_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector input_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); + ttnn::SmallVector output_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); tt::operations::primary::expand_to_max_dim(input_dim, adjusted_input_shape); tt::operations::primary::expand_to_max_dim(output_dim, output_shape); @@ -66,8 +66,8 @@ inline void validate_output_tensor_with_keepdim(const Tensor& input, const Tenso } else { TT_FATAL(!is_tile_dim, "Dimension {} should not be a tile dimension when keepdim is false.", dim); - std::vector expected_output_shape; - std::vector expected_output_shape_wo_padding; + ttnn::SmallVector expected_output_shape; + ttnn::SmallVector expected_output_shape_wo_padding; for (int i = 0; i < output_rank; ++i) { if (i == dim && !is_tile_dim) { expected_output_shape.push_back(1); @@ -144,8 +144,8 @@ MorehNormOperation::shape_return_value_t MorehNormOperation::compute_output_shap return Shape{tt::tt_metal::LegacyShape(shape, padding)}; } - std::vector shape; - std::vector pad_dimensions; + ttnn::SmallVector shape; + ttnn::SmallVector pad_dimensions; const std::size_t output_rank = is_tile_dim ? input_rank : input_rank - 1; auto input_padding = input_shape.padding(); for (int i = 0; i < input_rank; ++i) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.cpp index 489709711d5..19f40d6dc71 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.cpp @@ -10,22 +10,22 @@ namespace ttnn::operations::moreh::moreh_norm { Tensor MorehNorm::invoke( const Tensor& input, float p, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& output, const std::optional& memory_config, const std::optional& compute_kernel_config) { if (!dim.has_value()) { - std::vector dims(input.get_legacy_shape().rank()); + ttnn::SmallVector dims(input.get_legacy_shape().rank()); std::iota(dims.begin(), dims.end(), 0); dim = std::make_optional(dims); } if (auto single_dim = std::get_if(&dim.value())) return ttnn::prim::moreh_norm(input, p, *single_dim, keepdim, output, memory_config, compute_kernel_config); - auto dims = std::get>(dim.value()); + auto dims = std::get>(dim.value()); if (dims.empty()) { - std::vector all_dims(input.get_legacy_shape().rank()); + ttnn::SmallVector all_dims(input.get_legacy_shape().rank()); std::iota(all_dims.begin(), all_dims.end(), 0); dims = all_dims; } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.hpp index f70ef8ad158..4535e89f58f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/moreh_norm.hpp @@ -12,7 +12,7 @@ struct MorehNorm { static Tensor invoke( const Tensor& input, float p, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& output, const std::optional& memory_config, diff --git 
a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.cpp index b568f876bb1..6013b1a0b08 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.cpp @@ -56,12 +56,12 @@ MorehNormBackwardOperation::invoke( const Tensor& output, const Tensor& output_grad, float p, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, const std::optional& compute_kernel_config) { - std::vector dims = tt::operations::primary::get_dim(dim, input.get_legacy_shape().rank()); + ttnn::SmallVector dims = tt::operations::primary::get_dim(dim, input.get_legacy_shape().rank()); std::sort(dims.begin(), dims.end()); return { operation_attributes_t{ diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.hpp index 880122370cb..a955818f626 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_device_operation.hpp @@ -28,14 +28,14 @@ namespace ttnn::operations::moreh::moreh_norm_backward { std::tuple get_floored_p_and_decimal_and_p_is_negative(float p); -void get_tensor_dim(std::vector& dim, const Shape& shape); +void get_tensor_dim(ttnn::SmallVector& dim, const Shape& shape); tt::tt_metal::LegacyShape get_output_grad_shape( - const Tensor& output_grad, const Tensor& input_grad, const std::vector& dims, const bool& keepdim); + const Tensor& output_grad, const Tensor& input_grad, const ttnn::SmallVector& dims, const bool& keepdim); struct MorehNormBackwardOperation { struct operation_attributes_t { float p; - std::vector dims; + ttnn::SmallVector dims; bool keepdim; const MemoryConfig memory_config; const DeviceComputeKernelConfig compute_kernel_config; @@ -67,7 +67,7 @@ struct MorehNormBackwardOperation { const Tensor& output, const Tensor& output_grad, float p, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp index b432eb45ad3..79413615e68 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp @@ -17,7 +17,7 @@ std::tuple get_floored_p_and_decimal_and_p_is_negative(fl return std::make_tuple(static_cast(floored_p), decimal, p_is_negative); } -void get_tensor_dim(std::vector& dim, const tt::tt_metal::LegacyShape& shape) { +void get_tensor_dim(ttnn::SmallVector& dim, const tt::tt_metal::LegacyShape& shape) { const auto rank = shape.rank(); for (auto i = 0; i < rank; ++i) { auto idx = rank - 1 - i; @@ -29,7 +29,7 @@ void get_tensor_dim(std::vector& dim, const tt::tt_metal::LegacyShape& } tt::tt_metal::LegacyShape get_output_grad_shape( - const Tensor& output_grad, const Tensor& input_grad, const std::vector& dims, const bool& 
keepdim) { + const Tensor& output_grad, const Tensor& input_grad, const ttnn::SmallVector& dims, const bool& keepdim) { if (keepdim) return output_grad.get_legacy_shape(); @@ -69,16 +69,16 @@ MorehNormBackwardOperation::ProgramFactory::cached_program_t MorehNormBackwardOp const auto& input_grad_shape_wo_padding = input_grad_shape.without_padding(); const auto input_grad_rank = input_grad_shape.rank(); - std::vector input_grad_dim(input_grad_rank, 1); + ttnn::SmallVector input_grad_dim(input_grad_rank, 1); get_tensor_dim(input_grad_dim, input_grad_shape); tt::tt_metal::LegacyShape output_grad_shape = get_output_grad_shape(output_grad, input_grad, operation_attributes.dims, operation_attributes.keepdim); const auto output_grad_shape_wo_padding = output_grad_shape.without_padding(); - std::vector output_grad_dim(input_grad_rank, 1); + ttnn::SmallVector output_grad_dim(input_grad_rank, 1); get_tensor_dim(output_grad_dim, output_grad_shape); - std::vector need_bcast_dim(input_grad_rank, 0); + ttnn::SmallVector need_bcast_dim(input_grad_rank, 0); for (auto i = 0; i < input_grad_rank; ++i) { auto idx = input_grad_rank - 1 - i; bool is_tile_dim = (idx == input_grad_rank - 1 || idx == input_grad_rank - 2); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.cpp index f1c30f98d97..25410baebb1 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.cpp @@ -12,7 +12,7 @@ Tensor MorehNormBackward::invoke( const Tensor& output, const Tensor& output_grad, float p, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.hpp index 53af8033ce4..312cbd15bcb 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/moreh_norm_backward.hpp @@ -14,7 +14,7 @@ struct MorehNormBackward { const Tensor& output, const Tensor& output_grad, float p, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp index 6e549fe3124..f66e99ca63b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp @@ -93,8 +93,8 @@ MorehSumOperation::shape_return_value_t MorehSumOperation::compute_output_shapes output_shape = ttnn::Shape{tt::tt_metal::LegacyShape(shape, padding)}; } else { - std::vector shape; - std::vector pad_dimensions; + ttnn::SmallVector shape; + ttnn::SmallVector pad_dimensions; const std::size_t output_rank = (is_tile_dim) ? 
(input_rank) : (input_rank - 1); auto input_padding = input_shape.value.padding(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.cpp index 0209d5ab184..d5ce16b7cb4 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.cpp @@ -10,12 +10,12 @@ namespace ttnn::operations::moreh::moreh_sum { Tensor MorehSum::invoke( const Tensor& input, - std::optional>> dim, + std::optional>> dim, const bool keepdim, const std::optional& output, const std::optional& memory_config, const std::optional& compute_kernel_config) { - std::vector dims = tt::operations::primary::get_dim(dim, input.get_legacy_shape().rank()); + ttnn::SmallVector dims = tt::operations::primary::get_dim(dim, input.get_legacy_shape().rank()); std::sort(dims.begin(), dims.end()); auto temp_input = input; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.hpp index 94c64244a29..bf18e912e46 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.hpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_sum { struct MorehSum { static Tensor invoke( const Tensor& input, - std::optional>> dims, + std::optional>> dims, const bool keepdim, const std::optional& output, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp index 16137e23817..93f4bcc7c3f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp @@ -120,14 +120,14 @@ std::tuple& input, - const std::vector& dims, + std::span dims, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, const std::optional& compute_kernel_config) { return { operation_attributes_t{ - dims, + ttnn::SmallVector(dims.begin(), dims.end()), keepdim, memory_config.value_or(output_grad.memory_config()), init_device_compute_kernel_config( diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp index a6d76c561be..93fc4ccfa67 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_sum_backward { struct MorehSumBackwardOperation { struct operation_attributes_t { - const std::vector dims; + const ttnn::SmallVector dims; const bool keepdim; const MemoryConfig memory_config; const DeviceComputeKernelConfig compute_kernel_config; @@ -57,7 +57,7 @@ struct MorehSumBackwardOperation { static std::tuple invoke( const Tensor& output_grad, const std::optional& input, - const std::vector& dims, + std::span dims, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp index 
7a44381b5f0..670bfa0301a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_sum_backward { -void get_tensor_dim(std::vector &dim, const Shape &shape) { +void get_tensor_dim(ttnn::SmallVector &dim, const Shape &shape) { const auto rank = shape.rank(); for (auto i = 0; i < rank; ++i) { auto idx = rank - 1 - i; @@ -29,7 +29,7 @@ void get_tensor_dim(std::vector &dim, const Shape &shape) { } Shape get_output_grad_shape( - const Tensor &output_grad, const Tensor &input_grad, const std::vector &dims, const bool &keepdim) { + const Tensor &output_grad, const Tensor &input_grad, const ttnn::SmallVector &dims, const bool &keepdim) { if (keepdim) { return output_grad.get_shape(); } @@ -79,17 +79,17 @@ MorehSumBackwardOperation::ProgramFactory::cached_program_t MorehSumBackwardOper const auto &input_grad_shape_wo_padding = input_grad_shape.value.without_padding(); const uint32_t input_grad_rank = input_grad_shape.rank(); - std::vector input_grad_dim(input_grad_rank, 1); + ttnn::SmallVector input_grad_dim(input_grad_rank, 1); log_debug(tt::LogOp, "input_grad"); get_tensor_dim(input_grad_dim, input_grad_shape); const auto &output_grad_shape = get_output_grad_shape(output_grad, input_grad, dims, keepdim); const auto &output_grad_shape_wo_padding = output_grad_shape.value.without_padding(); - std::vector output_grad_dim(input_grad_rank, 1); + ttnn::SmallVector output_grad_dim(input_grad_rank, 1); log_debug(tt::LogOp, "output_grad"); get_tensor_dim(output_grad_dim, output_grad_shape); - std::vector need_bcast_dim(input_grad_rank, 0); + ttnn::SmallVector need_bcast_dim(input_grad_rank, 0); for (auto i = 0; i < input_grad_rank; ++i) { auto idx = input_grad_rank - 1 - i; bool is_tile_dim = (idx == input_grad_rank - 1 || idx == input_grad_rank - 2); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.cpp index bd911bbb918..1ea51a27cc5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.cpp @@ -11,14 +11,14 @@ namespace ttnn::operations::moreh::moreh_sum_backward { Tensor MorehSumBackward::invoke( const Tensor& output_grad, const std::optional& input, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, const std::optional& compute_kernel_config) { TT_FATAL((input.has_value() || input_grad.has_value()), "either input or input_grad must have a value"); uint32_t rank = input.has_value() ? 
input->get_shape().value.rank() : input_grad->get_shape().value.rank(); - std::vector dims = tt::operations::primary::get_dim(dim, rank); + ttnn::SmallVector dims = tt::operations::primary::get_dim(dim, rank); std::sort(dims.begin(), dims.end()); return ttnn::prim::moreh_sum_backward( output_grad, input, dims, keepdim, input_grad, memory_config, compute_kernel_config); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.hpp index 048288cf1b2..06e3d8adf8e 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/moreh_sum_backward.hpp @@ -10,7 +10,7 @@ struct MorehSumBackward { static Tensor invoke( const Tensor& output_grad, const std::optional& input, - std::optional>> dim, + std::optional>> dim, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp index fe49e797a50..56c7465e713 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_device_op.cpp @@ -71,7 +71,7 @@ MaxPool2D::shape_return_value_t MaxPool2D::compute_output_shapes(const operation uint32_t out_nhw_padded = tt::round_up(out_nhw, (is_out_tiled ? tt::constants::TILE_HEIGHT : 1) * sliding_window_config.num_cores_nhw); // {1, 1, N * H * W, C} - const auto out_dims = std::vector({1, 1, out_nhw_padded, out_c_padded}); + const ttnn::SmallVector out_dims({1, 1, out_nhw_padded, out_c_padded}); const auto padding = Padding( {{0, 0}, {0, 0}, {0, out_nhw_padded - out_nhw}, {0, out_c_padded - out_c}}, Padding::PadValue::NegativeInfinity); diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp index d43270ca96a..3095fc96aab 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp @@ -45,7 +45,7 @@ std::vector UpSample::compute_output_shapes(const std uint32_t out_h = input_shape[1] * scale_factor_h_; uint32_t out_w = input_shape[2] * scale_factor_w_; uint32_t out_c = input_shape[3]; - const auto out_dims = std::vector({ out_n, out_h, out_w, out_c }); //in the NHWC format + const ttnn::SmallVector out_dims({ out_n, out_h, out_w, out_c }); //in the NHWC format return {tt::tt_metal::LegacyShape{out_dims}}; } diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp index 5cce6fa248b..80becef2894 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp @@ -14,7 +14,7 @@ namespace operations::reduction { template static Tensor reduce_impl( const Tensor& input_tensor_arg, - const std::optional>>& dim_arg, + const std::optional>>& dim_arg, const bool keepdim, const std::optional& memory_config_arg, const std::optional& compute_kernel_config, @@ -29,16 +29,16 @@ static Tensor reduce_impl( auto rank = input_shape.size(); auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config()); - std::vector dim{}; + ttnn::SmallVector dim{}; if (dim_arg.has_value()) { - if (not std::holds_alternative>(dim_arg.value())) { + if (not 
std::holds_alternative>(dim_arg.value())) { auto dim_as_int = std::get(dim_arg.value()); - dim = std::vector({dim_as_int}); + dim = ttnn::SmallVector({dim_as_int}); } else { - dim = std::get>(dim_arg.value()); + dim = std::get>(dim_arg.value()); } } else { - dim = std::vector(rank); + dim = ttnn::SmallVector(rank); for (int i = 0; i < rank; i++) { dim[i] = i; } @@ -93,8 +93,8 @@ static Tensor reduce_impl( } std::sort(dim.begin(), dim.end()); - std::vector output_shape; - std::vector padded_output_shape; + ttnn::SmallVector output_shape; + ttnn::SmallVector padded_output_shape; for (int axis = 0; axis < input_shape.size(); axis++) { if (std::find(dim.begin(), dim.end(), axis) != dim.end()) { if (keepdim) { @@ -182,7 +182,7 @@ static Tensor reduce_impl( template Tensor Reduce::invoke( const Tensor& input_tensor_arg, - const std::optional>>& dim_arg, + const std::optional>>& dim_arg, const bool keepdim, const std::optional& memory_config_arg, const std::optional& compute_kernel_config, diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp index 1c6eb5e7566..695e15407ca 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp @@ -26,7 +26,7 @@ template struct Reduce { static Tensor invoke( const Tensor& input_tensor_arg, - const std::optional>>& dim_arg = std::nullopt, + const std::optional>>& dim_arg = std::nullopt, const bool keepdim = true, const std::optional& memory_config_arg = std::nullopt, const std::optional& compute_kernel_config = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp index 8d08d4f28ac..207eda4bbd5 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp @@ -96,14 +96,14 @@ Tensor prod_(const Tensor& input, const int64_t& dim, const MemoryConfig& mem_co Tensor prod_nc( const Tensor& input, const Tensor& output, - std::vector& dims, + ttnn::SmallVector& dims, const MemoryConfig& output_mem_config) { // reduce for all dims if (dims.empty()) { dims = {0, 1, 2, 3}; } - std::vector sorted_dims = dims; + ttnn::SmallVector sorted_dims = dims; std::sort(sorted_dims.begin(), sorted_dims.end()); auto temp_input = input; diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp index 5552f120a4d..cd5de8832e9 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp @@ -40,7 +40,7 @@ Tensor prod_( Tensor prod_nc( const Tensor &input, const Tensor &output, - std::vector &dims, + ttnn::SmallVector &dims, const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); } // namespace primary diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp index e4f56968aa9..0cd4d8d81e4 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/prod.cpp @@ -58,7 +58,7 @@ inline Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& outpu } } // Apply prod - std::vector dimension = {(dim == 1 || dim == -3) ? 1 : 0}; + ttnn::SmallVector dimension = {(dim == 1 || dim == -3) ? 
1 : 0}; tt::tt_metal::LegacyShape input_shape = formatted_input_tensor.get_legacy_shape(); std::array required = { ((dim == 1 || dim == -3) ? input_shape[0] : 1), @@ -91,32 +91,32 @@ Tensor ProdOperation::invoke(const Tensor& input_a, bool all_dimensions, int64_t Tensor temp = input_a; // Permute for dim 2,3 if (dim == 2 || dim == -2) { - std::vector permute_dims = {2, 0, 1, 3}; + ttnn::SmallVector permute_dims = {2, 0, 1, 3}; temp = ttnn::permute(input_a, permute_dims, output_mem_config); } else if (dim == 3 || dim == -1) { - std::vector permute_dims = {3, 0, 1, 2}; + ttnn::SmallVector permute_dims = {3, 0, 1, 2}; temp = ttnn::permute(input_a, permute_dims, output_mem_config); } Tensor result = prod_nc(temp, dim, output_mem_config); // Permute and unpad result for dim 2,3 - auto step = std::vector({1, 1, 1, 1}); + auto step = ttnn::SmallVector({1, 1, 1, 1}); if (dim == 0 || dim == 1 || dim == -4 || dim == -3) { return result; } else if (dim == 2 || dim == -2) { - std::vector after_permute_dims = {1, 2, 0, 3}; + ttnn::SmallVector after_permute_dims = {1, 2, 0, 3}; Tensor required = ttnn::permute(result, after_permute_dims, output_mem_config); tt::tt_metal::LegacyShape input_shape = input_a.get_legacy_shape(); - std::vector start_index = {0, 0, 0, 0}; - std::vector end_index = {input_shape[0], input_shape[1], 1, input_shape[3]}; + ttnn::SmallVector start_index = {0, 0, 0, 0}; + ttnn::SmallVector end_index = {input_shape[0], input_shape[1], 1, input_shape[3]}; return ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt); } else { // dim 3 // permute - std::vector after_permute_dims = {1, 2, 0, 3}; + ttnn::SmallVector after_permute_dims = {1, 2, 0, 3}; Tensor required = ttnn::permute(result, after_permute_dims, output_mem_config); // unpad tt::tt_metal::LegacyShape input_shape = input_a.get_legacy_shape(); - std::vector start_index = {0, 0, 0, 0}; - std::vector end_index = {input_shape[0], input_shape[1], 1, input_shape[2]}; + ttnn::SmallVector start_index = {0, 0, 0, 0}; + ttnn::SmallVector end_index = {input_shape[0], input_shape[1], 1, input_shape[2]}; Tensor new_unpad_tensor = ttnn::slice(DefaultQueueId, required, start_index, end_index, step, std::nullopt); // permute back after_permute_dims = {0, 1, 3, 2}; @@ -125,7 +125,7 @@ Tensor ProdOperation::invoke(const Tensor& input_a, bool all_dimensions, int64_t } } -Tensor ProdOperation::invoke(const Tensor &input, const Tensor &output, std::vector &dims, const std::optional& memory_config) { +Tensor ProdOperation::invoke(const Tensor &input, const Tensor &output, ttnn::SmallVector &dims, const std::optional& memory_config) { auto mem_cfg = memory_config.value_or(input.memory_config()); return tt::operations::primary::prod_nc(input, output, dims, mem_cfg); } diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/prod.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/prod.hpp index c5da8f2259d..bc869658d64 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/prod.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/prod.hpp @@ -23,7 +23,7 @@ namespace operations::reduction { static Tensor invoke( const Tensor& input, const Tensor& output, - std::vector &dims, + ttnn::SmallVector &dims, const std::optional& memory_config = std::nullopt); }; diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/prod_pybind.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/prod_pybind.hpp index 66156a64f02..9ff3c6420cd 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/prod_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/prod_pybind.hpp 
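A minimal, self-contained sketch of the pattern the prod/slice call sites above now follow. This is not part of the patch: the alias below is illustrative and only mirrors what ttnn::SmallVector wraps (boost::container::small_vector with inline capacity 8, per the SMALL_VECTOR_SIZE constant this patch introduces in vector_base.hpp), assuming only Boost.Container is available.

// sketch_small_vector_usage.cpp -- illustrative only, not part of this patch
#include <cstdint>
#include <iostream>
#include <boost/container/small_vector.hpp>

// Mirrors the ttnn::SmallVector wrapper: inline storage for up to 8 elements,
// spilling to the heap only for higher ranks.
template <typename T>
using SmallVector = boost::container::small_vector<T, 8>;

int main() {
    SmallVector<int64_t> dims = {0, 1};            // reduction dims, as passed to prod_nc
    SmallVector<uint32_t> start = {0, 0, 0, 0};    // slice start, as in ProdOperation::invoke
    SmallVector<uint32_t> end = {1, 2, 1, 4};      // slice end (treated as exclusive in this sketch)
    SmallVector<uint32_t> step(end.size(), 1);     // default step of 1 per dimension

    // Same surface API as std::vector: size(), operator[], iteration, etc.,
    // which is what keeps the std::vector -> SmallVector migration mostly mechanical.
    uint64_t sliced_volume = 1;
    for (size_t i = 0; i < end.size(); ++i) {
        sliced_volume *= (end[i] - start[i] + step[i] - 1) / step[i];
    }
    std::cout << "reduced over " << dims.size() << " dims, sliced volume = " << sliced_volume << "\n";
    return 0;
}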
@@ -64,13 +64,13 @@ void bind_reduction_prod_operation(py::module& module, const unary_operation_t& [](const unary_operation_t& self, const Tensor& input_tensor, const Tensor& output_tensor, - std::vector &dims, + ttnn::SmallVector& dims, const std::optional& memory_config) { return self(input_tensor, output_tensor, dims, memory_config); }, py::arg("input_tensor"), py::arg("output_tensor"), py::kw_only(), - py::arg("dims") = std::vector(), + py::arg("dims") = ttnn::SmallVector(), py::arg("memory_config") = std::nullopt} ); } diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp index 8286a278ddf..5ab2376eaff 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp @@ -14,7 +14,7 @@ std::size_t SlidingWindowConfig::get_hash() const { * Return the input shape (excluding depth) */ Shape SlidingWindowConfig::get_input_shape() const { - return Shape(std::vector{batch_size, std::get<0>(input_hw), std::get<1>(input_hw)}); + return Shape({batch_size, std::get<0>(input_hw), std::get<1>(input_hw)}); } bool SlidingWindowConfig::has_parallel_config() const { @@ -33,7 +33,7 @@ Shape SlidingWindowConfig::get_output_shape() const { output_w = input_hw.second; } log_debug(tt::LogOp, "output_size: {} {} {}", batch_size, output_h, output_w); - return Shape( std::vector{batch_size, output_h, output_w, 0}); + return Shape({batch_size, output_h, output_w, 0}); } /** @@ -433,7 +433,7 @@ Tensor construct_on_host_config_tensor(const std::vector>& // we need the last dim of tensors to be multiple of 2, pad if needed uint32_t extend_with_zeroes = config[0].size() % 2; extend_with_zeroes = extend_with_zeroes > 0 ? 2 - extend_with_zeroes : 0; - Shape config_shape = Shape(std::vector{(uint32_t) config.size(), (uint32_t) config[0].size() + extend_with_zeroes}); + Shape config_shape = Shape({(uint32_t) config.size(), (uint32_t) config[0].size() + extend_with_zeroes}); std::vector config_vector = flatten(config, extend_with_zeroes); if (p_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { auto config_buffer = owned_buffer::create(std::move(config_vector)); @@ -446,7 +446,7 @@ Tensor construct_on_host_config_tensor(const std::vector>& repeat_config.insert(repeat_config.end(), config_vector.begin(), config_vector.end()); } auto config_buffer = owned_buffer::create(std::move(repeat_config)); - config_shape = Shape(std::vector{config_shape[0] * repeat_factor, config_shape[1]}); + config_shape = Shape({config_shape[0] * repeat_factor, config_shape[1]}); return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); } else if (p_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) { TT_ASSERT(p_config.grid.ranges().size() == 1, "BLOCK_SHARDED should have just a single core range"); @@ -468,7 +468,7 @@ Tensor construct_on_host_config_tensor(const std::vector>& repeat_config.insert(repeat_config.end(), config_vector.begin(), config_vector.end()); } auto config_buffer = owned_buffer::create(std::move(repeat_config)); - config_shape = Shape(std::vector{config_shape[0] * repeat_factor, config_shape[1]}); + config_shape = Shape({config_shape[0] * repeat_factor, config_shape[1]}); return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); } else { TT_ASSERT(false, "Unsupported shard scheme"); diff --git a/ttnn/cpp/ttnn/run_operation.cpp b/ttnn/cpp/ttnn/run_operation.cpp index 
3941ce32e08..b1d9043407b 100644 --- a/ttnn/cpp/ttnn/run_operation.cpp +++ b/ttnn/cpp/ttnn/run_operation.cpp @@ -306,7 +306,7 @@ std::vector extract_legacy_shapes( legacy_shapes.reserve(simple_shapes.size()); for (size_t idx = 0; idx < simple_shapes.size(); idx++) { auto [data_type, layout] = layout_provider(idx); - legacy_shapes.emplace_back(simple_shapes[idx].as_vector(), get_physical_shape(simple_shapes[idx], data_type, layout).as_vector()); + legacy_shapes.emplace_back(simple_shapes[idx].view(), get_physical_shape(simple_shapes[idx], data_type, layout).view()); } return legacy_shapes; } diff --git a/ttnn/cpp/ttnn/tensor/CMakeLists.txt b/ttnn/cpp/ttnn/tensor/CMakeLists.txt index cef2d5ffdb7..1506fa9de29 100644 --- a/ttnn/cpp/ttnn/tensor/CMakeLists.txt +++ b/ttnn/cpp/ttnn/tensor/CMakeLists.txt @@ -5,6 +5,7 @@ set(TENSOR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/types.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/vector_base.cpp CACHE INTERNAL "Tensor sources to reuse in ttnn build" ) diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index d689237d5e9..f9d96ea891f 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -226,7 +226,7 @@ Tensor::~Tensor() { tensor_attributes.reset(); } -Tensor::Tensor(const Storage storage, const ttnn::SimpleShape& shape, DataType dtype, Layout layout, const std::optional& tile) : Tensor(storage, ttnn::Shape(shape.as_vector()), dtype, layout, tile) {} +Tensor::Tensor(const Storage storage, const ttnn::SimpleShape& shape, DataType dtype, Layout layout, const std::optional& tile) : Tensor(storage, ttnn::Shape(shape.view()), dtype, layout, tile) {} void Tensor::deallocate(bool force) { ZoneScopedN("TensorDeallocate"); @@ -682,7 +682,7 @@ Tensor create_device_tensor( auto device_buffer = tensor_impl::allocate_buffer_on_device( packed_size_in_bytes, device, padded_shape, data_type, layout, memory_config, shard_spec_buffer, tile); - auto output = Tensor(DeviceStorage{device_buffer}, ttnn::Shape(logical_shape.as_vector(), padded_shape.as_vector()), data_type, layout, tile); + auto output = Tensor(DeviceStorage{device_buffer}, ttnn::Shape(logical_shape.view(), padded_shape.view()), data_type, layout, tile); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -691,7 +691,7 @@ Tensor create_device_tensor( tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(padded_shape, data_type)); auto device_buffer = tensor_impl::allocate_buffer_on_device( packed_size_in_bytes, device, padded_shape, data_type, layout, memory_config, std::nullopt, tile); - auto output = Tensor(DeviceStorage{device_buffer}, ttnn::Shape(logical_shape.as_vector(), padded_shape.as_vector()), data_type, layout, tile); + auto output = Tensor(DeviceStorage{device_buffer}, ttnn::Shape(logical_shape.view(), padded_shape.view()), data_type, layout, tile); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index bea72879c2d..0ee8c0f7e02 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -1185,10 +1185,10 @@ Tensor pad(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_shape, return stride; }; - std::vector> pad_size{}; - std::vector input_strides{}; - std::vector output_strides{}; - 
std::vector input_indices(input_shape.rank(), 0); + ttnn::SmallVector> pad_size{}; + ttnn::SmallVector input_strides{}; + ttnn::SmallVector output_strides{}; + ttnn::SmallVector input_indices(input_shape.rank(), 0); for (auto index = 0; index < output_shape.rank(); index++) { // Check if input tensor fits in output tensor given the input tensor start indices @@ -1283,7 +1283,7 @@ Tensor unpad(const Tensor& tensor, const ttnn::SimpleShape& output_tensor_start, const auto input_strides = tensor.strides(); // Validate inputs and compute output shape - std::vector output_shape{}; + ttnn::SmallVector output_shape; for (auto i = 0; i < input_shape.rank(); i++) { // Check if tensor start and end indices are within input tensor shape TT_ASSERT(output_tensor_start[i] < input_shape[i]); @@ -1296,7 +1296,7 @@ Tensor unpad(const Tensor& tensor, const ttnn::SimpleShape& output_tensor_start, auto unpad = [&input_shape, &input_strides, &output_shape, &output_tensor_start, &output_tensor_end]( const auto& input_buffer) { - std::vector input_indices(input_shape.rank(), 0); + ttnn::SmallVector input_indices(input_shape.rank(), 0); auto flat_output_index = 0; auto output_buffer = owned_buffer::create(compute_volume(output_shape)); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index a27b2e0bfdb..55b4303913d 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -187,7 +187,7 @@ constexpr inline size_t packed_buffer_size_bytes(size_t volume_unpack // Layout converters // ====================================================================================== namespace detail { -static std::vector to_4D_shape(const tt::tt_metal::LegacyShape& shape) { +static ttnn::SmallVector to_4D_shape(const tt::tt_metal::LegacyShape& shape) { if (shape.rank() == 1) { return {1, 1, 1, shape[-1]}; } else if (shape.rank() == 2) { @@ -201,14 +201,6 @@ static std::vector to_4D_shape(const tt::tt_metal::LegacyShape& shape) } } -static std::vector to_vector(const tt::tt_metal::LegacyShape& shape) { - std::vector shape_vec; - for (int i = 0; i < shape.rank(); i++) { - shape_vec.push_back(shape[i]); - } - return shape_vec; -} - } // namespace detail template typename BufferType> @@ -220,7 +212,7 @@ inline std::vector convert_layout_row_major_to_tile(const tt::tt_metal::Legac auto tile_shape = std::vector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; auto face_shape = std::vector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, detail::to_vector(shape), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES, tile_shape, face_shape); + data_to_convert, std::vector(shape.begin(), shape.end()), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES, tile_shape, face_shape); } template typename BufferType> @@ -228,7 +220,7 @@ inline std::vector convert_layout_tile_to_row_major(const tt::tt_metal::Legac auto tile_shape = std::vector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; auto face_shape = std::vector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, detail::to_vector(shape), TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR, tile_shape, face_shape); + data_to_convert, std::vector(shape.begin(), shape.end()), TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR, tile_shape, face_shape); } // ====================================================================================== diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp 
b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index e63a37fb791..c133a1aff71 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -296,9 +296,9 @@ Tensor tensor_pad_to_tile(const Tensor& input_tensor, float pad_value) { uint32_t padded_height = round_up(height, constants::TILE_HEIGHT); uint32_t padded_width = round_up(width, constants::TILE_WIDTH); - std::vector shape; - std::vector padded_shape; - std::vector input_tensor_start; + ttnn::SmallVector shape; + ttnn::SmallVector padded_shape; + ttnn::SmallVector input_tensor_start; for (auto index = 0; index < input_tensor.get_legacy_shape().rank() - 2; index++) { shape.push_back(input_tensor.get_legacy_shape().without_padding()[index]); @@ -335,8 +335,8 @@ Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShap input_tensor.get_legacy_shape()[-2] - constants::TILE_HEIGHT < output_tensor_shape[-2] && input_tensor.get_legacy_shape()[-1] - constants::TILE_WIDTH < output_tensor_shape[-1], "Last 2 dims of output must be within range to have been padded to input"); - std::vector output_tensor_start{}; - std::vector output_tensor_end{}; + ttnn::SmallVector output_tensor_start{}; + ttnn::SmallVector output_tensor_end{}; for (auto index = 0; index < input_tensor.get_legacy_shape().rank(); index++) { output_tensor_start.push_back(0); output_tensor_end.push_back(output_tensor_shape[index]); @@ -427,7 +427,7 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) } Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::SimpleShape& new_shape) { - return tensor_reshape(input_tensor, ttnn::Shape(new_shape.as_vector())); + return tensor_reshape(input_tensor, ttnn::Shape(new_shape.view())); } } diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index 3ff6eb4452b..2299abefd18 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -266,13 +266,13 @@ static Tensor conv_group_weight_zero_pad_helper( for (int m = 0; m < original_weight_shape[3]; m++) { // Get value from original weight tensor auto value_flat_input_index = - compute_flat_indices({curr_batch_idx, j, k, m}, compute_strides(original_weight_shape)); + compute_flat_indices(ttnn::SmallVector{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape)); auto value = conv_weight_tensor_buffer[value_flat_input_index]; // Copy value to output tensor at the adjusted position auto new_channel_idx = new_channel_start_idx + j; auto output_flat_input_index = compute_flat_indices( - {new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape)); + ttnn::SmallVector{new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape)); output_buffer[output_flat_input_index] = value; } } @@ -301,9 +301,9 @@ static Tensor conv_depthwise_weight_bcast_helper( for (int k = 0; k < output_weight_shape[2]; k++) { for (int l = 0; l < output_weight_shape[3]; l++) { auto value_flat_input_index = - compute_flat_indices({i, 0, k, l}, compute_strides(original_weight_shape)); + compute_flat_indices(ttnn::SmallVector{i, 0, k, l}, compute_strides(original_weight_shape)); auto value = conv_weight_tensor_buffer[value_flat_input_index]; - auto output_flat_input_index = compute_flat_indices({i, j, k, l}, compute_strides(output_weight_shape)); + auto output_flat_input_index = compute_flat_indices(ttnn::SmallVector{i, j, k, l}, compute_strides(output_weight_shape)); output_buffer[output_flat_input_index] = value; } } @@ -456,7 +456,7 @@ 
Tensor convert_conv_weight_tensor_to_depthwise_layout( TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor"); } -const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, const std::vector& shape) { +const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, std::span shape) { int64_t old_volume = tensor.get_logical_volume(); int64_t new_volume = 1; int64_t index_of_negative_1 = -1; @@ -477,7 +477,7 @@ const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, const std:: } } - std::vector new_shape(shape.size()); + ttnn::SmallVector new_shape(shape.size()); std::copy(shape.begin(), shape.end(), new_shape.begin()); if (index_of_negative_1 == -1) { TT_FATAL(new_volume == old_volume, "Invalid arguments to reshape"); diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index 692e5b361fa..82f51fea057 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -34,7 +34,7 @@ Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, u // Converts convolution weights to depthwise layout with broadcasted weights Tensor convert_conv_weight_tensor_to_depthwise_layout(Tensor conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype); -const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, const std::vector& shape); +const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, std::span shape); // TODO: Remove this once we switch to SimpleShape .volume() static std::size_t compute_volume(const tt::tt_metal::LegacyShape& shape) { @@ -45,12 +45,12 @@ static std::size_t compute_volume(const tt::tt_metal::LegacyShape& shape) { return volume; } -static std::vector compute_strides(const ttnn::SimpleShape& shape) { +static ttnn::SmallVector compute_strides(const ttnn::SimpleShape& shape) { if (shape.rank() == 0) return {}; auto num_elements = shape.volume(); - std::vector strides; + ttnn::SmallVector strides; for (std::int32_t index = 0; index < shape.rank(); index++) { if (shape[index] == 0) { // Insert 0 to indicate no memory access for this dimension @@ -64,7 +64,7 @@ static std::vector compute_strides(const ttnn::SimpleShape& shape) { return strides; } -static int compute_flat_indices(const vector& indices, const vector strides) { +static int compute_flat_indices(std::span indices, std::span strides) { int flat_index = 0; for (auto i = 0; i < indices.size(); i++) { flat_index += indices[i] * strides[i]; diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp index 14d89db8a9b..7109afb5154 100644 --- a/ttnn/cpp/ttnn/tensor/types.cpp +++ b/ttnn/cpp/ttnn/tensor/types.cpp @@ -42,8 +42,8 @@ const Shape Shape::to_rank(size_t new_rank) const { auto padded_shape = value; auto shape = value.without_padding(); - std::vector new_shape(new_rank, 1); - std::vector new_padded_shape(new_rank, 1); + SmallVector new_shape(new_rank, 1); + SmallVector new_padded_shape(new_rank, 1); int cur_idx = static_cast(rank()) - 1; int new_idx = static_cast(new_rank) - 1; @@ -116,7 +116,7 @@ Padding::Padding(const std::initializer_list pad_dimensions, PadVa std::copy(std::begin(pad_dimensions), std::end(pad_dimensions), std::begin(this->pad_dimensions_)); } -Padding::Padding(const std::vector& pad_dimensions, PadValue pad_value) : +Padding::Padding(std::span pad_dimensions, PadValue pad_value) : rank_(pad_dimensions.size()), pad_dimensions_{}, pad_value_(pad_value) { std::copy(std::begin(pad_dimensions), 
std::end(pad_dimensions), std::begin(this->pad_dimensions_)); } @@ -166,7 +166,7 @@ LegacyShape::LegacyShape(const std::initializer_list dimensions) : rank_(dimensions.size()), dimensions_{}, padding_(dimensions.size()) { std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); } -LegacyShape::LegacyShape(const std::vector& dimensions) : +LegacyShape::LegacyShape(std::span dimensions) : rank_(dimensions.size()), dimensions_{}, padding_(dimensions.size()) { std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); } @@ -176,7 +176,7 @@ LegacyShape::LegacyShape(const std::initializer_list dimensions, const TT_ASSERT(this->padding_.rank_ == this->rank_); std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); } -LegacyShape::LegacyShape(const std::vector& dimensions, const Padding& padding) : +LegacyShape::LegacyShape(std::span dimensions, const Padding& padding) : rank_(dimensions.size()), dimensions_{}, padding_(padding) { TT_ASSERT(this->padding_.rank_ == this->rank_); std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); @@ -208,7 +208,7 @@ const Padding& LegacyShape::padding() const { const LegacyShape LegacyShape::without_padding() const { auto padding = this->padding_; - std::vector shape_without_padding; + ttnn::SmallVector shape_without_padding; for (auto index = 0; index < this->rank(); index++) { const auto dimension = this->operator[](index); auto&& [front_pad, back_pad] = padding.pad_dimensions_[index]; @@ -222,7 +222,7 @@ ttnn::SimpleShape LegacyShape::logical_shape() const { const LegacyShape logical = without_padding(); - std::vector values(rank()); + ttnn::SmallVector values(rank()); for (size_t i = 0; i < values.size(); i++) { values[i] = logical[i]; } @@ -391,25 +391,25 @@ int32_t normalized_index(int32_t index, size_t container_size) { } bool SimpleShape::operator==(const SimpleShape &other) const { - return this->value == other.value; + return this->m_value == other.m_value; } -bool SimpleShape::operator==(const std::vector &other) const { - return this->value == other; +bool SimpleShape::operator==(const SmallVector &other) const { + return this->m_value == other; } uint32_t SimpleShape::operator[](int32_t index) const { - auto norm_index = normalized_index(index, value.size()); - return value[norm_index]; + auto norm_index = normalized_index(index, m_value.size()); + return m_value[norm_index]; } uint32_t& SimpleShape::operator[](int32_t index) { - auto norm_index = normalized_index(index, value.size()); - return value[norm_index]; + auto norm_index = normalized_index(index, m_value.size()); + return m_value[norm_index]; } uint64_t SimpleShape::volume() const { - return std::accumulate(this->value.begin(), this->value.end(), + return std::accumulate(this->m_value.cbegin(), this->m_value.cend(), uint64_t{1}, std::multiplies()); } diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index 4364414bc76..de6bd3909c2 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -18,6 +18,7 @@ #include "tt_metal/tt_stl/concepts.hpp" #include "tt_metal/tt_stl/reflection.hpp" #include "ttnn/tensor/host_buffer/types.hpp" +#include "ttnn/tensor/vector_base.hpp" #include "ttnn/cpp/ttnn/tensor/enum_types.hpp" namespace ttnn { @@ -32,42 +33,41 @@ SimpleShape is a temporary measure aimed at making a clear distinction between S We will clearly see where full shape is used vs logical or physical shape is used. 
Need to split .hpp and .cpp **/ -class SimpleShape { +class SimpleShape final { public: - explicit SimpleShape(const std::vector& shape) : value(shape) {} - explicit SimpleShape(std::vector&& shape) : value(std::move(shape)) {} - explicit SimpleShape(std::initializer_list ilist) : value(ilist) {} + explicit SimpleShape(const SmallVector& shape) : m_value(shape) {} + explicit SimpleShape(SmallVector&& shape) : m_value(std::move(shape)) {} + explicit SimpleShape(std::initializer_list ilist) : m_value(ilist) {} template - explicit SimpleShape(const std::array& arr) : value(arr.begin(), arr.end()) {} + explicit SimpleShape(const std::array& arr) : m_value(arr) {} template bool operator==(const std::array &other) const { - bool sameSize = value.size() == N; - return sameSize && std::equal(value.begin(), value.end(), other.begin()); + return m_value == other; } bool operator==(const SimpleShape &other) const; - bool operator==(const std::vector &other) const; + bool operator==(const SmallVector &other) const; uint32_t operator[](int32_t index) const; uint32_t &operator[](int32_t index); - size_t rank() const { return this->value.size(); } + size_t rank() const { return m_value.size(); } uint64_t volume() const; - auto cbegin() const { return this->value.cbegin(); } - auto cend() const { return this->value.cend(); } + auto cbegin() const { return m_value.cbegin(); } + auto cend() const { return m_value.cend(); } - const std::vector& as_vector() const { return this->value; } + std::span view() const { return m_value.view(); } // Needed for reflect / fmt static constexpr auto attribute_names = std::forward_as_tuple("value"); - auto attribute_values() const { return std::forward_as_tuple(this->value); } + auto attribute_values() const { return std::forward_as_tuple(m_value); } friend std::ostream &operator<<(std::ostream &os, const SimpleShape &shape); private: - std::vector value; + VectorBase m_value; }; inline std::ostream &operator<<(std::ostream &os, const ttnn::SimpleShape &shape) { @@ -169,7 +169,7 @@ struct Padding { Padding(const std::size_t rank); Padding(const std::initializer_list pad_dimensions, PadValue pad_value); - Padding(const std::vector &pad_dimensions, PadValue pad_value); + Padding(std::span pad_dimensions, PadValue pad_value); template Padding(const std::array, Rank> pad_dimensions, PadValue pad_value) : @@ -241,9 +241,11 @@ class LegacyShape { ~LegacyShape() = default; LegacyShape(const std::initializer_list); - LegacyShape(const std::vector &); + LegacyShape(std::span); + LegacyShape(const ttnn::SmallVector& vec) : LegacyShape(std::span(vec)) {}; LegacyShape(const std::initializer_list, const Padding &); - LegacyShape(const std::vector &, const Padding &); + LegacyShape(std::span, const Padding &); + LegacyShape(const ttnn::SmallVector& vec, const Padding &padding) : LegacyShape(std::span(vec), padding) {}; explicit LegacyShape(const LegacyShape &, const Padding &); @@ -269,7 +271,7 @@ class LegacyShape { this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]}; } } - explicit LegacyShape(const std::vector &shape, const std::vector &shape_with_tile_padding) : + explicit LegacyShape(std::span shape, std::span shape_with_tile_padding) : rank_(shape.size()), dimensions_{}, padding_{shape.size()} { TT_ASSERT( shape.size() == shape_with_tile_padding.size(), @@ -280,6 +282,10 @@ class LegacyShape { this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]}; } } + explicit LegacyShape(const ttnn::SmallVector& shape, const ttnn::SmallVector& 
shape_with_tile_padding) + : LegacyShape(std::span(shape), std::span(shape_with_tile_padding)) {} + explicit LegacyShape(const std::initializer_list shape, const std::initializer_list shape_with_tile_padding) + : LegacyShape(ttnn::SmallVector(shape), ttnn::SmallVector(shape_with_tile_padding)) {} std::size_t rank() const; std::size_t size() const; @@ -480,7 +486,7 @@ static tt::tt_metal::LegacyShape compute_ttl_shape( for (auto index = 0; index < Rank; index++) { ttl_shape[index] = shape[index] + padding[index][0] + padding[index][1]; } - return tt::tt_metal::LegacyShape{ttl_shape, tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}}; + return tt::tt_metal::LegacyShape{std::span(ttl_shape), tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}}; } } // namespace detail @@ -507,18 +513,23 @@ struct Shape { const std::array &shape, const std::array, Rank> &tile_padding) : value{detail::compute_ttl_shape(shape, tile_padding)} {} - Shape(const std::vector &shape) : value{tt::tt_metal::LegacyShape{shape}} {} + Shape(std::span shape) : value{tt::tt_metal::LegacyShape{shape}} {} - explicit Shape(const std::vector &shape, const std::vector &shape_with_tile_padding) : + Shape(const SmallVector& shape) : value{tt::tt_metal::LegacyShape{shape}} {} + + explicit Shape(std::span shape, std::span shape_with_tile_padding) : + value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} + + explicit Shape(const std::initializer_list shape, const std::initializer_list shape_with_tile_padding) : value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} - explicit Shape(const std::vector &shape, const Padding &padding) : + explicit Shape(std::span shape, const Padding &padding) : value{tt::tt_metal::LegacyShape{shape, padding}} {} explicit Shape(const Shape &shape, const Padding &padding) : value{tt::tt_metal::LegacyShape{shape.value, padding}} {} - Shape(const SimpleShape& shape): value{shape.as_vector()} {} + Shape(const SimpleShape& shape): value{shape.view()} {} const auto rank() const { return this->value.rank(); } @@ -535,7 +546,7 @@ struct Shape { } SimpleShape padded_shape() const { - std::vector values(rank()); + SmallVector values(rank()); for (size_t i = 0; i < values.size(); i++) { values[i] = this->value[i]; // value stored LegacyShape, its operator[] returns padded value } @@ -544,7 +555,7 @@ struct Shape { // Returns the shape without padding, padding information is stripped SimpleShape logical_shape() const { - std::vector values(this->rank()); + SmallVector values(this->rank()); for (size_t i = 0; i < values.size(); i++) { values[i] = this->operator[](i); // operator[] returns the shape without padding } diff --git a/ttnn/cpp/ttnn/tensor/vector_base.cpp b/ttnn/cpp/ttnn/tensor/vector_base.cpp new file mode 100644 index 00000000000..afe4fa3360a --- /dev/null +++ b/ttnn/cpp/ttnn/tensor/vector_base.cpp @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "vector_base.hpp" +#include +#include "fmt/color.h" +#include "tt_metal/common/assert.hpp" + +namespace ttnn { + +namespace { + +constexpr size_t MIN_INTERNAL_SIZE = 4; + +int32_t normalized_index(int32_t index, size_t original_size, size_t container_size) { + int32_t orig_size = static_cast(original_size); + int32_t full_size = static_cast(container_size); + + int fixed_index = index; + if (fixed_index < 0) { + fixed_index += full_size; + } else { + fixed_index += full_size - orig_size; + } + + if (fixed_index < 0 || 
fixed_index >= full_size) { + TT_THROW("VectorBase[] index out of range. {} not in [{}, {})", index, -full_size, full_size); + } + + return fixed_index; +} +} + +void VectorBase::init() { + m_original_size = m_value.size(); + + if(m_original_size < MIN_INTERNAL_SIZE) { + m_value.resize(MIN_INTERNAL_SIZE); + size_t shift = MIN_INTERNAL_SIZE - m_original_size; + for (size_t idx = MIN_INTERNAL_SIZE - 1; idx >= shift; idx--) { + m_value[idx] = m_value[idx - shift]; + } + for(size_t idx = 0; idx < shift; idx++) { + m_value[idx] = 1; + } + } +} + +size_t VectorBase::size() const { + return m_original_size; +} + +std::span VectorBase::view() const { + return std::span(cbegin(), cend()); +} + +bool VectorBase::operator==(const VectorBase &other) const = default; + +bool VectorBase::operator==(const Container &other) const { + auto original_view = view(); + return std::equal(original_view.begin(), original_view.end(), other.cbegin(), other.cend()); +} + +uint32_t VectorBase::operator[](int32_t index) const { + auto norm_index = normalized_index(index, m_original_size, m_value.size()); + return m_value[norm_index]; +} + +uint32_t& VectorBase::operator[](int32_t index) { + auto norm_index = normalized_index(index, m_original_size, m_value.size()); + return m_value[norm_index]; +} + +VectorBase::Container::const_iterator VectorBase::cbegin() const { + return this->m_value.cbegin() + (m_value.size() - m_original_size); +} + +VectorBase::Container::const_iterator VectorBase::cend() const { + return this->m_value.cend(); +} + +} diff --git a/ttnn/cpp/ttnn/tensor/vector_base.hpp b/ttnn/cpp/ttnn/tensor/vector_base.hpp new file mode 100644 index 00000000000..678c6b599b3 --- /dev/null +++ b/ttnn/cpp/ttnn/tensor/vector_base.hpp @@ -0,0 +1,99 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include + +#include "tt_metal/tt_stl/reflection.hpp" + +namespace ttnn { + +static constexpr size_t SMALL_VECTOR_SIZE = 8; + +template +struct SmallVector: public boost::container::small_vector { + using boost::container::small_vector::small_vector; +}; + +template +std::ostream &operator<<(std::ostream &os, const SmallVector &vec) { + os << "SmallVector(["; + for (auto i = 0; i < vec.size(); ++i) { + if (i > 0) { + os << ", "; + } + os << vec[i]; + } + os << "])"; + return os; +} + +// Container wrapper that allows negative indexing +class VectorBase final { +public: + using Container = SmallVector; + + VectorBase() = default; + explicit VectorBase(const Container& shape) : m_value(shape) { init(); } + explicit VectorBase(Container&& shape) : m_value(std::move(shape)) { init(); } + explicit VectorBase(std::initializer_list ilist) : m_value(ilist) { init(); } + template + explicit VectorBase(const std::array& arr) : m_value(arr.begin(), arr.end()) { init(); } + + size_t size() const; + + template + bool operator==(const std::array &other) const { + return m_value.size() == N && std::equal(m_value.begin(), m_value.end(), other.begin()); + } + + bool operator==(const VectorBase &other) const; + bool operator==(const Container &other) const; + + uint32_t operator[](int32_t index) const; + uint32_t &operator[](int32_t index); + + Container::const_iterator cbegin() const; + Container::const_iterator cend() const; + + std::span view() const; + + static constexpr auto attribute_names = std::forward_as_tuple("value", "original_size"); + const auto attribute_values() const { return std::forward_as_tuple(this->m_value, 
this->m_original_size); } + +private: + void init(); + + Container m_value; + size_t m_original_size = 0; +}; + +} + +template +struct std::hash> { + size_t operator()(const ttnn::SmallVector& vec) const noexcept { + size_t hash = 0; + for (const auto& element : vec) { + hash = tt::stl::hash::detail::hash_objects(hash, element); + } + return hash; + } +}; + +template +struct fmt::formatter> { + constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } + + auto format(const ttnn::SmallVector& vector, format_context& ctx) const -> format_context::iterator { + std::stringstream ss; + ss << vector; + return fmt::format_to(ctx.out(), "{}", ss.str()); + } +}; From e35a340194e710ee5610e174be4ad23059c1777b Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 13:12:03 +0000 Subject: [PATCH 2/9] #0: Try to fix python --- ttnn/cpp/ttnn/tensor/vector_base.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ttnn/cpp/ttnn/tensor/vector_base.hpp b/ttnn/cpp/ttnn/tensor/vector_base.hpp index 678c6b599b3..1ab931d07d6 100644 --- a/ttnn/cpp/ttnn/tensor/vector_base.hpp +++ b/ttnn/cpp/ttnn/tensor/vector_base.hpp @@ -9,6 +9,7 @@ #include #include +#include #include "tt_metal/tt_stl/reflection.hpp" @@ -97,3 +98,8 @@ struct fmt::formatter> { return fmt::format_to(ctx.out(), "{}", ss.str()); } }; + +namespace PYBIND11_NAMESPACE { namespace detail { + template + struct type_caster> : list_caster, T> {}; +}} From 56f17ee3f69b7e379a7fd4e57138756ada1fbd7b Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 14:26:15 +0000 Subject: [PATCH 3/9] #0: Fixup --- .../data_movement/reshape_view/reshape_pybind.cpp | 2 +- .../operations/data_movement/slice/slice_pybind.hpp | 13 +++++-------- .../fast_reduce_nc/fast_reduce_nc_pybind.cpp | 4 ++-- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp index b57eddcc3f0..5df62a9620e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape_pybind.cpp @@ -34,7 +34,7 @@ void bind_reshape_view(pybind11::module& module, const data_movement_operation_t ttnn::pybind_overload_t{ [](const data_movement_operation_t& self, const ttnn::Tensor& input_tensor, - std::span shape + const ttnn::SmallVector shape ) -> ttnn::Tensor { return self(input_tensor, shape); }, diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp index 963b235f9a8..9fc0e2c0066 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice_pybind.hpp @@ -52,17 +52,14 @@ void bind_slice(py::module& module) { ttnn::pybind_overload_t{ [] (const OperationType& self, const ttnn::Tensor& input_tensor, - std::span slice_start, - std::span slice_end, - std::optional> step, + const ttnn::SmallVector& slice_start, + const ttnn::SmallVector& slice_end, + const std::optional>& step, const std::optional& memory_config, const std::optional& optional_output_tensor, uint8_t queue_id) { - if (step.has_value()) { - return self(queue_id, input_tensor, slice_start, slice_end, step.value(), memory_config, optional_output_tensor); - } else { - return self(queue_id, input_tensor, slice_start, slice_end, ttnn::SmallVector(slice_end.size(), 1), memory_config, 
optional_output_tensor); - } + const auto step_value = step.value_or(ttnn::SmallVector(slice_end.size(), 1)); + return self(queue_id, input_tensor, slice_start, slice_end, step_value, memory_config, optional_output_tensor); }, py::arg("input_tensor"), py::arg("slice_start"), diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp index 6174d6bae0c..a379faa9ef1 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc_pybind.cpp @@ -22,7 +22,7 @@ void bind_fast_reduce_nc(pybind11::module& module) { ttnn::pybind_overload_t{ [] (const OperationType& self, const ttnn::Tensor& input, - std::span dims, + const ttnn::SmallVector& dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config, @@ -31,7 +31,7 @@ void bind_fast_reduce_nc(pybind11::module& module) { }, pybind11::arg("input").noconvert(), pybind11::kw_only(), - pybind11::arg("dims").noconvert() = std::span(), + pybind11::arg("dims").noconvert() = ttnn::SmallVector(), pybind11::arg("output").noconvert() = std::nullopt, pybind11::arg("memory_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, pybind11::arg("compute_kernel_config").noconvert() = std::nullopt, From f1cea1edebab6559cd6ff6e9b9775f427e4abd5d Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 16:15:34 +0000 Subject: [PATCH 4/9] #0: Remove VectorBase, cleanup --- tt_metal/common/test_tiles.hpp | 19 ++--- ttnn/cpp/ttnn/tensor/CMakeLists.txt | 1 - ttnn/cpp/ttnn/tensor/small_vector.hpp | 61 +++++++++++++++ ttnn/cpp/ttnn/tensor/tensor_impl.hpp | 12 +-- ttnn/cpp/ttnn/tensor/types.cpp | 14 ++-- ttnn/cpp/ttnn/tensor/types.hpp | 25 +++--- ttnn/cpp/ttnn/tensor/vector_base.cpp | 83 -------------------- ttnn/cpp/ttnn/tensor/vector_base.hpp | 105 -------------------------- 8 files changed, 97 insertions(+), 223 deletions(-) create mode 100644 ttnn/cpp/ttnn/tensor/small_vector.hpp delete mode 100644 ttnn/cpp/ttnn/tensor/vector_base.cpp delete mode 100644 ttnn/cpp/ttnn/tensor/vector_base.hpp diff --git a/tt_metal/common/test_tiles.hpp b/tt_metal/common/test_tiles.hpp index 083093cbed8..9c8b471d544 100644 --- a/tt_metal/common/test_tiles.hpp +++ b/tt_metal/common/test_tiles.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "tt_metal/common/constants.hpp" #include "tt_metal/common/assert.hpp" #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" @@ -25,8 +26,8 @@ enum TensorLayout { template typename BufferType> std::vector convert_to_tile_layout( const BufferType& data, - const std::optional>& tile_shape = std::nullopt, - const std::optional>& face_shape = std::nullopt) { + std::optional> tile_shape = std::nullopt, + std::optional> face_shape = std::nullopt) { ZoneScoped; std::vector result; result.reserve(data.size()); @@ -79,8 +80,8 @@ std::vector convert_to_tile_layout( template typename BufferTyp> std::vector convert_to_flat_layout( const BufferTyp& data, - const std::optional>& tile_shape = std::nullopt, - const std::optional>& face_shape = std::nullopt) { + std::optional> tile_shape = std::nullopt, + std::optional> face_shape = std::nullopt) { ZoneScoped; std::vector result; result.reserve(data.size()); @@ -115,7 +116,7 @@ std::vector convert_to_flat_layout( // Converts a 32-swizzled tilized row-major tensor to a linear 
32-zero-padded row-major tensor template typename BufferType> -inline std::vector untilize_nchw(const BufferType& in, const std::vector& shape, const std::optional>& tile_shape = std::nullopt) { +inline std::vector untilize_nchw(const BufferType& in, std::span shape, std::optional> tile_shape = std::nullopt) { ZoneScoped; auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; @@ -159,7 +160,7 @@ inline std::uint32_t round_up_to_tile(int val, int tile_val) { return (val + til // Converts a linear non-zero-padded row-major tensor to zero-padded-32 32-swizzled tilized row-major tensor template typename BufferType> -inline std::vector tilize_nchw(const BufferType& in_rowmajor, const std::vector& shape, const std::optional>& tile_shape = std::nullopt) { +inline std::vector tilize_nchw(const BufferType& in_rowmajor, std::span shape, std::optional> tile_shape = std::nullopt) { ZoneScoped; int H = shape[shape.size() - 2], W = shape[shape.size() - 1]; auto batch_size = 1; @@ -221,11 +222,11 @@ struct TensAddr { template typename BufferType> inline std::vector convert_layout( const BufferType& inp, - const std::vector& shape, + std::span shape, TensorLayout inL, TensorLayout outL, - const std::optional>& tile_shape = std::nullopt, - const std::optional>& face_shape = std::nullopt) { + std::optional> tile_shape = std::nullopt, + std::optional> face_shape = std::nullopt) { ZoneScoped; switch (inL) { case TILED_SWIZZLED: diff --git a/ttnn/cpp/ttnn/tensor/CMakeLists.txt b/ttnn/cpp/ttnn/tensor/CMakeLists.txt index 1506fa9de29..cef2d5ffdb7 100644 --- a/ttnn/cpp/ttnn/tensor/CMakeLists.txt +++ b/ttnn/cpp/ttnn/tensor/CMakeLists.txt @@ -5,7 +5,6 @@ set(TENSOR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/types.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/vector_base.cpp CACHE INTERNAL "Tensor sources to reuse in ttnn build" ) diff --git a/ttnn/cpp/ttnn/tensor/small_vector.hpp b/ttnn/cpp/ttnn/tensor/small_vector.hpp new file mode 100644 index 00000000000..6fae1d164b3 --- /dev/null +++ b/ttnn/cpp/ttnn/tensor/small_vector.hpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "tt_metal/tt_stl/reflection.hpp" + +namespace ttnn { + +static constexpr size_t SMALL_VECTOR_SIZE = 8; + +template +struct SmallVector: public boost::container::small_vector { + using boost::container::small_vector::small_vector; +}; + +template +std::ostream &operator<<(std::ostream &os, const SmallVector &vec) { + os << "SmallVector(["; + for (auto i = 0; i < vec.size(); ++i) { + if (i > 0) { + os << ", "; + } + os << vec[i]; + } + os << "])"; + return os; +} + +} + +template +struct std::hash> { + size_t operator()(const ttnn::SmallVector& vec) const noexcept { + size_t hash = 0; + for (const auto& element : vec) { + hash = tt::stl::hash::detail::hash_objects(hash, element); + } + return hash; + } +}; + +template +struct fmt::formatter> { + constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } + + auto format(const ttnn::SmallVector& vector, format_context& ctx) const -> format_context::iterator { + std::stringstream ss; + ss << vector; + return fmt::format_to(ctx.out(), "{}", ss.str()); + } +}; + +namespace PYBIND11_NAMESPACE { namespace detail { + template + struct type_caster> : 
list_caster, T> {}; +}} diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index 55b4303913d..94a676a4fe3 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -209,18 +209,18 @@ inline std::vector convert_layout_row_major_to_tile(const tt::tt_metal::Legac (shape[-2] % tile.get_tile_shape()[0] == 0 && shape[-1] % tile.get_tile_shape()[1] == 0), "Unsupported shape for tensor conversion from row-major to tile layout. The tensor shape height and width must be a multiple of tile height ({}) and width ({}), but the provided shape is {}", tile.get_tile_shape()[0], tile.get_tile_shape()[1], shape); - auto tile_shape = std::vector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; - auto face_shape = std::vector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; + auto tile_shape = ttnn::SmallVector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; + auto face_shape = ttnn::SmallVector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, std::vector(shape.begin(), shape.end()), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES, tile_shape, face_shape); + data_to_convert, std::span(shape.begin(), shape.end()), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES, tile_shape, face_shape); } template typename BufferType> inline std::vector convert_layout_tile_to_row_major(const tt::tt_metal::LegacyShape& shape, const Tile& tile, const BufferType& data_to_convert) { - auto tile_shape = std::vector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; - auto face_shape = std::vector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; + auto tile_shape = ttnn::SmallVector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; + auto face_shape = ttnn::SmallVector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, std::vector(shape.begin(), shape.end()), TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR, tile_shape, face_shape); + data_to_convert, std::span(shape.begin(), shape.end()), TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR, tile_shape, face_shape); } // ====================================================================================== diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp index 7109afb5154..79d772718b3 100644 --- a/ttnn/cpp/ttnn/tensor/types.cpp +++ b/ttnn/cpp/ttnn/tensor/types.cpp @@ -391,25 +391,25 @@ int32_t normalized_index(int32_t index, size_t container_size) { } bool SimpleShape::operator==(const SimpleShape &other) const { - return this->m_value == other.m_value; + return this->value == other.value; } bool SimpleShape::operator==(const SmallVector &other) const { - return this->m_value == other; + return this->value == other; } uint32_t SimpleShape::operator[](int32_t index) const { - auto norm_index = normalized_index(index, m_value.size()); - return m_value[norm_index]; + auto norm_index = normalized_index(index, value.size()); + return value[norm_index]; } uint32_t& SimpleShape::operator[](int32_t index) { - auto norm_index = normalized_index(index, m_value.size()); - return m_value[norm_index]; + auto norm_index = normalized_index(index, value.size()); + return value[norm_index]; } uint64_t SimpleShape::volume() const { - return std::accumulate(this->m_value.cbegin(), this->m_value.cend(), + return std::accumulate(this->value.cbegin(), this->value.cend(), uint64_t{1}, std::multiplies()); } diff --git a/ttnn/cpp/ttnn/tensor/types.hpp 
b/ttnn/cpp/ttnn/tensor/types.hpp index de6bd3909c2..1d11463f00f 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "common/bfloat16.hpp" #include "tt_metal/common/core_coord.hpp" @@ -18,7 +19,7 @@ #include "tt_metal/tt_stl/concepts.hpp" #include "tt_metal/tt_stl/reflection.hpp" #include "ttnn/tensor/host_buffer/types.hpp" -#include "ttnn/tensor/vector_base.hpp" +#include "ttnn/tensor/small_vector.hpp" #include "ttnn/cpp/ttnn/tensor/enum_types.hpp" namespace ttnn { @@ -35,15 +36,15 @@ SimpleShape is a temporary measure aimed at making a clear distinction between S **/ class SimpleShape final { public: - explicit SimpleShape(const SmallVector& shape) : m_value(shape) {} - explicit SimpleShape(SmallVector&& shape) : m_value(std::move(shape)) {} - explicit SimpleShape(std::initializer_list ilist) : m_value(ilist) {} + explicit SimpleShape(const SmallVector& shape) : value(shape) {} + explicit SimpleShape(SmallVector&& shape) : value(std::move(shape)) {} + explicit SimpleShape(std::initializer_list ilist) : value(ilist) {} template - explicit SimpleShape(const std::array& arr) : m_value(arr) {} + explicit SimpleShape(const std::array& arr) : value(arr.begin(), arr.end()) {} template bool operator==(const std::array &other) const { - return m_value == other; + return value.size() == other.size() && std::equal(value.cbegin(), value.cend(), other.cbegin()); } bool operator==(const SimpleShape &other) const; @@ -52,22 +53,22 @@ class SimpleShape final { uint32_t operator[](int32_t index) const; uint32_t &operator[](int32_t index); - size_t rank() const { return m_value.size(); } + size_t rank() const { return value.size(); } uint64_t volume() const; - auto cbegin() const { return m_value.cbegin(); } - auto cend() const { return m_value.cend(); } + auto cbegin() const { return value.cbegin(); } + auto cend() const { return value.cend(); } - std::span view() const { return m_value.view(); } + std::span view() const { return value; } // Needed for reflect / fmt static constexpr auto attribute_names = std::forward_as_tuple("value"); - auto attribute_values() const { return std::forward_as_tuple(m_value); } + auto attribute_values() const { return std::forward_as_tuple(value); } friend std::ostream &operator<<(std::ostream &os, const SimpleShape &shape); private: - VectorBase m_value; + SmallVector value; }; inline std::ostream &operator<<(std::ostream &os, const ttnn::SimpleShape &shape) { diff --git a/ttnn/cpp/ttnn/tensor/vector_base.cpp b/ttnn/cpp/ttnn/tensor/vector_base.cpp deleted file mode 100644 index afe4fa3360a..00000000000 --- a/ttnn/cpp/ttnn/tensor/vector_base.cpp +++ /dev/null @@ -1,83 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#include "vector_base.hpp" -#include -#include "fmt/color.h" -#include "tt_metal/common/assert.hpp" - -namespace ttnn { - -namespace { - -constexpr size_t MIN_INTERNAL_SIZE = 4; - -int32_t normalized_index(int32_t index, size_t original_size, size_t container_size) { - int32_t orig_size = static_cast(original_size); - int32_t full_size = static_cast(container_size); - - int fixed_index = index; - if (fixed_index < 0) { - fixed_index += full_size; - } else { - fixed_index += full_size - orig_size; - } - - if (fixed_index < 0 || fixed_index >= full_size) { - TT_THROW("VectorBase[] index out of range. 
{} not in [{}, {})", index, -full_size, full_size); - } - - return fixed_index; -} -} - -void VectorBase::init() { - m_original_size = m_value.size(); - - if(m_original_size < MIN_INTERNAL_SIZE) { - m_value.resize(MIN_INTERNAL_SIZE); - size_t shift = MIN_INTERNAL_SIZE - m_original_size; - for (size_t idx = MIN_INTERNAL_SIZE - 1; idx >= shift; idx--) { - m_value[idx] = m_value[idx - shift]; - } - for(size_t idx = 0; idx < shift; idx++) { - m_value[idx] = 1; - } - } -} - -size_t VectorBase::size() const { - return m_original_size; -} - -std::span VectorBase::view() const { - return std::span(cbegin(), cend()); -} - -bool VectorBase::operator==(const VectorBase &other) const = default; - -bool VectorBase::operator==(const Container &other) const { - auto original_view = view(); - return std::equal(original_view.begin(), original_view.end(), other.cbegin(), other.cend()); -} - -uint32_t VectorBase::operator[](int32_t index) const { - auto norm_index = normalized_index(index, m_original_size, m_value.size()); - return m_value[norm_index]; -} - -uint32_t& VectorBase::operator[](int32_t index) { - auto norm_index = normalized_index(index, m_original_size, m_value.size()); - return m_value[norm_index]; -} - -VectorBase::Container::const_iterator VectorBase::cbegin() const { - return this->m_value.cbegin() + (m_value.size() - m_original_size); -} - -VectorBase::Container::const_iterator VectorBase::cend() const { - return this->m_value.cend(); -} - -} diff --git a/ttnn/cpp/ttnn/tensor/vector_base.hpp b/ttnn/cpp/ttnn/tensor/vector_base.hpp deleted file mode 100644 index 1ab931d07d6..00000000000 --- a/ttnn/cpp/ttnn/tensor/vector_base.hpp +++ /dev/null @@ -1,105 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include -#include - -#include -#include - -#include "tt_metal/tt_stl/reflection.hpp" - -namespace ttnn { - -static constexpr size_t SMALL_VECTOR_SIZE = 8; - -template -struct SmallVector: public boost::container::small_vector { - using boost::container::small_vector::small_vector; -}; - -template -std::ostream &operator<<(std::ostream &os, const SmallVector &vec) { - os << "SmallVector(["; - for (auto i = 0; i < vec.size(); ++i) { - if (i > 0) { - os << ", "; - } - os << vec[i]; - } - os << "])"; - return os; -} - -// Container wrapper that allows negative indexing -class VectorBase final { -public: - using Container = SmallVector; - - VectorBase() = default; - explicit VectorBase(const Container& shape) : m_value(shape) { init(); } - explicit VectorBase(Container&& shape) : m_value(std::move(shape)) { init(); } - explicit VectorBase(std::initializer_list ilist) : m_value(ilist) { init(); } - template - explicit VectorBase(const std::array& arr) : m_value(arr.begin(), arr.end()) { init(); } - - size_t size() const; - - template - bool operator==(const std::array &other) const { - return m_value.size() == N && std::equal(m_value.begin(), m_value.end(), other.begin()); - } - - bool operator==(const VectorBase &other) const; - bool operator==(const Container &other) const; - - uint32_t operator[](int32_t index) const; - uint32_t &operator[](int32_t index); - - Container::const_iterator cbegin() const; - Container::const_iterator cend() const; - - std::span view() const; - - static constexpr auto attribute_names = std::forward_as_tuple("value", "original_size"); - const auto attribute_values() const { return std::forward_as_tuple(this->m_value, this->m_original_size); } - -private: - void init(); - - Container 
m_value; - size_t m_original_size = 0; -}; - -} - -template -struct std::hash> { - size_t operator()(const ttnn::SmallVector& vec) const noexcept { - size_t hash = 0; - for (const auto& element : vec) { - hash = tt::stl::hash::detail::hash_objects(hash, element); - } - return hash; - } -}; - -template -struct fmt::formatter> { - constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } - - auto format(const ttnn::SmallVector& vector, format_context& ctx) const -> format_context::iterator { - std::stringstream ss; - ss << vector; - return fmt::format_to(ctx.out(), "{}", ss.str()); - } -}; - -namespace PYBIND11_NAMESPACE { namespace detail { - template - struct type_caster> : list_caster, T> {}; -}} From ccaa821d8eeb38d275ff1ae5b18b2bc27b19d7c8 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 17:01:50 +0000 Subject: [PATCH 5/9] #0: Rebase fix --- ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp index 9a77bba6f2d..a97e009c665 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp @@ -99,7 +99,8 @@ ttnn::Tensor slice_operation_invoke_impl( padded_ends[adjusted_rank - 2] = std::max(tt::round_up(padded_ends[adjusted_rank - 2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT); padded_ends[adjusted_rank - 1] = std::max(tt::round_up(padded_ends[adjusted_rank - 1], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH); } - SmallVector actual_shape, padded_shape; + + SmallVector actual_shape, final_padded_shape; actual_shape.reserve(input_rank); final_padded_shape.reserve(input_rank); bool empty = false; @@ -220,5 +221,5 @@ ttnn::Tensor SliceOperation::invoke( const std::optional& optional_output_tensor) { return slice_operation_invoke_impl(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); } -slice.cpp + } // namespace operations From 6ee79873bd912d57d98a607b0e9d02bd0c060699 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 17:59:18 +0000 Subject: [PATCH 6/9] #0: Rebase fixes --- .../data_movement/repeat/repeat.cpp | 2 +- .../operations/data_movement/slice/slice.cpp | 247 +++++++++++++++--- .../operations/data_movement/slice/slice.hpp | 62 +++-- .../binary_backward/binary_backward.cpp | 12 +- 4 files changed, 245 insertions(+), 78 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp index 436f2fb3482..0f60cf61b9e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp @@ -62,7 +62,7 @@ ttnn::Tensor RepeatOperation::invoke( if (input_tensor.get_layout() != Layout::ROW_MAJOR && logical_input_shape != padded_input_shape) { auto zero_indices = ttnn::SmallVector(input_rank, 0); - auto end_indices = repeated_logical_shape.view(); + auto end_indices = ttnn::SmallVector(repeated_logical_shape.cbegin(), repeated_logical_shape.cend()); auto step = ttnn::SmallVector(input_rank, 1); if (repeated_logical_shape.volume() % tt::constants::TILE_HW != 0) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp index a97e009c665..b1068c9b04c 100644 --- 
a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp @@ -15,9 +15,8 @@ namespace ttnn::operations::data_movement { -namespace { template -ttnn::Tensor slice_operation_invoke_impl( +ttnn::Tensor SliceOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, std::span begins, @@ -69,9 +68,9 @@ ttnn::Tensor slice_operation_invoke_impl( size_t adjusted_rank = padded_shape.rank(); // Now adjusted to 4 after unsqueeze // Create modified vectors with wrapped indices and adjust them to match the tensor's rank - std::vector modified_begins(adjusted_rank, 0); - std::vector modified_ends = padded_shape.as_vector(); - std::vector modified_step(adjusted_rank, 1); + ttnn::SmallVector modified_begins(adjusted_rank, 0); + ttnn::SmallVector modified_ends(padded_shape.cbegin(), padded_shape.cend()); + ttnn::SmallVector modified_step(adjusted_rank, 1); size_t rank_diff = adjusted_rank - input_rank; @@ -90,17 +89,17 @@ ttnn::Tensor slice_operation_invoke_impl( } } - auto output_dim_i = [&modified_begins, &modified_step](size_t i, const std::vector &modified_ends) { + auto output_dim_i = [&modified_begins, &modified_step](size_t i, const ttnn::SmallVector &modified_ends) { return (modified_ends[i] - modified_begins[i] + modified_step[i] - 1) / modified_step[i]; }; - std::vector padded_ends = modified_ends; + ttnn::SmallVector padded_ends = modified_ends; if (input.layout() == Layout::TILE) { padded_ends[adjusted_rank - 2] = std::max(tt::round_up(padded_ends[adjusted_rank - 2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT); padded_ends[adjusted_rank - 1] = std::max(tt::round_up(padded_ends[adjusted_rank - 1], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH); } - SmallVector actual_shape, final_padded_shape; + ttnn::SmallVector actual_shape, final_padded_shape; actual_shape.reserve(input_rank); final_padded_shape.reserve(input_rank); bool empty = false; @@ -178,48 +177,210 @@ ttnn::Tensor slice_operation_invoke_impl( return rm_only ? 
ttnn::to_layout(res, input_tensor.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : res; } } -} +template ttnn::Tensor SliceOperation::invoke( - uint8_t queue_id, - const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - return slice_operation_invoke_impl(queue_id, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); -} + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg); + } -ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - return slice_operation_invoke_impl(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); +// Specialization for uint32_t and N=4 +template<> +ttnn::Tensor SliceOperation::invoke( + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + const std::array &begins, + const std::array &ends, + const std::array &step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + + const auto& padded_input_shape = input_tensor.get_padded_shape(); + TT_FATAL(padded_input_shape.rank() == 4, "Input tensor must have rank 4"); + + bool no_step = step[0] == 1 && step[1] == 1 && step[2] == 1 && step[3] == 1; + bool starts_zero = begins[0]==0 && begins[1]==0 && begins[2]==0 && begins[3]==0; + bool ends_max = ends[0]==padded_input_shape[0] && ends[1]==padded_input_shape[1] && ends[2]==padded_input_shape[2] && ends[3]==padded_input_shape[3]; + + if (no_step && starts_zero && ends_max) { + if (input_tensor.storage_type() == StorageType::DEVICE) { + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input_tensor.memory_config()); + return ttnn::to_memory_config(input_tensor, memory_config, std::nullopt); + } + return input_tensor; + } + bool rm_only = !no_step && input_tensor.get_layout() == Layout::TILE; + ttnn::Tensor input = input_tensor; + if (rm_only) { + input = ttnn::to_layout(input_tensor, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + } + + const bool tiled = input.get_layout() == Layout::TILE; + bool on_device = input.storage_type() == StorageType::DEVICE; + + std::array actual_shape; + std::array padded_shape; + const std::array padded_ends = tiled ? 
std::array({ends[0], ends[1], std::max(tt::round_up(ends[2], tt::constants::TILE_HEIGHT), tt::constants::TILE_HEIGHT), std::max(tt::round_up(ends[3], tt::constants::TILE_WIDTH), tt::constants::TILE_WIDTH)}) : ends; + bool empty = false; + for (int i = 0; i < 4; ++i) { + TT_FATAL(ends[i] >= begins[i], "End {} must be greater than or equal to start {}", ends[i], begins[i]); + uint32_t offset = step[i] - begins[i] - 1; + uint32_t dim_size = (ends[i] + offset) / step[i]; + empty |= dim_size == 0; + actual_shape[i] = dim_size; + padded_shape[i]= std::max((padded_ends[i] + offset) / step[i], 1u); + } + + ttnn::Shape output_shape(actual_shape, padded_shape); + + if (empty) { + TT_FATAL(on_device, "Host tensor slice cannot return a scalar or empty tensor"); + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); + return ttnn::empty(output_shape, input.dtype(), input_tensor.layout(), + input.device(), memory_config); + } + + // Early exit if slice is a no-op + if (ttnn::Shape(padded_shape) == padded_input_shape && no_step) { + if (input.storage_type() == StorageType::DEVICE) { + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); + auto res = ttnn::to_memory_config(input, memory_config, std::nullopt); + return ttnn::reshape(res, output_shape); + } + return ttnn::reshape(input, output_shape); // change to view + } + + if (on_device) { + auto memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config_arg.value_or(input.memory_config()); + + // Check for in-place unpad optimization + if (input.is_sharded() && input.memory_config() == memory_config && padded_input_shape.rank() > 1) { + TT_FATAL(no_step, "Sharded tensor slice implementation does not support striding"); + bool in_place_unpad = true; + for (int i = 0; i < 2; ++i) { + in_place_unpad &= begins[i] == 0 && ends[i] == 1 && padded_input_shape[i] == 1; + } + in_place_unpad &= begins[2] == 0 && + tt::div_up(ends[2], input.shard_spec().value().shape[0]) == + tt::div_up(padded_input_shape[2], input.shard_spec().value().shape[0]); + in_place_unpad &= begins[3] == 0 && ends[3] == padded_input_shape[3]; + if (in_place_unpad) { + return ttnn::reshape(input, output_shape); + } + } + + input = operation::run( + SliceDeviceOperation{ + begins, + padded_ends, + step, + memory_config}, + {input}, {}, {optional_output_tensor}, queue_id)[0]; + input = ttnn::reshape(input, output_shape); + return rm_only ? 
ttnn::to_layout(input, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr) : input; + } + + TT_FATAL(no_step, "Host tensor slice does not support strides"); + + if (input.get_padded_shape() == actual_shape) { + return input; + } else { + auto input_4d_rm = ttnn::to_layout(input, Layout::ROW_MAJOR, std::nullopt, std::nullopt, (Device *)nullptr); + auto output_4d = input_4d_rm.unpad(ttnn::SimpleShape(begins), ttnn::SimpleShape(ends)); + auto output_4d_rm = ttnn::to_layout(output_4d, input.get_layout(), std::nullopt, std::nullopt, (Device *)nullptr); + return ttnn::reshape(output_4d_rm, output_shape); + } } +template ttnn::Tensor SliceOperation::invoke( - uint8_t queue_id, - const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - return slice_operation_invoke_impl(queue_id, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); -} + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + const std::array &output_tensor_start, + const std::array &output_tensor_end, + const std::array &step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + std::span start(output_tensor_start.begin(), output_tensor_start.end()); + std::span end(output_tensor_end.begin(), output_tensor_end.end()); + std::span step_vec(step.begin(), step.end()); + return SliceOperation::invoke(queue_id, input_tensor, start, end, step_vec, memory_config_arg); + } +template ttnn::Tensor SliceOperation::invoke( - const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, - const std::optional& memory_config_arg, - const std::optional& optional_output_tensor) { - return slice_operation_invoke_impl(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg, optional_output_tensor); -} + const ttnn::Tensor& input_tensor, + const std::array &output_tensor_start, + const std::array &output_tensor_end, + const std::array &step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor) { + return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, output_tensor_start, output_tensor_end, step, memory_config_arg); + } + +template ttnn::Tensor SliceOperation::invoke( + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor); + +template ttnn::Tensor SliceOperation::invoke( + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor); + + +template ttnn::Tensor SliceOperation::invoke( + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor); + +template ttnn::Tensor SliceOperation::invoke( + const ttnn::Tensor& input_tensor, + std::span begins, + std::span ends, + std::span step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor); + +template ttnn::Tensor SliceOperation::invoke( + const ttnn::Tensor& input_tensor, + const std::array &output_tensor_start, + const std::array &output_tensor_end, + const std::array &step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor); + +template 
ttnn::Tensor SliceOperation::invoke( + uint8_t queue_id, + const ttnn::Tensor& input_tensor, + const std::array &output_tensor_start, + const std::array &output_tensor_end, + const std::array &step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor); + +template ttnn::Tensor SliceOperation::invoke( + const ttnn::Tensor& input_tensor, + const std::array &output_tensor_start, + const std::array &output_tensor_end, + const std::array &step, + const std::optional& memory_config_arg, + const std::optional& optional_output_tensor); } // namespace operations diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp index 0206e0dd4ea..b4c3b951509 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp @@ -11,62 +11,68 @@ namespace operations { namespace data_movement { struct SliceOperation { + template static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + std::span begins, + std::span ends, + std::span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); + template static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + std::span output_tensor_start, + std::span output_tensor_end, + std::span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); + template static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + const ttnn::SmallVector& begins, + const ttnn::SmallVector& ends, + const ttnn::SmallVector& step, const std::optional& memory_config_arg = std::nullopt, - const std::optional& optional_output_tensor = std::nullopt); + const std::optional& optional_output_tensor = std::nullopt) { + return invoke(queue_id, input_tensor, std::span(begins.begin(), begins.end()), std::span(ends.begin(), ends.end()), std::span(step.begin(), step.end()), memory_config_arg, optional_output_tensor); + } + template static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + const ttnn::SmallVector& begins, + const ttnn::SmallVector& ends, + const ttnn::SmallVector& step, const std::optional& memory_config_arg = std::nullopt, - const std::optional& optional_output_tensor = std::nullopt); + const std::optional& optional_output_tensor = std::nullopt) { + return invoke(input_tensor, std::span(begins.begin(), begins.end()), std::span(ends.begin(), ends.end()), std::span(step.begin(), step.end()), memory_config_arg, optional_output_tensor); + } - template + template static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const std::array& begins, - const std::array& ends, - const std::array& step, + const std::array &output_tensor_start, + const std::array &output_tensor_end, + const std::array &step, const std::optional& memory_config_arg = std::nullopt, - const std::optional& optional_output_tensor = std::nullopt) { - return invoke(queue_id, input_tensor, std::span(begins), std::span(ends), std::span(step), memory_config_arg, optional_output_tensor); - } + const std::optional& optional_output_tensor = std::nullopt); - template + template static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - 
const std::array& begins, - const std::array& ends, - const std::array& step, + const std::array &output_tensor_start, + const std::array &output_tensor_end, + const std::array &step, const std::optional& memory_config_arg = std::nullopt, - const std::optional& optional_output_tensor = std::nullopt) { - return invoke(ttnn::DefaultQueueId, input_tensor, std::span(begins), std::span(ends), std::span(step), memory_config_arg, optional_output_tensor); - } + const std::optional& optional_output_tensor = std::nullopt); + + }; } // namespace data_movement diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp index ae9a5655ac6..a47fd0b0312 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp @@ -445,19 +445,19 @@ std::vector> ExecuteBackwardConcat::invoke( preallocated_tensors_check(input_grad, other_grad, input, other, {are_required_outputs[0], are_required_outputs[1]}); if(are_required_outputs[0]){ - std::vector start_index = {0, 0, 0, 0}; - std::vector end_index = { + ttnn::SmallVector start_index = {0, 0, 0, 0}; + ttnn::SmallVector end_index = { input.get_legacy_shape()[0], input.get_legacy_shape()[1], input.get_legacy_shape()[2], input.get_legacy_shape()[3]}; - std::vector step = std::vector({1, 1, 1, 1}); + ttnn::SmallVector step = {1, 1, 1, 1}; ttnn::slice(queue_id, grad, start_index, end_index, step, std::nullopt, input_grad); grad_tensor[0] = input_grad; } if(are_required_outputs[1]){ - std::vector start_index_2 = {0, 0, 0, 0}; + ttnn::SmallVector start_index_2 = {0, 0, 0, 0}; if (dim == 0) { start_index_2 = {input.get_legacy_shape()[0], 0, 0, 0}; } else if (dim == 1) { @@ -468,12 +468,12 @@ std::vector> ExecuteBackwardConcat::invoke( } else if (dim == 3) { start_index_2 = {0, 0, 0, input.get_legacy_shape()[3]}; } - std::vector end_index_2 = { + ttnn::SmallVector end_index_2 = { grad.get_legacy_shape()[0], grad.get_legacy_shape()[1], grad.get_legacy_shape()[2], grad.get_legacy_shape()[3]}; - std::vector step_2 = std::vector({1, 1, 1, 1}); + ttnn::SmallVector step_2 = {1, 1, 1, 1}; ttnn::slice(queue_id, grad, start_index_2, end_index_2, step_2, std::nullopt, other_grad); grad_tensor[1] = other_grad; } From a669f5190eccee21f91b325fa84da92f8895ce57 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 20:51:26 +0000 Subject: [PATCH 7/9] #0: Fix UB - access out of bounds in SqueezeOperation --- .../data_movement/squeeze/squeeze.cpp | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp b/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp index e622da4918d..88c79d6ca8d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp @@ -9,8 +9,8 @@ namespace ttnn::operations::data_movement { ttnn::Tensor SqueezeOperation::invoke(const ttnn::Tensor& input_tensor, const int dim) { - const auto original_logical_shape = input_tensor.get_shape(); - const auto padded_shape = input_tensor.get_shape().with_tile_padding(); + const auto original_logical_shape = input_tensor.get_logical_shape(); + const auto padded_shape = input_tensor.get_padded_shape(); const auto input_tensor_rank = original_logical_shape.rank(); int normal_dim = dim; @@ -19,22 +19,16 @@ ttnn::Tensor SqueezeOperation::invoke(const ttnn::Tensor& 
input_tensor, const in normal_dim += input_tensor_rank; } - SmallVector original_logical_shape_vector(input_tensor_rank - 1); - SmallVector padded_shape_vector(input_tensor_rank - 1); - uint32_t vector_id = 0; - for(int i=0; i< input_tensor_rank; i++) { - if(i != normal_dim or original_logical_shape[i] != 1) { - original_logical_shape_vector[vector_id] = original_logical_shape[i]; - padded_shape_vector[vector_id] = padded_shape[i]; - vector_id++; - } - } - // If dim is out of range or original dimension was not of size 1, include all dimensions - if (normal_dim >= static_cast(original_logical_shape.size()) || original_logical_shape[normal_dim] != 1) { + if (normal_dim < 0 || normal_dim >= original_logical_shape.rank() || original_logical_shape[normal_dim] != 1) { return input_tensor; } + SmallVector original_logical_shape_vector(original_logical_shape.cbegin(), original_logical_shape.cend()); + SmallVector padded_shape_vector(padded_shape.cbegin(), padded_shape.cend()); + original_logical_shape_vector.erase(original_logical_shape_vector.begin() + normal_dim); + padded_shape_vector.erase(padded_shape_vector.begin() + normal_dim); + return ttnn::reshape(input_tensor, ttnn::Shape(original_logical_shape_vector, padded_shape_vector)); } From b06ec4d1b33d4b512bdc80f815cf49f5a5372026 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Fri, 25 Oct 2024 22:17:34 +0000 Subject: [PATCH 8/9] #0: Review fixes --- ttnn/CMakeLists.txt | 7 ++++++- .../device/moreh_getitem_device_operation.cpp | 2 +- ttnn/cpp/ttnn/tensor/small_vector.hpp | 7 ++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 873296beece..cd9d099b94d 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -371,7 +371,6 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/device/uniform_program_factory.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp @@ -619,6 +618,12 @@ target_compile_options( -fno-var-tracking ) +if(WITH_PYTHON_BINDINGS) + target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=1) +else() + target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=0) +endif() + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") target_compile_definitions(ttnn PUBLIC DISABLE_NAMESPACE_STATIC_ASSERT) endif() diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp index 8ed30508837..ac71f8c1601 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp @@ -55,7 +55,7 @@ void MorehGetItemOperation::validate_inputs( TT_FATAL( dim_start + i == dim, "The value of index_dims={} must be consecutive integers.", - std::vector(operation_attributes.index_dims.begin(), operation_attributes.index_dims.end())); + operation_attributes.index_dims); i++; } if (!output_tensor.has_value()) { diff --git a/ttnn/cpp/ttnn/tensor/small_vector.hpp b/ttnn/cpp/ttnn/tensor/small_vector.hpp 
index 6fae1d164b3..0680b3ad476 100644 --- a/ttnn/cpp/ttnn/tensor/small_vector.hpp +++ b/ttnn/cpp/ttnn/tensor/small_vector.hpp @@ -5,10 +5,13 @@ #pragma once #include -#include #include "tt_metal/tt_stl/reflection.hpp" +#if TTNN_WITH_PYTHON_BINDINGS +#include +#endif + namespace ttnn { static constexpr size_t SMALL_VECTOR_SIZE = 8; @@ -55,7 +58,9 @@ struct fmt::formatter> { } }; +#if TTNN_WITH_PYTHON_BINDINGS namespace PYBIND11_NAMESPACE { namespace detail { template struct type_caster> : list_caster, T> {}; }} +#endif From 68195e002fe807c0a4d0eaeedbea2fba3662c659 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Sat, 26 Oct 2024 00:43:34 +0000 Subject: [PATCH 9/9] #0: Use tt::stl::Span --- best_practices.md | 14 +++---- ttnn/cpp/pybind11/pytensor.cpp | 2 +- .../ttnn/operations/data_movement/pad/pad.cpp | 2 +- .../ttnn/operations/data_movement/pad/pad.hpp | 2 +- .../data_movement/permute/permute.cpp | 12 +++--- .../data_movement/permute/permute.hpp | 6 +-- .../reshape_on_device/reshape.cpp | 6 +-- .../reshape_on_device/reshape.hpp | 6 +-- .../data_movement/reshape_view/reshape.cpp | 2 +- .../data_movement/reshape_view/reshape.hpp | 2 +- .../operations/data_movement/slice/slice.cpp | 42 +++++++++---------- .../operations/data_movement/slice/slice.hpp | 16 +++---- .../fast_reduce_nc_device_operation.cpp | 2 +- .../fast_reduce_nc_device_operation.hpp | 2 +- .../fast_reduce_nc/fast_reduce_nc.cpp | 4 +- .../fast_reduce_nc/fast_reduce_nc.hpp | 4 +- .../moreh_sum_backward_device_operation.cpp | 2 +- .../moreh_sum_backward_device_operation.hpp | 2 +- ttnn/cpp/ttnn/tensor/tensor_impl.hpp | 4 +- ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 2 +- ttnn/cpp/ttnn/tensor/tensor_utils.hpp | 4 +- ttnn/cpp/ttnn/tensor/types.cpp | 6 +-- ttnn/cpp/ttnn/tensor/types.hpp | 26 ++++++------ 23 files changed, 85 insertions(+), 85 deletions(-) diff --git a/best_practices.md b/best_practices.md index 86060de6406..d3be5fd3713 100644 --- a/best_practices.md +++ b/best_practices.md @@ -18,13 +18,13 @@ void write_buffer(queue_id cq_id, Tensor& dst, std::vector void write_buffer(queue_id cq_id, Tensor& dst, const std::vector>& src, const std::optional& transfer_size = std::nullopt); // Right! ``` -## 2. Use `std::span` for Input Parameters +## 2. Use `tt::stl::Span` for Input Parameters ### Practice -Consider using `std::span` as input instead of `std::vector`. This allows `std::array` to be used as an argument as well. +Consider using `tt::stl::Span` as input instead of `std::vector`. This allows `std::array` to be used as an argument as well. ### Explanation -`std::span` is a lightweight view over a contiguous sequence of objects, such as arrays and vectors. It provides a safe and flexible way to handle array-like data structures without copying them. +`tt::stl::Span` is a lightweight view over a contiguous sequence of objects, such as arrays and vectors. It provides a safe and flexible way to handle array-like data structures without copying them. ### Motivation - **Flexibility**: Enables functions to accept both `std::vector` and `std::array`. @@ -33,7 +33,7 @@ Consider using `std::span` as input instead of `std::vector`. This allows `std:: ### Example ``` template -void print_elements(std::span data) { +void print_elements(tt::stl::Span data) { for (const auto& element : data) { std::cout << element << " "; } @@ -217,7 +217,7 @@ Use the Copy-and-Swap idiom to avoid duplicating code between different construc ### Explanation The Copy-and-Swap idiom is a robust and elegant method to implement copy assignment operators.
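For reference, a minimal sketch of the idiom, using a hypothetical `Buffer` class rather than any type from this patch:

```cpp
#include <algorithm>  // std::copy
#include <cstddef>
#include <utility>    // std::swap

class Buffer {
public:
    explicit Buffer(std::size_t size) : size_(size), data_(new int[size]()) {}

    Buffer(const Buffer& other) : size_(other.size_), data_(new int[other.size_]) {
        std::copy(other.data_, other.data_ + size_, data_);
    }

    // Non-member swap exchanges the internals of two objects without allocating.
    friend void swap(Buffer& a, Buffer& b) noexcept {
        using std::swap;
        swap(a.size_, b.size_);
        swap(a.data_, b.data_);
    }

    // Copy-and-swap: the parameter is taken by value (reusing the copy
    // constructor), then swapped into *this; the old state is released
    // when `other` goes out of scope.
    Buffer& operator=(Buffer other) noexcept {
        swap(*this, other);
        return *this;
    }

    ~Buffer() { delete[] data_; }

private:
    std::size_t size_ = 0;
    int* data_ = nullptr;
};
```

Taking the right-hand side by value lets the compiler pick copy or move construction at the call site, so the single assignment operator covers both cases while the swap keeps the operation exception-safe.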
It leverages the copy constructor and the swap method to provide strong exception safety and reduce code duplication. -### Example +### Example https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom @@ -279,7 +279,7 @@ Prefer: enum class ThreadingOption { SingleCore, MultiCore }; tensor = tt::tt_metal::tilize_with_val_padding(tensor, output_shape, 0, output_memory_config, dtype, ThreadingOption::MultiCore); ``` -Also consider giving enums power-of-2 values to pass them all as a single argument, e.g. +Also consider giving enums power-of-2 values to pass them all as a single argument, e.g. ```cpp Options::FOO | Options::BAR ``` @@ -343,7 +343,7 @@ void doSomething(...) { Prefer: ```cpp void doSomething(...) { - if (!contractCheck) + if (!contractCheck) return; // Do a lot of things diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 6fa13b17842..2c02c12b406 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -64,7 +64,7 @@ void log_external_operation( #endif template -Tensor create_owned_tensor(T* data_ptr, size_t num_elements, std::span shape, DataType data_type, Layout layout, const std::optional& optional_tile = std::nullopt) +Tensor create_owned_tensor(T* data_ptr, size_t num_elements, tt::stl::Span shape, DataType data_type, Layout layout, const std::optional& optional_tile = std::nullopt) { auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp index a1bae5e6f95..21d2c93cbe8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp @@ -116,7 +116,7 @@ static ttnn::Tensor pad_impl( ttnn::Tensor ExecutePad::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span> padding, + tt::stl::Span> padding, const float value, const bool use_multicore, const std::optional& memory_config_arg) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp index 5c006212cbf..a6370abc6fe 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp @@ -32,7 +32,7 @@ struct ExecutePad { // Any rank tensor supported static ttnn::Tensor invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span> padding, + tt::stl::Span> padding, const float value, const bool use_multicore, const std::optional& memory_config_arg); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index 9efdcf72c38..29dd73506df 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -128,7 +128,7 @@ ttnn::Tensor permute_impl(const ttnn::Tensor &a, const SmallVector& di return output; } -ttnn::Tensor permute_launch(const ttnn::Tensor &a, std::span dims, const MemoryConfig& output_mem_config) { +ttnn::Tensor permute_launch(const ttnn::Tensor &a, tt::stl::Span dims, const MemoryConfig& output_mem_config) { std::vector output_tensors = {ttnn::Tensor(operation::get_workers_for_op_output({a}))}; operation::launch_with_autoformat( [dims, output_mem_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { @@ -147,7 +147,7 @@ 
ttnn::Tensor permute_launch(const ttnn::Tensor &a, std::span dims Tensor composite_invoke( const ttnn::Tensor& input_tensor, - std::span dims, + tt::stl::Span dims, const std::optional& memory_config) { auto output_tensor = permute_launch(input_tensor, dims, memory_config.value_or(input_tensor.memory_config())); @@ -159,7 +159,7 @@ Tensor composite_invoke( ttnn::Tensor ExecutePermute::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span dims, + tt::stl::Span dims, const std::optional& memory_config, bool composite) { @@ -175,7 +175,7 @@ ttnn::Tensor ExecutePermute::invoke( input_rank == dims.size(), "The number of dimensions in the tensor input does not match the length of the desired ordering"); - auto adjust_order = [](std::span dims) { + auto adjust_order = [](tt::stl::Span dims) { ttnn::SmallVector new_order; TT_FATAL(dims.size() <= 4, "Error"); int additional_ranks = 4 - dims.size(); @@ -218,12 +218,12 @@ ttnn::Tensor ExecutePermute::invoke( ttnn::Tensor ExecutePermute::invoke( const ttnn::Tensor& input_tensor, - std::span dims, + tt::stl::Span dims, const std::optional& memory_config) { return invoke(DefaultQueueId, input_tensor, dims, memory_config); } -ttnn::Tensor ExecutePermute::invoke(const ttnn::Tensor& input_tensor, std::span dims) { +ttnn::Tensor ExecutePermute::invoke(const ttnn::Tensor& input_tensor, tt::stl::Span dims) { return invoke(input_tensor, dims, std::nullopt); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp index a1b37e4994c..04f5231956b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp @@ -13,16 +13,16 @@ struct ExecutePermute { static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span dims, + tt::stl::Span dims, const std::optional& memory_config, bool composite = true); static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - std::span dims, + tt::stl::Span dims, const std::optional& memory_config); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span dims); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span dims); }; } // namespace operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp index 6b479264c73..c35ea040c4e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp @@ -94,15 +94,15 @@ ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const tt return invoke(DefaultQueueId, input_tensor, shape, std::nullopt); } -ttnn::Tensor ReshapeOperation::invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg) { +ttnn::Tensor ReshapeOperation::invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector, const std::optional& memory_config_arg) { return invoke(queue_id, input_tensor, ttnn::Shape(infer_dims_for_reshape(input_tensor, shape_vector).view()), memory_config_arg); } -ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg) { +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector, const std::optional& memory_config_arg) { return 
invoke(DefaultQueueId, input_tensor, shape_vector, memory_config_arg); } -ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, std::span shape_vector) { +ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector) { return invoke(input_tensor, shape_vector, std::nullopt); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp index cdd34ac34df..983af7ff34d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp @@ -24,9 +24,9 @@ struct ReshapeOperation { static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); - static ttnn::Tensor invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span shape_vector, const std::optional& memory_config_arg); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span shape_vector); + static ttnn::Tensor invoke(uint8_t queue_id, const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector, const std::optional& memory_config_arg); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector, const std::optional& memory_config_arg); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector); }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index ebbe216841d..5dd40c16db7 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -131,7 +131,7 @@ ttnn::Tensor ReshapeViewOperation::invoke(const ttnn::Tensor& tensor, const ttnn ttnn::Tensor ReshapeViewOperation::invoke( const ttnn::Tensor& tensor, - std::span shape_vector + tt::stl::Span shape_vector ) { return invoke(tensor, tt::tt_metal::infer_dims_for_reshape(tensor, shape_vector)); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp index 38c49b5ff32..ad1a1970804 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp @@ -13,7 +13,7 @@ namespace operations::data_movement { struct ReshapeViewOperation { static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& logical_shape); - static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, std::span shape_vector); + static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, tt::stl::Span shape_vector); }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp index b1068c9b04c..234f2f51c82 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.cpp @@ -19,9 +19,9 @@ template ttnn::Tensor SliceOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + tt::stl::Span begins, + tt::stl::Span ends, + tt::stl::Span step, const std::optional& memory_config_arg, const std::optional& 
optional_output_tensor) { @@ -181,9 +181,9 @@ ttnn::Tensor SliceOperation::invoke( template ttnn::Tensor SliceOperation::invoke( const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + tt::stl::Span begins, + tt::stl::Span ends, + tt::stl::Span step, const std::optional& memory_config_arg, const std::optional& optional_output_tensor) { return SliceOperation::invoke(ttnn::DefaultQueueId, input_tensor, begins, ends, step, memory_config_arg); @@ -306,9 +306,9 @@ ttnn::Tensor SliceOperation::invoke( const std::array &step, const std::optional& memory_config_arg, const std::optional& optional_output_tensor) { - std::span start(output_tensor_start.begin(), output_tensor_start.end()); - std::span end(output_tensor_end.begin(), output_tensor_end.end()); - std::span step_vec(step.begin(), step.end()); + tt::stl::Span start(output_tensor_start.begin(), output_tensor_start.end()); + tt::stl::Span end(output_tensor_end.begin(), output_tensor_end.end()); + tt::stl::Span step_vec(step.begin(), step.end()); return SliceOperation::invoke(queue_id, input_tensor, start, end, step_vec, memory_config_arg); } @@ -326,17 +326,17 @@ ttnn::Tensor SliceOperation::invoke( template ttnn::Tensor SliceOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + tt::stl::Span begins, + tt::stl::Span ends, + tt::stl::Span step, const std::optional& memory_config_arg, const std::optional& optional_output_tensor); template ttnn::Tensor SliceOperation::invoke( const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + tt::stl::Span begins, + tt::stl::Span ends, + tt::stl::Span step, const std::optional& memory_config_arg, const std::optional& optional_output_tensor); @@ -344,17 +344,17 @@ template ttnn::Tensor SliceOperation::invoke( template ttnn::Tensor SliceOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + tt::stl::Span begins, + tt::stl::Span ends, + tt::stl::Span step, const std::optional& memory_config_arg, const std::optional& optional_output_tensor); template ttnn::Tensor SliceOperation::invoke( const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + tt::stl::Span begins, + tt::stl::Span ends, + tt::stl::Span step, const std::optional& memory_config_arg, const std::optional& optional_output_tensor); diff --git a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp index b4c3b951509..afb0ed5b0b8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp @@ -15,18 +15,18 @@ struct SliceOperation { static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - std::span begins, - std::span ends, - std::span step, + tt::stl::Span begins, + tt::stl::Span ends, + tt::stl::Span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); template static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - std::span output_tensor_start, - std::span output_tensor_end, - std::span step, + tt::stl::Span output_tensor_start, + tt::stl::Span output_tensor_end, + tt::stl::Span step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt); @@ -39,7 +39,7 @@ struct SliceOperation { const ttnn::SmallVector& step, 
const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { - return invoke(queue_id, input_tensor, std::span(begins.begin(), begins.end()), std::span(ends.begin(), ends.end()), std::span(step.begin(), step.end()), memory_config_arg, optional_output_tensor); + return invoke(queue_id, input_tensor, tt::stl::Span(begins), tt::stl::Span(ends), tt::stl::Span(step), memory_config_arg, optional_output_tensor); } template @@ -50,7 +50,7 @@ struct SliceOperation { const ttnn::SmallVector& step, const std::optional& memory_config_arg = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { - return invoke(input_tensor, std::span(begins.begin(), begins.end()), std::span(ends.begin(), ends.end()), std::span(step.begin(), step.end()), memory_config_arg, optional_output_tensor); + return invoke(input_tensor, tt::stl::Span(begins), tt::stl::Span(ends), tt::stl::Span(step), memory_config_arg, optional_output_tensor); } template diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp index d794d3adda1..e4ddeaf7d9e 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.cpp @@ -97,7 +97,7 @@ operation::ProgramWithCallbacks FastReduceNCDeviceOperation::create_program( Tensor fast_reduce_nc( uint8_t queue_id, const ttnn::Tensor& input, - std::span dims, + tt::stl::Span dims, const std::optional output, const MemoryConfig& output_mem_config, std::optional compute_kernel_config) { diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp index b19e85abd86..deb11acd213 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_device_operation.hpp @@ -29,7 +29,7 @@ struct FastReduceNCDeviceOperation { Tensor fast_reduce_nc( uint8_t queue_id, const ttnn::Tensor &input, - std::span dims, + tt::stl::Span dims, const std::optional output = std::nullopt, const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional compute_kernel_config = std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp index 8b7560cd494..e9eb7ecd802 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.cpp @@ -14,7 +14,7 @@ namespace operations::experimental::reduction{ ttnn::Tensor FastReduceNCOperation::invoke( uint8_t queue_id, const ttnn::Tensor& input, - std::span dims, + tt::stl::Span dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config) { @@ -23,7 +23,7 @@ ttnn::Tensor FastReduceNCOperation::invoke( ttnn::Tensor FastReduceNCOperation::invoke( const ttnn::Tensor& input, - std::span dims, + tt::stl::Span dims, const std::optional output, const ttnn::MemoryConfig 
memory_config, std::optional compute_kernel_config) { diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp index 0a731a40838..42decb2cac6 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/fast_reduce_nc.hpp @@ -16,14 +16,14 @@ struct FastReduceNCOperation { static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input, - std::span dims, + tt::stl::Span dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config); static ttnn::Tensor invoke( const ttnn::Tensor& input, - std::span dims, + tt::stl::Span dims, const std::optional output, const ttnn::MemoryConfig memory_config, std::optional compute_kernel_config); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp index 93f4bcc7c3f..a07da8910c3 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.cpp @@ -120,7 +120,7 @@ std::tuple& input, - std::span dims, + tt::stl::Span dims, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp index 93fc4ccfa67..e848e61fa2b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_device_operation.hpp @@ -57,7 +57,7 @@ struct MorehSumBackwardOperation { static std::tuple invoke( const Tensor& output_grad, const std::optional& input, - std::span dims, + tt::stl::Span dims, bool keepdim, const std::optional& input_grad, const std::optional& memory_config, diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index 94a676a4fe3..4e50cf9c29d 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -212,7 +212,7 @@ inline std::vector convert_layout_row_major_to_tile(const tt::tt_metal::Legac auto tile_shape = ttnn::SmallVector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; auto face_shape = ttnn::SmallVector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, std::span(shape.begin(), shape.end()), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES, tile_shape, face_shape); + data_to_convert, tt::stl::Span(shape.begin(), shape.end()), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES, tile_shape, face_shape); } template typename BufferType> @@ -220,7 +220,7 @@ inline std::vector convert_layout_tile_to_row_major(const tt::tt_metal::Legac auto tile_shape = ttnn::SmallVector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; auto face_shape = ttnn::SmallVector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, std::span(shape.begin(), shape.end()), TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR, tile_shape, face_shape); + data_to_convert, 
tt::stl::Span(shape.begin(), shape.end()), TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR, tile_shape, face_shape); } // ====================================================================================== diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index 2299abefd18..0c4b05e4307 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -456,7 +456,7 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout( TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor"); } -const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, std::span shape) { +const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Span shape) { int64_t old_volume = tensor.get_logical_volume(); int64_t new_volume = 1; int64_t index_of_negative_1 = -1; diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index 82f51fea057..733a764ce99 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -34,7 +34,7 @@ Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, u // Converts convolution weights to depthwise layout with broadcasted weights Tensor convert_conv_weight_tensor_to_depthwise_layout(Tensor conv_weight_tensor, uint32_t act_block_h_ntiles, DataType output_dtype); -const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, std::span shape); +const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Span shape); // TODO: Remove this once we switch to SimpleShape .volume() static std::size_t compute_volume(const tt::tt_metal::LegacyShape& shape) { @@ -64,7 +64,7 @@ static ttnn::SmallVector compute_strides(const ttnn::SimpleShape& shap return strides; } -static int compute_flat_indices(std::span indices, std::span strides) { +static int compute_flat_indices(tt::stl::Span indices, tt::stl::Span strides) { int flat_index = 0; for (auto i = 0; i < indices.size(); i++) { flat_index += indices[i] * strides[i]; diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp index 79d772718b3..729593cfd72 100644 --- a/ttnn/cpp/ttnn/tensor/types.cpp +++ b/ttnn/cpp/ttnn/tensor/types.cpp @@ -116,7 +116,7 @@ Padding::Padding(const std::initializer_list pad_dimensions, PadVa std::copy(std::begin(pad_dimensions), std::end(pad_dimensions), std::begin(this->pad_dimensions_)); } -Padding::Padding(std::span pad_dimensions, PadValue pad_value) : +Padding::Padding(tt::stl::Span pad_dimensions, PadValue pad_value) : rank_(pad_dimensions.size()), pad_dimensions_{}, pad_value_(pad_value) { std::copy(std::begin(pad_dimensions), std::end(pad_dimensions), std::begin(this->pad_dimensions_)); } @@ -166,7 +166,7 @@ LegacyShape::LegacyShape(const std::initializer_list dimensions) : rank_(dimensions.size()), dimensions_{}, padding_(dimensions.size()) { std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); } -LegacyShape::LegacyShape(std::span dimensions) : +LegacyShape::LegacyShape(tt::stl::Span dimensions) : rank_(dimensions.size()), dimensions_{}, padding_(dimensions.size()) { std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); } @@ -176,7 +176,7 @@ LegacyShape::LegacyShape(const std::initializer_list dimensions, const TT_ASSERT(this->padding_.rank_ == this->rank_); std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); } 
-LegacyShape::LegacyShape(std::span dimensions, const Padding& padding) : +LegacyShape::LegacyShape(tt::stl::Span dimensions, const Padding& padding) : rank_(dimensions.size()), dimensions_{}, padding_(padding) { TT_ASSERT(this->padding_.rank_ == this->rank_); std::copy(std::begin(dimensions), std::end(dimensions), std::begin(this->dimensions_)); diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index 1d11463f00f..617f6eeba37 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include "common/bfloat16.hpp" #include "tt_metal/common/core_coord.hpp" @@ -18,6 +17,7 @@ #include "tt_metal/impl/device/device.hpp" #include "tt_metal/tt_stl/concepts.hpp" #include "tt_metal/tt_stl/reflection.hpp" +#include "tt_metal/tt_stl/span.hpp" #include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/tensor/small_vector.hpp" #include "ttnn/cpp/ttnn/tensor/enum_types.hpp" @@ -59,7 +59,7 @@ class SimpleShape final { auto cbegin() const { return value.cbegin(); } auto cend() const { return value.cend(); } - std::span view() const { return value; } + tt::stl::Span view() const { return value; } // Needed for reflect / fmt static constexpr auto attribute_names = std::forward_as_tuple("value"); @@ -170,7 +170,7 @@ struct Padding { Padding(const std::size_t rank); Padding(const std::initializer_list pad_dimensions, PadValue pad_value); - Padding(std::span pad_dimensions, PadValue pad_value); + Padding(tt::stl::Span pad_dimensions, PadValue pad_value); template Padding(const std::array, Rank> pad_dimensions, PadValue pad_value) : @@ -242,11 +242,11 @@ class LegacyShape { ~LegacyShape() = default; LegacyShape(const std::initializer_list); - LegacyShape(std::span); - LegacyShape(const ttnn::SmallVector& vec) : LegacyShape(std::span(vec)) {}; + LegacyShape(tt::stl::Span); + LegacyShape(const ttnn::SmallVector& vec) : LegacyShape(tt::stl::Span(vec)) {}; LegacyShape(const std::initializer_list, const Padding &); - LegacyShape(std::span, const Padding &); - LegacyShape(const ttnn::SmallVector& vec, const Padding &padding) : LegacyShape(std::span(vec), padding) {}; + LegacyShape(tt::stl::Span, const Padding &); + LegacyShape(const ttnn::SmallVector& vec, const Padding &padding) : LegacyShape(tt::stl::Span(vec), padding) {}; explicit LegacyShape(const LegacyShape &, const Padding &); @@ -272,7 +272,7 @@ class LegacyShape { this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]}; } } - explicit LegacyShape(std::span shape, std::span shape_with_tile_padding) : + explicit LegacyShape(tt::stl::Span shape, tt::stl::Span shape_with_tile_padding) : rank_(shape.size()), dimensions_{}, padding_{shape.size()} { TT_ASSERT( shape.size() == shape_with_tile_padding.size(), @@ -284,7 +284,7 @@ class LegacyShape { } } explicit LegacyShape(const ttnn::SmallVector& shape, const ttnn::SmallVector& shape_with_tile_padding) - : LegacyShape(std::span(shape), std::span(shape_with_tile_padding)) {} + : LegacyShape(tt::stl::Span(shape), tt::stl::Span(shape_with_tile_padding)) {} explicit LegacyShape(const std::initializer_list shape, const std::initializer_list shape_with_tile_padding) : LegacyShape(ttnn::SmallVector(shape), ttnn::SmallVector(shape_with_tile_padding)) {} @@ -487,7 +487,7 @@ static tt::tt_metal::LegacyShape compute_ttl_shape( for (auto index = 0; index < Rank; index++) { ttl_shape[index] = shape[index] + padding[index][0] + padding[index][1]; } - return tt::tt_metal::LegacyShape{std::span(ttl_shape), 
tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}}; + return tt::tt_metal::LegacyShape{tt::stl::Span(ttl_shape), tt::tt_metal::Padding{padding, tt::tt_metal::Padding::PadValue::Any}}; } } // namespace detail @@ -514,17 +514,17 @@ struct Shape { const std::array &shape, const std::array, Rank> &tile_padding) : value{detail::compute_ttl_shape(shape, tile_padding)} {} - Shape(std::span shape) : value{tt::tt_metal::LegacyShape{shape}} {} + Shape(tt::stl::Span shape) : value{tt::tt_metal::LegacyShape{shape}} {} Shape(const SmallVector& shape) : value{tt::tt_metal::LegacyShape{shape}} {} - explicit Shape(std::span shape, std::span shape_with_tile_padding) : + explicit Shape(tt::stl::Span shape, tt::stl::Span shape_with_tile_padding) : value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} explicit Shape(const std::initializer_list shape, const std::initializer_list shape_with_tile_padding) : value{tt::tt_metal::LegacyShape{shape, shape_with_tile_padding}} {} - explicit Shape(std::span shape, const Padding &padding) : + explicit Shape(tt::stl::Span shape, const Padding &padding) : value{tt::tt_metal::LegacyShape{shape, padding}} {} explicit Shape(const Shape &shape, const Padding &padding) :