Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dist Dialect] Add placements and dist_reshape api in pir #69262

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,81 @@ class ProcessMeshAttrStorage : public pir::AttributeStorage {
ParamKey process_mesh;
};

// NOTE (pkuzyc): now the Placements is only used for the ``local_reshape``
// op, for the case when one tensor dimension is sharded by multiple mesh
// dimensions. In the future, the dims_mapping in TensorDistAttr will be
// replaced by this Placements.
class PlacementsAttrStorage : public pir::AttributeStorage {
public:
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = phi::distributed::Placements;

PlacementsAttrStorage(ParamKey&& placements) // NOLINT
: placements(std::move(placements)) {}

///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static PlacementsAttrStorage* Construct(ParamKey&& key) {
return new PlacementsAttrStorage(std::move(key));
}

static std::string to_string(const ParamKey& key) {
std::string s = "[";
for (const auto& p : key) {
s += p->to_string() + ", ";
}
s.pop_back();
s.pop_back();
return s + "]";
}

///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey& key) {
return std::hash<std::string>()(to_string(key));
}

///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const {
bool equal = true;
if (placements.size() != key.size()) {
return false;
}
size_t len = key.size();
for (size_t i = 0; i < len; i++) {
if (*placements[i] != *key[i]) {
equal = false;
break;
}
}
return equal;
}

ParamKey placements;
};

class TensorDistAttrStorage : public pir::AttributeStorage {
public:
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = std::tuple<ProcessMeshAttribute,
std::vector<int64_t>,
flat_hash_map<int64_t, phi::ReduceType>>;
flat_hash_map<int64_t, phi::ReduceType>,
std::optional<PlacementsAttribute>>;

TensorDistAttrStorage(ParamKey&& param) // NOLINT
: mesh_attr(std::get<0>(param)),
dims_mapping(std::move(std::get<1>(param))),
partial_status(std::move(std::get<2>(param))) {}
partial_status(std::move(std::get<2>(param))),
placements_(std::move(std::get<3>(param))) {}
///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
Expand All @@ -95,6 +157,10 @@ class TensorDistAttrStorage : public pir::AttributeStorage {
}
partial_status_str += "]";
auto combine_hash = pir::detail::hash_combine(mesh_hash, dims_map_hash);
if (std::get<3>(key).has_value()) {
combine_hash =
pir::detail::hash_combine(combine_hash, std::get<3>(key)->hash());
}
return pir::detail::hash_combine(
combine_hash, std::hash<std::string>()(partial_status_str));
}
Expand All @@ -113,6 +179,7 @@ class TensorDistAttrStorage : public pir::AttributeStorage {
// iterate operation (copy and comparison) would more frequency than random
// element access. <key: dim on mesh, value: reduce type>
flat_hash_map<int64_t, phi::ReduceType> partial_status;
std::optional<PlacementsAttribute> placements_;
};

class OperationDistAttrStorage : public pir::AttributeStorage {
Expand Down
23 changes: 23 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,27 @@ pir::Value moe_global_mesh_tensor(
return op.result(0);
}

// Builds a DistReshapeOp that maps input `x` (whose sharding is described by
// `x_placements`) onto a tensor with the given global/local shapes, and whose
// output dist attribute is assembled from `mesh`, `placements`,
// `dims_mapping` and `partial_status`. Returns the op's single result value.
pir::Value dist_reshape(
    const pir::Value& x,
    const phi::distributed::Placements& x_placements,
    const std::vector<int64_t>& global_shape,
    const std::vector<int64_t>& local_shape,
    const phi::distributed::ProcessMesh& mesh,
    const phi::distributed::Placements& placements,
    const std::vector<int64_t>& dims_mapping,
    const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
  auto* ctx = pir::IrContext::Instance();
  // Wrap both placement lists as attributes in the current IR context.
  const auto src_placements_attr = PlacementsAttribute::get(ctx, x_placements);
  const auto out_placements_attr = PlacementsAttribute::get(ctx, placements);
  // The output dist attribute carries the target mesh/mapping/partial info
  // together with the target placements.
  const auto out_dist_attr = TensorDistAttribute::get(
      ctx, mesh, dims_mapping, partial_status, out_placements_attr);
  auto reshape_op = ApiBuilder::Instance().GetBuilder()->Build<DistReshapeOp>(
      x,
      src_placements_attr,
      common::make_ddim(global_shape),
      common::make_ddim(local_shape),
      out_dist_attr);
  return reshape_op.result(0);
}

} // namespace paddle::dialect
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,15 @@ pir::Value moe_global_mesh_tensor(
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status,
const std::vector<int64_t>& global_shape);

// Builds a dist_reshape op: reshapes `x` (sharded as described by
// `x_placements`) to a tensor with `global_shape`/`local_shape`; the output
// dist attribute is assembled from `mesh`, `placements`, `dims_mapping` and
// `partial_status`. Returns the op's single result value.
pir::Value dist_reshape(
    const pir::Value& x,
    const phi::distributed::Placements& x_placements,
    const std::vector<int64_t>& global_shape,
    const std::vector<int64_t>& local_shape,
    const phi::distributed::ProcessMesh& mesh,
    const phi::distributed::Placements& placements,
    const std::vector<int64_t>& dims_mapping,
    const flat_hash_map<int64_t, phi::ReduceType>& partial_status);

} // namespace dialect
} // namespace paddle
27 changes: 25 additions & 2 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ ProcessMeshAttribute ProcessMeshAttribute::get(
return Base::get(ctx, shape, process_ids, dim_names);
}

// Accessor for the placements list held by the attribute storage.
const phi::distributed::Placements& PlacementsAttribute::placements() const {
  return storage()->placements;
}

// Interns (or retrieves) the unique PlacementsAttribute for `placements`
// in context `ctx`.
PlacementsAttribute PlacementsAttribute::get(
    pir::IrContext* ctx, const phi::distributed::Placements& placements) {
  return Base::get(ctx, placements);
}

// Delegates to the storage's formatter ("[p0, p1, ...]").
std::string PlacementsAttribute::to_string() const {
  return PlacementsAttrStorage::to_string(placements());
}

// Hashes the string form, matching PlacementsAttrStorage::HashValue, which
// hashes the same to_string() output.
size_t PlacementsAttribute::hash() const {
  return std::hash<std::string>()(to_string());
}

///
/// \brief TensorDistAttribute interface.
///
Expand All @@ -45,6 +62,10 @@ ProcessMeshAttribute TensorDistAttribute::process_mesh_attr() const {
const std::vector<int64_t>& TensorDistAttribute::dims_mapping() const {
return storage()->dims_mapping;
}
// Returns the optional placements attribute; std::nullopt when the dist
// attribute was built without explicit placements.
std::optional<PlacementsAttribute> TensorDistAttribute::placements_attr()
    const {
  return storage()->placements_;
}

std::set<int64_t> TensorDistAttribute::partial_dims() const {
auto& partial = partial_status();
Expand Down Expand Up @@ -96,12 +117,13 @@ TensorDistAttribute TensorDistAttribute::get(
pir::IrContext* ctx,
ProcessMeshAttribute mesh,
const std::vector<int64_t>& dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
const flat_hash_map<int64_t, phi::ReduceType>& partial_status,
const std::optional<PlacementsAttribute>& placements) {
PADDLE_ENFORCE_NOT_NULL(mesh,
common::errors::PreconditionNotMet(
"Building tensor_dist_attr through a nullptr "
"mesh attribute is currently not supported."));
return Base::get(ctx, mesh, dims_mapping, partial_status);
return Base::get(ctx, mesh, dims_mapping, partial_status, placements);
}

///
Expand Down Expand Up @@ -166,5 +188,6 @@ OperationDistAttribute OperationDistAttribute::get(
} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PlacementsAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperationDistAttribute)
28 changes: 25 additions & 3 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace dialect {
class ProcessMeshAttrStorage;
class TensorDistAttrStorage;
class OperationDistAttrStorage;
class PlacementsAttrStorage;

class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
pir::Attribute,
Expand Down Expand Up @@ -65,13 +66,30 @@ class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
static std::string name() { return "a_process_mesh"; }
};

// Attribute wrapping a phi::distributed::Placements list, describing how a
// tensor is placed across mesh dimensions (see the note on
// PlacementsAttrStorage: currently used by the local_reshape op).
class PlacementsAttribute : public pir::AttrBase<PlacementsAttribute,
                                                 pir::Attribute,
                                                 PlacementsAttrStorage> {
 public:
  using Base::Base;
  // Accessor for the stored placements list.
  const phi::distributed::Placements& placements() const;

  // Hash of the string form; consistent with the storage's HashValue.
  size_t hash() const;
  // "[p0, p1, ...]" rendering of the placements.
  std::string to_string() const;

  static std::string name() { return "a_placements"; }

  // Interns (or retrieves) the unique attribute for `placements` in `ctx`.
  static PlacementsAttribute get(
      pir::IrContext* ctx, const phi::distributed::Placements& placements);
};

class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
pir::Attribute,
TensorDistAttrStorage> {
public:
using Base::Base;
ProcessMeshAttribute process_mesh_attr() const;
const std::vector<int64_t>& dims_mapping() const;
std::optional<PlacementsAttribute> placements_attr() const;

// Returns the set of mesh dims on which this tensor is partial
std::set<int64_t> partial_dims() const;
Expand All @@ -89,16 +107,19 @@ class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
pir::IrContext* ctx,
ProcessMeshAttribute mesh,
const std::vector<int64_t>& dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {});
const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {},
const std::optional<PlacementsAttribute>& placements = std::nullopt);
// Convenience overload: wraps the raw ProcessMesh into a
// ProcessMeshAttribute and forwards to the attribute-based get().
// `placements` is optional and defaults to std::nullopt.
static TensorDistAttribute get(
    pir::IrContext* ctx,
    const phi::distributed::ProcessMesh& mesh,
    const std::vector<int64_t>& dims_mapping,
    const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {},
    const std::optional<PlacementsAttribute>& placements = std::nullopt) {
  return get(ctx,
             ProcessMeshAttribute::get(ctx, mesh),
             dims_mapping,
             partial_status,
             placements);
}

static std::string name() { return "a_tensor_dist"; }
Expand Down Expand Up @@ -145,5 +166,6 @@ class OperationDistAttribute : public pir::AttrBase<OperationDistAttribute,
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PlacementsAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperationDistAttribute)
6 changes: 5 additions & 1 deletion paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ DistDialect::DistDialect(pir::IrContext *context)

// Registers all attributes, types and ops owned by the distributed dialect
// with the IR context.
void DistDialect::initialize() {
  RegisterAttributes<ProcessMeshAttribute,
                     PlacementsAttribute,
                     TensorDistAttribute,
                     OperationDistAttribute>();
  RegisterTypes<DistDenseTensorType>();
  RegisterOps<ShardTensorOp,
              ReshardOp,
              MoESubMeshTensorsOp,
              MoEGlobalMeshTensorOp,
              DistReshapeOp>();
}

void DistDialect::PrintType(pir::Type type, std::ostream &os) const {
Expand Down Expand Up @@ -115,6 +117,8 @@ void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
}
os << ",chunk_id:" << op_dist_attr.chunk_id();
os << "}";
} else if (auto placements_attr = attr.dyn_cast<PlacementsAttribute>()) {
os << placements_attr.to_string();
} else {
os << "error_attribute_type";
}
Expand Down
Loading