Skip to content

Commit

Permalink
Added support to expose imports for Python to Starlark
Browse files Browse the repository at this point in the history
Fixes #2617.

Basically just re-packaging the code and tests from @irengrig's work at srikalyan#1.

Closes #6423.

PiperOrigin-RevId: 218508231
  • Loading branch information
asafflesch authored and Copybara-Service committed Oct 24, 2018
1 parent 0f69a74 commit aab45c4
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ public Collection<Artifact> precompiledPythonFiles(
}

@Override
public List<PathFragment> getImports(RuleContext ruleContext) {
List<PathFragment> result = new ArrayList<>();
public List<String> getImports(RuleContext ruleContext) {
List<String> result = new ArrayList<>();
PathFragment packageFragment = ruleContext.getLabel().getPackageIdentifier().getRunfilesPath();
// Python scripts start with x.runfiles/ as the module space, so everything must be manually
// adjusted to be relative to the workspace name.
Expand All @@ -119,7 +119,7 @@ public List<PathFragment> getImports(RuleContext ruleContext) {
ruleContext.attributeError("imports",
"Path " + importsAttr + " references a path above the execution root");
}
result.add(importsPath);
result.add(importsPath.getPathString());
}
return result;
}
Expand All @@ -134,7 +134,7 @@ public Artifact createExecutable(
RuleContext ruleContext,
PyCommon common,
CcLinkingInfo ccLinkingInfo,
NestedSet<PathFragment> imports)
NestedSet<String> imports)
throws InterruptedException {
String main = common.determineMainExecutableSource(/*withWorkspaceName=*/ true);
Artifact executable = common.getExecutable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.google.devtools.build.lib.rules.cpp.CcCommon.CcFlagsSupplier;
import com.google.devtools.build.lib.rules.cpp.CcLinkingInfo;
import com.google.devtools.build.lib.syntax.Type;
import com.google.devtools.build.lib.vfs.PathFragment;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -72,7 +71,7 @@ static RuleConfiguredTargetBuilder init(
return null;
}

NestedSet<PathFragment> imports = common.collectImports(ruleContext, semantics);
NestedSet<String> imports = common.collectImports(ruleContext, semantics);
if (ruleContext.hasErrors()) {
return null;
}
Expand Down Expand Up @@ -118,7 +117,7 @@ static RuleConfiguredTargetBuilder init(

RuleConfiguredTargetBuilder builder =
new RuleConfiguredTargetBuilder(ruleContext);
common.addCommonTransitiveInfoProviders(builder, semantics, common.getFilesToBuild());
common.addCommonTransitiveInfoProviders(builder, semantics, common.getFilesToBuild(), imports);

semantics.postInitBinary(ruleContext, runfilesSupport, common);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public final class PyCommon {
public static final String PYTHON_SKYLARK_PROVIDER_NAME = "py";
public static final String TRANSITIVE_PYTHON_SRCS = "transitive_sources";
public static final String IS_USING_SHARED_LIBRARY = "uses_shared_libraries";
public static final String IMPORTS = "imports";

private static final LocalMetadataCollector METADATA_COLLECTOR = new LocalMetadataCollector() {
@Override
Expand Down Expand Up @@ -156,8 +157,11 @@ public Artifact getPythonLauncherArtifact(Artifact executable) {
return ruleContext.getRelatedArtifact(executable.getRootRelativePath(), "");
}

public void addCommonTransitiveInfoProviders(RuleConfiguredTargetBuilder builder,
PythonSemantics semantics, NestedSet<Artifact> filesToBuild) {
public void addCommonTransitiveInfoProviders(
RuleConfiguredTargetBuilder builder,
PythonSemantics semantics,
NestedSet<Artifact> filesToBuild,
NestedSet<String> imports) {

builder
.add(
Expand All @@ -169,7 +173,7 @@ public void addCommonTransitiveInfoProviders(RuleConfiguredTargetBuilder builder
filesToBuild))
.addSkylarkTransitiveInfo(
PYTHON_SKYLARK_PROVIDER_NAME,
createSourceProvider(this.transitivePythonSources, usesSharedLibraries()))
createSourceProvider(this.transitivePythonSources, usesSharedLibraries(), imports))
// Python targets are not really compilable. The best we can do is make sure that all
// generated source files are ready.
.addOutputGroup(OutputGroupInfo.FILES_TO_COMPILE, transitivePythonSources)
Expand All @@ -182,13 +186,17 @@ public void addCommonTransitiveInfoProviders(RuleConfiguredTargetBuilder builder
* <p>addSkylarkTransitiveInfo(PYTHON_SKYLARK_PROVIDER_NAME, createSourceProvider(...))
*/
public static StructImpl createSourceProvider(
NestedSet<Artifact> transitivePythonSources, boolean isUsingSharedLibrary) {
NestedSet<Artifact> transitivePythonSources,
boolean isUsingSharedLibrary,
NestedSet<String> imports) {
return StructProvider.STRUCT.create(
ImmutableMap.<String, Object>of(
TRANSITIVE_PYTHON_SRCS,
SkylarkNestedSet.of(Artifact.class, transitivePythonSources),
IS_USING_SHARED_LIBRARY,
isUsingSharedLibrary),
isUsingSharedLibrary,
IMPORTS,
SkylarkNestedSet.of(String.class, imports)),
"No such attribute '%s'");
}

Expand Down Expand Up @@ -384,15 +392,14 @@ public NestedSet<Artifact> collectTransitivePythonSourcesWithoutLocal() {
return builder.build();
}

public NestedSet<PathFragment> collectImports(
RuleContext ruleContext, PythonSemantics semantics) {
NestedSetBuilder<PathFragment> builder = NestedSetBuilder.compileOrder();
public NestedSet<String> collectImports(RuleContext ruleContext, PythonSemantics semantics) {
NestedSetBuilder<String> builder = NestedSetBuilder.compileOrder();
builder.addAll(semantics.getImports(ruleContext));
collectTransitivePythonImports(builder);
return builder.build();
}

private void collectTransitivePythonImports(NestedSetBuilder<PathFragment> builder) {
private void collectTransitivePythonImports(NestedSetBuilder<String> builder) {
for (TransitiveInfoCollection dep : getTargetDeps()) {
if (dep.getProvider(PythonImportsProvider.class) != null) {
PythonImportsProvider provider = dep.getProvider(PythonImportsProvider.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
import com.google.devtools.build.lib.collect.nestedset.Order;
import com.google.devtools.build.lib.vfs.PathFragment;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -60,7 +59,7 @@ public ConfiguredTarget create(final RuleContext ruleContext)
NestedSetBuilder.wrap(Order.STABLE_ORDER, allOutputs);
common.addPyExtraActionPseudoAction();

NestedSet<PathFragment> imports = common.collectImports(ruleContext, semantics);
NestedSet<String> imports = common.collectImports(ruleContext, semantics);
if (ruleContext.hasErrors()) {
return null;
}
Expand All @@ -76,7 +75,7 @@ public ConfiguredTarget create(final RuleContext ruleContext)
runfilesBuilder.addRunfiles(ruleContext, RunfilesProvider.DEFAULT_RUNFILES);

RuleConfiguredTargetBuilder builder = new RuleConfiguredTargetBuilder(ruleContext);
common.addCommonTransitiveInfoProviders(builder, semantics, filesToBuild);
common.addCommonTransitiveInfoProviders(builder, semantics, filesToBuild, imports);

return builder
.setFilesToBuild(filesToBuild)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,19 @@
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.concurrent.ThreadSafety.Immutable;
import com.google.devtools.build.lib.skyframe.serialization.autocodec.AutoCodec;
import com.google.devtools.build.lib.vfs.PathFragment;

/** A {@link TransitiveInfoProvider} that supplies import directories for Python dependencies. */
@Immutable
@AutoCodec
public final class PythonImportsProvider implements TransitiveInfoProvider {

private final NestedSet<PathFragment> transitivePythonImports;
private final NestedSet<String> transitivePythonImports;

public PythonImportsProvider(NestedSet<PathFragment> transitivePythonImports) {
public PythonImportsProvider(NestedSet<String> transitivePythonImports) {
this.transitivePythonImports = transitivePythonImports;
}

public NestedSet<PathFragment> getTransitivePythonImports() {
public NestedSet<String> getTransitivePythonImports() {
return transitivePythonImports;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.packages.RuleClass.ConfiguredTargetFactory.RuleErrorException;
import com.google.devtools.build.lib.rules.cpp.CcLinkingInfo;
import com.google.devtools.build.lib.vfs.PathFragment;
import java.util.Collection;
import java.util.List;

Expand Down Expand Up @@ -64,10 +63,8 @@ void collectDefaultRunfilesForBinary(RuleContext ruleContext, Runfiles.Builder b
Collection<Artifact> precompiledPythonFiles(
RuleContext ruleContext, Collection<Artifact> sources, PyCommon common);

/**
* Returns a list of PathFragments for the import paths specified in the imports attribute.
*/
List<PathFragment> getImports(RuleContext ruleContext);
/** Returns a list of PathFragments for the import paths specified in the imports attribute. */
List<String> getImports(RuleContext ruleContext);

/**
* Create the actual executable artifact.
Expand All @@ -78,7 +75,7 @@ Artifact createExecutable(
RuleContext ruleContext,
PyCommon common,
CcLinkingInfo ccLinkingInfo,
NestedSet<PathFragment> imports)
NestedSet<String> imports)
throws InterruptedException, RuleErrorException;

/**
Expand Down
1 change: 1 addition & 0 deletions src/test/java/com/google/devtools/build/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,7 @@ java_test(
"//src/main/java/com/google/devtools/build/lib:events",
"//src/main/java/com/google/devtools/build/lib:java-compilation",
"//src/main/java/com/google/devtools/build/lib:packages-internal",
"//src/main/java/com/google/devtools/build/lib:python-rules",
"//src/main/java/com/google/devtools/build/lib:syntax",
"//src/main/java/com/google/devtools/build/lib:util",
"//src/main/java/com/google/devtools/build/lib/actions",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2018 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.devtools.build.lib.bazel.rules.python;

import static com.google.common.truth.Truth.assertThat;
import static com.google.devtools.build.lib.actions.util.ActionsTestUtil.prettyArtifactNames;

import com.google.devtools.build.lib.actions.Artifact;
import com.google.devtools.build.lib.analysis.ConfiguredTarget;
import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
import com.google.devtools.build.lib.packages.SkylarkInfo;
import com.google.devtools.build.lib.rules.python.PyCommon;
import com.google.devtools.build.lib.syntax.SkylarkNestedSet;
import java.io.IOException;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Python Starlark APi test */
@RunWith(JUnit4.class)
public class BazelPythonStarlarkApiTest extends BuildViewTestCase {
@Test
public void pythonProviderWithFields() throws Exception {
simpleSources();
assertNoEvents();

ConfiguredTarget helloTarget = getConfiguredTarget("//py:hello");
SkylarkInfo provider = getPythonSkylarkProvider(helloTarget);

assertThat(provider.hasField(PyCommon.TRANSITIVE_PYTHON_SRCS)).isTrue();
assertThat(provider.hasField(PyCommon.IS_USING_SHARED_LIBRARY)).isTrue();
assertThat(provider.hasField(PyCommon.IMPORTS)).isTrue();
assertThat(provider.hasField("srcs")).isFalse();
}

@Test
public void simpleFieldsValues() throws Exception {
simpleSources();
assertNoEvents();

ConfiguredTarget helloTarget = getConfiguredTarget("//py:hello");
SkylarkInfo provider = getPythonSkylarkProvider(helloTarget);

SkylarkNestedSet sources =
provider.getValue(PyCommon.TRANSITIVE_PYTHON_SRCS, SkylarkNestedSet.class);
assertThat(prettyArtifactNames(sources.getSet(Artifact.class))).containsExactly("py/hello.py");

assertThat(provider.getValue(PyCommon.IS_USING_SHARED_LIBRARY, Boolean.class)).isFalse();

SkylarkNestedSet imports = provider.getValue(PyCommon.IMPORTS, SkylarkNestedSet.class);
assertThat(imports.getSet(String.class)).containsExactly("__main__/py");
}

@Test
public void transitiveFieldsValues() throws Exception {
simpleSources();
assertNoEvents();

ConfiguredTarget helloTarget = getConfiguredTarget("//py:sayHello");
SkylarkInfo provider = getPythonSkylarkProvider(helloTarget);

SkylarkNestedSet sources =
provider.getValue(PyCommon.TRANSITIVE_PYTHON_SRCS, SkylarkNestedSet.class);
assertThat(prettyArtifactNames(sources.getSet(Artifact.class)))
.containsExactly("py/hello.py", "py/sayHello.py");

assertThat(provider.getValue(PyCommon.IS_USING_SHARED_LIBRARY, Boolean.class)).isFalse();

SkylarkNestedSet imports = provider.getValue(PyCommon.IMPORTS, SkylarkNestedSet.class);
assertThat(imports.getSet(String.class)).containsExactly("__main__/py");
}

private void simpleSources() throws IOException {
scratch.file(
"py/hello.py",
"import os",
"def Hello():",
" print(\"Hello, World!\")",
" print(\"Hello, \" + os.getcwd() + \"!\")");
scratch.file("py/sayHello.py", "from py import hello", "hello.Hello()");
scratch.file(
"py/BUILD",
"py_binary(name=\"sayHello\", srcs=[\"sayHello.py\"], deps=[\":hello\"])",
"py_library(name=\"hello\", srcs=[\"hello.py\"], imports= [\".\"])");
}

private SkylarkInfo getPythonSkylarkProvider(ConfiguredTarget target) {
Object object = target.get(PyCommon.PYTHON_SKYLARK_PROVIDER_NAME);
assertThat(object).isInstanceOf(SkylarkInfo.class);
SkylarkInfo provider = (SkylarkInfo) object;

assertThat(provider).isNotNull();
return provider;
}
}

0 comments on commit aab45c4

Please sign in to comment.