add placements attribute and dist_reshape api in pir
pkuzyc committed Nov 12, 2024
1 parent 54f40d2 commit ec0d61f
Showing 15 changed files with 563 additions and 246 deletions.
30 changes: 25 additions & 5 deletions paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h
@@ -83,13 +83,13 @@ class PlacementsAttrStorage : public pir::AttributeStorage {
}

static std::string to_string(const ParamKey& key) {
std::string s = "(";
std::string s = "[";
for (const auto& p : key) {
s += p->to_string() + ", ";
}
s.pop_back();
s.pop_back();
return s + ")";
return s + "]";
}

///
@@ -102,7 +102,20 @@ class PlacementsAttrStorage : public pir::AttributeStorage {
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
- bool operator==(const ParamKey& key) const { return placements == key; }
+ bool operator==(const ParamKey& key) const {
+ bool equal = true;
+ if (placements.size() != key.size()) {
+ return false;
+ }
+ size_t len = key.size();
+ for (size_t i = 0; i < len; i++) {
+ if (*placements[i] != *key[i]) {
+ equal = false;
+ break;
+ }
+ }
+ return equal;
+ }

ParamKey placements;
};
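A note on the operator== change above: the elements of ParamKey are pointers to placement objects (the new code dereferences them with *), so the old `placements == key` compared the stored pointers rather than the placements they point to. The following self-contained sketch, using a stand-in Placement struct instead of the real phi type, shows the difference the element-wise comparison makes:

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Stand-in for the real placement class; illustration only.
struct Placement {
  std::string kind;
  bool operator==(const Placement& other) const { return kind == other.kind; }
  bool operator!=(const Placement& other) const { return !(*this == other); }
};

using Key = std::vector<std::shared_ptr<Placement>>;

// Default container comparison: equal only if both vectors hold the very
// same pointer values, which two independently built keys never do.
bool pointer_equal(const Key& a, const Key& b) { return a == b; }

// Element-wise comparison of the pointed-to values, mirroring the new
// operator== in PlacementsAttrStorage.
bool value_equal(const Key& a, const Key& b) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (*a[i] != *b[i]) return false;
  }
  return true;
}

Without the value comparison, two attributes built from equal but separately constructed placements would presumably never be recognized as the same attribute when the storage is uniqued.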
@@ -114,12 +127,14 @@ class TensorDistAttrStorage : public pir::AttributeStorage {
///
using ParamKey = std::tuple<ProcessMeshAttribute,
std::vector<int64_t>,
- flat_hash_map<int64_t, phi::ReduceType>>;
+ flat_hash_map<int64_t, phi::ReduceType>,
+ std::optional<PlacementsAttribute>>;

TensorDistAttrStorage(ParamKey&& param) // NOLINT
: mesh_attr(std::get<0>(param)),
dims_mapping(std::move(std::get<1>(param))),
- partial_status(std::move(std::get<2>(param))) {}
+ partial_status(std::move(std::get<2>(param))),
+ placements_(std::move(std::get<3>(param))) {}
///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
@@ -142,6 +157,10 @@ class TensorDistAttrStorage : public pir::AttributeStorage {
}
partial_status_str += "]";
auto combine_hash = pir::detail::hash_combine(mesh_hash, dims_map_hash);
+ if (std::get<3>(key).has_value()) {
+ combine_hash =
+ pir::detail::hash_combine(combine_hash, std::get<3>(key)->hash());
+ }
return pir::detail::hash_combine(
combine_hash, std::hash<std::string>()(partial_status_str));
}
@@ -160,6 +179,7 @@ class TensorDistAttrStorage : public pir::AttributeStorage {
// iterate operation (copy and comparison) would more frequency than random
// element access. <key: dim on mesh, value: reduce type>
flat_hash_map<int64_t, phi::ReduceType> partial_status;
+ std::optional<PlacementsAttribute> placements_;
};

class OperationDistAttrStorage : public pir::AttributeStorage {
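The TensorDistAttrStorage changes above add the optional placements to the storage's ParamKey, member list, and hash, with the hash folding it in only when a value is present. Below is a small self-contained sketch of that pattern; hash_combine here is a local stand-in rather than pir::detail::hash_combine, and the member types are simplified:

#include <cstddef>
#include <functional>
#include <optional>
#include <string>

// Local stand-in for a hash-combining helper (illustration only).
inline std::size_t hash_combine(std::size_t seed, std::size_t value) {
  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct StorageSketch {
  std::string mesh;                       // stand-in for the mesh attribute
  std::optional<std::string> placements;  // stand-in for PlacementsAttribute

  std::size_t HashValue() const {
    std::size_t hash = std::hash<std::string>()(mesh);
    // Mix in the optional member only when it holds a value, so storages
    // created without placements keep hashing exactly as before.
    if (placements.has_value()) {
      hash = hash_combine(hash, std::hash<std::string>()(*placements));
    }
    return hash;
  }
};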
23 changes: 23 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.cc
@@ -112,4 +112,27 @@ pir::Value moe_global_mesh_tensor(
return op.result(0);
}

+ pir::Value dist_reshape(
+ const pir::Value& x,
+ const phi::distributed::Placements& x_placements,
+ const std::vector<int64_t>& global_shape,
+ const std::vector<int64_t>& local_shape,
+ const phi::distributed::ProcessMesh& mesh,
+ const phi::distributed::Placements& placements,
+ const std::vector<int64_t>& dims_mapping,
+ const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
+ pir::IrContext* ctx = pir::IrContext::Instance();
+ common::DDim global_dims = common::make_ddim(global_shape);
+ common::DDim local_dims = common::make_ddim(local_shape);
+ PlacementsAttribute x_placements_attr =
+ PlacementsAttribute::get(ctx, x_placements);
+ PlacementsAttribute placements_attr =
+ PlacementsAttribute::get(ctx, placements);
+ TensorDistAttribute out_dist_attr = TensorDistAttribute::get(
+ ctx, mesh, dims_mapping, partial_status, placements_attr);
+ auto op = ApiBuilder::Instance().GetBuilder()->Build<DistReshapeOp>(
+ x, x_placements_attr, global_dims, local_dims, out_dist_attr);
+ return op.result(0);
+ }

} // namespace paddle::dialect
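To make the new builder concrete, here is a hypothetical wrapper around dist_reshape. It relies only on the signature added in this commit; the input value, mesh, and placements are taken as parameters because constructing them is outside this diff, and the concrete shapes, dims_mapping, and the helper's name are illustrative assumptions:

#include <cstdint>
#include <vector>

#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h"

namespace paddle::dialect {

// Hypothetical example: a tensor with global shape [8, 16], sharded along
// dim 0 on a 2-way mesh so that each rank holds a [4, 16] slice. The
// dims_mapping and partial_status describe the output distribution, matching
// how dist_reshape builds out_dist_attr above.
pir::Value BuildShardedReshape(
    const pir::Value& x,
    const phi::distributed::Placements& x_placements,
    const phi::distributed::ProcessMesh& mesh,
    const phi::distributed::Placements& out_placements) {
  const std::vector<int64_t> global_shape = {8, 16};
  const std::vector<int64_t> local_shape = {4, 16};
  // Tensor dim 0 maps to mesh dim 0; -1 marks a dim that is not sharded.
  const std::vector<int64_t> dims_mapping = {0, -1};
  // No dims are in a pending-reduce (partial) state in this example.
  const flat_hash_map<int64_t, phi::ReduceType> partial_status;

  return dist_reshape(x,
                      x_placements,
                      global_shape,
                      local_shape,
                      mesh,
                      out_placements,
                      dims_mapping,
                      partial_status);
}

}  // namespace paddle::dialect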
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.h
@@ -60,5 +60,15 @@ pir::Value moe_global_mesh_tensor(
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status,
const std::vector<int64_t>& global_shape);

+ pir::Value dist_reshape(
+ const pir::Value& x,
+ const phi::distributed::Placements& x_placements,
+ const std::vector<int64_t>& global_shape,
+ const std::vector<int64_t>& local_shape,
+ const phi::distributed::ProcessMesh& mesh,
+ const phi::distributed::Placements& placements,
+ const std::vector<int64_t>& dims_mapping,
+ const flat_hash_map<int64_t, phi::ReduceType>& partial_status);

} // namespace dialect
} // namespace paddle
16 changes: 7 additions & 9 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc
@@ -47,13 +47,6 @@ PlacementsAttribute PlacementsAttribute::get(

std::string PlacementsAttribute::to_string() const {
return PlacementsAttrStorage::to_string(placements());
- // std::string s = "(";
- // for (const auto& p : placements()) {
- // s += p->to_string() + ", ";
- // }
- // s.pop_back();
- // s.pop_back();
- // return s + ")";
}

size_t PlacementsAttribute::hash() const {
@@ -69,6 +62,10 @@ ProcessMeshAttribute TensorDistAttribute::process_mesh_attr() const {
const std::vector<int64_t>& TensorDistAttribute::dims_mapping() const {
return storage()->dims_mapping;
}
+ std::optional<PlacementsAttribute> TensorDistAttribute::placements_attr()
+ const {
+ return storage()->placements_;
+ }

std::set<int64_t> TensorDistAttribute::partial_dims() const {
auto& partial = partial_status();
@@ -120,12 +117,13 @@ TensorDistAttribute TensorDistAttribute::get(
pir::IrContext* ctx,
ProcessMeshAttribute mesh,
const std::vector<int64_t>& dims_mapping,
- const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
+ const flat_hash_map<int64_t, phi::ReduceType>& partial_status,
+ const std::optional<PlacementsAttribute>& placements) {
PADDLE_ENFORCE_NOT_NULL(mesh,
common::errors::PreconditionNotMet(
"Building tensor_dist_attr through a nullptr "
"mesh attribute is currently not supported."));
- return Base::get(ctx, mesh, dims_mapping, partial_status);
+ return Base::get(ctx, mesh, dims_mapping, partial_status, placements);
}

///
10 changes: 7 additions & 3 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
@@ -89,6 +89,7 @@ class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
using Base::Base;
ProcessMeshAttribute process_mesh_attr() const;
const std::vector<int64_t>& dims_mapping() const;
+ std::optional<PlacementsAttribute> placements_attr() const;

// return vector of mesh dims on which the this tensor is partial on
std::set<int64_t> partial_dims() const;
@@ -106,16 +107,19 @@ class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
pir::IrContext* ctx,
ProcessMeshAttribute mesh,
const std::vector<int64_t>& dims_mapping,
- const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {});
+ const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {},
+ const std::optional<PlacementsAttribute>& placements = std::nullopt);
static TensorDistAttribute get(
pir::IrContext* ctx,
const phi::distributed::ProcessMesh& mesh,
const std::vector<int64_t>& dims_mapping,
- const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {}) {
+ const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {},
+ const std::optional<PlacementsAttribute>& placements = std::nullopt) {
return get(ctx,
ProcessMeshAttribute::get(ctx, mesh),
dims_mapping,
- partial_status);
+ partial_status,
+ placements);
}

static std::string name() { return "a_tensor_dist"; }
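Since the new placements parameter is a defaulted std::optional, existing TensorDistAttribute::get call sites keep compiling unchanged, while new call sites can attach a PlacementsAttribute explicitly. The hypothetical snippet below sketches both call shapes; the helper name and the assumption that mesh_attr, dims_mapping, and placements are built elsewhere are illustrative:

#include <cstdint>
#include <vector>

#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"

namespace paddle::dialect {

TensorDistAttribute MakeDistAttr(
    ProcessMeshAttribute mesh_attr,
    const std::vector<int64_t>& dims_mapping,
    const phi::distributed::Placements& placements) {
  pir::IrContext* ctx = pir::IrContext::Instance();

  // Old call shape: the trailing optional defaults to std::nullopt.
  TensorDistAttribute without_placements =
      TensorDistAttribute::get(ctx, mesh_attr, dims_mapping);
  (void)without_placements;  // kept only to show the pre-existing call shape

  // New call shape: wrap the placements in an attribute and pass it through.
  PlacementsAttribute placements_attr =
      PlacementsAttribute::get(ctx, placements);
  return TensorDistAttribute::get(
      ctx, mesh_attr, dims_mapping, /*partial_status=*/{}, placements_attr);
}

}  // namespace paddle::dialect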
6 changes: 5 additions & 1 deletion paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
@@ -32,13 +32,15 @@ DistDialect::DistDialect(pir::IrContext *context)

void DistDialect::initialize() {
RegisterAttributes<ProcessMeshAttribute,
+ PlacementsAttribute,
TensorDistAttribute,
OperationDistAttribute>();
RegisterTypes<DistDenseTensorType>();
RegisterOps<ShardTensorOp,
ReshardOp,
MoESubMeshTensorsOp,
- MoEGlobalMeshTensorOp>();
+ MoEGlobalMeshTensorOp,
+ DistReshapeOp>();
}

void DistDialect::PrintType(pir::Type type, std::ostream &os) const {
@@ -115,6 +117,8 @@ void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
}
os << ",chunk_id:" << op_dist_attr.chunk_id();
os << "}";
+ } else if (auto placements_attr = attr.dyn_cast<PlacementsAttribute>()) {
+ os << placements_attr.to_string();
} else {
os << "error_attribute_type";
}