Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#21 from Superjomn/fea/init-tensor
Browse files Browse the repository at this point in the history
init tensor
  • Loading branch information
Superjomn authored Feb 7, 2020
2 parents b20f50f + e60dcf8 commit 7b15a98
Show file tree
Hide file tree
Showing 16 changed files with 212 additions and 22 deletions.
5 changes: 4 additions & 1 deletion cinn/common/shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ struct Shared {
inline bool operator<(const Shared& other) const { return p_ < other.p_; }
inline bool operator==(const Shared& other) const { return p_ == other.p_; }

~Shared() { DesRef(p_); }
~Shared() {
DesRef(p_);
p_ = nullptr;
}

private:
//! Increase the share count.
Expand Down
6 changes: 3 additions & 3 deletions cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ cc_library(ir SRCS
function_definition.cc
ir_operators.cc
buffer.cc
tensor.cc
#tensor.cc
function_base.cc
operation.cc
#operation.cc
DEPS common boost
)

cc_test(test_ir SRCS ir_test.cc DEPS ir)
cc_test(test_ir_printer SRCS ir_printer_test.cc DEPS ir)
cc_test(test_ir_operators SRCS ir_operators_test.cc DEPS ir)
cc_test(test_tensor SRCS tensor_test.cc DEPS ir)
#cc_test(test_tensor SRCS tensor_test.cc DEPS ir)
cc_test(test_function SRCS function_test.cc DEPS ir)
18 changes: 17 additions & 1 deletion cinn/ir/ir.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "cinn/ir/ir.h"
#include "cinn/common/pod_value.h"
#include "cinn/ir/ir_visitor.h"
#include "cinn/lang/tensor.h"

namespace cinn {

Expand Down Expand Up @@ -219,6 +220,22 @@ Expr Call::Make(Type type,
node->set_type(type);
return Expr(node);
}

void _Tensor_::Accept(IrVisitor *v) const { v->Visit(this); }

lang::Tensor _Tensor_::Make(const std::vector<Expr> &shape,
const std::vector<Var> &iterators,
Type dtype,
ir::Expr expr) {
CHECK_EQ(shape.size(), iterators.size()) << "dimension of the shape and the iterators should match";
auto n = common::make_shared<_Tensor_>();
n->dtype = dtype;
n->shape = shape;
n->expr = expr;
n->iterators = iterators;
return lang::Tensor(n);
}

} // namespace ir

namespace common {
Expand Down Expand Up @@ -247,5 +264,4 @@ Value ToValue<ir::Var>(ir::Var v) {
}

} // namespace common

} // namespace cinn
30 changes: 30 additions & 0 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
#include "cinn/ir/node.h"

namespace cinn {

namespace poly {
class Element;
} // namespace poly

namespace lang {
class Tensor;
} // namespace lang

namespace ir {

using common::Object;
Expand Down Expand Up @@ -505,6 +514,27 @@ class _IterVar_ : public IrNode {
static const IrNodeTy _node_type_ = IrNodeTy::_Range_;
};

class _Tensor_ : public IrNode {
public:
//! Shape of this tensor.
std::vector<Expr> shape;
//! Data type of this tensor.
Type dtype;
//! The expression that generate this tensor.
ir::Expr expr;
//! The iterators, we store the iterators to name the dimensions for better readability.
std::vector<Var> iterators;
//! Polyhedral element for analysis and schedule.
std::unique_ptr<poly::Element> poly_element;

static lang::Tensor Make(const std::vector<Expr>& shape,
const std::vector<Var>& iterators,
Type dtype,
ir::Expr expr);

void Accept(ir::IrVisitor* v) const override;
};

static IterVar thread_axis(Range dom, const std::string& tag) {
return _IterVar_::Make(dom, Var(tag), IterVarType::kThreadIndex, tag);
}
Expand Down
2 changes: 2 additions & 0 deletions cinn/ir/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Operation : public FunctionRef {
std::string name;
};

/*
class _Tensor_ : public IrNode {
public:
//! The shape of the tensor.
Expand All @@ -121,6 +122,7 @@ class _Tensor_ : public IrNode {
static const IrNodeTy _node_type_ = IrNodeTy::_Tensor_;
};
*/

class _Operation_ : public ir::FunctionBase {
public:
Expand Down
2 changes: 1 addition & 1 deletion cinn/ir/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ TEST(Tensor, func) {
}

} // namespace ir
} // namespace cinn
} // namespace cinn
9 changes: 8 additions & 1 deletion cinn/lang/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
cc_library(lang SRCS buffer.cc tensor.cc DEPS common ir)
cc_library(lang SRCS
buffer.cc
tensor.cc
compute.cc
DEPS ir poly)

cc_test(test_compute SRCS compute_test.cc DEPS lang)
cc_test(test_tensor2 SRCS tensor_test.cc DEPS lang)
26 changes: 26 additions & 0 deletions cinn/lang/compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "cinn/lang/compute.h"
#include "cinn/poly/dim.h"
#include "cinn/utils/functional.h"

namespace cinn {
namespace lang {

using ir::Expr;

template <>
Tensor Compute<compute_handle_1_t>(const std::vector<int>& dims, compute_handle_1_t handle) {
CHECK_EQ(dims.size(), 1);

poly::Dim dim("i", 0, dims[0] - 1);

Var i("i", Int(32));
auto expr = handle(i);
std::vector<Expr> shape;
for (int v : dims) shape.emplace_back(v);

Tensor tensor(shape, {i}, expr.type(), expr);
return std::move(tensor);
}

} // namespace lang
} // namespace cinn
27 changes: 27 additions & 0 deletions cinn/lang/compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once
#include <functional>
#include <utility>
#include <vector>
#include "cinn/ir/ir.h"
#include "cinn/lang/tensor.h"

namespace cinn {
namespace lang {

using ir::Var;
using compute_handle_1_t = std::function<ir::Expr(Var i)>;
using compute_handle_2_t = std::function<ir::Expr(Var i0, Var i1)>;
using compute_handle_3_t = std::function<ir::Expr(Var i0, Var i1, Var i2)>;
using compute_handle_4_t = std::function<ir::Expr(Var i0, Var i1, Var i2, Var i3)>;

/**
* Compute a Tensor.
* @param dims Dimensions.
* @param iterators
* @param handle
*/
template <typename Fn>
Tensor Compute(const std::vector<int>& dims, Fn handle);

} // namespace lang
} // namespace cinn
7 changes: 7 additions & 0 deletions cinn/lang/compute_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include "cinn/lang/compute.h"
#include <gtest/gtest.h>
#include "cinn/lang/tensor.h"

namespace cinn {
namespace lang {} // namespace lang
} // namespace cinn
20 changes: 20 additions & 0 deletions cinn/lang/tensor.cc
Original file line number Diff line number Diff line change
@@ -1 +1,21 @@
#include "cinn/lang/tensor.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_visitor.h"

namespace cinn {
namespace lang {

Tensor::Tensor(const std::vector<Expr> &shape, const std::vector<Var> &iterators, Type dtype, ir::Expr expr)
: n_(ir::_Tensor_::Make(shape, iterators, dtype, expr)) {}

size_t Tensor::ndims() const { return tensor_p()->shape.size(); }

Expr Tensor::operator()(const std::vector<Expr> &indices) const {
CHECK_EQ(indices.size(), ndims()) << "number of indices not match the dimension";
auto n = ir::Call::Make(tensor_p()->type().ElementOf(), "cinn_buffer_get_element", indices, ir::Call::Halide);
n->set_type(tensor_p()->type());
return n;
}

} // namespace lang
} // namespace cinn
52 changes: 42 additions & 10 deletions cinn/lang/tensor.h
Original file line number Diff line number Diff line change
@@ -1,31 +1,63 @@
#pragma once

#include <cinn/ir/node.h>
#include "cinn/common/common.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/node.h"
#include "cinn/poly/element.h"

namespace cinn {
namespace lang {

using ir::Expr;
using ir::IrNodeRef;
using ir::Type;
using ir::Var;

class _Tensor_;
namespace detail {
constexpr bool LE(int a, int b) { return a <= b; }
constexpr bool GE(int a, int b) { return a >= b; }
} // namespace detail

/**
* @brief Tensor representing a possible input, or intermediate computation result.
*/
class Tensor : public ir::IrNode {
class Tensor : public ir::IrNodeRef {
public:
Tensor() = default;
explicit Tensor(IrNode* n) : n_(n) {}
explicit Tensor(ir::IrNode* n) : n_(n) {}
Tensor(const std::vector<Expr>& shape, const std::vector<Var>& iterators, Type dtype, ir::Expr expr);

//! Get number of dimensions.
inline size_t ndims() const;

inline const ir::_Tensor_* operator->() const { return tensor_p(); }
inline ir::_Tensor_* operator->() { return tensor_p(); }

/**
* Take elements from the tensor.
* This take one or multiple expression as indices.
*/
// @{
Expr operator()(const Expr& a) const { return operator()({a}); }
template <typename... Args>
inline typename std::enable_if<detail::GE(sizeof...(Args), 2), Expr>::type operator()(Args... args) const {
return operator()({std::forward<Args>(args)...});
}
// @}

/**
* Take elements from the tensor.
* @param indices The indices.
* @return The result expression representing a tensor read.
*/
Expr operator()(const std::vector<Expr>& indices) const;

ir::_Tensor_* tensor_p() { return n_->As<ir::_Tensor_>(); }
const ir::_Tensor_* tensor_p() const { return n_->As<ir::_Tensor_>(); }

private:
ir::IrNodeRef n_;
};

class _Tensor_ : public ir::IrNodeRef {
public:
std::vector<Expr> shape;
Type dtype;
};

} // namespace lang
} // namespace cinn
20 changes: 20 additions & 0 deletions cinn/lang/tensor_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "cinn/lang/tensor.h"
#include <gtest/gtest.h>
#include "cinn/ir/ir.h"

namespace cinn {
namespace lang {

TEST(Tensor, basic) {
Expr M(100);
Expr N(20);

Expr x(100);

Var i("i"), j("j");

Tensor tensor({M, N}, x, {i, j}, Float(32));
}

} // namespace lang
} // namespace cinn
2 changes: 1 addition & 1 deletion cinn/poly/element.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void Element::InitSchedule() {
}
}

Element::Element(isl::set domain) : domain_(domain) {
Element::Element(const isl::set &domain) : domain_(domain) {
CHECK(!domain_.is_null());
CHECK(!domain_.is_empty());

Expand Down
2 changes: 1 addition & 1 deletion cinn/poly/element.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace poly {
*/
class Element {
public:
explicit Element(isl::set domain);
explicit Element(const isl::set& domain);

/**
* The id of this element, should be unique across the schedule.
Expand Down
6 changes: 3 additions & 3 deletions docs/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ The layers are as follows
Var i("i"), j("j"), k("k");
Constant N("N"), M("M"), K("K");

PlaceHolder<float> x("x");
PlaceHolder<float> y("x");
PlaceHolder<float> x("x", {M, K});
PlaceHolder<float> y("x", {K, N});

Tensor C = compute({M, N}/*dims*/, [&](Var i, Var j, Var k){
Tensor C = compute({M, N, K}/*dims*/, [&](Var i, Var j, Var k){
return x(i,k) * y(k,j);
});

Expand Down

0 comments on commit 7b15a98

Please sign in to comment.