Skip to content

Commit

Permalink
Add a config field to GraphNode to allow for custom configuration of …
Browse files Browse the repository at this point in the history
…nodes

Users can set `pinToGroupTop` for a node to be pinned to the top of a layer.

PiperOrigin-RevId: 677915322
  • Loading branch information
yijie-yang authored and copybara-github committed Sep 23, 2024
1 parent 706dab5 commit 154b960
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/builtin-adapter/formats/schema_structs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ llvm::json::Object GraphEdge::Json() {
return json_edge;
}

const char GraphNodeConfig::kPinToGroupTop[] = "pinToGroupTop";

llvm::json::Object GraphNodeConfig::Json() {
llvm::json::Object json_config;
json_config[kPinToGroupTop] = pin_to_group_top;
return json_config;
}

const char GraphNode::kNodeId[] = "id";
const char GraphNode::kNodeLabel[] = "label";
const char GraphNode::kNodeName[] = "namespace";
Expand All @@ -72,6 +80,7 @@ const char GraphNode::kNodeAttrs[] = "attrs";
const char GraphNode::kIncomingEdges[] = "incomingEdges";
const char GraphNode::kInputsMetadata[] = "inputsMetadata";
const char GraphNode::kOutputsMetadata[] = "outputsMetadata";
const char GraphNode::kConfig[] = "config";

llvm::json::Object GraphNode::Json() {
llvm::json::Object json_node;
Expand Down Expand Up @@ -105,6 +114,11 @@ llvm::json::Object GraphNode::Json() {
for (Metadata& metadata : outputs_metadata) {
json_outputs_metadata->push_back(metadata.Json());
}

if (config.has_value()) { // Only add config if it exists
json_node[kConfig] = config->Json();
}

return json_node;
}

Expand Down
14 changes: 14 additions & 0 deletions src/builtin-adapter/formats/schema_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_FORMATS_SCHEMA_STRUCTS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_FORMATS_SCHEMA_STRUCTS_H_

#include <optional>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -64,6 +65,17 @@ struct GraphEdge {
static const char kEdgeMetadata[];
};

// Configuration for a graph node.
struct GraphNodeConfig {
// Whether to pin the node to the top of the group it belongs to.
bool pin_to_group_top = false;

llvm::json::Object Json();

private:
static const char kPinToGroupTop[];
};

struct GraphNode {
std::string node_id;
std::string node_label;
Expand All @@ -73,6 +85,7 @@ struct GraphNode {
std::vector<GraphEdge> incoming_edges;
std::vector<Metadata> inputs_metadata;
std::vector<Metadata> outputs_metadata;
std::optional<GraphNodeConfig> config;

llvm::json::Object Json();

Expand All @@ -85,6 +98,7 @@ struct GraphNode {
static const char kIncomingEdges[];
static const char kInputsMetadata[];
static const char kOutputsMetadata[];
static const char kConfig[];
};

struct Subgraph {
Expand Down
8 changes: 8 additions & 0 deletions src/builtin-adapter/graphnode_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,13 @@ void GraphNodeBuilder::AppendAttrToMetadata(const EdgeType edge_type,
}
}

void GraphNodeBuilder::SetPinToGroupTop(bool pin_to_group_top) {
if (node_.config.has_value()) {
node_.config->pin_to_group_top = pin_to_group_top;
} else {
node_.config = GraphNodeConfig{.pin_to_group_top = pin_to_group_top};
}
}

} // namespace visualization_client
} // namespace tooling
5 changes: 5 additions & 0 deletions src/builtin-adapter/graphnode_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class GraphNodeBuilder {
absl::string_view attr_key,
absl::string_view attr_value);

// Sets the node to be pinned to the top of the group it belongs to.
// User needs to ensure the node's namespace is indeed a layer for the pinning
// to work.
void SetPinToGroupTop(bool pin_to_group_top);

// Returns the node that has been created by this class.
GraphNode Build() && { return std::move(node_); }

Expand Down

0 comments on commit 154b960

Please sign in to comment.