Skip to content

Commit

Permalink
Go back to using GenerateSourcesRequest so export-codegen works
Browse files Browse the repository at this point in the history
# Rust tests and lints will be skipped. Delete if not intended.
[ci skip-rust]

# Building wheels and fs_util will be skipped. Delete if not intended.
[ci skip-build-wheels]
  • Loading branch information
Eric-Arellano committed Mar 4, 2022
1 parent f1d769b commit f2d3f4d
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 49 deletions.
57 changes: 32 additions & 25 deletions src/python/pants/backend/codegen/protobuf/go/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from dataclasses import dataclass

from pants.backend.codegen.protobuf.protoc import Protoc
from pants.backend.codegen.protobuf.target_types import ProtobufSourceField
from pants.backend.codegen.protobuf.target_types import ProtobufGrpcToggleField, ProtobufSourceField
from pants.backend.go.target_types import GoPackageSourcesField
from pants.backend.go.util_rules.build_pkg import BuildGoPackageRequest
from pants.backend.go.util_rules.build_pkg_target import GoCodegenRequest
from pants.backend.go.util_rules.build_pkg_target import GoCodegenBuildRequest
from pants.backend.go.util_rules.sdk import GoSdkProcess
from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest
from pants.core.util_rules.source_files import SourceFilesRequest
Expand All @@ -21,31 +22,31 @@
FileContent,
MergeDigests,
RemovePrefix,
Snapshot,
)
from pants.engine.internals.native_engine import EMPTY_DIGEST
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.platform import Platform
from pants.engine.process import Process, ProcessResult
from pants.engine.rules import collect_rules, rule
from pants.engine.target import TransitiveTargets, TransitiveTargetsRequest
from pants.engine.target import (
GeneratedSources,
GenerateSourcesRequest,
TransitiveTargets,
TransitiveTargetsRequest,
)
from pants.engine.unions import UnionRule
from pants.source.source_root import SourceRoot, SourceRootRequest
from pants.util.logging import LogLevel


class GoCodegenProtobufRequest(GoCodegenRequest):
class GoCodegenBuildProtobufRequest(GoCodegenBuildRequest):
generate_from = ProtobufSourceField


@dataclass(frozen=True)
class _GeneratedGoFilesRequest:
source: ProtobufSourceField
grpc: bool


@dataclass(frozen=True)
class _GeneratedGoFiles:
digest: Digest
class GenerateGoFromProtobufRequest(GenerateSourcesRequest):
input = ProtobufSourceField
output = GoPackageSourcesField


@dataclass(frozen=True)
Expand All @@ -55,25 +56,25 @@ class _SetupGoProtocPlugin:

@rule
async def setup_build_go_package_request_for_protobuf(
_: GoCodegenProtobufRequest,
_: GoCodegenBuildProtobufRequest,
) -> BuildGoPackageRequest:
raise NotImplementedError()


@rule(desc="Generate Go source files from Protobuf", level=LogLevel.DEBUG)
async def generate_go_from_protobuf(
request: _GeneratedGoFilesRequest,
request: GenerateGoFromProtobufRequest,
protoc: Protoc,
go_protoc_plugin: _SetupGoProtocPlugin,
) -> _GeneratedGoFiles:
) -> GeneratedSources:
output_dir = "_generated_files"
protoc_relpath = "__protoc"
protoc_go_plugin_relpath = "__protoc_gen_go"

downloaded_protoc_binary, empty_output_dir, transitive_targets = await MultiGet(
Get(DownloadedExternalTool, ExternalToolRequest, protoc.get_request(Platform.current)),
Get(Digest, CreateDigest([Directory(output_dir)])),
Get(TransitiveTargets, TransitiveTargetsRequest([request.source.address])),
Get(TransitiveTargets, TransitiveTargetsRequest([request.protocol_target.address])),
)

# NB: By stripping the source roots, we avoid having to set the value `--proto_path`
Expand All @@ -87,15 +88,17 @@ async def generate_go_from_protobuf(
if tgt.has_field(ProtobufSourceField)
),
),
Get(StrippedSourceFiles, SourceFilesRequest([request.source])),
Get(
StrippedSourceFiles, SourceFilesRequest([request.protocol_target[ProtobufSourceField]])
),
)

input_digest = await Get(
Digest, MergeDigests([all_sources_stripped.snapshot.digest, empty_output_dir])
)

maybe_grpc_plugin_args = []
if request.grpc:
if request.protocol_target.get(ProtobufGrpcToggleField).value:
maybe_grpc_plugin_args = [
f"--go-grpc_out={output_dir}",
"--go-grpc_opt=paths=source_relative",
Expand All @@ -120,23 +123,23 @@ async def generate_go_from_protobuf(
protoc_relpath: downloaded_protoc_binary.digest,
protoc_go_plugin_relpath: go_protoc_plugin.digest,
},
description=f"Generating Go sources from {request.source.address}.",
description=f"Generating Go sources from {request.protocol_target.address}.",
level=LogLevel.DEBUG,
output_directories=(output_dir,),
),
)

normalized_digest, source_root = await MultiGet(
Get(Digest, RemovePrefix(result.output_digest, output_dir)),
Get(SourceRoot, SourceRootRequest, SourceRootRequest.for_address(request.source.address)),
Get(SourceRoot, SourceRootRequest, SourceRootRequest.for_target(request.protocol_target)),
)

source_root_restored = (
await Get(Digest, AddPrefix(normalized_digest, source_root.path))
await Get(Snapshot, AddPrefix(normalized_digest, source_root.path))
if source_root.path != "."
else normalized_digest
else await Get(Snapshot, Digest, normalized_digest)
)
return _GeneratedGoFiles(source_root_restored)
return GeneratedSources(source_root_restored)


# Note: The versions of the Go protoc and gRPC plugins are hard coded in the following go.mod. To update,
Expand Down Expand Up @@ -249,4 +252,8 @@ async def setup_go_protoc_plugin(platform: Platform) -> _SetupGoProtocPlugin:


def rules():
return (*collect_rules(), UnionRule(GoCodegenRequest, GoCodegenProtobufRequest))
return (
*collect_rules(),
UnionRule(GenerateSourcesRequest, GenerateGoFromProtobufRequest),
UnionRule(GoCodegenBuildRequest, GoCodegenBuildProtobufRequest),
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

import pytest

from pants.backend.codegen.protobuf.go.rules import _GeneratedGoFiles, _GeneratedGoFilesRequest
from pants.backend.codegen.protobuf.go.rules import GenerateGoFromProtobufRequest
from pants.backend.codegen.protobuf.go.rules import rules as go_protobuf_rules
from pants.backend.codegen.protobuf.target_types import (
ProtobufGrpcToggleField,
ProtobufSourceField,
ProtobufSourcesGeneratorTarget,
)
Expand All @@ -20,9 +19,9 @@
from pants.build_graph.address import Address
from pants.core.util_rules import config_files, source_files, stripped_source_files
from pants.core.util_rules.external_tool import rules as external_tool_rules
from pants.engine.fs import Digest, DigestContents, Snapshot
from pants.engine.fs import Digest, DigestContents
from pants.engine.rules import QueryRule
from pants.engine.target import HydratedSources, HydrateSourcesRequest
from pants.engine.target import GeneratedSources, HydratedSources, HydrateSourcesRequest
from pants.jvm.jdk_rules import rules as jdk_rules
from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, RuleRunner

Expand All @@ -40,7 +39,7 @@ def rule_runner() -> RuleRunner:
*go_protobuf_rules(),
*sdk.rules(),
QueryRule(HydratedSources, [HydrateSourcesRequest]),
QueryRule(_GeneratedGoFiles, [_GeneratedGoFilesRequest]),
QueryRule(GeneratedSources, [GenerateGoFromProtobufRequest]),
QueryRule(DigestContents, (Digest,)),
],
target_types=[
Expand All @@ -66,16 +65,13 @@ def assert_files_generated(
args = [f"--source-root-patterns={repr(source_roots)}", *extra_args]
rule_runner.set_options(args, env_inherit=PYTHON_BOOTSTRAP_ENV)
tgt = rule_runner.get_target(address)
protocol_sources = rule_runner.request(
HydratedSources, [HydrateSourcesRequest(tgt[ProtobufSourceField])]
)
generated_sources = rule_runner.request(
_GeneratedGoFiles,
[
_GeneratedGoFilesRequest(
tgt[ProtobufSourceField], grpc=tgt[ProtobufGrpcToggleField].value
)
],
GeneratedSources, [GenerateGoFromProtobufRequest(protocol_sources.snapshot, tgt)]
)
snapshot = rule_runner.request(Snapshot, [generated_sources.digest])
assert set(snapshot.files) == set(expected_files)
assert set(generated_sources.snapshot.files) == set(expected_files)


def test_generates_go(rule_runner: RuleRunner) -> None:
Expand Down
19 changes: 12 additions & 7 deletions src/python/pants/backend/go/util_rules/build_pkg_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,19 @@ def debug_hint(self) -> str:

@union
@dataclass(frozen=True)
class GoCodegenRequest:
"""The plugin hook to generate Go code.
class GoCodegenBuildRequest:
"""The plugin hook to build/compile Go code.
Note that you should still use the normal `GenerateSourcesRequest` plugin hook from
`pants.engine.target` too, which is necessary for integrations like the `export-codegen` goal.
However, that is only helpful to generate the raw `.go` files; you also need to use this
plugin hook so that Pants knows how to compile those generated `.go` files.
Subclass this and set the class property `generate_from`. Define a rule that goes from your
subclass to `BuildGoPackageRequest` - the request must result in valid compilation, which you
should test for by using `rule_runner.request(BuiltGoPackage, BuildGoPackageRequest)` in your
tests. For example, make sure to set up any third-party packages needed by the generated code.
Finally, register `UnionRule(GoCodegenRequest, MySubclass)`.
Finally, register `UnionRule(GoCodegenBuildRequest, MySubclass)`.
"""

target: Target
Expand All @@ -79,11 +84,11 @@ class GoCodegenRequest:

def maybe_get_codegen_request_type(
tgt: Target, union_membership: UnionMembership
) -> GoCodegenRequest | None:
) -> GoCodegenBuildRequest | None:
if not tgt.has_field(SourcesField):
return None
generate_request_types = cast(
FrozenOrderedSet[type[GoCodegenRequest]], union_membership.get(GoCodegenRequest)
FrozenOrderedSet[type[GoCodegenBuildRequest]], union_membership.get(GoCodegenBuildRequest)
)
sources_field = tgt[SourcesField]
relevant_requests = [
Expand All @@ -92,7 +97,7 @@ def maybe_get_codegen_request_type(
if len(relevant_requests) > 1:
generate_from_sources = relevant_requests[0].generate_from.__name__
raise AmbiguousCodegenImplementationsException(
f"Multiple of the registered code generators from {GoCodegenRequest.__name__} can "
f"Multiple of the registered code generators from {GoCodegenBuildRequest.__name__} can "
f"generate from {generate_from_sources}. It is ambiguous which implementation to "
f"use.\n\n"
f"Possible implementations:\n\n"
Expand All @@ -112,7 +117,7 @@ async def setup_build_go_package_target_request(

codegen_request = maybe_get_codegen_request_type(target, union_membership)
if codegen_request:
codegen_result = await Get(BuildGoPackageRequest, GoCodegenRequest, codegen_request)
codegen_result = await Get(BuildGoPackageRequest, GoCodegenBuildRequest, codegen_request)
return FallibleBuildGoPackageRequest(codegen_result, codegen_result.import_path)

embed_config: EmbedConfig | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from pants.backend.go.util_rules.build_pkg_target import (
BuildGoPackageTargetRequest,
GoCodegenRequest,
GoCodegenBuildRequest,
)
from pants.core.target_types import FileSourceField, FileTarget
from pants.engine.addresses import Address
Expand All @@ -41,12 +41,12 @@


# Set up a trivial codegen plugin.
class GoCodegenFilesRequest(GoCodegenRequest):
class GoCodegenBuildFilesRequest(GoCodegenBuildRequest):
generate_from = FileSourceField


@rule
async def generate_from_file(_: GoCodegenFilesRequest) -> BuildGoPackageRequest:
async def generate_from_file(_: GoCodegenBuildFilesRequest) -> BuildGoPackageRequest:
content = dedent(
"""\
package gen
Expand Down Expand Up @@ -89,7 +89,7 @@ def rule_runner() -> RuleRunner:
QueryRule(FallibleBuiltGoPackage, [BuildGoPackageRequest]),
QueryRule(BuildGoPackageRequest, [BuildGoPackageTargetRequest]),
QueryRule(FallibleBuildGoPackageRequest, [BuildGoPackageTargetRequest]),
UnionRule(GoCodegenRequest, GoCodegenFilesRequest),
UnionRule(GoCodegenBuildRequest, GoCodegenBuildFilesRequest),
],
target_types=[GoModTarget, GoPackageTarget, FileTarget],
)
Expand Down

0 comments on commit f2d3f4d

Please sign in to comment.