
#8109: ported ttnn.softmax into C++
arakhmati committed May 8, 2024
1 parent fc84f76 commit 561c4e2
Showing 11 changed files with 290 additions and 302 deletions.
112 changes: 16 additions & 96 deletions tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.hpp
@@ -5,16 +5,15 @@
#pragma once

#include "tensor/tensor.hpp"

#include "tt_dnn/op_library/run_operation.hpp"
#include "ttnn/operations/core.hpp"

namespace tt {

namespace tt_metal {

// input_tensor - qkv tensor if kv_tensor is nullopt, q tensor if kv_tensor is populated
// expectation for both interleaved and sharded implementation is that each q, k and v vector are concatenated across the last dimension (|q_i k_i v_i| is each row of the tensor)
// expectation for both interleaved and sharded implementation is that each q, k and v vector are concatenated across
// the last dimension (|q_i k_i v_i| is each row of the tensor)

// operation::ProgramWithCallbacks multi_core_create_qkv_heads_interleaved(const Tensor &input_tensor_qkv, const uint32_t num_q_heads, const uint32_t num_kv_heads, const uint32_t head_dim, const bool transpose_k_heads, std::vector<Tensor>& output, CoreCoord compute_with_storage_grid_size);
operation::ProgramWithCallbacks multi_core_create_qkv_heads_sharded(const Tensor &input_tensor_qkv, const uint32_t num_q_heads, const uint32_t num_kv_heads, const uint32_t head_dim, const bool transpose_k_heads, std::vector<Tensor>& output, CoreCoord compute_with_storage_grid_size);
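Note: a small, illustrative sketch of the row layout described in the comment above, assuming each row really is a single |q_i k_i v_i| group with head_dim elements per vector; the helper and the head_dim = 64 value are hypothetical, not taken from this header.

#include <array>
#include <cstdint>
#include <iostream>

// Column offsets of q_i, k_i and v_i inside one fused row of width 3 * head_dim,
// following the |q_i k_i v_i| layout noted above (illustrative only).
std::array<uint32_t, 3> qkv_column_offsets(uint32_t head_dim) {
    return {0u, head_dim, 2u * head_dim};
}

int main() {
    const auto offsets = qkv_column_offsets(64);  // e.g. head_dim = 64
    std::cout << "q starts at column " << offsets[0] << ", k at " << offsets[1] << ", v at " << offsets[2] << "\n";
    return 0;
}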
@@ -126,13 +125,20 @@ struct NlpConcatHeadsDecode {
tt::stl::reflection::Attributes attributes() const;
};

inline std::vector<Tensor> nlp_create_qkv_heads_falcon7b(const Tensor &input_tensor_a, const MemoryConfig& mem_config) {
// TODO: hard-coded for falcon-7b; can delete if we switch to the more generic one (but perf may be worse)
std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a})), Tensor(operation::get_workers_for_op_output({input_tensor_a})), Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
operation::launch_op(
[mem_config] (std::vector<Tensor> input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
return operation::run(NlpCreateHeadsFalcon7B{mem_config}, input_tensors);
}, {input_tensor_a}, output_tensors);
inline std::vector<Tensor> nlp_create_qkv_heads_falcon7b(const Tensor& input_tensor_a, const MemoryConfig& mem_config) {
// TODO: hard-coded for falcon-7b; can delete if we switch to the more generic one (but perf may be worse)
std::vector<Tensor> output_tensors = {
Tensor(operation::get_workers_for_op_output({input_tensor_a})),
Tensor(operation::get_workers_for_op_output({input_tensor_a})),
Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
operation::launch_op(
[mem_config](
std::vector<Tensor> input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
return operation::run(NlpCreateHeadsFalcon7B{mem_config}, input_tensors);
},
{input_tensor_a},
output_tensors);
return output_tensors;
}
inline std::vector<Tensor> nlp_create_qkv_heads_decode(
@@ -198,89 +204,3 @@ inline Tensor nlp_concat_heads_decode(const Tensor &input_tensor_a, const uint32
} // namespace tt_metal

} // namespace tt

namespace ttnn {
namespace operations {
namespace transformer {

struct ConcatenateHeads : public NlpConcatHeads {
static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{
4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false}};
}

void validate(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
const auto head_size = input_tensor.get_shape()[-1];
const auto padded_head_size = input_tensor.get_legacy_shape()[-1];
TT_FATAL(
head_size % ttnn::types::TILE_SIZE == 0,
fmt::format(
"Head size must be a multiple of {} but was found to be {}! Update matmul that uses the output of this "
"operation to have the "
"padding in the weights!",
ttnn::types::TILE_SIZE,
head_size));
TT_FATAL(padded_head_size - head_size == 0, "Head size cannot have tile padding!");

NlpConcatHeads::validate(input_tensors);
}

std::vector<tt::tt_metal::Shape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
std::vector<tt::tt_metal::Shape> output_shape_vec;
const auto& input_tensor = input_tensors.at(0);
const ttnn::types::Shape input_shape = input_tensor.get_shape();
const ttnn::types::Shape padded_input_shape = input_shape.with_tile_padding();

auto batch_size = input_shape[0];
auto num_heads = input_shape[1];
auto sequence_size = input_shape[2];
auto padded_sequence_size = padded_input_shape[2];
auto head_size = input_shape[3];
auto padded_head_size = padded_input_shape[3];

std::array<uint32_t, 3> intended_output_shape = {batch_size, sequence_size, num_heads * head_size};
std::array<uint32_t, 3> padded_output_shape = {batch_size, padded_sequence_size, num_heads * padded_head_size};
return {ttnn::types::Shape(intended_output_shape, padded_output_shape).value()};
}

std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
if (this->output_mem_config.is_sharded()) {
ShardSpec shard_spec = input_tensor.shard_spec().value();
uint32_t num_cores = shard_spec.num_cores();
uint32_t heads_per_shard = shard_spec.shape[0] / input_tensor.get_legacy_shape()[-2];
shard_spec.shape = {shard_spec.shape[0] / heads_per_shard, shard_spec.shape[1] * heads_per_shard};
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
return {create_sharded_device_tensor(
this->compute_output_shapes(input_tensors).at(0),
input_tensor.get_dtype(),
Layout::TILE,
input_tensor.device(),
mem_config)};
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
}
}
};

inline Tensor concatenate_heads(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config) {
std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
operation::launch_op(
[memory_config](
std::vector<Tensor> input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
auto& input_tensor = input_tensors.at(0);
return operation::run(
ConcatenateHeads{memory_config.value_or(input_tensor.memory_config())}, {input_tensor});
},
{input_tensor},
output_tensors);
return output_tensors.at(0);
}

} // namespace transformer
} // namespace operations
} // namespace ttnn
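Note: a host-side sketch of the shape transform that ConcatenateHeads::compute_output_shapes above performs; the concrete sizes (batch 1, 12 heads, sequence 384, head size 64) are made up for illustration and do not come from the commit.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
    // Input shape [batch_size, num_heads, sequence_size, head_size], as validated above.
    const uint32_t batch_size = 1, num_heads = 12, sequence_size = 384, head_size = 64;

    // concatenate_heads folds the heads back into the width dimension:
    // [batch_size, num_heads, sequence_size, head_size] -> [batch_size, sequence_size, num_heads * head_size]
    const std::array<uint32_t, 3> output_shape = {batch_size, sequence_size, num_heads * head_size};

    std::cout << "[" << output_shape[0] << ", " << output_shape[1] << ", " << output_shape[2] << "]\n";
    // Prints: [1, 384, 768]
    return 0;
}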
86 changes: 0 additions & 86 deletions tt_eager/tt_dnn/op_library/softmax/softmax_op.hpp
@@ -12,7 +12,6 @@
#include "tt_dnn/op_library/operation.hpp"
#include "tt_dnn/op_library/run_operation.hpp"
#include "tt_dnn/op_library/compute_kernel_config.hpp"
#include "ttnn/types.hpp"

namespace tt {
namespace operations {
@@ -127,88 +126,3 @@ Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional<float> scale
} // namespace tt_metal

} // namespace tt

namespace ttnn {

namespace operations {

namespace transformer {

struct Softmax : public tt::operations::primary::Softmax {
static inline const std::vector<TensorSchema> input_tensor_schemas() {
return {
ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, false},
ttnn::TensorSchema{4, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT}, true, false, false, true}};
}

void validate(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
const auto& shape = input_tensor.get_shape();

TT_FATAL(
shape.rank() == 4,
fmt::format("Input Tensor must have strictly 4 dimensions however it currently has {}!", shape.rank()));

TT_FATAL(input_tensor.get_layout() == ttnn::TILE_LAYOUT, "Input Tensor must be in a TILE_LAYOUT!");
tt::operations::primary::Softmax::validate(input_tensors, optional_input_tensors);
}
};

inline ttnn::Tensor attention_softmax_(
const ttnn::Tensor& input_tensor,
const std::optional<int> head_size = std::nullopt,
const std::optional<const ttnn::Tensor>& attention_mask = std::nullopt,
const tt::operations::primary::transformers::SoftmaxProgramConfig& program_config =
tt::operations::primary::transformers::SoftmaxDefaultProgramConfig{},
const std::optional<bool> causal_mask = false,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt) {
TT_FATAL(attention_mask.has_value(), "Cannot apply divide by sqrt(head_size) using in-place version!");

// std::vector<Tensor> output_tensors = {
// Tensor(operation::get_workers_for_op_output({input_tensor}, {attention_mask}))};
// std::cout << "Launching attention_softmax_ seems to hang unfortunately." << std::endl;
// operation::launch_op(
// [&input_tensor, &head_size, &attention_mask, &program_config, &causal_mask, &memory_config](
// std::vector<Tensor> input_tensors,
// const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
// std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt;
// auto kernel_config_val = init_device_compute_kernel_config(
// input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
// return operation::run(
// Softmax{
// head_size.has_value() ? 1.0 / sqrt(head_size.value()) : 1.0,
// true,
// memory_config.value_or(input_tensor.memory_config()),
// program_config,
// causal_mask.value(),
// kernel_config_val,
// false},
// {input_tensor},
// {attention_mask});
// },
// {input_tensor},
// output_tensors,
// {attention_mask});
// return output_tensors.at(0);

std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt;
auto kernel_config_val = init_device_compute_kernel_config(
input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
operation::run(
Softmax{
head_size.has_value() ? 1.0 / sqrt(head_size.value()) : 1.0,
true,
memory_config.value_or(input_tensor.memory_config()),
program_config,
causal_mask.value(),
kernel_config_val,
false},
{input_tensor},
{attention_mask});
return input_tensor;
}
} // namespace transformer
} // namespace operations
} // namespace ttnn
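Note: a host-side reference sketch of the math the attention_softmax_ helper above delegates to the device: scale by 1 / sqrt(head_size) when a head size is given, add the attention mask, then take a softmax along the row. This is an illustration of the intended semantics under those assumptions, not the device kernel.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <optional>
#include <vector>

// Reference scale-mask-softmax for a single row of attention scores.
std::vector<float> attention_softmax_reference(
    std::vector<float> scores, const std::optional<int>& head_size, const std::vector<float>& mask) {
    const float scale = head_size.has_value() ? 1.0f / std::sqrt(static_cast<float>(head_size.value())) : 1.0f;

    float row_max = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < scores.size(); ++i) {
        scores[i] = scores[i] * scale + mask[i];  // scale, then add the (already broadcast) mask
        row_max = std::max(row_max, scores[i]);
    }

    float denominator = 0.0f;
    for (auto& value : scores) {
        value = std::exp(value - row_max);  // subtract the row max for numerical stability
        denominator += value;
    }
    for (auto& value : scores) {
        value /= denominator;
    }
    return scores;
}

int main() {
    // Example: two scores with head_size = 64 and a mask that hides the second position.
    const auto probs = attention_softmax_reference({1.0f, 1.0f}, 64, {0.0f, -1e9f});
    return probs[0] > 0.99f ? 0 : 1;  // the unmasked position takes essentially all the probability
}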
28 changes: 26 additions & 2 deletions ttnn/cpp/pybind11/operations/normalization.hpp
@@ -20,10 +20,34 @@ namespace ttnn {
namespace operations {
namespace normalization {
void py_module(py::module& module) {
ttnn::bind_registered_operation(
module,
ttnn::softmax,
R"doc(softmax(input_tensor: ttnn.Tensor, dim: int, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Compute softmax over :attr:`input_tensor` along :attr:`dim`.
Args:
* :attr:`input_tensor`: the input tensor
* :attr:`dim`: the dimension along which to compute softmax.
Keyword Args:
* :attr:`memory_config`: the memory configuration for the output tensor. If not provided, the memory configuration of the input tensor is used.
Example::
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16)), device)
>>> output = ttnn.softmax(tensor, -1)
>>> print(output[0, 0, 0, :3])
ttnn.Tensor([ 0.0310059, 0.0310059, 0.0310059], dtype=bfloat16 )
)doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"), py::arg("dim"), py::kw_only(), py::arg("memory_config") = std::nullopt});

ttnn::bind_registered_operation(
module,
ttnn::layer_norm,
R"doc(rms_norm(input_tensor: ttnn.Tensor, epsilon: float = 1e-12, weight: ttnn.Tensor, bias: ttnn.Tensor, residual_input_tensor: ttnn.Tensor, memory_config: ttnn.MemoryConfig, program_config: ttnn.LayerNormProgramConfig) -> ttnn.Tensor
R"doc(rms_norm(input_tensor: ttnn.Tensor, epsilon: float = 1e-12, weight: Optional[ttnn.Tensor] = None, bias: Optional[ttnn.Tensor] = None, residual_input_tensor: Optional[ttnn.Tensor] = None, memory_config: Optional[ttnn.MemoryConfig] = None, program_config: Optional[ttnn.ProgramConfig] = None) -> ttnn.Tensor
Compute layer_norm over :attr:`input_tensor`.
)doc",
ttnn::pybind_arguments_t{
@@ -39,7 +63,7 @@ void py_module(py::module& module) {
ttnn::bind_registered_operation(
module,
ttnn::rms_norm,
R"doc(rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-12, memory_config: ttnn.MemoryConfig) -> ttnn.Tensor
R"doc(rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float = 1e-12, Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Compute rms_norm over :attr:`input_tensor`.
)doc",
ttnn::pybind_arguments_t{
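Note: a quick host-side check of the docstring example above. Softmax over a row of 32 zeros is exactly 1/32 = 0.03125; the 0.0310059 the device prints is presumably the same value after bfloat16 rounding and the hardware exp approximation.

#include <cmath>
#include <iostream>
#include <vector>

int main() {
    // The ttnn.softmax docstring runs softmax along the last dimension of a
    // (1, 1, 64, 32) tensor of zeros, i.e. over rows of 32 equal values.
    std::vector<float> row(32, 0.0f);

    float denominator = 0.0f;
    for (float value : row) {
        denominator += std::exp(value);  // 32 * exp(0) = 32
    }

    std::cout << std::exp(row[0]) / denominator << "\n";  // prints 0.03125 (= 1/32)
    return 0;
}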
63 changes: 37 additions & 26 deletions ttnn/cpp/pybind11/operations/transformer.hpp
@@ -7,6 +7,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "../decorators.hpp"
#include "ttnn/operations/transformer.hpp"

namespace py = pybind11;
@@ -16,32 +17,42 @@ namespace operations {
namespace transformer {

void py_module(py::module& module) {
module.def(
"concatenate_heads",
[](const ttnn::Tensor& input_tensor, const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt)
-> ttnn::Tensor { return ttnn::operations::transformer::concatenate_heads(input_tensor, memory_config); },
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt);

module.def(
"attention_softmax_",
[](const ttnn::Tensor& tensor,
const std::optional<int> head_size,
const std::optional<const ttnn::Tensor>& attention_mask,
const tt::operations::primary::transformers::SoftmaxProgramConfig& program_config,
const std::optional<bool> causal_mask,
const std::optional<ttnn::MemoryConfig>& memory_config) -> ttnn::Tensor {
return ttnn::operations::transformer::attention_softmax_(
tensor, head_size, attention_mask, program_config, causal_mask, memory_config);
},
py::arg("tensor"),
py::kw_only(),
py::arg("head_size") = std::nullopt,
py::arg("attention_mask") = std::nullopt,
py::arg("program_config").noconvert() = tt::operations::primary::transformers::SoftmaxDefaultProgramConfig{},
py::arg("causal_mask") = false,
py::arg("memory_config") = std::nullopt);
ttnn::bind_registered_operation(
module,
ttnn::transformer::concatenate_heads,
R"doc(concatenate_heads(input_tensor: ttnn.Tensor, *, memory_config: Optional[MemoryConfig] = None) -> ttnn.Tensor
Takes a tensor of shape ``[batch_size, num_heads, sequence_size, head_size]``, concatenates the heads back along the width dimension, and returns a tensor of shape ``[batch_size, sequence_size, num_heads * head_size]``
Args:
* :attr:`input_tensor`: Input Tensor
* :attr:`memory_config`: Memory config of the output tensor; if None, defaults to input_tensor.memory_config()
)doc",
ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt});

ttnn::bind_registered_operation(
module,
ttnn::transformer::attention_softmax_,
R"doc(attention_softmax_(tensor: ttnn.Tensor, *, head_size: Optional[int] = None, attention_mask: Optional[ttnn.Tensor] = None, program_config: Optional[SoftmaxProgramConfig] = SoftmaxDefaultProgramConfig(), causal_mask: bool = False, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
In place, divides :attr:`tensor` by the square root of :attr:`head_size`, optionally adds :attr:`attention_mask`, and computes softmax.
Args:
* :attr:`tensor`: Input Tensor
* :attr:`head_size`: Size of each attention head; the tensor is scaled by 1 / sqrt(head_size)
* :attr:`attention_mask`: Attention Mask
* :attr:`causal_mask`: Whether :attr:`attention_mask` is a causal mask
* :attr:`program_config`: Program config for the softmax operation
* :attr:`memory_config`: Memory config of the output tensor; defaults to input_tensor.memory_config()
)doc",
ttnn::pybind_arguments_t{
py::arg("tensor"),
py::kw_only(),
py::arg("head_size") = std::nullopt,
py::arg("attention_mask") = std::nullopt,
py::arg("program_config").noconvert() =
tt::operations::primary::transformers::SoftmaxDefaultProgramConfig{},
py::arg("causal_mask") = false,
py::arg("memory_config") = std::nullopt});

module.def("split_query_key_value_and_split_heads",
[](const Tensor &input_tensor, const std::optional<Tensor> &input_tensor_kv,
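Note: a minimal, hypothetical pybind11 module showing the keyword-only argument convention the bindings above rely on; concatenate_heads_stub and the int placeholder types are stand-ins invented for this sketch and are not part of ttnn.

#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Stand-in for a registered operation: returns memory_config if given, otherwise the input.
int concatenate_heads_stub(int input_tensor, const std::optional<int>& memory_config) {
    return memory_config.value_or(input_tensor);
}

PYBIND11_MODULE(example, m) {
    // Same convention as above: the tensor is positional, everything after
    // py::kw_only() must be passed by keyword, and Python None maps to std::nullopt.
    m.def(
        "concatenate_heads",
        &concatenate_heads_stub,
        py::arg("input_tensor"),
        py::kw_only(),
        py::arg("memory_config") = std::nullopt);
}

From Python this would be called as example.concatenate_heads(t, memory_config=...); passing memory_config positionally raises a TypeError.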
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/op_library/binary/binary_op.hpp
@@ -89,7 +89,7 @@ struct Binary {
template <typename... Args>
static auto input_tensors_to_validate(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) {
return std::make_tuple(input_tensor_a, input_tensor_b);
};
}

static Tensor execute(
const Tensor &input_tensor_a,
@@ -139,7 +139,7 @@ struct Binary {
template <typename... Args>
static auto input_tensors_to_validate(const Tensor &input_tensor_a, const float input_tensor_b, Args &&...args) {
return std::make_tuple(input_tensor_a, input_tensor_b);
};
}

// TODO: this case should use BinaryWithScalarProgramConfig and there should be a custom kernel to run this
// Currently, this is exactly how tt::tt_metal::add_unary works
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.hpp
@@ -25,7 +25,7 @@ namespace operations {
namespace core {

struct ToLayout {
static inline const std::vector<TensorSchema> input_tensor_schemas() {
static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{
1,
4,
@@ -40,7 +40,7 @@ struct ToLayout {
template <typename... Args>
static auto input_tensors_to_validate(const Tensor& tensor_arg, Args&&... args) {
return std::make_tuple(tensor_arg);
};
}

static Tensor execute(
const ttnn::Tensor& tensor_arg,