#8118: ported ttnn::embedding to C++
arakhmati committed May 21, 2024
1 parent e68f0ce commit d716aab
Showing 5 changed files with 141 additions and 100 deletions.
4 changes: 4 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
@@ -17,6 +17,7 @@
#include "normalization.hpp"
#include "kv_cache.hpp"
#include "pool.hpp"
#include "embedding.hpp"

namespace py = pybind11;

@@ -34,6 +35,9 @@ void py_module(py::module& module) {
auto m_core = module.def_submodule("core", "core operations");
core::py_module(m_core);

auto m_embedding = module.def_submodule("embedding", "embedding operations");
embedding::py_module(m_embedding);

auto m_matmul = module.def_submodule("matmul", "matmul operations");
matmul::py_module(m_matmul);

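With the submodule registered above, the raw bound function becomes reachable from Python at ttnn._ttnn.operations.embedding.embedding (the path the updated ttnn/ttnn/operations/embedding.py later in this diff binds to). A minimal sketch:

>>> import ttnn
>>> # raw pybind11 function exposed by the submodule registration above
>>> raw_op = ttnn._ttnn.operations.embedding.embedding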
62 changes: 62 additions & 0 deletions ttnn/cpp/pybind11/operations/embedding.hpp
@@ -0,0 +1,62 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/embedding.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace embedding {

void py_module(py::module& module) {
bind_registered_operation(
module,
ttnn::embedding,
R"doc(
embedding(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, pad_token: Optional[int] = None, layout: ttnn.Layout = ttnn.ROW_MAJOR_LAYOUT, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Retrieves word embeddings using input_tensor. The input_tensor is a tensor of indices into the embedding matrix weight, and the output is the corresponding word embeddings.
Args:
* :attr:`input_tensor`: the indices ttnn.Tensor
* :attr:`weight`: the embeddings ttnn.Tensor that correspond to the indices ttnn.Tensor
Keyword Args:
* :attr:`pad_token`: the padding token. Default is None.
* :attr:`layout`: the layout of the input and output tensors. Default is ttnn.ROW_MAJOR_LAYOUT.
* :attr:`memory_config`: the memory configuration of the output tensor. Default is ttnn.DRAM_MEMORY_CONFIG.
Example::
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> input_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]), dtype=ttnn.uint32), device)
>>> # an embedding matrix containing 10 tensors of size 4
>>> weight = ttnn.to_device(ttnn.from_torch(torch.rand(10, 4), dtype=ttnn.bfloat16), device)
>>> ttnn.embedding(input_tensor, weight)
ttnn.Tensor([ [[1, 0.106445, 0.988281, 0.59375],
[0.212891, 0.964844, 0.199219, 0.996094],
[3.78362e-38, 0, 7.89785e-39, 0],
[8.04479e-38, 0, 1.25815e-38, 0]],
[[2.71833e-38, 0, 3.59995e-38, 0],
[7.60398e-38, 0, 1.83671e-38, 0],
[2.22242e-38, 0, 1.88263e-38, 0],
[1.35917e-38, 0, 4.49994e-39, 0]]], dtype=bfloat16 ))doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("weight"),
py::arg("pad_token") = std::nullopt,
py::arg("layout") = ttnn::ROW_MAJOR_LAYOUT,
py::arg("memory_config") = std::nullopt});
}

} // namespace embedding
} // namespace operations
} // namespace ttnn
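For reference, a minimal usage sketch of the operation as bound above; the device setup and tensor values are illustrative assumptions, not part of this commit:

>>> import torch
>>> import ttnn
>>> device = ttnn.open_device(device_id=0)
>>> indices = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5]]), dtype=ttnn.uint32), device)
>>> weight = ttnn.to_device(ttnn.from_torch(torch.rand(10, 4), dtype=ttnn.bfloat16), device)
>>> # pad_token and layout are the optional kwargs bound above; TILE_LAYOUT requests tilized output
>>> output = ttnn.embedding(indices, weight, pad_token=0, layout=ttnn.TILE_LAYOUT)
>>> ttnn.close_device(device)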
72 changes: 72 additions & 0 deletions ttnn/cpp/ttnn/operations/embedding.hpp
@@ -0,0 +1,72 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "tt_eager/tt_dnn/op_library/embeddings/embeddings_op.hpp"
#include "tt_eager/tt_dnn/op_library/run_operation.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/operations/core.hpp"
#include "ttnn/validation.hpp"

namespace ttnn {

namespace operations {

namespace embedding {

using EmbeddingsType = tt::tt_metal::EmbeddingsType;

struct Embedding {
static const std::array<ttnn::TensorSchema, 2> input_tensor_schemas() {
return {
ttnn::TensorSchema{2, 2, {ttnn::uint32}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false},
ttnn::TensorSchema{2, 4, {ttnn::bfloat16}, {ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}};
}

template <typename... Args>
static auto input_tensors_to_validate(const Tensor& input_tensor, const Tensor& weight, Args&&... args) {
return std::make_tuple(input_tensor, weight);
}

static Tensor execute(
const Tensor& input_tensor_arg,
const Tensor& weight_arg,
const std::optional<int>& pad_token = std::nullopt,
const Layout& layout = ttnn::ROW_MAJOR_LAYOUT,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
auto embeddings_type = EmbeddingsType::GENERIC;
if (pad_token.has_value()) {
embeddings_type = EmbeddingsType::PADDED;
}

auto hidden_embedding_dim = weight_arg.get_shape()[-1];
auto padded_hidden_embedding_dim = weight_arg.get_shape().with_tile_padding()[-1];
auto weight = ttnn::unsqueeze_to_4D(weight_arg);

auto batch_size = input_tensor_arg.get_shape()[0];
auto sentence_size = input_tensor_arg.get_shape()[-1];
auto input_tensor = ttnn::reshape(input_tensor_arg, ttnn::Shape{{batch_size, 1, 1, sentence_size}});

bool tilized = layout == ttnn::TILE_LAYOUT;
auto embeddings = operation::run(
tt::tt_metal::Embeddings{
.output_mem_config = memory_config.value_or(input_tensor.memory_config()),
.tilized = tilized,
.embeddings_type = embeddings_type,
.pad_token = pad_token,
.output_dtype = weight.get_dtype()},
{input_tensor, weight})
.at(0);
embeddings = ttnn::reshape(embeddings, ttnn::Shape{{batch_size, sentence_size, hidden_embedding_dim}});
return embeddings;
}
};

} // namespace embedding
} // namespace operations

constexpr auto embedding = ttnn::register_operation<ttnn::operations::embedding::Embedding>("ttnn::embedding");

} // namespace ttnn
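The execute path above widens the 2D index tensor to 4D before dispatching to tt::tt_metal::Embeddings, then reshapes the result back to 3D. A sketch of the shape contract this implies, with shapes inferred from the code above rather than from separate documentation:

>>> # indices: (batch_size, sentence_size), e.g. (2, 4)
>>> # weight:  (num_embeddings, hidden_embedding_dim), e.g. (10, 4)
>>> output = ttnn.embedding(indices, weight)
>>> # output: (batch_size, sentence_size, hidden_embedding_dim), e.g. (2, 4, 4)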
5 changes: 2 additions & 3 deletions ttnn/cpp/ttnn/operations/transformer.hpp
@@ -272,8 +272,7 @@ struct RotaryEmbedding : public tt::tt_metal::RotaryEmbedding {
uint32_t B = input_tensor.get_legacy_shape()[0];
uint32_t X = input_tensor.get_legacy_shape()[-1];

auto arch = input_tensor.storage_type() == StorageType::DEVICE ? input_tensor.device()->arch()
: AutoFormat::GetDefaultDevice()->arch();
auto arch = input_tensor.device()->arch();
auto kernel_config_val =
init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false);

@@ -363,7 +362,7 @@ constexpr auto attention_softmax =
constexpr auto attention_softmax_ =
ttnn::register_operation<ttnn::operations::transformer::ExecuteAttentionSoftmax<true>>(
"ttnn::transfomer::attention_softmax_");

} // namespace transformer

} // namespace ttnn
98 changes: 1 addition & 97 deletions ttnn/ttnn/operations/embedding.py
@@ -2,13 +2,6 @@

# SPDX-License-Identifier: Apache-2.0

from typing import Tuple, Union, Optional


from loguru import logger

import tt_lib as ttl

import ttnn


@@ -19,96 +12,7 @@ def _golden_function(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, **_):
return output_tensor


def _embedding_validate_input_tensors(operation_name, input_tensor, weight, *args, **kwargs):
ttnn.validate_input_tensor(
operation_name,
input_tensor,
ranks=(2, 3, 4),
dtypes=(ttnn.uint32, ttnn.bfloat16),
layouts=(ttnn.ROW_MAJOR_LAYOUT,),
can_be_on_device=True,
can_be_on_cpu=False,
)
ttnn.validate_input_tensor(
operation_name,
weight,
ranks=(2, 3, 4),
dtypes=(ttnn.bfloat16,),
layouts=(ttnn.ROW_MAJOR_LAYOUT,),
can_be_on_device=True,
can_be_on_cpu=False,
)


@ttnn.register_operation(
name="ttnn.embedding",
validate_input_tensors=_embedding_validate_input_tensors,
golden_function=_golden_function,
)
def embedding(
input_tensor: ttnn.Tensor,
weight: ttnn.Tensor,
*,
pad_token: Optional[int] = None,
layout: ttnn.Layout = ttnn.ROW_MAJOR_LAYOUT,
memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
):
r"""
embedding(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, pad_token: Optional[int] = None, layout: ttnn.Layout = ttnn.ROW_MAJOR_LAYOUT, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Retrieves word embeddings using input_tensor. The input_tensor is a tensor of indices into the embedding matrix weight, and the output is the corresponding word embeddings.
Args:
* :attr:`input_tensor`: the indices ttnn.Tensor
* :attr:`weight`: the embeddings ttnn.Tensor that correspond to the indices ttnn.Tensor
Keyword Args:
* :attr:`pad_token`: the padding token. Default is None.
* :attr:`layout`: the layout of the input and output tensors. Default is ttnn.ROW_MAJOR_LAYOUT.
* :attr:`memory_config`: the memory configuration of the output tensor. Default is ttnn.DRAM_MEMORY_CONFIG.
Example::
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> input_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]), dtype=ttnn.uint32), device)
>>> # an embedding matrix containing 10 tensors of size 4
>>> weight = ttnn.to_device(ttnn.from_torch(torch.rand(10, 4), dtype=ttnn.bfloat16), device)
>>> ttnn.embedding(input_tensor, weight)
ttnn.Tensor([ [[1, 0.106445, 0.988281, 0.59375],
[0.212891, 0.964844, 0.199219, 0.996094],
[3.78362e-38, 0, 7.89785e-39, 0],
[8.04479e-38, 0, 1.25815e-38, 0]],
[[2.71833e-38, 0, 3.59995e-38, 0],
[7.60398e-38, 0, 1.83671e-38, 0],
[2.22242e-38, 0, 1.88263e-38, 0],
[1.35917e-38, 0, 4.49994e-39, 0]]], dtype=bfloat16 )
"""

if pad_token is not None:
embeddings_type = ttl.tensor.EmbeddingsType.PADDED
else:
embeddings_type = ttl.tensor.EmbeddingsType.GENERIC

*_, hidden_embedding_dim = weight.shape
*_, padded_hidden_embedding_dim = weight.shape.with_tile_padding()
weight = ttnn.unsqueeze_to_4D(weight)

batch_size, sentence_size = input_tensor.shape
input_tensor = ttnn.reshape(input_tensor, shape=(batch_size, 1, 1, sentence_size))

tilized = layout == ttnn.TILE_LAYOUT
embeddings = ttl.tensor.embeddings(
input_tensor,
weight,
tilized,
embeddings_type=embeddings_type,
pad_token=pad_token,
output_mem_config=memory_config,
)
embeddings = ttnn.reshape(embeddings, shape=(batch_size, sentence_size, hidden_embedding_dim))

return embeddings
embedding = ttnn.register_operation(golden_function=_golden_function)(ttnn._ttnn.operations.embedding.embedding)


__all__ = []
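
Since the Python wrapper now delegates to the C++ operation while keeping _golden_function for reference checks, a minimal validation sketch; the torch.nn.functional.embedding call is an assumption about the golden reference, as the function body is elided above:

>>> import torch
>>> import ttnn
>>> torch_input = torch.tensor([[1, 2, 4, 5]])
>>> torch_weight = torch.rand(10, 4)
>>> golden = torch.nn.functional.embedding(torch_input, torch_weight)  # assumed golden reference
>>> device = ttnn.open_device(device_id=0)
>>> output = ttnn.embedding(
...     ttnn.to_device(ttnn.from_torch(torch_input, dtype=ttnn.uint32), device),
...     ttnn.to_device(ttnn.from_torch(torch_weight, dtype=ttnn.bfloat16), device),
... )
>>> torch.allclose(ttnn.to_torch(output), golden, atol=1e-1)  # loose tolerance for bfloat16
>>> ttnn.close_device(device)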
