diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java index 2952548e851bd5..cfc09237b415b0 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/Module.java @@ -242,6 +242,8 @@ public Builder addExtensionUsage(ModuleExtensionUsage value) { return this; } + abstract ModuleKey getKey(); + abstract String getName(); abstract Optional getRepoName(); diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java index 3b3f42c0ba0709..da7ba7318cb49e 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileGlobals.java @@ -61,6 +61,7 @@ public class ModuleFileGlobals { Pattern.compile("(>|<|-|<=|>=)(\\d+\\.){2}\\d+"); private boolean moduleCalled = false; + private boolean hadNonModuleCall = false; private final boolean ignoreDevDeps; private final Module.Builder module; private final Map deps = new LinkedHashMap<>(); @@ -208,6 +209,9 @@ public void module( if (moduleCalled) { throw Starlark.errorf("the module() directive can only be called once"); } + if (hadNonModuleCall) { + throw Starlark.errorf("if module() is called, it must be called before any other functions"); + } moduleCalled = true; if (!name.isEmpty()) { validateModuleName(name); @@ -298,6 +302,7 @@ private static ImmutableList checkAllCompatibilityVersions( public void bazelDep( String name, String version, String repoName, boolean devDependency, StarlarkThread thread) throws EvalException { + hadNonModuleCall = true; if (repoName.isEmpty()) { repoName = name; } @@ -330,6 +335,7 @@ public void bazelDep( allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)}, doc = "The labels of the platforms to register.")) public void registerExecutionPlatforms(Sequence platformLabels) throws EvalException { + hadNonModuleCall = true; module.addExecutionPlatformsToRegister( checkAllAbsolutePatterns(platformLabels, "register_execution_platforms")); } @@ -347,6 +353,7 @@ public void registerExecutionPlatforms(Sequence platformLabels) throws EvalEx allowedTypes = {@ParamType(type = Sequence.class, generic1 = String.class)}, doc = "The labels of the toolchains to register.")) public void registerToolchains(Sequence toolchainLabels) throws EvalException { + hadNonModuleCall = true; module.addToolchainsToRegister( checkAllAbsolutePatterns(toolchainLabels, "register_toolchains")); } @@ -376,7 +383,14 @@ public void registerToolchains(Sequence toolchainLabels) throws EvalException }, useStarlarkThread = true) public ModuleExtensionProxy useExtension( - String extensionBzlFile, String extensionName, boolean devDependency, StarlarkThread thread) { + String rawExtensionBzlFile, + String extensionName, + boolean devDependency, + StarlarkThread thread) { + hadNonModuleCall = true; + + String extensionBzlFile = normalizeLabelString(rawExtensionBzlFile); + ModuleExtensionUsageBuilder newUsageBuilder = new ModuleExtensionUsageBuilder( extensionBzlFile, extensionName, thread.getCallerLocation()); @@ -399,6 +413,22 @@ public ModuleExtensionProxy useExtension( return newUsageBuilder.getProxy(devDependency); } + private String normalizeLabelString(String rawExtensionBzlFile) { + // Normalize the label by adding the current module's repo_name if the label doesn't specify a + // repository name. This is necessary as ModuleExtensionUsages are grouped by the string value + // of this label, but later mapped to their Label representation. If multiple strings map to the + // same Label, this would result in a crash. + // ownName can't change anymore as calling module() after this results in an error. + String ownName = module.getRepoName().orElse(module.getName()); + if (module.getKey().equals(ModuleKey.ROOT) && rawExtensionBzlFile.startsWith("@//")) { + return "@" + ownName + rawExtensionBzlFile.substring(1); + } else if (rawExtensionBzlFile.startsWith("//")) { + return "@" + ownName + rawExtensionBzlFile; + } else { + return rawExtensionBzlFile; + } + } + class ModuleExtensionUsageBuilder { private final String extensionBzlFile; private final String extensionName; @@ -516,6 +546,7 @@ public void useRepo( Dict kwargs, StarlarkThread thread) throws EvalException { + hadNonModuleCall = true; Location location = thread.getCallerLocation(); for (String arg : Sequence.cast(args, String.class, "args")) { extensionProxy.addImport(arg, arg, location); @@ -598,6 +629,7 @@ public void singleVersionOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; Version parsedVersion; try { parsedVersion = Version.parse(version); @@ -652,6 +684,7 @@ public void singleVersionOverride( }) public void multipleVersionOverride(String moduleName, Iterable versions, String registry) throws EvalException { + hadNonModuleCall = true; ImmutableList.Builder parsedVersionsBuilder = new ImmutableList.Builder<>(); try { for (String version : Sequence.cast(versions, String.class, "versions").getImmutableList()) { @@ -735,6 +768,7 @@ public void archiveOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; ImmutableList urlList = urls instanceof String ? ImmutableList.of((String) urls) @@ -806,6 +840,7 @@ public void gitOverride( Iterable patchCmds, StarlarkInt patchStrip) throws EvalException { + hadNonModuleCall = true; addOverride( moduleName, GitOverride.create( @@ -835,6 +870,7 @@ public void gitOverride( positional = false), }) public void localPathOverride(String moduleName, String path) throws EvalException { + hadNonModuleCall = true; addOverride(moduleName, LocalPathOverride.create(path)); } diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java index a6f3127a8d9b24..2b488324e115b0 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java @@ -329,6 +329,88 @@ public void simpleExtension() throws Exception { assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba"); } + @Test + public void simpleExtension_nonCanonicalLabel() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "module(name='my_module', version = '1.0')", + "bazel_dep(name='data_repo', version='1.0')", + "ext1 = use_extension('//:defs.bzl', 'ext')", + "ext1.tag(name='foo', data='fu')", + "use_repo(ext1, 'foo')", + "ext2 = use_extension('@my_module//:defs.bzl', 'ext')", + "ext2.tag(name='bar', data='ba')", + "use_repo(ext2, 'bar')", + "ext3 = use_extension('@//:defs.bzl', 'ext')", + "ext3.tag(name='quz', data='qu')", + "use_repo(ext3, 'quz')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " data_repo(name=tag.name,data=tag.data)", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "load('@quz//:data.bzl', quz_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu"); + } + + @Test + public void simpleExtension_nonCanonicalLabel_repoName() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "module(name='my_module', version = '1.0', repo_name='my_name')", + "bazel_dep(name='data_repo', version='1.0')", + "ext1 = use_extension('//:defs.bzl', 'ext')", + "ext1.tag(name='foo', data='fu')", + "use_repo(ext1, 'foo')", + "ext2 = use_extension('@my_name//:defs.bzl', 'ext')", + "ext2.tag(name='bar', data='ba')", + "use_repo(ext2, 'bar')", + "ext3 = use_extension('@//:defs.bzl', 'ext')", + "ext3.tag(name='quz', data='qu')", + "use_repo(ext3, 'quz')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {'name':attr.string(),'data':attr.string()})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " data_repo(name=tag.name,data=tag.data)", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "load('@quz//:data.bzl', quz_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data+' quz:'+quz_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseCanonical("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")).isEqualTo("foo:fu bar:ba quz:qu"); + } + @Test public void multipleModules() throws Exception { scratch.file( diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java index 2df3f7af45a1e2..1b8e52ac9cc8fe 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java @@ -473,7 +473,7 @@ public void testModuleExtensions_good() throws Exception { .setRegistry(registry) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext1") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 2, 23)) .setImports(ImmutableBiMap.of("repo1", "repo1")) @@ -491,7 +491,7 @@ public void testModuleExtensions_good() throws Exception { .build()) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext2") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23)) .setImports(ImmutableBiMap.of("other_repo1", "repo1", "repo2", "repo2")) @@ -582,7 +582,7 @@ public void testModuleExtensions_duplicateProxy_asRoot() throws Exception { .setKey(ModuleKey.ROOT) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@//:defs.bzl") .setExtensionName("myext") .setLocation(Location.fromFileLineColumn("/MODULE.bazel", 1, 23)) .setImports( @@ -672,7 +672,7 @@ public void testModuleExtensions_duplicateProxy_asDep() throws Exception { .setRegistry(registry) .addExtensionUsage( ModuleExtensionUsage.builder() - .setExtensionBzlFile("//:defs.bzl") + .setExtensionBzlFile("@mymod//:defs.bzl") .setExtensionName("myext") .setLocation(Location.fromFileLineColumn("mymod@1.0/MODULE.bazel", 5, 23)) .setImports(ImmutableBiMap.of("beta", "beta", "delta", "delta")) @@ -956,4 +956,34 @@ public void moduleRepoName_conflict() throws Exception { assertContainsEvent("The repo name 'bbb' is already being used as the module's own repo name"); } + + @Test + public void module_calledTwice() throws Exception { + scratch.file( + rootDirectory.getRelative("MODULE.bazel").getPathString(), + "module(name='aaa',version='0.1',repo_name='bbb')", + "module(name='aaa',version='0.1',repo_name='bbb')"); + FakeRegistry registry = registryFactory.newFakeRegistry("/foo"); + ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl())); + + reporter.removeHandler(failFastHandler); // expect failures + evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext); + + assertContainsEvent("the module() directive can only be called once"); + } + + @Test + public void module_calledLate() throws Exception { + scratch.file( + rootDirectory.getRelative("MODULE.bazel").getPathString(), + "use_extension('//:extensions.bzl', 'my_ext')", + "module(name='aaa',version='0.1',repo_name='bbb')"); + FakeRegistry registry = registryFactory.newFakeRegistry("/foo"); + ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl())); + + reporter.removeHandler(failFastHandler); // expect failures + evaluator.evaluate(ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext); + + assertContainsEvent("if module() is called, it must be called before any other functions"); + } }