Skip to content

Commit

Permalink
add placements attribute in pir
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuzyc committed Nov 12, 2024
1 parent e2bdc5d commit 54f40d2
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 0 deletions.
47 changes: 47 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,53 @@ class ProcessMeshAttrStorage : public pir::AttributeStorage {
ParamKey process_mesh;
};

// NOTE (pkuzyc): now the Placements is only used for the ``local_reshape``
// op, for the case when one tensor dimension is sharded by multiple mesh
// dimensions. In the future, the dims_mapping in TensorDistAttr will be
// replaced by this Placements.
class PlacementsAttrStorage : public pir::AttributeStorage {
public:
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = phi::distributed::Placements;

PlacementsAttrStorage(ParamKey&& placements) // NOLINT
: placements(std::move(placements)) {}

///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static PlacementsAttrStorage* Construct(ParamKey&& key) {
return new PlacementsAttrStorage(std::move(key));
}

static std::string to_string(const ParamKey& key) {
std::string s = "(";
for (const auto& p : key) {
s += p->to_string() + ", ";
}
s.pop_back();
s.pop_back();
return s + ")";
}

///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey& key) {
return std::hash<std::string>()(to_string(key));
}

///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const { return placements == key; }

ParamKey placements;
};

class TensorDistAttrStorage : public pir::AttributeStorage {
public:
///
Expand Down
25 changes: 25 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@ ProcessMeshAttribute ProcessMeshAttribute::get(
return Base::get(ctx, shape, process_ids, dim_names);
}

const phi::distributed::Placements& PlacementsAttribute::placements() const {
return storage()->placements;
}

PlacementsAttribute PlacementsAttribute::get(
pir::IrContext* ctx, const phi::distributed::Placements& placements) {
return Base::get(ctx, placements);
}

std::string PlacementsAttribute::to_string() const {
return PlacementsAttrStorage::to_string(placements());
// std::string s = "(";
// for (const auto& p : placements()) {
// s += p->to_string() + ", ";
// }
// s.pop_back();
// s.pop_back();
// return s + ")";
}

size_t PlacementsAttribute::hash() const {
return std::hash<std::string>()(to_string());
}

///
/// \brief TensorDistAttribute interface.
///
Expand Down Expand Up @@ -166,5 +190,6 @@ OperationDistAttribute OperationDistAttribute::get(
} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PlacementsAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperationDistAttribute)
18 changes: 18 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace dialect {
class ProcessMeshAttrStorage;
class TensorDistAttrStorage;
class OperationDistAttrStorage;
class PlacementsAttrStorage;

class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
pir::Attribute,
Expand Down Expand Up @@ -65,6 +66,22 @@ class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
static std::string name() { return "a_process_mesh"; }
};

class PlacementsAttribute : public pir::AttrBase<PlacementsAttribute,
pir::Attribute,
PlacementsAttrStorage> {
public:
using Base::Base;
const phi::distributed::Placements& placements() const;

size_t hash() const;
std::string to_string() const;

static std::string name() { return "a_placements"; }

static PlacementsAttribute get(
pir::IrContext* ctx, const phi::distributed::Placements& placements);
};

class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
pir::Attribute,
TensorDistAttrStorage> {
Expand Down Expand Up @@ -145,5 +162,6 @@ class OperationDistAttribute : public pir::AttrBase<OperationDistAttribute,
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PlacementsAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperationDistAttribute)
24 changes: 24 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,30 @@ class MoEGlobalMeshTensorOp
std::vector<pir::Value> results() { return operation()->results(); }
};

class DistReshapeOp : public pir::Op<DistReshapeOp, VjpInterface> {
public:
using Op::Op;
static const char* name() { return "dist_op.dist_reshape"; }
static const char* attributes_name[1];
static constexpr uint32_t attributes_num = 1;
TEST_API static void Build(pir::Builder& builder, // NOLINT
pir::OperationArgument& argument, // NOLINT
pir::Value input,
const phi::DDim& global_dims,
const phi::DDim& local_dims);

static OpInfoTuple GetOpInfo();
static std::vector<std::vector<pir::Value>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs_,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);

void VerifySig();
std::vector<pir::Value> results() { return operation()->results(); }
};

} // namespace dialect
} // namespace paddle

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/dist_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct type_caster<paddle::flat_hash_map<Key, Value, Hash, Equal, Alloc>>

using paddle::dialect::DistTypeInterface;
using paddle::dialect::OperationDistAttribute;
using paddle::dialect::PlacementsAttribute;
using paddle::dialect::ProcessMeshAttribute;
using paddle::dialect::TensorDistAttribute;
using pir::ArrayAttribute;
Expand Down

0 comments on commit 54f40d2

Please sign in to comment.