forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#21 from Superjomn/fea/init-tensor
init tensor
- Loading branch information
Showing
16 changed files
with
212 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,4 +39,4 @@ TEST(Tensor, func) { | |
} | ||
|
||
} // namespace ir | ||
} // namespace cinn | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,8 @@ | ||
cc_library(lang SRCS buffer.cc tensor.cc DEPS common ir) | ||
cc_library(lang SRCS | ||
buffer.cc | ||
tensor.cc | ||
compute.cc | ||
DEPS ir poly) | ||
|
||
cc_test(test_compute SRCS compute_test.cc DEPS lang) | ||
cc_test(test_tensor2 SRCS tensor_test.cc DEPS lang) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#include "cinn/lang/compute.h" | ||
#include "cinn/poly/dim.h" | ||
#include "cinn/utils/functional.h" | ||
|
||
namespace cinn { | ||
namespace lang { | ||
|
||
using ir::Expr; | ||
|
||
template <> | ||
Tensor Compute<compute_handle_1_t>(const std::vector<int>& dims, compute_handle_1_t handle) { | ||
CHECK_EQ(dims.size(), 1); | ||
|
||
poly::Dim dim("i", 0, dims[0] - 1); | ||
|
||
Var i("i", Int(32)); | ||
auto expr = handle(i); | ||
std::vector<Expr> shape; | ||
for (int v : dims) shape.emplace_back(v); | ||
|
||
Tensor tensor(shape, {i}, expr.type(), expr); | ||
return std::move(tensor); | ||
} | ||
|
||
} // namespace lang | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#pragma once | ||
#include <functional> | ||
#include <utility> | ||
#include <vector> | ||
#include "cinn/ir/ir.h" | ||
#include "cinn/lang/tensor.h" | ||
|
||
namespace cinn { | ||
namespace lang { | ||
|
||
using ir::Var; | ||
using compute_handle_1_t = std::function<ir::Expr(Var i)>; | ||
using compute_handle_2_t = std::function<ir::Expr(Var i0, Var i1)>; | ||
using compute_handle_3_t = std::function<ir::Expr(Var i0, Var i1, Var i2)>; | ||
using compute_handle_4_t = std::function<ir::Expr(Var i0, Var i1, Var i2, Var i3)>; | ||
|
||
/** | ||
* Compute a Tensor. | ||
* @param dims Dimensions. | ||
* @param iterators | ||
* @param handle | ||
*/ | ||
template <typename Fn> | ||
Tensor Compute(const std::vector<int>& dims, Fn handle); | ||
|
||
} // namespace lang | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#include "cinn/lang/compute.h" | ||
#include <gtest/gtest.h> | ||
#include "cinn/lang/tensor.h" | ||
|
||
namespace cinn { | ||
namespace lang {} // namespace lang | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,21 @@ | ||
#include "cinn/lang/tensor.h" | ||
#include "cinn/ir/ir.h" | ||
#include "cinn/ir/ir_visitor.h" | ||
|
||
namespace cinn { | ||
namespace lang { | ||
|
||
Tensor::Tensor(const std::vector<Expr> &shape, const std::vector<Var> &iterators, Type dtype, ir::Expr expr) | ||
: n_(ir::_Tensor_::Make(shape, iterators, dtype, expr)) {} | ||
|
||
size_t Tensor::ndims() const { return tensor_p()->shape.size(); } | ||
|
||
Expr Tensor::operator()(const std::vector<Expr> &indices) const { | ||
CHECK_EQ(indices.size(), ndims()) << "number of indices not match the dimension"; | ||
auto n = ir::Call::Make(tensor_p()->type().ElementOf(), "cinn_buffer_get_element", indices, ir::Call::Halide); | ||
n->set_type(tensor_p()->type()); | ||
return n; | ||
} | ||
|
||
} // namespace lang | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,63 @@ | ||
#pragma once | ||
|
||
#include <cinn/ir/node.h> | ||
#include "cinn/common/common.h" | ||
#include "cinn/ir/ir.h" | ||
#include "cinn/ir/node.h" | ||
#include "cinn/poly/element.h" | ||
|
||
namespace cinn { | ||
namespace lang { | ||
|
||
using ir::Expr; | ||
using ir::IrNodeRef; | ||
using ir::Type; | ||
using ir::Var; | ||
|
||
class _Tensor_; | ||
namespace detail { | ||
constexpr bool LE(int a, int b) { return a <= b; } | ||
constexpr bool GE(int a, int b) { return a >= b; } | ||
} // namespace detail | ||
|
||
/** | ||
* @brief Tensor representing a possible input, or intermediate computation result. | ||
*/ | ||
class Tensor : public ir::IrNode { | ||
class Tensor : public ir::IrNodeRef { | ||
public: | ||
Tensor() = default; | ||
explicit Tensor(IrNode* n) : n_(n) {} | ||
explicit Tensor(ir::IrNode* n) : n_(n) {} | ||
Tensor(const std::vector<Expr>& shape, const std::vector<Var>& iterators, Type dtype, ir::Expr expr); | ||
|
||
//! Get number of dimensions. | ||
inline size_t ndims() const; | ||
|
||
inline const ir::_Tensor_* operator->() const { return tensor_p(); } | ||
inline ir::_Tensor_* operator->() { return tensor_p(); } | ||
|
||
/** | ||
* Take elements from the tensor. | ||
* This take one or multiple expression as indices. | ||
*/ | ||
// @{ | ||
Expr operator()(const Expr& a) const { return operator()({a}); } | ||
template <typename... Args> | ||
inline typename std::enable_if<detail::GE(sizeof...(Args), 2), Expr>::type operator()(Args... args) const { | ||
return operator()({std::forward<Args>(args)...}); | ||
} | ||
// @} | ||
|
||
/** | ||
* Take elements from the tensor. | ||
* @param indices The indices. | ||
* @return The result expression representing a tensor read. | ||
*/ | ||
Expr operator()(const std::vector<Expr>& indices) const; | ||
|
||
ir::_Tensor_* tensor_p() { return n_->As<ir::_Tensor_>(); } | ||
const ir::_Tensor_* tensor_p() const { return n_->As<ir::_Tensor_>(); } | ||
|
||
private: | ||
ir::IrNodeRef n_; | ||
}; | ||
|
||
class _Tensor_ : public ir::IrNodeRef { | ||
public: | ||
std::vector<Expr> shape; | ||
Type dtype; | ||
}; | ||
|
||
} // namespace lang | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#include "cinn/lang/tensor.h" | ||
#include <gtest/gtest.h> | ||
#include "cinn/ir/ir.h" | ||
|
||
namespace cinn { | ||
namespace lang { | ||
|
||
TEST(Tensor, basic) { | ||
Expr M(100); | ||
Expr N(20); | ||
|
||
Expr x(100); | ||
|
||
Var i("i"), j("j"); | ||
|
||
Tensor tensor({M, N}, x, {i, j}, Float(32)); | ||
} | ||
|
||
} // namespace lang | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters