Skip to content

Commit

Permalink
Batching of lint and fmt invokes (#14186)
Browse files Browse the repository at this point in the history
As described in #13462, there are correctness concerns around not breaking large batches of files into smaller batches in `lint` and `fmt`. But there are other reasons to batch, including improving the performance of linters which don't support internal parallelism (by breaking them into multiple processes which _can_ be parallelized).

This change adds a function to sequentially partition a list of items into stable batches, and then uses it to create batches by default in `lint` and `fmt`. Sequential partitioning was chosen rather than bucketing by hash, because it was easier to reason about in the presence of minimum and maximum bucket sizes.

Additionally, this implementation is at the level of the `lint` and `fmt` goals themselves (rather than within individual `lint`/`fmt` `@rule` sets, as originally suggested [on the ticket](#13462 (comment))) because that reduces the effort of implementing a linter or formatter, and would likely ease doing further "automatic"/declarative partitioning in those goals (by `Field` values, for example).

`./pants --no-pantsd --no-local-cache --no-remote-cache-read fmt lint ::` runs ~4% faster than on main.

Fixes #13462.

[ci skip-build-wheels]
  • Loading branch information
stuhood authored Jan 19, 2022
1 parent 2d9e7b5 commit 9c1eb9f
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 71 deletions.
3 changes: 2 additions & 1 deletion src/python/pants/backend/terraform/lint/tffmt/tffmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import textwrap

from pants.backend.terraform.style import StyleSetup, StyleSetupRequest
from pants.backend.terraform.target_types import TerraformFieldSet
from pants.backend.terraform.tool import TerraformProcess
from pants.backend.terraform.tool import rules as tool_rules
from pants.core.goals.fmt import FmtRequest, FmtResult
Expand Down Expand Up @@ -39,7 +40,7 @@ def register_options(cls, register):


class TffmtRequest(FmtRequest):
    """Fmt request for Terraform files, handled by the `terraform fmt` rule below."""

    # Narrow this request to targets matched by the Terraform field set.
    field_set_type = TerraformFieldSet


@rule(desc="Format with `terraform fmt`")
Expand Down
28 changes: 26 additions & 2 deletions src/python/pants/core/goals/fmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import dataclass
from typing import TypeVar, cast

from pants.core.goals.style_request import StyleRequest
from pants.core.goals.style_request import StyleRequest, style_batch_size_help
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.console import Console
from pants.engine.engine_aware import EngineAwareReturnType
Expand All @@ -18,6 +18,7 @@
from pants.engine.rules import Get, MultiGet, collect_rules, goal_rule, rule
from pants.engine.target import SourcesField, Targets
from pants.engine.unions import UnionMembership, union
from pants.util.collections import partition_sequentially
from pants.util.logging import LogLevel
from pants.util.strutil import strip_v2_chroot_path

Expand Down Expand Up @@ -135,6 +136,15 @@ def register_options(cls, register) -> None:
advanced=True,
type=bool,
default=False,
removal_version="2.11.0.dev0",
removal_hint=(
"Formatters are now broken into multiple batches by default using the "
"`--batch-size` argument.\n"
"\n"
"To keep (roughly) this option's behavior, set [fmt].batch_size = 1. However, "
"you'll likely get better performance by using a larger batch size because of "
"reduced overhead launching processes."
),
help=(
"Rather than formatting all files in a single batch, format each file as a "
"separate process.\n\nWhy do this? You'll get many more cache hits. Why not do "
Expand All @@ -145,11 +155,22 @@ def register_options(cls, register) -> None:
"faster than `--no-per-file-caching` for your use case."
),
)
register(
"--batch-size",
advanced=True,
type=int,
default=128,
help=style_batch_size_help(uppercase="Formatter", lowercase="formatter"),
)

    @property
    def per_file_caching(self) -> bool:
        """Whether to format each file in its own process.

        Deprecated (removal_version 2.11.0.dev0) in favor of `--batch-size`.
        """
        return cast(bool, self.options.per_file_caching)

    @property
    def batch_size(self) -> int:
        """Target minimum number of files per formatter batch (`--batch-size`, default 128)."""
        return cast(int, self.options.batch_size)


class Fmt(Goal):
subsystem_cls = FmtSubsystem
Expand Down Expand Up @@ -187,9 +208,12 @@ async def fmt(
per_language_results = await MultiGet(
Get(
_LanguageFmtResults,
_LanguageFmtRequest(fmt_requests, Targets(targets)),
_LanguageFmtRequest(fmt_requests, Targets(target_batch)),
)
for fmt_requests, targets in targets_by_fmt_request_order.items()
for target_batch in partition_sequentially(
targets, key=lambda t: t.address.spec, size_min=fmt_subsystem.batch_size
)
)

individual_results = list(
Expand Down
3 changes: 1 addition & 2 deletions src/python/pants/core/goals/fmt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pants.engine.rules import Get, collect_rules, rule
from pants.engine.target import FieldSet, MultipleSourcesField, Target
from pants.engine.unions import UnionRule
from pants.testutil.rule_runner import RuleRunner, logging
from pants.testutil.rule_runner import RuleRunner
from pants.util.logging import LogLevel

FORTRAN_FILE = FileContent("formatted.f98", b"READ INPUT TAPE 5\n")
Expand Down Expand Up @@ -135,7 +135,6 @@ def run_fmt(rule_runner: RuleRunner, *, target_specs: List[str], per_file_cachin
return result.stderr


@logging
@pytest.mark.parametrize("per_file_caching", [True, False])
def test_summary(per_file_caching: bool) -> None:
"""Tests that the final summary is correct.
Expand Down
83 changes: 58 additions & 25 deletions src/python/pants/core/goals/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from dataclasses import dataclass
from typing import Any, Iterable, cast

from pants.core.goals.style_request import StyleRequest, write_reports
from pants.core.goals.style_request import StyleRequest, style_batch_size_help, write_reports
from pants.core.util_rules.distdir import DistDir
from pants.engine.console import Console
from pants.engine.engine_aware import EngineAwareReturnType
from pants.engine.fs import EMPTY_DIGEST, Digest, Workspace
from pants.engine.goal import Goal, GoalSubsystem
from pants.engine.process import FallibleProcessResult
from pants.engine.rules import Get, MultiGet, collect_rules, goal_rule
from pants.engine.target import Targets
from pants.engine.target import FieldSet, Targets
from pants.engine.unions import UnionMembership, union
from pants.util.collections import partition_sequentially
from pants.util.logging import LogLevel
from pants.util.memo import memoized_property
from pants.util.meta import frozen_after_init
Expand Down Expand Up @@ -153,6 +154,15 @@ def register_options(cls, register) -> None:
advanced=True,
type=bool,
default=False,
removal_version="2.11.0.dev0",
removal_hint=(
"Linters are now broken into multiple batches by default using the "
"`--batch-size` argument.\n"
"\n"
"To keep (roughly) this option's behavior, set [lint].batch_size = 1. However, "
"you'll likely get better performance by using a larger batch size because of "
"reduced overhead launching processes."
),
help=(
"Rather than linting all files in a single batch, lint each file as a "
"separate process.\n\nWhy do this? You'll get many more cache hits. Why not do "
Expand All @@ -163,11 +173,22 @@ def register_options(cls, register) -> None:
"faster than `--no-per-file-caching` for your use case."
),
)
register(
"--batch-size",
advanced=True,
type=int,
default=128,
help=style_batch_size_help(uppercase="Linter", lowercase="linter"),
)

    @property
    def per_file_caching(self) -> bool:
        """Whether to lint each file in its own process.

        Deprecated (removal_version 2.11.0.dev0) in favor of `--batch-size`.
        """
        return cast(bool, self.options.per_file_caching)

    @property
    def batch_size(self) -> int:
        """Target minimum number of files per linter batch (`--batch-size`, default 128)."""
        return cast(int, self.options.batch_size)


class Lint(Goal):
subsystem_cls = LintSubsystem
Expand All @@ -182,7 +203,7 @@ async def lint(
union_membership: UnionMembership,
dist_dir: DistDir,
) -> Lint:
request_types = cast("Iterable[type[StyleRequest]]", union_membership[LintRequest])
request_types = cast("Iterable[type[LintRequest]]", union_membership[LintRequest])
requests = tuple(
request_type(
request_type.field_set_type.create(target)
Expand All @@ -193,36 +214,48 @@ async def lint(
)

if lint_subsystem.per_file_caching:
all_per_file_results = await MultiGet(
all_batch_results = await MultiGet(
Get(LintResults, LintRequest, request.__class__([field_set]))
for request in requests
for field_set in request.field_sets
if request.field_sets
for field_set in request.field_sets
)
else:

def key_fn(results: LintResults):
return results.linter_name

# NB: We must pre-sort the data for itertools.groupby() to work properly.
sorted_all_per_files_results = sorted(all_per_file_results, key=key_fn)
# We consolidate all results for each linter into a single `LintResults`.
all_results = tuple(
LintResults(
itertools.chain.from_iterable(
per_file_results.results for per_file_results in all_linter_results
),
linter_name=linter_name,
)
for linter_name, all_linter_results in itertools.groupby(
sorted_all_per_files_results, key=key_fn
def address_str(fs: FieldSet) -> str:
return fs.address.spec

all_batch_results = await MultiGet(
Get(LintResults, LintRequest, request.__class__(field_set_batch))
for request in requests
if request.field_sets
for field_set_batch in partition_sequentially(
request.field_sets, key=address_str, size_min=lint_subsystem.batch_size
)
)
else:
all_results = await MultiGet(
Get(LintResults, LintRequest, request) for request in requests if request.field_sets
)

all_results = tuple(sorted(all_results, key=lambda results: results.linter_name))
def key_fn(results: LintResults):
return results.linter_name

# NB: We must pre-sort the data for itertools.groupby() to work properly.
sorted_all_batch_results = sorted(all_batch_results, key=key_fn)
# We consolidate all results for each linter into a single `LintResults`.
all_results = tuple(
sorted(
(
LintResults(
itertools.chain.from_iterable(
per_file_results.results for per_file_results in all_linter_results
),
linter_name=linter_name,
)
for linter_name, all_linter_results in itertools.groupby(
sorted_all_batch_results, key=key_fn
)
),
key=key_fn,
)
)

def get_tool_name(res: LintResults) -> str:
return res.linter_name
Expand Down
102 changes: 62 additions & 40 deletions src/python/pants/core/goals/lint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def run_lint_rule(
*,
lint_request_types: List[Type[LintRequest]],
targets: List[Target],
per_file_caching: bool,
per_file_caching: bool = False,
batch_size: int = 128,
) -> Tuple[int, str]:
with mock_console(rule_runner.options_bootstrapper) as (console, stdio_reader):
union_membership = UnionMembership({LintRequest: lint_request_types})
Expand All @@ -123,7 +124,9 @@ def run_lint_rule(
Workspace(rule_runner.scheduler, _enforce_effects=False),
Targets(targets),
create_goal_subsystem(
LintSubsystem, per_file_caching=per_file_caching, per_target_caching=False
LintSubsystem,
per_file_caching=per_file_caching,
batch_size=128,
),
union_membership,
DistDir(relpath=Path("dist")),
Expand All @@ -141,22 +144,20 @@ def run_lint_rule(
return result.exit_code, stdio_reader.get_stderr()


def test_invalid_target_noops(rule_runner: RuleRunner) -> None:
def assert_noops(per_file_caching: bool) -> None:
exit_code, stderr = run_lint_rule(
rule_runner,
lint_request_types=[InvalidRequest],
targets=[make_target()],
per_file_caching=per_file_caching,
)
assert exit_code == 0
assert stderr == ""

assert_noops(per_file_caching=False)
assert_noops(per_file_caching=True)
@pytest.mark.parametrize("per_file_caching", [True, False])
def test_invalid_target_noops(rule_runner: RuleRunner, per_file_caching: bool) -> None:
    """A lint request that does not apply to the given targets should no-op.

    Exercised in both caching modes: exit code 0 and no summary output.
    """
    exit_code, stderr = run_lint_rule(
        rule_runner,
        lint_request_types=[InvalidRequest],
        targets=[make_target()],
        per_file_caching=per_file_caching,
    )
    assert exit_code == 0
    assert stderr == ""


def test_summary(rule_runner: RuleRunner) -> None:
@pytest.mark.parametrize("per_file_caching", [True, False])
def test_summary(rule_runner: RuleRunner, per_file_caching: bool) -> None:
"""Test that we render the summary correctly.
This tests that we:
Expand All @@ -166,31 +167,52 @@ def test_summary(rule_runner: RuleRunner) -> None:
good_address = Address("", target_name="good")
bad_address = Address("", target_name="bad")

def assert_expected(*, per_file_caching: bool) -> None:
exit_code, stderr = run_lint_rule(
rule_runner,
lint_request_types=[
ConditionallySucceedsRequest,
FailingRequest,
SkippedRequest,
SuccessfulRequest,
],
targets=[make_target(good_address), make_target(bad_address)],
per_file_caching=per_file_caching,
)
assert exit_code == FailingRequest.exit_code([bad_address])
assert stderr == dedent(
"""\
𐄂 ConditionallySucceedsLinter failed.
𐄂 FailingLinter failed.
- SkippedLinter skipped.
✓ SuccessfulLinter succeeded.
"""
)
exit_code, stderr = run_lint_rule(
rule_runner,
lint_request_types=[
ConditionallySucceedsRequest,
FailingRequest,
SkippedRequest,
SuccessfulRequest,
],
targets=[make_target(good_address), make_target(bad_address)],
per_file_caching=per_file_caching,
)
assert exit_code == FailingRequest.exit_code([bad_address])
assert stderr == dedent(
"""\
assert_expected(per_file_caching=False)
assert_expected(per_file_caching=True)
𐄂 ConditionallySucceedsLinter failed.
𐄂 FailingLinter failed.
- SkippedLinter skipped.
✓ SuccessfulLinter succeeded.
"""
)


@pytest.mark.parametrize("batch_size", [1, 32, 128, 1024])
def test_batched(rule_runner: RuleRunner, batch_size: int) -> None:
    """Batching must not change the consolidated output.

    512 targets are partitioned with batch sizes both smaller and larger than the
    target count; the per-batch `LintResults` must still collapse into exactly one
    summary line per linter, with the same exit code.
    """
    exit_code, stderr = run_lint_rule(
        rule_runner,
        lint_request_types=[
            ConditionallySucceedsRequest,
            FailingRequest,
            SkippedRequest,
            SuccessfulRequest,
        ],
        # NOTE(review): targets are all named `good{i}`, so ConditionallySucceedsLinter
        # succeeds here, while FailingLinter fails regardless of input.
        targets=[make_target(Address("", target_name=f"good{i}")) for i in range(0, 512)],
        batch_size=batch_size,
    )
    assert exit_code == FailingRequest.exit_code([])
    assert stderr == dedent(
        """\
        ✓ ConditionallySucceedsLinter succeeded.
        𐄂 FailingLinter failed.
        - SkippedLinter skipped.
        ✓ SuccessfulLinter succeeded.
        """
    )


def test_streaming_output_skip() -> None:
Expand Down
17 changes: 17 additions & 0 deletions src/python/pants/core/goals/style_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@
_FS = TypeVar("_FS", bound=FieldSet)


def style_batch_size_help(uppercase: str, lowercase: str) -> str:
    """Render the shared `--batch-size` help text for a style goal.

    `uppercase`/`lowercase` name the tool kind (e.g. "Linter"/"linter") and are
    interpolated into an otherwise identical help string, so `lint` and `fmt`
    stay in sync.
    """
    intro = (
        f"The target minimum number of files that will be included in each {lowercase} batch.\n"
        "\n"
        f"{uppercase} processes are batched for a few reasons:\n"
        "\n"
    )
    reasons = (
        "1. to avoid OS argument length limits (in processes which don't support argument "
        "files)",
        "2. to support more stable cache keys than would be possible if all files were "
        "operated on in a single batch.",
        f"3. to allow for parallelism in {lowercase} processes which don't have internal "
        "parallelism, or -- if they do support internal parallelism -- to improve scheduling "
        "behavior when multiple processes are competing for cores and so internal "
        "parallelism cannot be used perfectly.",
    )
    return intro + "\n".join(reasons) + "\n"


@frozen_after_init
@dataclass(unsafe_hash=True)
class StyleRequest(Generic[_FS], EngineAwareParameter, metaclass=ABCMeta):
Expand Down
Loading

0 comments on commit 9c1eb9f

Please sign in to comment.