diff --git a/.buildkite/windows/install/reqs.txt b/.buildkite/windows/install/reqs.txt index d4bff5c46644..76a574868235 100644 --- a/.buildkite/windows/install/reqs.txt +++ b/.buildkite/windows/install/reqs.txt @@ -41,8 +41,8 @@ pytest-tornasync pytest-trio pytest-twisted werkzeug -git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn -git+git://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray +git+https://github.com/ray-project/tune-sklearn@master#tune-sklearn +git+https://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray scikit-optimize tensorflow gym diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index 0471886e0159..829cd4c6d4c7 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -31,9 +31,9 @@ starlette tabulate uvicorn werkzeug -git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn -git+git://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray -git+git://github.com/ray-project/lightgbm_ray@main#lightgbm_ray +git+https://github.com/ray-project/tune-sklearn@master#tune-sklearn +git+https://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray +git+https://github.com/ray-project/lightgbm_ray@main#lightgbm_ray git+https://github.com/ray-project/ray_lightning#ray_lightning scikit-optimize sphinx-sitemap==2.2.0 diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 1d8d60e1048f..af90153a506b 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -106,7 +106,8 @@ public void start() { JobConfig.newBuilder() .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess) .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) - .addAllCodeSearchPath(rayConfig.codeSearchPath); + .addAllCodeSearchPath(rayConfig.codeSearchPath) + .setRayNamespace(rayConfig.namespace); RuntimeEnv.Builder runtimeEnvBuilder = RuntimeEnv.newBuilder(); if (!rayConfig.workerEnv.isEmpty()) { // TODO(SongGuyang): Suppport complete runtime env interface for users. diff --git a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java index 511489b0d2f9..d651b08f6607 100644 --- a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java @@ -69,6 +69,8 @@ public LoggerConf(String loggerName, String fileName, String pattern) { public final int numWorkersPerProcess; + public final String namespace; + public final List jvmOptionsForJavaWorker; public final Map workerEnv; @@ -118,6 +120,9 @@ public RayConfig(Config config) { this.jobId = JobId.NIL; } + // Namespace of this job. + namespace = config.getString("ray.job.namespace"); + // jvm options for java workers of this job. jvmOptionsForJavaWorker = config.getStringList("ray.job.jvm-options"); diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 2c865530e5ce..0d94d72e30b3 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -36,6 +36,10 @@ ray { // key1 : "value1" // key2 : "value2" } + /// The namespace of this job. It's used for isolation between jobs. + /// Jobs in different namespaces cannot access each other. + /// If it's not specified, a randomized value will be used instead. 
+ namespace: "" } // Configurations about raylet diff --git a/java/serve/src/main/java/io/ray/serve/BackendConfig.java b/java/serve/src/main/java/io/ray/serve/DeploymentConfig.java similarity index 65% rename from java/serve/src/main/java/io/ray/serve/BackendConfig.java rename to java/serve/src/main/java/io/ray/serve/DeploymentConfig.java index 6e6f4c3693f0..9f39f4453cb9 100644 --- a/java/serve/src/main/java/io/ray/serve/BackendConfig.java +++ b/java/serve/src/main/java/io/ray/serve/DeploymentConfig.java @@ -3,7 +3,7 @@ import com.google.common.base.Preconditions; import java.io.Serializable; -public class BackendConfig implements Serializable { +public class DeploymentConfig implements Serializable { private static final long serialVersionUID = 4037621960087621036L; @@ -19,13 +19,13 @@ public class BackendConfig implements Serializable { private boolean isCrossLanguage; - private int backendLanguage = 1; + private int deploymentLanguage = 1; public int getNumReplicas() { return numReplicas; } - public BackendConfig setNumReplicas(int numReplicas) { + public DeploymentConfig setNumReplicas(int numReplicas) { this.numReplicas = numReplicas; return this; } @@ -34,7 +34,7 @@ public int getMaxConcurrentQueries() { return maxConcurrentQueries; } - public BackendConfig setMaxConcurrentQueries(int maxConcurrentQueries) { + public DeploymentConfig setMaxConcurrentQueries(int maxConcurrentQueries) { Preconditions.checkArgument(maxConcurrentQueries >= 0, "max_concurrent_queries must be >= 0"); this.maxConcurrentQueries = maxConcurrentQueries; return this; @@ -44,7 +44,7 @@ public Object getUserConfig() { return userConfig; } - public BackendConfig setUserConfig(Object userConfig) { + public DeploymentConfig setUserConfig(Object userConfig) { this.userConfig = userConfig; return this; } @@ -53,7 +53,7 @@ public double getGracefulShutdownWaitLoopS() { return gracefulShutdownWaitLoopS; } - public BackendConfig setGracefulShutdownWaitLoopS(double gracefulShutdownWaitLoopS) { + public DeploymentConfig setGracefulShutdownWaitLoopS(double gracefulShutdownWaitLoopS) { this.gracefulShutdownWaitLoopS = gracefulShutdownWaitLoopS; return this; } @@ -62,7 +62,7 @@ public double getGracefulShutdownTimeoutS() { return gracefulShutdownTimeoutS; } - public BackendConfig setGracefulShutdownTimeoutS(double gracefulShutdownTimeoutS) { + public DeploymentConfig setGracefulShutdownTimeoutS(double gracefulShutdownTimeoutS) { this.gracefulShutdownTimeoutS = gracefulShutdownTimeoutS; return this; } @@ -71,17 +71,17 @@ public boolean isCrossLanguage() { return isCrossLanguage; } - public BackendConfig setCrossLanguage(boolean isCrossLanguage) { + public DeploymentConfig setCrossLanguage(boolean isCrossLanguage) { this.isCrossLanguage = isCrossLanguage; return this; } - public int getBackendLanguage() { - return backendLanguage; + public int getDeploymentLanguage() { + return deploymentLanguage; } - public BackendConfig setBackendLanguage(int backendLanguage) { - this.backendLanguage = backendLanguage; + public DeploymentConfig setDeploymentLanguage(int deploymentLanguage) { + this.deploymentLanguage = deploymentLanguage; return this; } } diff --git a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java index 02ed2e510410..65e65460004b 100644 --- a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java +++ b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java @@ -13,7 +13,7 @@ public class DeploymentInfo implements Serializable { private Object[] 
initArgs; - private BackendConfig backendConfig; + private DeploymentConfig deploymentConfig; private DeploymentVersion deploymentVersion; @@ -46,12 +46,12 @@ public DeploymentInfo setInitArgs(Object[] initArgs) { return this; } - public BackendConfig getBackendConfig() { - return backendConfig; + public DeploymentConfig getDeploymentConfig() { + return deploymentConfig; } - public DeploymentInfo setBackendConfig(BackendConfig backendConfig) { - this.backendConfig = backendConfig; + public DeploymentInfo setDeploymentConfig(DeploymentConfig deploymentConfig) { + this.deploymentConfig = deploymentConfig; return this; } diff --git a/java/serve/src/main/java/io/ray/serve/RayServeReplicaImpl.java b/java/serve/src/main/java/io/ray/serve/RayServeReplicaImpl.java index 0c6a955b58c3..0ffdd16f02b4 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeReplicaImpl.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeReplicaImpl.java @@ -32,7 +32,7 @@ public class RayServeReplicaImpl implements RayServeReplica { private String replicaTag; - private BackendConfig config; + private DeploymentConfig config; private AtomicInteger numOngoingRequests = new AtomicInteger(); @@ -60,13 +60,13 @@ public class RayServeReplicaImpl implements RayServeReplica { public RayServeReplicaImpl( Object callable, - BackendConfig backendConfig, + DeploymentConfig deploymentConfig, DeploymentVersion version, BaseActorHandle actorHandle) { this.backendTag = Serve.getReplicaContext().getBackendTag(); this.replicaTag = Serve.getReplicaContext().getReplicaTag(); this.callable = callable; - this.config = backendConfig; + this.config = deploymentConfig; this.version = version; this.checkHealthMethod = getRunnerMethod(Constants.CHECK_HEALTH_METHOD, null, true); @@ -75,7 +75,7 @@ public RayServeReplicaImpl( Map keyListeners = new HashMap<>(); keyListeners.put( new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag), - newConfig -> updateBackendConfigs(newConfig)); + newConfig -> updateDeploymentConfigs(newConfig)); this.longPollClient = new LongPollClient(actorHandle, keyListeners); this.longPollClient.start(); registerMetrics(); @@ -348,8 +348,8 @@ public DeploymentVersion reconfigure(Object userConfig) { * * @param newConfig the new configuration of backend */ - private void updateBackendConfigs(Object newConfig) { - config = (BackendConfig) newConfig; + private void updateDeploymentConfigs(Object newConfig) { + config = (DeploymentConfig) newConfig; } public DeploymentVersion getVersion() { diff --git a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java index ed33e315e2b2..34e112ba8283 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java @@ -30,17 +30,17 @@ public RayServeWrappedReplica( String replicaTag, String backendDef, byte[] initArgsbytes, - byte[] backendConfigBytes, + byte[] deploymentConfigBytes, byte[] deploymentVersionBytes, String controllerName) { - // Parse BackendConfig. - BackendConfig backendConfig = ServeProtoUtil.parseBackendConfig(backendConfigBytes); + // Parse DeploymentConfig. + DeploymentConfig deploymentConfig = ServeProtoUtil.parseDeploymentConfig(deploymentConfigBytes); // Parse init args. 
Object[] initArgs = null; try { - initArgs = parseInitArgs(initArgsbytes, backendConfig); + initArgs = parseInitArgs(initArgsbytes, deploymentConfig); } catch (IOException e) { String errMsg = LogUtil.format( @@ -51,10 +51,11 @@ public RayServeWrappedReplica( throw new RayServeException(errMsg, e); } + // Init replica. init( new DeploymentInfo() .setName(backendTag) - .setBackendConfig(backendConfig) + .setDeploymentConfig(deploymentConfig) .setDeploymentVersion(ServeProtoUtil.parseDeploymentVersion(deploymentVersionBytes)) .setBackendDef(backendDef) .setInitArgs(initArgs), @@ -103,7 +104,7 @@ private void init( this.backend = new RayServeReplicaImpl( callable, - deploymentInfo.getBackendConfig(), + deploymentInfo.getDeploymentConfig(), deploymentInfo.getDeploymentVersion(), optional.get()); this.deploymentInfo = deploymentInfo; @@ -131,14 +132,14 @@ private void enableMetrics(Map config) { }); } - private Object[] parseInitArgs(byte[] initArgsbytes, BackendConfig backendConfig) + private Object[] parseInitArgs(byte[] initArgsbytes, DeploymentConfig deploymentConfig) throws IOException { if (initArgsbytes == null || initArgsbytes.length == 0) { return new Object[0]; } - if (backendConfig.isCrossLanguage()) { + if (deploymentConfig.isCrossLanguage()) { // For other language like Python API, not support Array type. return new Object[] {MessagePackSerializer.decode(initArgsbytes, Object.class)}; } else { @@ -199,10 +200,10 @@ public boolean prepareForShutdown() { public Object reconfigure(Object userConfig) { DeploymentVersion deploymentVersion = backend.reconfigure( - deploymentInfo.getBackendConfig().isCrossLanguage() && userConfig != null + deploymentInfo.getDeploymentConfig().isCrossLanguage() && userConfig != null ? MessagePackSerializer.decode((byte[]) userConfig, Object.class) : userConfig); - return deploymentInfo.getBackendConfig().isCrossLanguage() + return deploymentInfo.getDeploymentConfig().isCrossLanguage() ? ServeProtoUtil.toProtobuf(deploymentVersion).toByteArray() : deploymentVersion; } @@ -215,7 +216,7 @@ public Object reconfigure(Object userConfig) { */ public Object getVersion() { DeploymentVersion deploymentVersion = backend.getVersion(); - return deploymentInfo.getBackendConfig().isCrossLanguage() + return deploymentInfo.getDeploymentConfig().isCrossLanguage() ? 
ServeProtoUtil.toProtobuf(deploymentVersion).toByteArray() : deploymentVersion; } diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java index 4e15a64a056e..fb6e9ec8625a 100644 --- a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java +++ b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java @@ -9,7 +9,7 @@ import io.ray.runtime.metric.Metrics; import io.ray.runtime.metric.TagKey; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.DeploymentConfig; import io.ray.serve.util.CollectionUtil; import java.util.ArrayList; import java.util.HashSet; @@ -48,8 +48,8 @@ public ReplicaSet(String backendTag) { .register()); } - public void setMaxConcurrentQueries(Object backendConfig) { - int newValue = ((BackendConfig) backendConfig).getMaxConcurrentQueries(); + public void setMaxConcurrentQueries(Object deploymentConfig) { + int newValue = ((DeploymentConfig) deploymentConfig).getMaxConcurrentQueries(); if (newValue != this.maxConcurrentQueries) { this.maxConcurrentQueries = newValue; LOGGER.info("ReplicaSet: changing max_concurrent_queries to {}", newValue); diff --git a/java/serve/src/main/java/io/ray/serve/Router.java b/java/serve/src/main/java/io/ray/serve/Router.java index 5ef339d77767..0b744c835a09 100644 --- a/java/serve/src/main/java/io/ray/serve/Router.java +++ b/java/serve/src/main/java/io/ray/serve/Router.java @@ -38,7 +38,7 @@ public Router(BaseActorHandle controllerHandle, String backendTag) { Map keyListeners = new HashMap<>(); keyListeners.put( new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag), - backendConfig -> replicaSet.setMaxConcurrentQueries(backendConfig)); // cross language + deploymentConfig -> replicaSet.setMaxConcurrentQueries(deploymentConfig)); // cross language keyListeners.put( new KeyType(LongPollNamespace.REPLICA_HANDLES, backendTag), workerReplicas -> replicaSet.updateWorkerReplicas(workerReplicas)); // cross language diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java index 308391254e10..54d107b717aa 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java +++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java @@ -47,7 +47,7 @@ public class LongPollClient { static { DESERIALIZERS.put( - LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseBackendConfig(body)); + LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseDeploymentConfig(body)); DESERIALIZERS.put( LongPollNamespace.REPLICA_HANDLES, body -> ServeProtoUtil.parseEndpointSet(body)); DESERIALIZERS.put( diff --git a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java index 61d94ba817bb..7727e0e3255e 100644 --- a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java +++ b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java @@ -5,11 +5,11 @@ import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.BackendConfig; import io.ray.serve.Constants; +import io.ray.serve.DeploymentConfig; import io.ray.serve.DeploymentVersion; import io.ray.serve.RayServeException; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.EndpointInfo; import 
io.ray.serve.generated.EndpointSet; import io.ray.serve.generated.LongPollResult; @@ -25,51 +25,54 @@ public class ServeProtoUtil { private static final Gson GSON = new Gson(); - public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) { + public static DeploymentConfig parseDeploymentConfig(byte[] deploymentConfigBytes) { - BackendConfig backendConfig = new BackendConfig(); - if (backendConfigBytes == null) { - return backendConfig; + DeploymentConfig deploymentConfig = new DeploymentConfig(); + if (deploymentConfigBytes == null) { + return deploymentConfig; } - io.ray.serve.generated.BackendConfig pbBackendConfig = null; + io.ray.serve.generated.DeploymentConfig pbDeploymentConfig = null; try { - pbBackendConfig = io.ray.serve.generated.BackendConfig.parseFrom(backendConfigBytes); + pbDeploymentConfig = io.ray.serve.generated.DeploymentConfig.parseFrom(deploymentConfigBytes); } catch (InvalidProtocolBufferException e) { - throw new RayServeException("Failed to parse BackendConfig from protobuf bytes.", e); + throw new RayServeException("Failed to parse DeploymentConfig from protobuf bytes.", e); } - if (pbBackendConfig == null) { - return backendConfig; + if (pbDeploymentConfig == null) { + return deploymentConfig; } - if (pbBackendConfig.getNumReplicas() != 0) { - backendConfig.setNumReplicas(pbBackendConfig.getNumReplicas()); + if (pbDeploymentConfig.getNumReplicas() != 0) { + deploymentConfig.setNumReplicas(pbDeploymentConfig.getNumReplicas()); } - if (pbBackendConfig.getMaxConcurrentQueries() != 0) { - backendConfig.setMaxConcurrentQueries(pbBackendConfig.getMaxConcurrentQueries()); + if (pbDeploymentConfig.getMaxConcurrentQueries() != 0) { + deploymentConfig.setMaxConcurrentQueries(pbDeploymentConfig.getMaxConcurrentQueries()); } - if (pbBackendConfig.getGracefulShutdownWaitLoopS() != 0) { - backendConfig.setGracefulShutdownWaitLoopS(pbBackendConfig.getGracefulShutdownWaitLoopS()); + if (pbDeploymentConfig.getGracefulShutdownWaitLoopS() != 0) { + deploymentConfig.setGracefulShutdownWaitLoopS( + pbDeploymentConfig.getGracefulShutdownWaitLoopS()); } - if (pbBackendConfig.getGracefulShutdownTimeoutS() != 0) { - backendConfig.setGracefulShutdownTimeoutS(pbBackendConfig.getGracefulShutdownTimeoutS()); + if (pbDeploymentConfig.getGracefulShutdownTimeoutS() != 0) { + deploymentConfig.setGracefulShutdownTimeoutS( + pbDeploymentConfig.getGracefulShutdownTimeoutS()); } - backendConfig.setCrossLanguage(pbBackendConfig.getIsCrossLanguage()); - if (pbBackendConfig.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) { + deploymentConfig.setCrossLanguage(pbDeploymentConfig.getIsCrossLanguage()); + if (pbDeploymentConfig.getDeploymentLanguage() == DeploymentLanguage.UNRECOGNIZED) { throw new RayServeException( LogUtil.format( - "Unrecognized backend language {}. Backend language must be in {}.", - pbBackendConfig.getBackendLanguageValue(), - Lists.newArrayList(BackendLanguage.values()))); - } - backendConfig.setBackendLanguage(pbBackendConfig.getBackendLanguageValue()); - if (pbBackendConfig.getUserConfig() != null && pbBackendConfig.getUserConfig().size() != 0) { - backendConfig.setUserConfig( + "Unrecognized deployment language {}. 
Deployment language must be in {}.", + pbDeploymentConfig.getDeploymentLanguage(), + Lists.newArrayList(DeploymentLanguage.values()))); + } + deploymentConfig.setDeploymentLanguage(pbDeploymentConfig.getDeploymentLanguageValue()); + if (pbDeploymentConfig.getUserConfig() != null + && pbDeploymentConfig.getUserConfig().size() != 0) { + deploymentConfig.setUserConfig( MessagePackSerializer.decode( - pbBackendConfig.getUserConfig().toByteArray(), Object.class)); + pbDeploymentConfig.getUserConfig().toByteArray(), Object.class)); } - return backendConfig; + return deploymentConfig; } public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) { diff --git a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java index b9266762e515..c8330445bc21 100644 --- a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java +++ b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java @@ -5,7 +5,7 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.EndpointInfo; import io.ray.serve.util.CommonUtil; import java.io.IOException; @@ -52,8 +52,8 @@ public void test() throws IOException { DeploymentInfo deploymentInfo = new DeploymentInfo() .setName(deploymentName) - .setBackendConfig( - new BackendConfig().setBackendLanguage(BackendLanguage.JAVA.getNumber())) + .setDeploymentConfig( + new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber())) .setDeploymentVersion(new DeploymentVersion(version)) .setBackendDef(DummyBackendReplica.class.getName()); diff --git a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java index 5d62fe786fb1..edcd1dee683c 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java @@ -5,7 +5,7 @@ import io.ray.api.Ray; import io.ray.serve.api.Serve; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentLanguage; import org.testng.Assert; import org.testng.annotations.Test; @@ -28,15 +28,15 @@ public void test() { Ray.actor(DummyServeController::new).setName(controllerName).remote(); // Replica - BackendConfig backendConfig = - new BackendConfig().setBackendLanguage(BackendLanguage.JAVA.getNumber()); + DeploymentConfig deploymentConfig = + new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber()); Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; DeploymentInfo deploymentInfo = new DeploymentInfo() .setName(backendTag) - .setBackendConfig(backendConfig) + .setDeploymentConfig(deploymentConfig) .setDeploymentVersion(new DeploymentVersion(version)) .setBackendDef("io.ray.serve.ReplicaContext") .setInitArgs(initArgs); diff --git a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java index 94b38d960ef8..49cf94def8a2 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java @@ -5,7 +5,7 @@ import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.serve.api.Serve; -import io.ray.serve.generated.BackendLanguage; +import 
io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; import java.io.IOException; @@ -30,12 +30,12 @@ public void test() throws IOException { ActorHandle controllerHandle = Ray.actor(DummyServeController::new).setName(controllerName).remote(); - BackendConfig backendConfig = - new BackendConfig().setBackendLanguage(BackendLanguage.JAVA.getNumber()); + DeploymentConfig deploymentConfig = + new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber()); DeploymentInfo deploymentInfo = new DeploymentInfo() .setName(backendTag) - .setBackendConfig(backendConfig) + .setDeploymentConfig(deploymentConfig) .setDeploymentVersion(new DeploymentVersion(version)) .setBackendDef(DummyBackendReplica.class.getName()); diff --git a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java index 93fa8ba6580a..6d8c60eb9724 100644 --- a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java +++ b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java @@ -5,7 +5,7 @@ import io.ray.api.Ray; import io.ray.serve.api.Serve; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.RequestMetadata; import java.util.Map; import java.util.Set; @@ -20,8 +20,8 @@ public class ReplicaSetTest { @Test public void setMaxConcurrentQueriesTest() { ReplicaSet replicaSet = new ReplicaSet(backendTag); - io.ray.serve.generated.BackendConfig.Builder builder = - io.ray.serve.generated.BackendConfig.newBuilder(); + io.ray.serve.generated.DeploymentConfig.Builder builder = + io.ray.serve.generated.DeploymentConfig.newBuilder(); builder.setMaxConcurrentQueries(200); replicaSet.setMaxConcurrentQueries(builder.build()); @@ -56,15 +56,15 @@ public void assignReplicaTest() { Ray.actor(DummyServeController::new).setName(controllerName).remote(); // Replica - BackendConfig backendConfig = - new BackendConfig().setBackendLanguage(BackendLanguage.JAVA.getNumber()); + DeploymentConfig deploymentConfig = + new DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber()); Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; DeploymentInfo deploymentInfo = new DeploymentInfo() .setName(backendTag) - .setBackendConfig(backendConfig) + .setDeploymentConfig(deploymentConfig) .setDeploymentVersion(new DeploymentVersion(version)) .setBackendDef("io.ray.serve.ReplicaContext") .setInitArgs(initArgs); diff --git a/java/serve/src/test/java/io/ray/serve/RouterTest.java b/java/serve/src/test/java/io/ray/serve/RouterTest.java index 2bd59f00f791..de8faf8d9852 100644 --- a/java/serve/src/test/java/io/ray/serve/RouterTest.java +++ b/java/serve/src/test/java/io/ray/serve/RouterTest.java @@ -5,7 +5,7 @@ import io.ray.api.Ray; import io.ray.serve.api.Serve; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.RequestMetadata; import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; @@ -30,15 +30,15 @@ public void test() { Ray.actor(DummyServeController::new).setName(controllerName).remote(); // Replica - BackendConfig backendConfig = - new BackendConfig().setBackendLanguage(BackendLanguage.JAVA.getNumber()); + DeploymentConfig deploymentConfig = + new 
DeploymentConfig().setDeploymentLanguage(DeploymentLanguage.JAVA.getNumber()); Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; DeploymentInfo deploymentInfo = new DeploymentInfo() .setName(backendTag) - .setBackendConfig(backendConfig) + .setDeploymentConfig(deploymentConfig) .setDeploymentVersion(new DeploymentVersion(version)) .setBackendDef("io.ray.serve.ReplicaContext") .setInitArgs(initArgs); diff --git a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java index 7141af418e1a..4347e191ed0f 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java @@ -1,7 +1,7 @@ package io.ray.serve.poll; import com.google.protobuf.ByteString; -import io.ray.serve.BackendConfig; +import io.ray.serve.DeploymentConfig; import io.ray.serve.generated.UpdatedObject; import java.util.HashMap; import java.util.Map; @@ -19,19 +19,19 @@ public void test() throws Throwable { KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "backendTag"); Map keyListeners = new HashMap<>(); keyListeners.put( - keyType, (object) -> a[0] = String.valueOf(((BackendConfig) object).getNumReplicas())); + keyType, (object) -> a[0] = String.valueOf(((DeploymentConfig) object).getNumReplicas())); // Initialize LongPollClient. LongPollClient longPollClient = new LongPollClient(null, keyListeners); // Construct updated object. - io.ray.serve.generated.BackendConfig.Builder backendConfig = - io.ray.serve.generated.BackendConfig.newBuilder(); - backendConfig.setNumReplicas(20); + io.ray.serve.generated.DeploymentConfig.Builder deploymentConfig = + io.ray.serve.generated.DeploymentConfig.newBuilder(); + deploymentConfig.setNumReplicas(20); int snapshotId = 10; UpdatedObject.Builder updatedObject = UpdatedObject.newBuilder(); updatedObject.setSnapshotId(snapshotId); - updatedObject.setObjectSnapshot(ByteString.copyFrom(backendConfig.build().toByteArray())); + updatedObject.setObjectSnapshot(ByteString.copyFrom(deploymentConfig.build().toByteArray())); // Process update. Map updates = new HashMap<>(); @@ -41,8 +41,8 @@ public void test() throws Throwable { // Validation. 
Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId); Assert.assertEquals( - ((BackendConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(), - backendConfig.getNumReplicas()); - Assert.assertEquals(a[0], String.valueOf(backendConfig.getNumReplicas())); + ((DeploymentConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(), + deploymentConfig.getNumReplicas()); + Assert.assertEquals(a[0], String.valueOf(deploymentConfig.getNumReplicas())); } } diff --git a/java/serve/src/test/java/io/ray/serve/util/ServeProtoUtilTest.java b/java/serve/src/test/java/io/ray/serve/util/ServeProtoUtilTest.java index fecf8cda8105..207d55e41e74 100644 --- a/java/serve/src/test/java/io/ray/serve/util/ServeProtoUtilTest.java +++ b/java/serve/src/test/java/io/ray/serve/util/ServeProtoUtilTest.java @@ -1,7 +1,7 @@ package io.ray.serve.util; import com.google.protobuf.ByteString; -import io.ray.serve.BackendConfig; +import io.ray.serve.DeploymentConfig; import io.ray.serve.DeploymentVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; @@ -12,20 +12,21 @@ public class ServeProtoUtilTest { @Test - public void parseBackendConfigTest() { + public void parseDeploymentConfigTest() { int numReplicas = 10; - io.ray.serve.generated.BackendConfig pbBackendConfig = - io.ray.serve.generated.BackendConfig.newBuilder().setNumReplicas(numReplicas).build(); + io.ray.serve.generated.DeploymentConfig pbDeploymentConfig = + io.ray.serve.generated.DeploymentConfig.newBuilder().setNumReplicas(numReplicas).build(); - BackendConfig backendConfig = ServeProtoUtil.parseBackendConfig(pbBackendConfig.toByteArray()); - Assert.assertNotNull(backendConfig); - Assert.assertEquals(backendConfig.getNumReplicas(), numReplicas); - Assert.assertEquals(backendConfig.getBackendLanguage(), 0); - Assert.assertEquals(backendConfig.getGracefulShutdownTimeoutS(), 20); - Assert.assertEquals(backendConfig.getGracefulShutdownWaitLoopS(), 2); - Assert.assertEquals(backendConfig.getMaxConcurrentQueries(), 100); - Assert.assertNull(backendConfig.getUserConfig()); - Assert.assertEquals(backendConfig.isCrossLanguage(), false); + DeploymentConfig deploymentConfig = + ServeProtoUtil.parseDeploymentConfig(pbDeploymentConfig.toByteArray()); + Assert.assertNotNull(deploymentConfig); + Assert.assertEquals(deploymentConfig.getNumReplicas(), numReplicas); + Assert.assertEquals(deploymentConfig.getDeploymentLanguage(), 0); + Assert.assertEquals(deploymentConfig.getGracefulShutdownTimeoutS(), 20); + Assert.assertEquals(deploymentConfig.getGracefulShutdownWaitLoopS(), 2); + Assert.assertEquals(deploymentConfig.getMaxConcurrentQueries(), 100); + Assert.assertNull(deploymentConfig.getUserConfig()); + Assert.assertEquals(deploymentConfig.isCrossLanguage(), false); } @Test diff --git a/java/test/src/main/java/io/ray/test/MultiDriverTest.java b/java/test/src/main/java/io/ray/test/MultiDriverTest.java index cd2057279467..99425989e323 100644 --- a/java/test/src/main/java/io/ray/test/MultiDriverTest.java +++ b/java/test/src/main/java/io/ray/test/MultiDriverTest.java @@ -3,7 +3,6 @@ import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.runtime.config.RayConfig; import io.ray.runtime.util.SystemUtil; import java.io.BufferedReader; import java.io.IOException; @@ -102,7 +101,6 @@ public void testMultiDrivers() throws InterruptedException, IOException { } private Process startDriver() throws IOException { - RayConfig rayConfig 
= TestUtils.getRuntime().getRayConfig(); ProcessBuilder builder = TestUtils.buildDriver(MultiDriverTest.class, null); builder.redirectError(Redirect.INHERIT); return builder.start(); diff --git a/java/test/src/main/java/io/ray/test/NamespaceTest.java b/java/test/src/main/java/io/ray/test/NamespaceTest.java new file mode 100644 index 000000000000..1aa0ca6f5e52 --- /dev/null +++ b/java/test/src/main/java/io/ray/test/NamespaceTest.java @@ -0,0 +1,71 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import java.io.IOException; +import java.util.NoSuchElementException; +import java.util.concurrent.TimeUnit; +import org.testng.Assert; +import org.testng.annotations.Test; + +@Test(groups = "cluster") +public class NamespaceTest { + + private static class A { + public String hello() { + return "hello"; + } + } + + /// This case tests that an actor cannot be accessed from a different namespace. + public void testIsolationBetweenNamespaces() throws IOException, InterruptedException { + System.setProperty("ray.job.namespace", "test2"); + testIsolation( + () -> + Assert.assertThrows( + NoSuchElementException.class, + () -> { + Ray.getGlobalActor("a").get(); + })); + } + + /// This case tests that an actor can be accessed from another job in the same namespace. + public void testIsolationInTheSameNamespaces() throws IOException, InterruptedException { + System.setProperty("ray.job.namespace", "test1"); + testIsolation( + () -> { + ActorHandle a = (ActorHandle) Ray.getGlobalActor("a").get(); + Assert.assertEquals("hello", a.task(A::hello).remote().get()); + }); + } + + public static void main(String[] args) throws IOException, InterruptedException { + System.setProperty("ray.job.namespace", "test1"); + Ray.init(); + ActorHandle a = Ray.actor(A::new).setGlobalName("a").remote(); + Assert.assertEquals("hello", a.task(A::hello).remote().get()); + /// Because we don't support long-running jobs yet, sleep for a while so that this + /// job is not destroyed too early. Otherwise the actor created in this job would be + /// destroyed as well. + TimeUnit.SECONDS.sleep(10); + Ray.shutdown(); + } + + private void testIsolation(Runnable runnable) throws IOException, InterruptedException { + Process driver = null; + try { + Ray.init(); + ProcessBuilder builder = TestUtils.buildDriver(NamespaceTest.class, null); + builder.redirectError(ProcessBuilder.Redirect.INHERIT); + driver = builder.start(); + // Wait for driver to start. + TimeUnit.SECONDS.sleep(3); + runnable.run(); + } finally { + if (driver != null) { + driver.waitFor(1, TimeUnit.SECONDS); + } + Ray.shutdown(); + } + } +} diff --git a/python/build-wheel-macos-arm64.sh b/python/build-wheel-macos-arm64.sh new file mode 100644 index 000000000000..097834570469 --- /dev/null +++ b/python/build-wheel-macos-arm64.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# Cause the script to exit if a single command fails. +set -e + +# Show explicitly which commands are currently running. +set -x + +DOWNLOAD_DIR=python_downloads + +NODE_VERSION="14" +PY_VERSIONS=("3.8.2" + "3.9.1") +PY_MMS=("3.8" + "3.9") + + +if [[ -n "${SKIP_DEP_RES}" ]]; then + ./ci/travis/install-bazel.sh + + curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash + curl -o- https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh | bash + source ~/.bash_profile + conda init bash + source ~/.bash_profile + + # Use the latest version of Node.js in order to build the dashboard.
+ source "$HOME"/.nvm/nvm.sh + nvm install $NODE_VERSION + nvm use $NODE_VERSION +fi + +# Build the dashboard so its static assets can be included in the wheel. +pushd python/ray/dashboard/client + npm ci + npm run build +popd + +mkdir -p .whl + +for ((i=0; i<${#PY_VERSIONS[@]}; ++i)); do + PY_MM=${PY_MMS[i]} + CONDA_ENV_NAME="p$PY_MM" + + # The -f flag is passed twice to also run git clean in the arrow subdirectory. + # The -d flag removes directories. The -x flag ignores the .gitignore file, + # and the -e flag ensures that we don't remove the .whl directory. + git clean -f -f -x -d -e .whl -e $DOWNLOAD_DIR -e python/ray/dashboard/client -e dashboard/client + + + # Install python using conda. This should be easier to produce consistent results in buildkite and locally. + source ~/.bash_profile + conda create -y -n "$CONDA_ENV_NAME" + conda activate "$CONDA_ENV_NAME" + conda remove -y python || true + conda install -y python="$PY_MM" + + # NOTE: We expect conda to set the PATH properly. + PIP_CMD=pip + PYTHON_EXE=python + + $PIP_CMD install --upgrade pip + + if [ -z "${TRAVIS_COMMIT}" ]; then + TRAVIS_COMMIT=${BUILDKITE_COMMIT} + fi + + pushd python + # Setuptools on CentOS is too old to install arrow 0.9.0, therefore we upgrade. + $PIP_CMD install --upgrade setuptools + $PIP_CMD install -q cython==0.29.15 + # Install wheel to avoid the error "invalid command 'bdist_wheel'". + $PIP_CMD install -q wheel + # Set the commit SHA in __init__.py. + if [ -n "$TRAVIS_COMMIT" ]; then + echo "TRAVIS_COMMIT variable detected. ray.__commit__ will be set to $TRAVIS_COMMIT" + else + echo "TRAVIS_COMMIT variable is not set, getting the current commit from git." + TRAVIS_COMMIT=$(git rev-parse HEAD) + fi + + sed -i .bak "s/{{RAY_COMMIT_SHA}}/$TRAVIS_COMMIT/g" ray/__init__.py && rm ray/__init__.py.bak + + # Add the correct Python to the path and build the wheel. This is only + # needed so that the installation finds the cython executable. 
+ # build ray wheel + $PYTHON_EXE setup.py bdist_wheel + # build ray-cpp wheel + RAY_INSTALL_CPP=1 $PYTHON_EXE setup.py bdist_wheel + mv dist/*.whl ../.whl/ + popd +done diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 06bacdf2b2d1..952e9aee57fa 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -17,7 +17,7 @@ from ray.actor import ActorHandle from ray.serve.common import BackendInfo, GoalId, ReplicaTag -from ray.serve.config import (AutoscalingConfig, BackendConfig, HTTPOptions, +from ray.serve.config import (AutoscalingConfig, DeploymentConfig, HTTPOptions, ReplicaConfig) from ray.serve.constants import (DEFAULT_CHECKPOINT_PATH, HTTP_PROXY_TIMEOUT, SERVE_CONTROLLER_NAME, MAX_CACHED_HANDLES, @@ -187,18 +187,19 @@ def _wait_for_goal(self, return False @_ensure_connected - def deploy(self, - name: str, - backend_def: Union[Callable, Type[Callable], str], - init_args: Tuple[Any], - init_kwargs: Dict[Any, Any], - ray_actor_options: Optional[Dict] = None, - config: Optional[Union[BackendConfig, Dict[str, Any]]] = None, - version: Optional[str] = None, - prev_version: Optional[str] = None, - route_prefix: Optional[str] = None, - url: str = "", - _blocking: Optional[bool] = True) -> Optional[GoalId]: + def deploy( + self, + name: str, + deployment_def: Union[Callable, Type[Callable], str], + init_args: Tuple[Any], + init_kwargs: Dict[Any, Any], + ray_actor_options: Optional[Dict] = None, + config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None, + version: Optional[str] = None, + prev_version: Optional[str] = None, + route_prefix: Optional[str] = None, + url: str = "", + _blocking: Optional[bool] = True) -> Optional[GoalId]: if config is None: config = {} if ray_actor_options is None: @@ -212,23 +213,25 @@ def deploy(self, ray_actor_options["runtime_env"] = curr_job_env replica_config = ReplicaConfig( - backend_def, + deployment_def, init_args=init_args, init_kwargs=init_kwargs, ray_actor_options=ray_actor_options) if isinstance(config, dict): - backend_config = BackendConfig.parse_obj(config) - elif isinstance(config, BackendConfig): - backend_config = config + deployment_config = DeploymentConfig.parse_obj(config) + elif isinstance(config, DeploymentConfig): + deployment_config = config else: - raise TypeError("config must be a BackendConfig or a dictionary.") + raise TypeError( + "config must be a DeploymentConfig or a dictionary.") goal_id, updating = ray.get( - self._controller.deploy.remote( - name, backend_config.to_proto_bytes(), replica_config, version, - prev_version, route_prefix, - ray.get_runtime_context().job_id)) + self._controller.deploy.remote(name, + deployment_config.to_proto_bytes(), + replica_config, version, + prev_version, route_prefix, + ray.get_runtime_context().job_id)) tag = f"component=serve deployment={name}" @@ -626,7 +629,7 @@ class Deployment: def __init__(self, func_or_class: Callable, name: str, - config: BackendConfig, + config: DeploymentConfig, version: Optional[str] = None, prev_version: Optional[str] = None, init_args: Optional[Tuple[Any]] = None, @@ -1021,7 +1024,7 @@ class MyDeployment: raise ValueError("Manually setting num_replicas is not allowed when " "_autoscaling_config is provided.") - config = BackendConfig() + config = DeploymentConfig() if num_replicas is not None: config.num_replicas = num_replicas @@ -1085,9 +1088,10 @@ def get_deployment(name: str) -> Deployment: raise KeyError(f"Deployment {name} was not found. 
" "Did you call Deployment.deploy()?") return Deployment( - cloudpickle.loads(backend_info.replica_config.serialized_backend_def), + cloudpickle.loads( + backend_info.replica_config.serialized_deployment_def), name, - backend_info.backend_config, + backend_info.deployment_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, init_kwargs=backend_info.replica_config.init_kwargs, @@ -1109,9 +1113,9 @@ def list_deployments() -> Dict[str, Deployment]: for name, (backend_info, route_prefix) in infos.items(): deployments[name] = Deployment( cloudpickle.loads( - backend_info.replica_config.serialized_backend_def), + backend_info.replica_config.serialized_deployment_def), name, - backend_info.backend_config, + backend_info.deployment_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, init_kwargs=backend_info.replica_config.init_kwargs, diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 8f23bc3fba75..7b00235f34af 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -11,7 +11,7 @@ from ray.serve.async_goal_manager import AsyncGoalManager from ray.serve.common import (BackendInfo, BackendTag, Duration, GoalId, ReplicaTag, ReplicaName, RunningReplicaInfo) -from ray.serve.config import BackendConfig +from ray.serve.config import DeploymentConfig from ray.serve.constants import ( CONTROLLER_STARTUP_GRACE_PERIOD_S, SERVE_CONTROLLER_NAME, SERVE_PROXY_NAME, MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT, MAX_NUM_DELETED_DEPLOYMENTS) @@ -156,9 +156,9 @@ def start(self, backend_info: BackendInfo, version: DeploymentVersion): """ self._actor_resources = backend_info.replica_config.resource_dict self._max_concurrent_queries = ( - backend_info.backend_config.max_concurrent_queries) + backend_info.deployment_config.max_concurrent_queries) self._graceful_shutdown_timeout_s = ( - backend_info.backend_config.graceful_shutdown_timeout_s) + backend_info.deployment_config.graceful_shutdown_timeout_s) if USE_PLACEMENT_GROUP: self._placement_group = self.create_placement_group( self._placement_group_name, self._actor_resources) @@ -177,11 +177,11 @@ def start(self, backend_info: BackendInfo, version: DeploymentVersion): self.backend_tag, self.replica_tag, backend_info.replica_config.init_args, backend_info.replica_config.init_kwargs, - backend_info.backend_config.to_proto_bytes(), version, + backend_info.deployment_config.to_proto_bytes(), version, self._controller_name, self._detached) self._ready_obj_ref = self._actor_handle.reconfigure.remote( - backend_info.backend_config.user_config) + backend_info.deployment_config.user_config) def update_user_config(self, user_config: Any): """ @@ -242,11 +242,11 @@ def check_ready( return ReplicaStartupStatus.PENDING, None elif len(ready) > 0: try: - backend_config, version = ray.get(ready)[0] + deployment_config, version = ray.get(ready)[0] self._max_concurrent_queries = ( - backend_config.max_concurrent_queries) + deployment_config.max_concurrent_queries) self._graceful_shutdown_timeout_s = ( - backend_config.graceful_shutdown_timeout_s) + deployment_config.graceful_shutdown_timeout_s) except Exception: return ReplicaStartupStatus.FAILED, None @@ -726,11 +726,11 @@ def _set_backend_goal(self, backend_info: Optional[BackendInfo]) -> None: if backend_info is not None: self._target_info = backend_info - self._target_replicas = backend_info.backend_config.num_replicas + self._target_replicas = backend_info.deployment_config.num_replicas 
self._target_version = DeploymentVersion( backend_info.version, - user_config=backend_info.backend_config.user_config) + user_config=backend_info.deployment_config.user_config) else: self._target_replicas = 0 @@ -746,7 +746,7 @@ def deploy(self, backend_info: BackendInfo) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version and BackendConfig, + If the backend already exists with the same version and config, this is a no-op and returns the GoalId corresponding to the existing update if there is one. @@ -760,7 +760,8 @@ def deploy(self, # Redeploying should not reset the deployment's start time. backend_info.start_time_ms = existing_info.start_time_ms - if (existing_info.backend_config == backend_info.backend_config + if (existing_info.deployment_config == + backend_info.deployment_config and backend_info.version is not None and existing_info.version == backend_info.version): return self._curr_goal, False @@ -1291,19 +1292,20 @@ def get_running_replica_infos( return replicas - def get_backend_configs(self, - filter_tag: Optional[BackendTag] = None, - include_deleted: Optional[bool] = False - ) -> Dict[BackendTag, BackendConfig]: - configs: Dict[BackendTag, BackendConfig] = {} + def get_deployment_configs(self, + filter_tag: Optional[BackendTag] = None, + include_deleted: Optional[bool] = False + ) -> Dict[BackendTag, DeploymentConfig]: + configs: Dict[BackendTag, DeploymentConfig] = {} for backend_tag, backend_state in self._backend_states.items(): if filter_tag is None or backend_tag == filter_tag: - configs[backend_tag] = backend_state.target_info.backend_config + configs[ + backend_tag] = backend_state.target_info.deployment_config if include_deleted: for backend_tag, info in self._deleted_backend_metadata.items(): if filter_tag is None or backend_tag == filter_tag: - configs[backend_tag] = info.backend_config + configs[backend_tag] = info.deployment_config return configs @@ -1322,7 +1324,7 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo ) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version and BackendConfig, + If the backend already exists with the same version and config, this is a no-op and returns the GoalId corresponding to the existing update if there is one. diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index 3c236a668d26..eb0a5dead6ea 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -5,7 +5,7 @@ from uuid import UUID from ray.actor import ActorClass, ActorHandle -from ray.serve.config import BackendConfig, ReplicaConfig +from ray.serve.config import DeploymentConfig, ReplicaConfig from ray.serve.autoscaling_policy import AutoscalingPolicy BackendTag = str @@ -23,7 +23,7 @@ class EndpointInfo: class BackendInfo: def __init__(self, - backend_config: BackendConfig, + deployment_config: DeploymentConfig, replica_config: ReplicaConfig, start_time_ms: int, actor_def: Optional[ActorClass] = None, @@ -31,7 +31,7 @@ def __init__(self, deployer_job_id: "Optional[ray._raylet.JobID]" = None, end_time_ms: Optional[int] = None, autoscaling_policy: Optional[AutoscalingPolicy] = None): - self.backend_config = backend_config + self.deployment_config = deployment_config self.replica_config = replica_config # The time when .deploy() was first called for this deployment. 
self.start_time_ms = start_time_ms diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 7f010d3ff37e..60a1e1344dd8 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -7,10 +7,10 @@ from google.protobuf.json_format import MessageToDict from pydantic import BaseModel, NonNegativeFloat, PositiveInt, validator from ray.serve.constants import DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT -from ray.serve.generated.serve_pb2 import (BackendConfig as BackendConfigProto, - AutoscalingConfig as - AutoscalingConfigProto) -from ray.serve.generated.serve_pb2 import BackendLanguage +from ray.serve.generated.serve_pb2 import ( + DeploymentConfig as DeploymentConfigProto, AutoscalingConfig as + AutoscalingConfigProto) +from ray.serve.generated.serve_pb2 import DeploymentLanguage from ray import cloudpickle as cloudpickle @@ -56,21 +56,21 @@ class AutoscalingConfig(BaseModel): # TODO(architkulkarni): Add pydantic validation. E.g. max_replicas>=min -class BackendConfig(BaseModel): - """Configuration options for a backend, to be set by the user. +class DeploymentConfig(BaseModel): + """Configuration options for a deployment, to be set by the user. Args: num_replicas (Optional[int]): The number of processes to start up that - will handle requests to this backend. Defaults to 1. + will handle requests to this deployment. Defaults to 1. max_concurrent_queries (Optional[int]): The maximum number of queries - that will be sent to a replica of this backend without receiving a - response. Defaults to 100. + that will be sent to a replica of this deployment without receiving + a response. Defaults to 100. user_config (Optional[Any]): Arguments to pass to the reconfigure - method of the backend. The reconfigure method is called if + method of the deployment. The reconfigure method is called if user_config is not None. graceful_shutdown_wait_loop_s (Optional[float]): Duration - that backend workers will wait until there is no more work to be - done before shutting down. Defaults to 2s. + that deployment replicas will wait until there is no more work to + be done before shutting down. Defaults to 2s. graceful_shutdown_timeout_s (Optional[float]): Controller waits for this duration to forcefully kill the replica for shutdown. Defaults to 20s. @@ -107,15 +107,15 @@ def to_proto_bytes(self): if data.get("autoscaling_config"): data["autoscaling_config"] = AutoscalingConfigProto( **data["autoscaling_config"]) - return BackendConfigProto( + return DeploymentConfigProto( is_cross_language=False, - backend_language=BackendLanguage.PYTHON, + deployment_language=DeploymentLanguage.PYTHON, **data, ).SerializeToString() @classmethod def from_proto_bytes(cls, proto_bytes: bytes): - proto = BackendConfigProto.FromString(proto_bytes) + proto = DeploymentConfigProto.FromString(proto_bytes) data = MessageToDict( proto, including_default_value_fields=True, @@ -131,37 +131,36 @@ def from_proto_bytes(cls, proto_bytes: bytes): # Delete fields which are only used in protobuf, not in Python. del data["is_cross_language"] - del data["backend_language"] + del data["deployment_language"] return cls(**data) class ReplicaConfig: def __init__(self, - backend_def: Callable, + deployment_def: Callable, init_args: Optional[Tuple[Any]] = None, init_kwargs: Optional[Dict[Any, Any]] = None, ray_actor_options=None): - # Validate that backend_def is an import path, function, or class. 
- if isinstance(backend_def, str): - self.func_or_class_name = backend_def - pass - elif inspect.isfunction(backend_def): - self.func_or_class_name = backend_def.__name__ + # Validate that deployment_def is an import path, function, or class. + if isinstance(deployment_def, str): + self.func_or_class_name = deployment_def + elif inspect.isfunction(deployment_def): + self.func_or_class_name = deployment_def.__name__ if init_args: raise ValueError( - "init_args not supported for function backend.") + "init_args not supported for function deployments.") if init_kwargs: raise ValueError( - "init_kwargs not supported for function backend.") - elif inspect.isclass(backend_def): - self.func_or_class_name = backend_def.__name__ + "init_kwargs not supported for function deployments.") + elif inspect.isclass(deployment_def): + self.func_or_class_name = deployment_def.__name__ else: raise TypeError( - "Backend must be an import path, function or class, it is {}.". - format(type(backend_def))) + "Deployment must be a function or class, it is {}.".format( + type(deployment_def))) - self.serialized_backend_def = cloudpickle.dumps(backend_def) + self.serialized_deployment_def = cloudpickle.dumps(deployment_def) self.init_args = init_args if init_args is not None else () self.init_kwargs = init_kwargs if init_kwargs is not None else {} if ray_actor_options is None: @@ -175,7 +174,7 @@ def __init__(self, def _validate(self): if "placement_group" in self.ray_actor_options: - raise ValueError("Providing placement_group for backend actors " + raise ValueError("Providing placement_group for deployment actors " "is not currently supported.") if not isinstance(self.ray_actor_options, dict): diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index e1dd906a398a..cbebddafa0c4 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -20,7 +20,7 @@ NodeId, RunningReplicaInfo, ) -from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig +from ray.serve.config import DeploymentConfig, HTTPOptions, ReplicaConfig from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY from ray.serve.endpoint_state import EndpointState from ray.serve.http_state import HTTPState @@ -142,7 +142,7 @@ def autoscale(self) -> None: """Updates autoscaling deployments with calculated num_replicas.""" for deployment_name, (backend_info, route_prefix) in self.list_deployments().items(): - backend_config = backend_info.backend_config + deployment_config = backend_info.deployment_config autoscaling_policy = backend_info.autoscaling_policy if autoscaling_policy is None: @@ -166,16 +166,16 @@ def autoscale(self) -> None: if len(current_num_ongoing_requests) == 0: continue - new_backend_config = backend_config.copy() + new_deployment_config = deployment_config.copy() decision_num_replicas = ( autoscaling_policy.get_decision_num_replicas( current_num_ongoing_requests=current_num_ongoing_requests, - curr_target_num_replicas=backend_config.num_replicas)) - new_backend_config.num_replicas = decision_num_replicas + curr_target_num_replicas=deployment_config.num_replicas)) + new_deployment_config.num_replicas = decision_num_replicas new_backend_info = copy(backend_info) - new_backend_info.backend_config = new_backend_config + new_backend_info.deployment_config = new_deployment_config goal_id, updating = self.backend_state_manager.deploy_backend( deployment_name, new_backend_info) @@ -275,7 +275,7 @@ async def shutdown(self) -> List[GoalId]: def deploy(self, name: str, - 
backend_config_proto_bytes: bytes, + deployment_config_proto_bytes: bytes, replica_config: ReplicaConfig, version: Optional[str], prev_version: Optional[str], @@ -285,8 +285,8 @@ def deploy(self, if route_prefix is not None: assert route_prefix.startswith("/") - backend_config = BackendConfig.from_proto_bytes( - backend_config_proto_bytes) + deployment_config = DeploymentConfig.from_proto_bytes( + deployment_config_proto_bytes) if prev_version is not None: existing_backend_info = self.backend_state_manager.get_backend( @@ -301,10 +301,10 @@ def deploy(self, "does not match with the existing " f"version '{existing_backend_info.version}'.") - autoscaling_config = backend_config.autoscaling_config + autoscaling_config = deployment_config.autoscaling_config if autoscaling_config is not None: # TODO: is this the desired behaviour? Should this be a setting? - backend_config.num_replicas = autoscaling_config.min_replicas + deployment_config.num_replicas = autoscaling_config.min_replicas autoscaling_policy = BasicAutoscalingPolicy(autoscaling_config) else: @@ -312,10 +312,10 @@ def deploy(self, backend_info = BackendInfo( actor_def=ray.remote( - create_replica_wrapper(name, - replica_config.serialized_backend_def)), + create_replica_wrapper( + name, replica_config.serialized_deployment_def)), version=version, - backend_config=backend_config, + deployment_config=deployment_config, replica_config=replica_config, deployer_job_id=deployer_job_id, start_time_ms=int(time.time() * 1000), @@ -373,6 +373,6 @@ def list_deployments(self, include_deleted: Optional[bool] = False name: (self.backend_state_manager.get_backend( name, include_deleted=include_deleted), self.endpoint_state.get_endpoint_route(name)) - for name in self.backend_state_manager.get_backend_configs( + for name in self.backend_state_manager.get_deployment_configs( include_deleted=include_deleted) } diff --git a/python/ray/serve/replica.py b/python/ray/serve/replica.py index a85ed390b71c..45db316106a7 100644 --- a/python/ray/serve/replica.py +++ b/python/ray/serve/replica.py @@ -15,7 +15,7 @@ from ray.serve.autoscaling_metrics import start_metrics_pusher from ray.serve.common import BackendTag, ReplicaTag -from ray.serve.config import BackendConfig +from ray.serve.config import DeploymentConfig from ray.serve.http_util import ASGIHTTPSender from ray.serve.utils import parse_request_item, _get_logger from ray.serve.exceptions import RayServeException @@ -31,30 +31,30 @@ logger = _get_logger() -def create_replica_wrapper(name: str, serialized_backend_def: bytes): +def create_replica_wrapper(name: str, serialized_deployment_def: bytes): """Creates a replica class wrapping the provided function or class. This approach is picked over inheritance to avoid conflict between user provided class and the RayServeReplica class. 
""" - serialized_backend_def = serialized_backend_def + serialized_deployment_def = serialized_deployment_def # TODO(architkulkarni): Add type hints after upgrading cloudpickle class RayServeWrappedReplica(object): async def __init__(self, backend_tag, replica_tag, init_args, - init_kwargs, backend_config_proto_bytes: bytes, + init_kwargs, deployment_config_proto_bytes: bytes, version: DeploymentVersion, controller_name: str, detached: bool): - backend = cloudpickle.loads(serialized_backend_def) - backend_config = BackendConfig.from_proto_bytes( - backend_config_proto_bytes) + backend = cloudpickle.loads(serialized_deployment_def) + deployment_config = DeploymentConfig.from_proto_bytes( + deployment_config_proto_bytes) if inspect.isfunction(backend): is_function = True elif inspect.isclass(backend): is_function = False else: - assert False, ("backend_def must be function, class, or " + assert False, ("deployment_def must be function, class, or " "corresponding import path.") # Set the controller name so that serve.connect() in the user's @@ -85,10 +85,10 @@ async def __init__(self, backend_tag, replica_tag, init_args, detached) controller_handle = ray.get_actor( controller_name, namespace=controller_namespace) - self.backend = RayServeReplica(_callable, backend_tag, replica_tag, - backend_config, - backend_config.user_config, version, - is_function, controller_handle) + self.backend = RayServeReplica( + _callable, backend_tag, replica_tag, deployment_config, + deployment_config.user_config, version, is_function, + controller_handle) # asyncio.Event used to signal that the replica is shutting down. self.shutdown_event = asyncio.Event() @@ -109,14 +109,14 @@ async def handle_request( return await self.backend.handle_request(query) async def reconfigure(self, user_config: Optional[Any] = None - ) -> Tuple[BackendConfig, DeploymentVersion]: + ) -> Tuple[DeploymentConfig, DeploymentVersion]: if user_config is not None: await self.backend.reconfigure(user_config) return self.get_metadata() - def get_metadata(self) -> Tuple[BackendConfig, DeploymentVersion]: - return self.backend.backend_config, self.backend.version + def get_metadata(self) -> Tuple[DeploymentConfig, DeploymentVersion]: + return self.backend.deployment_config, self.backend.version async def prepare_for_shutdown(self): self.shutdown_event.set() @@ -146,10 +146,10 @@ class RayServeReplica: """Handles requests with the provided callable.""" def __init__(self, _callable: Callable, backend_tag: BackendTag, - replica_tag: ReplicaTag, backend_config: BackendConfig, + replica_tag: ReplicaTag, deployment_config: DeploymentConfig, user_config: Any, version: DeploymentVersion, is_function: bool, controller_handle: ActorHandle) -> None: - self.backend_config = backend_config + self.deployment_config = deployment_config self.backend_tag = backend_tag self.replica_tag = replica_tag self.callable = _callable @@ -211,10 +211,10 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.restart_counter.inc() self._shutdown_wait_loop_s = ( - backend_config.graceful_shutdown_wait_loop_s) + deployment_config.graceful_shutdown_wait_loop_s) - if backend_config.autoscaling_config: - config = backend_config.autoscaling_config + if deployment_config.autoscaling_config: + config = deployment_config.autoscaling_config start_metrics_pusher( interval_s=config.metrics_interval_s, collection_callback=self._collect_autoscaling_metrics, @@ -319,7 +319,8 @@ async def reconfigure(self, user_config: Any): self.version = DeploymentVersion( 
self.version.code_version, user_config=user_config) if self.is_function: - raise ValueError("backend_def must be a class to use user_config") + raise ValueError( + "deployment_def must be a class to use user_config") elif not hasattr(self.callable, BACKEND_RECONFIGURE_METHOD): raise RayServeException("user_config specified but backend " + self.backend_tag + " missing " + diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py index 74a170a6f1c6..687382945249 100644 --- a/python/ray/serve/tests/test_backend_state.py +++ b/python/ray/serve/tests/test_backend_state.py @@ -8,7 +8,7 @@ from ray.actor import ActorHandle from ray.serve.common import ( - BackendConfig, + DeploymentConfig, BackendInfo, BackendTag, ReplicaConfig, @@ -130,7 +130,7 @@ def available_resources(self) -> Dict[str, float]: def graceful_stop(self) -> None: assert self.started self.stopped = True - return self.backend_info.backend_config.graceful_shutdown_timeout_s + return self.backend_info.deployment_config.graceful_shutdown_timeout_s def check_stopped(self) -> bool: return self.done_stopping @@ -154,7 +154,7 @@ def backend_info(version: Optional[str] = None, actor_def=None, version=version, start_time_ms=0, - backend_config=BackendConfig( + deployment_config=DeploymentConfig( num_replicas=num_replicas, user_config=user_config, **config_opts), replica_config=ReplicaConfig(lambda x: x)) @@ -163,7 +163,8 @@ def backend_info(version: Optional[str] = None, else: code_version = get_random_letters() - version = DeploymentVersion(code_version, info.backend_config.user_config) + version = DeploymentVersion(code_version, + info.deployment_config.user_config) return info, version diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 505147055e34..f8e96487434d 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -1,29 +1,29 @@ import pytest from pydantic import ValidationError -from ray.serve.config import (BackendConfig, DeploymentMode, HTTPOptions, +from ray.serve.config import (DeploymentConfig, DeploymentMode, HTTPOptions, ReplicaConfig) from ray.serve.config import AutoscalingConfig -def test_backend_config_validation(): +def test_deployment_config_validation(): # Test unknown key. with pytest.raises(ValidationError): - BackendConfig(unknown_key=-1) + DeploymentConfig(unknown_key=-1) # Test num_replicas validation. - BackendConfig(num_replicas=1) + DeploymentConfig(num_replicas=1) with pytest.raises(ValidationError, match="type_error"): - BackendConfig(num_replicas="hello") + DeploymentConfig(num_replicas="hello") with pytest.raises(ValidationError, match="value_error"): - BackendConfig(num_replicas=-1) + DeploymentConfig(num_replicas=-1) # Test dynamic default for max_concurrent_queries. - assert BackendConfig().max_concurrent_queries == 100 + assert DeploymentConfig().max_concurrent_queries == 100 -def test_backend_config_update(): - b = BackendConfig(num_replicas=1, max_concurrent_queries=1) +def test_deployment_config_update(): + b = DeploymentConfig(num_replicas=1, max_concurrent_queries=1) # Test updating a key works. 
b.num_replicas = 2 @@ -108,18 +108,18 @@ def test_http_options(): def test_with_proto(): # Test roundtrip - config = BackendConfig(num_replicas=100, max_concurrent_queries=16) - assert config == BackendConfig.from_proto_bytes(config.to_proto_bytes()) + config = DeploymentConfig(num_replicas=100, max_concurrent_queries=16) + assert config == DeploymentConfig.from_proto_bytes(config.to_proto_bytes()) # Test user_config object - config = BackendConfig(user_config={"python": ("native", ["objects"])}) - assert config == BackendConfig.from_proto_bytes(config.to_proto_bytes()) + config = DeploymentConfig(user_config={"python": ("native", ["objects"])}) + assert config == DeploymentConfig.from_proto_bytes(config.to_proto_bytes()) def test_zero_default_proto(): # Test that options set to zero (protobuf default value) still retain their # original value after being serialized and deserialized. - config = BackendConfig( + config = DeploymentConfig( autoscaling_config={ "min_replicas": 1, "max_replicas": 2, @@ -127,7 +127,7 @@ def test_zero_default_proto(): "downscale_delay_s": 0 }) serialized_config = config.to_proto_bytes() - deserialized_config = BackendConfig.from_proto_bytes(serialized_config) + deserialized_config = DeploymentConfig.from_proto_bytes(serialized_config) new_delay_s = deserialized_config.autoscaling_config.downscale_delay_s assert new_delay_s == 0 diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index e118fe72b897..ad34bca748db 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -1,4 +1,5 @@ import os +import signal import sys import time @@ -7,8 +8,9 @@ import ray import ray._private.utils +import ray._private.gcs_utils as gcs_utils import ray.ray_constants as ray_constants -from ray.exceptions import RayTaskError +from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError from ray._private.test_utils import (wait_for_condition, SignalActor, init_error_pubsub, get_error_message) @@ -587,6 +589,88 @@ def f(): assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR +@pytest.mark.parametrize( + "ray_start_cluster_head", [{ + "num_cpus": 0, + "_system_config": { + "raylet_death_check_interval_milliseconds": 10 * 1000, + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "timeout_ms_task_wait_for_death_info": 100, + } + }], + indirect=True) +def test_actor_failover_with_bad_network(ray_start_cluster_head): + # This test covers the scenario where, when an actor failover (FO) happens, + # the caller receives the actor ALIVE notification and connects to the new + # actor instance while some tasks sent to the previous actor instance + # have not yet returned. + # + # It's not easy to reproduce this scenario, so we set + # `raylet_death_check_interval_milliseconds` to a large value and add a + # never-return function for the actor to keep the RPC connection alive + # while killing the node to trigger actor failover. Later we send SIGKILL + # to kill the previous actor process to let the task fail. + # + # The expected behavior is that after the actor is alive again and the + # previous RPC connection is broken, tasks sent via the previous RPC + # connection should fail, but tasks sent via the new RPC connection should + # succeed.
+ + cluster = ray_start_cluster_head + node = cluster.add_node(num_cpus=1) + + @ray.remote(max_restarts=1) + class Actor: + def getpid(self): + return os.getpid() + + def never_return(self): + while True: + time.sleep(1) + return 0 + + # The actor should be placed on the non-head node. + actor = Actor.remote() + pid = ray.get(actor.getpid.remote()) + + # Submit a never-return task (task 1) to the actor. The return + # object should be unready. + obj1 = actor.never_return.remote() + with pytest.raises(GetTimeoutError): + ray.get(obj1, timeout=1) + + # Kill the non-head node and start a new one. Now GCS should trigger actor + # FO. Since we changed the interval of worker checking death of Raylet, + # the actor process won't quit in a short time. + cluster.remove_node(node, allow_graceful=False) + cluster.add_node(num_cpus=1) + + # The removed node will be marked as dead by GCS after 1 second and task 1 + # will return with failure after that. + with pytest.raises(RayActorError): + ray.get(obj1, timeout=2) + + # Wait for the actor to be alive again in a new worker process. + def check_actor_restart(): + actors = list(ray.state.actors().values()) + assert len(actors) == 1 + print(actors) + return (actors[0]["State"] == gcs_utils.ActorTableData.ALIVE + and actors[0]["NumRestarts"] == 1) + + wait_for_condition(check_actor_restart) + + # Kill the previous actor process. + os.kill(pid, signal.SIGKILL) + + # Submit another task (task 2) to the actor. + obj2 = actor.getpid.remote() + + # We should be able to get the return value of task 2 without any issue + ray.get(obj2) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_placement_group_3.py b/python/ray/tests/test_placement_group_3.py index cee2b819c1a6..57009bbd90f4 100644 --- a/python/ray/tests/test_placement_group_3.py +++ b/python/ray/tests/test_placement_group_3.py @@ -646,5 +646,57 @@ def check_bundle_leaks(): wait_for_condition(check_bundle_leaks) +def test_placement_group_local_resource_view(monkeypatch, ray_start_cluster): + """Please refer to https://github.com/ray-project/ray/pull/19911 + for more details. + """ + with monkeypatch.context() as m: + # Increase broadcasting interval so that node resource will arrive + # at raylet after local resource all being allocated. 
+ m.setenv("RAY_raylet_report_resources_period_milliseconds", "2000") + m.setenv("RAY_grpc_based_resource_broadcast", "true") + cluster = ray_start_cluster + + cluster.add_node(num_cpus=16, object_store_memory=1e9) + cluster.wait_for_nodes() + cluster.add_node(num_cpus=16, num_gpus=1) + cluster.wait_for_nodes() + NUM_CPU_BUNDLES = 30 + + @ray.remote(num_cpus=1) + class Worker(object): + def __init__(self, i): + self.i = i + + def work(self): + time.sleep(0.1) + print("work ", self.i) + + @ray.remote(num_cpus=1, num_gpus=1) + class Trainer(object): + def __init__(self, i): + self.i = i + + def train(self): + time.sleep(0.2) + print("train ", self.i) + + ray.init(address="auto") + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] + pg = placement_group(bundles, strategy="PACK") + ray.get(pg.ready()) + + # Local resources will be allocated here, and we want to ensure that the + # local view stays consistent and that node resource updates are discarded + workers = [ + Worker.options(placement_group=pg).remote(i) + for i in range(NUM_CPU_BUNDLES) + ] + trainer = Trainer.options(placement_group=pg).remote(0) + ray.get([workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)]) + ray.get(trainer.train.remote()) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/train/examples/tensorflow_mnist_example.py b/python/ray/train/examples/tensorflow_mnist_example.py index 5f71842bf78b..0880f3347cde 100644 --- a/python/ray/train/examples/tensorflow_mnist_example.py +++ b/python/ray/train/examples/tensorflow_mnist_example.py @@ -72,7 +72,7 @@ def train_func(config): return results -def train_tensorflow_mnist(num_workers=2, use_gpu=False): +def train_tensorflow_mnist(num_workers=2, use_gpu=False, epochs=4): trainer = Trainer( backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu) trainer.start() @@ -81,7 +81,7 @@ def train_tensorflow_mnist(num_workers=2, use_gpu=False): config={ "lr": 1e-3, "batch_size": 64, - "epochs": 4 + "epochs": epochs }) trainer.shutdown() print(f"Results: {results[0]}") @@ -105,6 +105,8 @@ def train_tensorflow_mnist(num_workers=2, use_gpu=False): action="store_true", default=False, help="Enables GPU training") + parser.add_argument( + "--epochs", type=int, default=3, help="Number of epochs to train for.") parser.add_argument( "--smoke-test", action="store_true", @@ -117,6 +119,10 @@ def train_tensorflow_mnist(num_workers=2, use_gpu=False): if args.smoke_test: ray.init(num_cpus=2) + train_tensorflow_mnist() else: ray.init(address=args.address) - train_tensorflow_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu) + train_tensorflow_mnist( + num_workers=args.num_workers, + use_gpu=args.use_gpu, + epochs=args.epochs) diff --git a/python/ray/train/examples/train_linear_example.py b/python/ray/train/examples/train_linear_example.py index 50bfbd0fe2aa..2512f022a596 100644 --- a/python/ray/train/examples/train_linear_example.py +++ b/python/ray/train/examples/train_linear_example.py @@ -25,8 +25,9 @@ def __len__(self): return len(self.x) -def train_epoch(dataloader, model, loss_fn, optimizer): +def train_epoch(dataloader, model, loss_fn, optimizer, device): for X, y in dataloader: + X, y = X.to(device), y.to(device) # Compute prediction error pred = model(X) loss = loss_fn(pred, y) @@ -37,12 +38,13 @@ def train_epoch(dataloader, model, loss_fn, optimizer): optimizer.step() -def validate_epoch(dataloader, model, loss_fn): +def validate_epoch(dataloader, model, loss_fn, device): num_batches = len(dataloader)
model.eval() loss = 0 with torch.no_grad(): for X, y in dataloader: + X, y = X.to(device), y.to(device) pred = model(X) loss += loss_fn(pred, y).item() loss /= num_batches @@ -58,6 +60,9 @@ def train_func(config): lr = config.get("lr", 1e-2) epochs = config.get("epochs", 3) + device = torch.device(f"cuda:{train.local_rank()}" + if torch.cuda.is_available() else "cpu") + train_dataset = LinearDataset(2, 5, size=data_size) val_dataset = LinearDataset(2, 5, size=val_size) train_loader = torch.utils.data.DataLoader( @@ -70,7 +75,10 @@ def train_func(config): sampler=DistributedSampler(val_dataset)) model = nn.Linear(1, hidden_size) - model = DistributedDataParallel(model) + model.to(device) + model = DistributedDataParallel( + model, + device_ids=[device.index] if torch.cuda.is_available() else None) loss_fn = nn.MSELoss() @@ -79,17 +87,20 @@ def train_func(config): results = [] for _ in range(epochs): - train_epoch(train_loader, model, loss_fn, optimizer) - result = validate_epoch(validation_loader, model, loss_fn) + train_epoch(train_loader, model, loss_fn, optimizer, device) + result = validate_epoch(validation_loader, model, loss_fn, device) train.report(**result) results.append(result) return results -def train_linear(num_workers=2): - trainer = Trainer(TorchConfig(backend="gloo"), num_workers=num_workers) - config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 3} +def train_linear(num_workers=2, use_gpu=False, epochs=3): + trainer = Trainer( + backend=TorchConfig(backend="gloo"), + num_workers=num_workers, + use_gpu=use_gpu) + config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs} trainer.start() results = trainer.run( train_func, @@ -115,6 +126,12 @@ def train_linear(num_workers=2): type=int, default=2, help="Sets number of workers for training.") + parser.add_argument( + "--use-gpu", + action="store_true", + help="Whether to use GPU for training.") + parser.add_argument( + "--epochs", type=int, default=3, help="Number of epochs to train for.") parser.add_argument( "--smoke-test", action="store_true", @@ -127,7 +144,10 @@ def train_linear(num_workers=2): if args.smoke_test: ray.init(num_cpus=2) + train_linear() else: ray.init(address=args.address) - - train_linear(num_workers=args.num_workers) + train_linear( + num_workers=args.num_workers, + use_gpu=args.use_gpu, + epochs=args.epochs) diff --git a/python/ray/tune/examples/horovod_cifar_pbt_example.py b/python/ray/tune/examples/horovod_cifar_pbt_example.py index 4bd6bd44dd8d..d66c985f4f1e 120000 --- a/python/ray/tune/examples/horovod_cifar_pbt_example.py +++ b/python/ray/tune/examples/horovod_cifar_pbt_example.py @@ -1 +1 @@ -../../../../release/horovod_tests/workloads/horovod_test.py \ No newline at end of file +../../../../release/horovod_tests/workloads/horovod_tune_test.py \ No newline at end of file diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 01618be75caa..964f45fc3e8c 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -85,7 +85,7 @@ class Experiment: max_failures=2) """ - # keys that will be present in `public_spec` dict + # Keys that will be present in `public_spec` dict. 
PUBLIC_KEYS = {"stop", "num_samples"} def __init__(self, diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 4cebe9eefcfd..16d1fea7b249 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -544,7 +544,8 @@ def send(tensor, dst_rank: int, group_name: str = "default"): def send_multigpu(tensor, dst_rank: int, dst_gpu_index: int, - group_name: str = "default"): + group_name: str = "default", + n_elements: int = 0): """Send a tensor to a remote GPU synchronously. The function asssume each process owns >1 GPUs, and the sender @@ -555,6 +556,8 @@ def send_multigpu(tensor, dst_rank (int): the rank of the destination process. dst_gpu_index (int): the destination gpu index. group_name (str): the name of the collective group. + n_elements (int): if specified, send the next n elements + from the starting address of tensor. Returns: None @@ -567,9 +570,13 @@ def send_multigpu(tensor, if dst_rank == g.rank: raise RuntimeError("The dst_rank '{}' is self. Considering " "doing GPU to GPU memcpy instead?".format(dst_rank)) + if n_elements < 0: + raise RuntimeError( + "The n_elements '{}' should >= 0.".format(n_elements)) opts = types.SendOptions() opts.dst_rank = dst_rank opts.dst_gpu_index = dst_gpu_index + opts.n_elements = n_elements g.send([tensor], opts) @@ -598,7 +605,8 @@ def recv(tensor, src_rank: int, group_name: str = "default"): def recv_multigpu(tensor, src_rank: int, src_gpu_index: int, - group_name: str = "default"): + group_name: str = "default", + n_elements: int = 0): """Receive a tensor from a remote GPU synchronously. The function asssume each process owns >1 GPUs, and the sender @@ -621,9 +629,13 @@ def recv_multigpu(tensor, if src_rank == g.rank: raise RuntimeError("The dst_rank '{}' is self. 
Considering " "doing GPU to GPU memcpy instead?".format(src_rank)) + if n_elements < 0: + raise RuntimeError( + "The n_elements '{}' should be >= 0.".format(n_elements)) opts = types.RecvOptions() opts.src_rank = src_rank opts.src_gpu_index = src_gpu_index + opts.n_elements = n_elements g.recv([tensor], opts) diff --git a/python/ray/util/collective/collective_group/nccl_collective_group.py b/python/ray/util/collective/collective_group/nccl_collective_group.py index a73dc9526e7a..6825ed0813da 100644 --- a/python/ray/util/collective/collective_group/nccl_collective_group.py +++ b/python/ray/util/collective/collective_group/nccl_collective_group.py @@ -348,7 +348,8 @@ def send(self, tensors, send_options=SendOptions()): def p2p_fn(tensor, comm, stream, peer): comm.send( - nccl_util.get_tensor_ptr(tensor), + nccl_util.get_tensor_ptr(tensor), send_options.n_elements + if send_options.n_elements > 0 else nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) @@ -368,7 +369,8 @@ def recv(self, tensors, recv_options=RecvOptions()): def p2p_fn(tensor, comm, stream, peer): comm.recv( - nccl_util.get_tensor_ptr(tensor), + nccl_util.get_tensor_ptr(tensor), recv_options.n_elements + if recv_options.n_elements > 0 else nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index 6aff71b3a53d..b949177657c7 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -101,6 +101,7 @@ class ReduceScatterOptions: class SendOptions: dst_rank = 0 dst_gpu_index = 0 + n_elements = 0 timeout_ms = unset_timeout_ms @@ -108,4 +109,5 @@ class SendOptions: class RecvOptions: src_rank = 0 src_gpu_index = 0 + n_elements = 0 unset_timeout_ms = unset_timeout_ms diff --git a/python/ray/util/horovod/horovod_example.py b/python/ray/util/horovod/horovod_example.py index 59aa0850245e..1e285d0128af 100644 --- a/python/ray/util/horovod/horovod_example.py +++ b/python/ray/util/horovod/horovod_example.py @@ -115,11 +115,20 @@ def train_fn(data_dir=None, 100. * batch_idx / len(train_loader), loss.item())) -def main(num_workers, use_gpu, **kwargs): - settings = RayExecutor.create_settings(timeout_s=30) +def main(num_workers, + use_gpu, + timeout_s=30, + placement_group_timeout_s=100, + kwargs=None): + kwargs = kwargs or {} + if use_gpu: + kwargs["use_cuda"] = True + settings = RayExecutor.create_settings( + timeout_s=timeout_s, + placement_group_timeout_s=placement_group_timeout_s) executor = RayExecutor(settings, use_gpu=use_gpu, num_workers=num_workers) executor.start() - executor.run(train_fn, **kwargs) + executor.run(train_fn, kwargs=kwargs) if __name__ == "__main__": diff --git a/python/ray/util/ray_lightning/BUILD b/python/ray/util/ray_lightning/BUILD index 75b45d271927..f4aaac2a00c0 100644 --- a/python/ray/util/ray_lightning/BUILD +++ b/python/ray/util/ray_lightning/BUILD @@ -1,5 +1,5 @@ # -------------------------------------------------------------------- -# Tests from the python/ray/util/xgboost directory. +# Tests from the python/ray/util/ray_lightning directory. # Please keep these sorted alphabetically. 
# -------------------------------------------------------------------- py_test( diff --git a/python/ray/util/ray_lightning/simple_example.py b/python/ray/util/ray_lightning/simple_example.py index 9b8b728364c8..9064fb0d208a 100644 --- a/python/ray/util/ray_lightning/simple_example.py +++ b/python/ray/util/ray_lightning/simple_example.py @@ -1,3 +1,4 @@ +import argparse import os import torch from torch import nn @@ -38,15 +39,41 @@ def configure_optimizers(self): return optimizer -def main(): +def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10): dataset = MNIST( os.getcwd(), download=True, transform=transforms.ToTensor()) train, val = random_split(dataset, [55000, 5000]) autoencoder = LitAutoEncoder() - trainer = pl.Trainer(plugins=[RayPlugin(num_workers=2)], max_steps=10) + trainer = pl.Trainer( + plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)], + max_steps=max_steps) trainer.fit(autoencoder, DataLoader(train), DataLoader(val)) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser( + description="Ray Lightning Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--num-workers", + type=int, + default=2, + help="Number of workers to use for training.") + parser.add_argument( + "--max-steps", + type=int, + default=10, + help="Maximum number of steps to run for training.") + parser.add_argument( + "--use-gpu", + action="store_true", + default=False, + help="Whether to enable GPU training.") + + args = parser.parse_args() + + main( + num_workers=args.num_workers, + max_steps=args.max_steps, + use_gpu=args.use_gpu) diff --git a/python/ray/workflow/BUILD b/python/ray/workflow/BUILD index 3b04110be7dc..b995b61d9ef9 100644 --- a/python/ray/workflow/BUILD +++ b/python/ray/workflow/BUILD @@ -13,10 +13,20 @@ SRCS = [] + select({ "//conditions:default": [], }) +LARGE_TESTS = ["tests/test_recovery.py"] + py_test_module_list( - files = glob(["tests/test_*.py", "examples/**/*.py"]), + files = glob(["tests/test_*.py", "examples/**/*.py"], exclude=LARGE_TESTS), size = "medium", extra_srcs = SRCS, tags = ["team:core", "exclusive"], deps = ["//:ray_lib"], ) + +py_test_module_list( + files = LARGE_TESTS, + size = "large", + extra_srcs = SRCS, + tags = ["team:core", "exclusive"], + deps = ["//:ray_lib"], +) diff --git a/python/requirements_ml_docker.txt b/python/requirements_ml_docker.txt index f30397ede599..ede4f9b30ef0 100644 --- a/python/requirements_ml_docker.txt +++ b/python/requirements_ml_docker.txt @@ -1,4 +1,9 @@ ipython + +# Needed for Ray Client error message serialization/deserialization. +tblib + + # In TF >v2, GPU support is included in the base package. 
tensorflow==2.5.0 tensorflow-probability==0.13.0 diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py index 5faa766d8190..d3faa2ccc872 100644 --- a/release/.buildkite/build_pipeline.py +++ b/release/.buildkite/build_pipeline.py @@ -20,7 +20,12 @@ class ReleaseTest: - def __init__(self, name: str, smoke_test: bool = False, retry: int = 0): + def __init__( + self, + name: str, + smoke_test: bool = False, + retry: int = 0, + ): self.name = name self.smoke_test = smoke_test self.retry = retry @@ -243,6 +248,19 @@ def __init__(self, ], } +HOROVOD_INSTALL_ENV_VARS = [ + "HOROVOD_WITH_GLOO", "HOROVOD_WITHOUT_MPI", "HOROVOD_WITHOUT_TENSORFLOW", + "HOROVOD_WITHOUT_MXNET", "HOROVOD_WITH_PYTORCH" +] + +HOROVOD_SETUP_COMMANDS = [ + "sudo apt update", "sudo apt -y install build-essential", + "pip install cmake" +] + [ + f"export {horovod_env_var}=1" + for horovod_env_var in HOROVOD_INSTALL_ENV_VARS +] + # This test suite holds "user" tests to test important user workflows # in a particular environment. # All workloads in this test suite should: @@ -250,7 +268,39 @@ def __init__(self, # 2. Use autoscaling/scale up (no wait_cluster.py) # 3. Use GPUs if applicable # 4. Have the `use_connect` flag set. -USER_TESTS = {} +USER_TESTS = { + "~/ray/release/ray_lightning_tests/ray_lightning_tests.yaml": [ + ConnectTest( + "ray_lightning_user_test_latest", + requirements_file="release/ray_lightning_tests" + "/driver_requirements.txt"), + ConnectTest( + "ray_lightning_user_test_master", + requirements_file="release/ray_lightning_tests" + "/driver_requirements.txt") + ], + "~/ray/release/horovod_tests/horovod_tests.yaml": [ + ConnectTest( + "horovod_user_test_latest", + setup_commands=HOROVOD_SETUP_COMMANDS, + requirements_file="release/horovod_tests/driver_requirements.txt"), + ConnectTest( + "horovod_user_test_master", + setup_commands=HOROVOD_SETUP_COMMANDS, + requirements_file="release/horovod_tests" + "/driver_requirements_master.txt") + ], + "~/ray/release/train_tests/train_tests.yaml": [ + ConnectTest( + "train_tensorflow_mnist_test", + requirements_file="release/train_tests" + "/driver_requirements.txt"), + ConnectTest( + "train_torch_linear_test", + requirements_file="release/train_tests" + "/driver_requirements.txt") + ], +} SUITES = { "core-nightly": CORE_NIGHTLY_TESTS, @@ -473,22 +523,21 @@ def create_test_step( }] } - step_conf["commands"] = [ - "pip install -q -r release/requirements.txt", - "pip install -U boto3 botocore", - f"git clone -b {ray_test_branch} {ray_test_repo} ~/ray", cmd, - "sudo cp -rf /tmp/artifacts/* /tmp/ray_release_test_artifacts " - "|| true" - ] - if isinstance(test_name, ConnectTest): # Add driver side setup commands to the step. 
pip_requirements_command = [f"pip install -U -r " f"{test_name.requirements_file}"] if \ test_name.requirements_file else [] step_conf["commands"] = test_name.setup_commands \ - + pip_requirements_command \ - + step_conf["commands"] + + pip_requirements_command + + step_conf["commands"] += [ + "pip install -q -r release/requirements.txt", + "pip install -U boto3 botocore", + f"git clone -b {ray_test_branch} {ray_test_repo} ~/ray", cmd, + "sudo cp -rf /tmp/artifacts/* /tmp/ray_release_test_artifacts " + "|| true" + ] step_conf["label"] = ( f"{test_name} " diff --git a/release/horovod_tests/app_config.yaml b/release/horovod_tests/app_config.yaml index 15dd1051603e..6678b4beb922 100644 --- a/release/horovod_tests/app_config.yaml +++ b/release/horovod_tests/app_config.yaml @@ -14,7 +14,7 @@ post_build_cmds: - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy - pip3 install numpy || true - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} - - pip3 install 'ray[rllib]' + - pip3 install 'ray[tune]' - pip3 install torch torchvision - - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip3 install -U git+https://github.com/horovod/horovod.git + - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip3 install -U horovod - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/horovod_tests/app_config_master.yaml b/release/horovod_tests/app_config_master.yaml new file mode 100644 index 000000000000..c53c0e981fa9 --- /dev/null +++ b/release/horovod_tests/app_config_master.yaml @@ -0,0 +1,20 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: {} +debian_packages: + - curl + +python: + pip_packages: + - pytest + - awscli + conda_packages: [] + +post_build_cmds: + - pip uninstall -y numpy ray || true + - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy + - pip3 install numpy || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - pip3 install 'ray[tune]' + - pip3 install torch torchvision + - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip3 install -U git+https://github.com/horovod/horovod.git + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/horovod_tests/base_driver_reqs.txt b/release/horovod_tests/base_driver_reqs.txt new file mode 100644 index 000000000000..057dffa317b2 --- /dev/null +++ b/release/horovod_tests/base_driver_reqs.txt @@ -0,0 +1,8 @@ +# Make sure the driver versions are the same as cluster versions. +# The cluster uses ray-ml Docker image. +# ray-ml Docker image installs dependencies from ray/python/requirements/ml/ directory. +# We constrain on these requirements file so that the same versions are installed. 
+-c ../../python/requirements/ml/requirements_dl.txt + +torch +torchvision \ No newline at end of file diff --git a/release/horovod_tests/compute_tpl.yaml b/release/horovod_tests/compute_tpl.yaml index 1d6d1686a0a7..3a5b4428d90b 100644 --- a/release/horovod_tests/compute_tpl.yaml +++ b/release/horovod_tests/compute_tpl.yaml @@ -10,8 +10,8 @@ head_node_type: worker_node_types: - name: worker_node instance_type: g3.8xlarge - min_workers: 3 max_workers: 3 + min_workers: 3 use_spot: false aws: diff --git a/release/horovod_tests/compute_tpl_autoscaling.yaml b/release/horovod_tests/compute_tpl_autoscaling.yaml new file mode 100644 index 000000000000..7f156c8756ab --- /dev/null +++ b/release/horovod_tests/compute_tpl_autoscaling.yaml @@ -0,0 +1,24 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 3 + +head_node_type: + name: head_node + instance_type: g3.8xlarge + +worker_node_types: + - name: worker_node + instance_type: g3.8xlarge + max_workers: 3 + min_workers: 0 + use_spot: false + +aws: + TagSpecifications: + - ResourceType: "instance" + Tags: + - Key: anyscale-user + Value: '{{env["ANYSCALE_USER"]}}' + - Key: anyscale-expiration + Value: '{{env["EXPIRATION_1D"]}}' diff --git a/release/horovod_tests/driver_requirements.txt b/release/horovod_tests/driver_requirements.txt new file mode 100644 index 000000000000..8ce2c8a59a29 --- /dev/null +++ b/release/horovod_tests/driver_requirements.txt @@ -0,0 +1,3 @@ +-r ./base_driver_reqs.txt + +horovod diff --git a/release/horovod_tests/driver_requirements_master.txt b/release/horovod_tests/driver_requirements_master.txt new file mode 100644 index 000000000000..5fc9bfa194a5 --- /dev/null +++ b/release/horovod_tests/driver_requirements_master.txt @@ -0,0 +1,4 @@ +-r ./base_driver_reqs.txt + +# Horovod master. 
+git+https://github.com/horovod/horovod.git \ No newline at end of file diff --git a/release/horovod_tests/horovod_tests.yaml b/release/horovod_tests/horovod_tests.yaml index 0ddcd5b7bf12..9d3815d8315a 100644 --- a/release/horovod_tests/horovod_tests.yaml +++ b/release/horovod_tests/horovod_tests.yaml @@ -1,6 +1,6 @@ - name: horovod_test cluster: - app_config: app_config.yaml + app_config: app_config_master.yaml compute_template: compute_tpl.yaml run: @@ -12,3 +12,25 @@ smoke_test: run: timeout: 1800 + +- name: horovod_user_test_latest + cluster: + app_config: app_config.yaml + compute_template: compute_tpl_autoscaling.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/horovod_user_test.py + +- name: horovod_user_test_master + cluster: + app_config: app_config_master.yaml + compute_template: compute_tpl_autoscaling.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/horovod_user_test.py diff --git a/release/horovod_tests/workloads/horovod_test.py b/release/horovod_tests/workloads/horovod_tune_test.py similarity index 100% rename from release/horovod_tests/workloads/horovod_test.py rename to release/horovod_tests/workloads/horovod_tune_test.py diff --git a/release/horovod_tests/workloads/horovod_user_test.py b/release/horovod_tests/workloads/horovod_user_test.py new file mode 100644 index 000000000000..f1b53e350df2 --- /dev/null +++ b/release/horovod_tests/workloads/horovod_user_test.py @@ -0,0 +1,33 @@ +import json +import os +import time + +import ray +from ray.util.horovod.horovod_example import main + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + main( + num_workers=6, + use_gpu=True, + placement_group_timeout_s=900, + kwargs={"num_epochs": 20}) + + taken = time.time() - start + result = { + "time_taken": taken, + } + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/horovod_user_test.json") + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") diff --git a/release/ray_lightning_tests/app_config.yaml b/release/ray_lightning_tests/app_config.yaml new file mode 100644 index 000000000000..e3935fe41be0 --- /dev/null +++ b/release/ray_lightning_tests/app_config.yaml @@ -0,0 +1,20 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: + PL_TORCH_DISTRIBUTED_BACKEND: gloo + +debian_packages: + - curl + +python: + pip_packages: + # TODO(amogkam): Remove the tblib, torch, and torchvision installs once we use nightly image. 
+ - tblib + - torch==1.9.0 + - torchvision==0.10.0 + - ray-lightning + conda_packages: [] + +post_build_cmds: + - pip uninstall -y ray || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/ray_lightning_tests/app_config_master.yaml b/release/ray_lightning_tests/app_config_master.yaml new file mode 100644 index 000000000000..d99cb56e0fa6 --- /dev/null +++ b/release/ray_lightning_tests/app_config_master.yaml @@ -0,0 +1,20 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: + PL_TORCH_DISTRIBUTED_BACKEND: gloo + +debian_packages: + - curl + +python: + pip_packages: + # TODO(amogkam): Remove the tblib, torch, and torchvision installs once we use nightly image. + - tblib + - torch==1.9.0 + - torchvision==0.10.0 + - git+https://github.com/ray-project/ray_lightning#ray_lightning + conda_packages: [] + +post_build_cmds: + - pip uninstall -y ray || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/ray_lightning_tests/compute_tpl.yaml b/release/ray_lightning_tests/compute_tpl.yaml new file mode 100644 index 000000000000..7809c13e7761 --- /dev/null +++ b/release/ray_lightning_tests/compute_tpl.yaml @@ -0,0 +1,24 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 3 + +head_node_type: + name: head_node + instance_type: g3.8xlarge + +worker_node_types: + - name: worker_node + instance_type: g3.8xlarge + min_workers: 0 + max_workers: 2 + use_spot: false + +aws: + TagSpecifications: + - ResourceType: "instance" + Tags: + - Key: anyscale-user + Value: '{{env["ANYSCALE_USER"]}}' + - Key: anyscale-expiration + Value: '{{env["EXPIRATION_1D"]}}' diff --git a/release/ray_lightning_tests/driver_requirements.txt b/release/ray_lightning_tests/driver_requirements.txt new file mode 100644 index 000000000000..ba19e088e192 --- /dev/null +++ b/release/ray_lightning_tests/driver_requirements.txt @@ -0,0 +1,9 @@ +# Make sure the driver versions are the same as cluster versions. +# The cluster uses ray-ml Docker image. +# ray-ml Docker image installs dependencies from ray/python/requirements/ml/ directory. +# We constrain on these requirements file so that the same versions are installed. 
+-c ../../python/requirements/ml/requirements_dl.txt + +torch +torchvision +pytorch-lightning \ No newline at end of file diff --git a/release/ray_lightning_tests/ray_lightning_tests.yaml b/release/ray_lightning_tests/ray_lightning_tests.yaml new file mode 100644 index 000000000000..5eab5a9605d7 --- /dev/null +++ b/release/ray_lightning_tests/ray_lightning_tests.yaml @@ -0,0 +1,22 @@ +- name: ray_lightning_user_test_latest + cluster: + app_config: app_config.yaml + compute_template: compute_tpl.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/ray_lightning_user_test.py + + +- name: ray_lightning_user_test_master + cluster: + app_config: app_config_master.yaml + compute_template: compute_tpl.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/ray_lightning_user_test.py \ No newline at end of file diff --git a/release/ray_lightning_tests/workloads/ray_lightning_user_test.py b/release/ray_lightning_tests/workloads/ray_lightning_user_test.py new file mode 100644 index 000000000000..211e4cd96209 --- /dev/null +++ b/release/ray_lightning_tests/workloads/ray_lightning_user_test.py @@ -0,0 +1,29 @@ +import json +import os +import time + +import ray +from ray.util.ray_lightning.simple_example import main + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + main(num_workers=6, use_gpu=True, max_steps=50) + + taken = time.time() - start + result = { + "time_taken": taken, + } + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/ray_lightning_user_test.json") + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") diff --git a/release/release_logs/1.8.0/benchmarks/many_actors.json b/release/release_logs/1.8.0/benchmarks/many_actors.json new file mode 100644 index 000000000000..1b9e2b210fae --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_actors.json @@ -0,0 +1,10 @@ +{ + "actors_per_second":502.27667887403527, + "num_actors":10000, + "time":19.909345626831055, + "success":"1", + "_runtime":34.931002140045166, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_NtMW8qrs2wfGbb1DSfhkXpa7", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/benchmarks/many_nodes.json b/release/release_logs/1.8.0/benchmarks/many_nodes.json new file mode 100644 index 000000000000..4272a3c2aadb --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_nodes.json @@ -0,0 +1,10 @@ +{ + "tasks_per_second":3.1463954810569446, + "num_tasks":1000, + "time":617.8240008354187, + "success":"1", + "_runtime":627.3966097831726, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_M5RJTBWv4HPcVW4LJBpQ3QbU", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/benchmarks/many_pgs.json b/release/release_logs/1.8.0/benchmarks/many_pgs.json new file mode 100644 index 
000000000000..4c834fb150d7 --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_pgs.json @@ -0,0 +1,10 @@ +{ + "pgs_per_second":18.38968677640112, + "num_pgs":1000, + "time":54.378305196762085, + "success":"1", + "_runtime":70.17751049995422, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_NTMUAfRBeFAnGucHNvxk9SCe", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/benchmarks/many_tasks.json b/release/release_logs/1.8.0/benchmarks/many_tasks.json new file mode 100644 index 000000000000..70c9e1af8a41 --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_tasks.json @@ -0,0 +1,10 @@ +{ + "tasks_per_second":27.380515768078205, + "num_tasks":10000, + "time":665.2232151031494, + "success":"1", + "_runtime":676.2212898731232, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_DRNyaBrjPz92eS7XGH6s86J2", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/microbenchmark.json b/release/release_logs/1.8.0/microbenchmark.json new file mode 100644 index 000000000000..1d8c2cc53d17 --- /dev/null +++ b/release/release_logs/1.8.0/microbenchmark.json @@ -0,0 +1,134 @@ +{ + "single_client_get_calls":[ + 30940.134096399626, + 317.0832533971391 + ], + "single_client_put_calls":[ + 48943.191804559654, + 219.49873322500986 + ], + "multi_client_put_calls":[ + 196309.52527610504, + 3737.9627941064145 + ], + "single_client_get_calls_Plasma_Store":[ + 6846.037019609748, + 22.8467171817285 + ], + "single_client_put_calls_Plasma_Store":[ + 6460.924069006876, + 80.8868946635643 + ], + "multi_client_put_calls_Plasma_Store":[ + 10041.514280594329, + 156.65023333195015 + ], + "single_client_put_gigabytes":[ + 19.409071238092803, + 5.346465780702707 + ], + "single_client_tasks_and_get_batch":[ + 14.009166496750346, + 0.21491106195168685 + ], + "multi_client_put_gigabytes":[ + 35.23946199953088, + 0.7596648393204519 + ], + "single_client_get_object_containing_10k_refs":[ + 13.444237954966656, + 0.23494256884483813 + ], + "single_client_tasks_sync":[ + 1635.8932418252052, + 27.31879448903416 + ], + "single_client_tasks_async":[ + 13818.97086807901, + 296.5632852961327 + ], + "multi_client_tasks_async":[ + 39833.383674172335, + 2733.4624633691615 + ], + "1_1_actor_calls_sync":[ + 2642.40420304774, + 41.91740557372103 + ], + "1_1_actor_calls_async":[ + 6926.955725014122, + 41.74140409894753 + ], + "1_1_actor_calls_concurrent":[ + 5676.325754099567, + 151.98507185863667 + ], + "1_n_actor_calls_async":[ + 14243.648501790683, + 308.2488257960866 + ], + "n_n_actor_calls_async":[ + 44238.1477273468, + 3957.0334984514498 + ], + "n_n_actor_calls_with_arg_async":[ + 3338.0491560395135, + 20.518610734699735 + ], + "1_1_async_actor_calls_sync":[ + 1960.6729229123905, + 19.399824989325484 + ], + "1_1_async_actor_calls_async":[ + 4015.4271724213463, + 325.935225727618 + ], + "1_1_async_actor_calls_with_args_async":[ + 2844.8051074858054, + 170.12415992293433 + ], + "1_n_async_actor_calls_async":[ + 15039.391685457813, + 680.5919632322198 + ], + "n_n_async_actor_calls_async":[ + 39514.27962589182, + 1501.8795564889163 + ], + "client__get_calls":[ + 
1975.254953927033, + 8.735788531735636 + ], + "client__put_calls":[ + 1098.0277773219968, + 7.155561544388648 + ], + "client__put_gigabytes":[ + 0.13727111000071365, + 0.0026377327384535312 + ], + "client__tasks_and_put_batch":[ + 69321.64862718538, + 1767.6000938242169 + ], + "client__1_1_actor_calls_sync":[ + 564.4912609733829, + 30.465268060749132 + ], + "client__1_1_actor_calls_async":[ + 943.8060173510352, + 17.7532101628873 + ], + "client__1_1_actor_calls_concurrent":[ + 949.6015148062681, + 12.69568974417014 + ], + "client__tasks_and_get_batch":[ + 0.8703009855123071, + 0.013280271782555824 + ], + "_runtime":533.6895747184753, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_DFQpftV7rrj7THGByj5N9ELj", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/scalability/object_store.json b/release/release_logs/1.8.0/scalability/object_store.json new file mode 100644 index 000000000000..4f4847311731 --- /dev/null +++ b/release/release_logs/1.8.0/scalability/object_store.json @@ -0,0 +1,10 @@ +{ + "broadcast_time":1478.5116119949998, + "object_size":1073741824, + "num_nodes":50, + "success":"1", + "_runtime":1500.23917222023, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_4HtGPbfxkKyjpLYxs4Xc7Quf", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/scalability/single_node.json b/release/release_logs/1.8.0/scalability/single_node.json new file mode 100644 index 000000000000..5f3368320acf --- /dev/null +++ b/release/release_logs/1.8.0/scalability/single_node.json @@ -0,0 +1,17 @@ +{ + "args_time":17.633971685999995, + "num_args":10000, + "returns_time":5.793101844999967, + "num_returns":3000, + "get_time":28.41037361399998, + "num_get_args":10000, + "queued_time":155.58971768599997, + "num_queued":1000000, + "large_object_time":289.177471706, + "large_object_size":107374182400, + "success":"1", + "_runtime":547.9792876243591, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_BiK1hJqWQaKY1kPtzWZtrk7z", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/stress_tests/stress_test_dead_actors.json b/release/release_logs/1.8.0/stress_tests/stress_test_dead_actors.json new file mode 100644 index 000000000000..7b0d4fe4ecc8 --- /dev/null +++ b/release/release_logs/1.8.0/stress_tests/stress_test_dead_actors.json @@ -0,0 +1,11 @@ +{ + "success":1, + "total_time":131.1956226825714, + "avg_iteration_time":1.3119536685943602, + "max_iteration_time":3.662248373031616, + "min_iteration_time":0.08778810501098633, + "_runtime":4927.166862249374, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_bEKXDxxZwwg3p4cYEiWA82gR", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git 
a/release/release_logs/1.8.0/stress_tests/stress_test_many_tasks.json b/release/release_logs/1.8.0/stress_tests/stress_test_many_tasks.json new file mode 100644 index 000000000000..3a6c9c74e9b1 --- /dev/null +++ b/release/release_logs/1.8.0/stress_tests/stress_test_many_tasks.json @@ -0,0 +1,19 @@ +{ + "success":1, + "stage_0_time":5.756206750869751, + "stage_1_time":190.08577489852905, + "stage_1_avg_iteration_time":19.008567762374877, + "stage_1_max_iteration_time":19.4687020778656, + "stage_1_min_iteration_time":18.608890295028687, + "stage_2_time":246.9729506969452, + "stage_2_avg_iteration_time":49.39428896903992, + "stage_2_max_iteration_time":50.23058724403381, + "stage_2_min_iteration_time":47.09399747848511, + "stage_3_creation_time":0.05593752861022949, + "stage_3_time":1843.905479669571, + "stage_4_spread":3.2320969134286446, + "_runtime":4446.973560810089, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_wv1Ch4n2WCKNDLzUJwySyYJ2", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/stress_tests/stress_test_placement_group.json b/release/release_logs/1.8.0/stress_tests/stress_test_placement_group.json new file mode 100644 index 000000000000..6b22a4231ff3 --- /dev/null +++ b/release/release_logs/1.8.0/stress_tests/stress_test_placement_group.json @@ -0,0 +1,9 @@ +{ + "success":1, + "avg_pg_create_time_ms":0.9178227477476738, + "avg_pg_remove_time_ms":3.5015487627617348, + "_runtime":381.0450084209442, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_hMCMNudYEmmqv987qYbadjrQ", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/train_tests/app_config.yaml b/release/train_tests/app_config.yaml new file mode 100644 index 000000000000..446b53847dfa --- /dev/null +++ b/release/train_tests/app_config.yaml @@ -0,0 +1,13 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: { } +debian_packages: + - curl + +python: + pip_packages: [ ] + conda_packages: [ ] + +post_build_cmds: + - pip3 uninstall -y ray || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/train_tests/compute_tpl.yaml b/release/train_tests/compute_tpl.yaml new file mode 100644 index 000000000000..221bb8f66548 --- /dev/null +++ b/release/train_tests/compute_tpl.yaml @@ -0,0 +1,15 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 2 + +head_node_type: + name: head_node + instance_type: g3.8xlarge + +worker_node_types: + - name: worker_node + instance_type: g3.8xlarge + min_workers: 0 + max_workers: 2 + use_spot: false diff --git a/release/train_tests/driver_requirements.txt b/release/train_tests/driver_requirements.txt new file mode 100644 index 000000000000..b779be3a31dd --- /dev/null +++ b/release/train_tests/driver_requirements.txt @@ -0,0 +1,8 @@ +# Make sure the driver versions are the same as cluster versions. +# The cluster uses ray-ml Docker image. +# ray-ml Docker image installs dependencies from ray/python/requirements/ml/ directory. 
+# We constrain on these requirements file so that the same versions are installed. +-c ../../python/requirements/ml/requirements_dl.txt + +torch +tensorflow \ No newline at end of file diff --git a/release/train_tests/train_tests.yaml b/release/train_tests/train_tests.yaml new file mode 100644 index 000000000000..c19493f85d56 --- /dev/null +++ b/release/train_tests/train_tests.yaml @@ -0,0 +1,17 @@ +- name: train_tensorflow_mnist_test + cluster: + app_config: app_config.yaml + compute_template: compute_tpl.yaml + + run: + timeout: 36000 + script: python workloads/train_tensorflow_mnist_test.py + +- name: train_torch_linear_test + cluster: + app_config: app_config.yaml + compute_template: compute_tpl.yaml + + run: + timeout: 36000 + script: python workloads/train_torch_linear_test.py diff --git a/release/train_tests/workloads/train_tensorflow_mnist_test.py b/release/train_tests/workloads/train_tensorflow_mnist_test.py new file mode 100644 index 000000000000..376979d93c3a --- /dev/null +++ b/release/train_tests/workloads/train_tensorflow_mnist_test.py @@ -0,0 +1,31 @@ +import json +import os +import time + +import ray +from ray.train.examples.tensorflow_mnist_example import train_tensorflow_mnist + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + train_tensorflow_mnist(num_workers=6, use_gpu=True, epochs=20) + + taken = time.time() - start + result = { + "time_taken": taken, + } + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/train_torc_linear_test.json") + + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") diff --git a/release/train_tests/workloads/train_torch_linear_test.py b/release/train_tests/workloads/train_torch_linear_test.py new file mode 100644 index 000000000000..fe013a8ef971 --- /dev/null +++ b/release/train_tests/workloads/train_torch_linear_test.py @@ -0,0 +1,30 @@ +import json +import os +import time + +import ray + +from ray.train.examples.train_linear_example import train_linear + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + results = train_linear(num_workers=6, use_gpu=True, epochs=20) + + taken = time.time() - start + result = {"time_taken": taken} + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/train_torc_linear_test.json") + + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index 78ea8cf8e28d..0dd87a9b233f 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -14,9 +14,9 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import explained_variance +from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.typing import TrainerConfigDict, 
TensorType, \ PolicyID, LocalOptimizer, ModelGradients diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 10508d073f63..557d7eb53fc4 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule, \ EntropyCoeffSchedule -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import apply_grad_clipping, sequence_mask from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ diff --git a/rllib/agents/ars/ars.py b/rllib/agents/ars/ars.py index ef83ea6ad247..bf55e38fa8af 100644 --- a/rllib/agents/ars/ars.py +++ b/rllib/agents/ars/ars.py @@ -16,7 +16,8 @@ from ray.rllib.agents.es.es_tf_policy import rollout from ray.rllib.env.env_context import EnvContext from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.annotations import Deprecated, override +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.torch_ops import set_torch_seed from ray.rllib.utils import FilterManager diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index 561a98075fc7..3d48fc712c99 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -14,6 +14,7 @@ import psutil if TYPE_CHECKING: + from ray.rllib.agents.trainer import Trainer from ray.rllib.evaluation import RolloutWorker @@ -51,8 +52,6 @@ def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv, state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index: Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -73,19 +72,17 @@ def on_episode_step(self, """Runs on each episode step. Args: - worker (RolloutWorker): Reference to the current rollout worker. - base_env (BaseEnv): BaseEnv running the episode. The underlying + worker: Reference to the current rollout worker. + base_env: BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. - policies (Optional[Dict[PolicyID, Policy]]): Mapping of policy id + policies: Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". - episode (Episode): Episode object which contains episode + episode: Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index (EnvID): Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -101,19 +98,17 @@ def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, """Runs when an episode is done. Args: - worker (RolloutWorker): Reference to the current rollout worker. - base_env (BaseEnv): BaseEnv running the episode. The underlying + worker: Reference to the current rollout worker. + base_env: BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. 
- policies (Dict[PolicyID, Policy]): Mapping of policy id to policy + policies: Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". - episode (Episode): Episode object which contains episode + episode: Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index (EnvID): Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -136,16 +131,16 @@ def on_postprocess_trajectory( settings. Args: - worker (RolloutWorker): Reference to the current rollout worker. - episode (Episode): Episode object. - agent_id (str): Id of the current agent. - policy_id (str): Id of the current policy for the agent. - policies (dict): Mapping of policy id to policy objects. In single + worker: Reference to the current rollout worker. + episode: Episode object. + agent_id: Id of the current agent. + policy_id: Id of the current policy for the agent. + policies: Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". - postprocessed_batch (SampleBatch): The postprocessed sample batch + postprocessed_batch: The postprocessed sample batch for this agent. You can mutate this object to apply your own trajectory postprocessing. - original_batches (dict): Mapping of agents to their unpostprocessed + original_batches: Mapping of agents to their unpostprocessed trajectory data. You should not mutate this object. kwargs: Forward compatibility placeholder. """ @@ -164,8 +159,8 @@ def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, """Called at the end of RolloutWorker.sample(). Args: - worker (RolloutWorker): Reference to the current rollout worker. - samples (SampleBatch): Batch to be returned. You can mutate this + worker: Reference to the current rollout worker. + samples: Batch to be returned. You can mutate this object to modify the samples generated. kwargs: Forward compatibility placeholder. """ @@ -184,21 +179,22 @@ def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch, `pad_batch_to_sequences_of_same_size`. Args: - policy (Policy): Reference to the current Policy object. - train_batch (SampleBatch): SampleBatch to be trained on. You can + policy: Reference to the current Policy object. + train_batch: SampleBatch to be trained on. You can mutate this object to modify the samples generated. - result (dict): A results dict to add custom metrics to. + result: A results dict to add custom metrics to. kwargs: Forward compatibility placeholder. """ pass - def on_train_result(self, *, trainer, result: dict, **kwargs) -> None: + def on_train_result(self, *, trainer: "Trainer", result: dict, + **kwargs) -> None: """Called at the end of Trainable.train(). Args: - trainer (Trainer): Current trainer instance. - result (dict): Dict of results returned from trainer.train() call. + trainer: Current trainer instance. + result: Dict of results returned from trainer.train() call. You can mutate this object to add additional metrics. kwargs: Forward compatibility placeholder. 
""" diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index d3c295feba94..be185f6d634d 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -26,7 +26,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import get_variable, try_import_tf from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable +from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ LocalOptimizer, ModelGradients from ray.util.debug import log_once diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index 88e3e30e06fa..d24ee4477392 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -20,8 +20,8 @@ from ray.rllib.utils.exploration import ParameterNoise from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.tf_ops import (huber_loss, make_tf_callable, - minimize_and_clip, reduce_mean_ignore_inf) +from ray.rllib.utils.tf_utils import ( + huber_loss, make_tf_callable, minimize_and_clip, reduce_mean_ignore_inf) from ray.rllib.utils.typing import (ModelGradients, TensorType, TrainerConfigDict) diff --git a/rllib/agents/dqn/learner_thread.py b/rllib/agents/dqn/learner_thread.py index 93bed4b18de5..6e7b1ebae348 100644 --- a/rllib/agents/dqn/learner_thread.py +++ b/rllib/agents/dqn/learner_thread.py @@ -3,8 +3,8 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder +from ray.rllib.utils.metrics.window_stat import WindowStat from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils.window_stat import WindowStat LEARNER_QUEUE_MAX_SIZE = 16 diff --git a/rllib/agents/dqn/r2d2_tf_policy.py b/rllib/agents/dqn/r2d2_tf_policy.py index d34c35a44976..2f922a7dfd03 100644 --- a/rllib/agents/dqn/r2d2_tf_policy.py +++ b/rllib/agents/dqn/r2d2_tf_policy.py @@ -17,7 +17,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import huber_loss +from ray.rllib.utils.tf_utils import huber_loss from ray.rllib.utils.typing import ModelInputDict, TensorType, \ TrainerConfigDict diff --git a/rllib/agents/dqn/simple_q_tf_policy.py b/rllib/agents/dqn/simple_q_tf_policy.py index 13e62bca1fd9..49674c5e752e 100644 --- a/rllib/agents/dqn/simple_q_tf_policy.py +++ b/rllib/agents/dqn/simple_q_tf_policy.py @@ -18,7 +18,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable +from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable from ray.rllib.utils.typing import TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/es/es.py b/rllib/agents/es/es.py index 796076b01d9d..3535bf245926 100644 --- a/rllib/agents/es/es.py +++ b/rllib/agents/es/es.py @@ -14,7 +14,8 @@ from ray.rllib.env.env_context import EnvContext from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils import FilterManager -from ray.rllib.utils.annotations import Deprecated, override +from ray.rllib.utils.annotations import override +from 
ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.torch_ops import set_torch_seed logger = logging.getLogger(__name__) diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index 5a786a4da8e9..4e9594588b66 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -15,7 +15,7 @@ EntropyCoeffSchedule from ray.rllib.utils import force_list from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import explained_variance +from ray.rllib.utils.tf_utils import explained_variance tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/marwil/marwil_tf_policy.py b/rllib/agents/marwil/marwil_tf_policy.py index 9a386671792e..2c0f25913236 100644 --- a/rllib/agents/marwil/marwil_tf_policy.py +++ b/rllib/agents/marwil/marwil_tf_policy.py @@ -9,7 +9,7 @@ Postprocessing from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils.framework import try_import_tf, get_variable -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.policy.policy import Policy from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ PolicyID diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index c579d21c3123..98af35d13a01 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -28,7 +28,7 @@ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.typing import AgentID, TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 76610567abcb..6fc5ac27d88e 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -17,10 +17,10 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.annotations import Deprecated -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \ + deprecation_warning from ray.rllib.utils.framework import try_import_tf, get_variable -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/agents/registry.py b/rllib/agents/registry.py index 6139afbb17f0..f5cfe0b41748 100644 --- a/rllib/agents/registry.py +++ b/rllib/agents/registry.py @@ -3,7 +3,7 @@ import traceback from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated def _import_a2c(): diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index c6b5be01b6bf..97dbad921b1d 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -27,7 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework 
import get_variable, try_import_tf from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.tf_ops import huber_loss +from ray.rllib.utils.tf_utils import huber_loss from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 375a3d1ac888..323bc01b3e6d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -28,10 +28,11 @@ from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils import deep_update, FilterManager, merge_dicts -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \ +from ray.rllib.utils.annotations import DeveloperAPI, override, \ PublicAPI from ray.rllib.utils.debug import update_global_seed_if_necessary -from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE +from ray.rllib.utils.deprecation import Deprecated, deprecation_warning, \ + DEPRECATED_VALUE from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.from_config import from_config @@ -529,19 +530,34 @@ def with_common_config( @PublicAPI class Trainer(Trainable): - """A trainer coordinates the optimization of one or more RL policies. - - All RLlib trainers extend this base class, e.g., the A3CTrainer implements - the A3C algorithm for single and multi-agent training. - - Trainer objects retain internal model state between calls to train(), so - you should create a new trainer instance for each training session. - - Attributes: - env_creator (func): Function that creates a new training env. - config (obj): Algorithm-specific configuration data. - logdir (str): Directory in which training outputs should be placed. + """An RLlib algorithm responsible for optimizing one or more Policies. + + Trainers contain a WorkerSet under `self.workers`. A WorkerSet is + normally composed of a single local worker + (self.workers.local_worker()), used to compute and apply learning updates, + and optionally one or more remote workers (self.workers.remote_workers()), + used to generate environment samples in parallel. + + Each worker (remote or local) contains a PolicyMap, which itself + may contain either one policy for single-agent training or one or more + policies for multi-agent training. Policies are synchronized + automatically from time to time using ray.remote calls. The exact + synchronization logic depends on the specific algorithm (Trainer) used, + but this usually happens from local worker to all remote workers and + after each training update. + + You can write your own Trainer sub-classes by using the + rllib.agents.trainer_template.py::build_trainer() utility function. + This allows you to provide a custom `execution_plan`. You can find the + different built-in algorithms' execution plans in their respective main + py files, e.g. rllib.agents.dqn.dqn.py or rllib.agents.impala.impala.py. + + The most important API methods a Trainer exposes are `train()`, + `evaluate()`, `save()` and `restore()`. Trainer objects retain internal + model state between calls to train(), so you should create a new + Trainer instance for each training session. """ + # Whether to allow unknown top-level config keys.
_allow_unknown_configs = False @@ -562,15 +578,18 @@ class Trainer(Trainable): @PublicAPI def __init__(self, config: TrainerConfigDict = None, - env: str = None, + env: Union[str, EnvType, None] = None, logger_creator: Callable[[], Logger] = None): - """Initialize an RLLib trainer. + """Initializes a Trainer instance. Args: - config (dict): Algorithm-specific configuration data. - env (str): Name of the environment to use. Note that this can also - be specified as the `env` key in config. - logger_creator (func): Function that creates a ray.tune.Logger + config: Algorithm-specific configuration dict. + env: Name of the environment to use (e.g. a gym-registered str), + a full class path (e.g. + "ray.rllib.examples.env.random_env.RandomEnv"), or an Env + class directly. Note that this arg can also be specified via + the "env" key in `config`. + logger_creator: Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created. """ @@ -623,151 +642,6 @@ def default_logger_creator(config): super().__init__(config, logger_creator) - @classmethod - @override(Trainable) - def default_resource_request( - cls, config: PartialTrainerConfigDict) -> \ - Union[Resources, PlacementGroupFactory]: - cf = dict(cls._default_config, **config) - - eval_config = cf["evaluation_config"] - - # TODO(ekl): add custom resources here once tune supports them - # Return PlacementGroupFactory containing all needed resources - # (already properly defined as device bundles). - return PlacementGroupFactory( - bundles=[{ - # Driver. - "CPU": cf["num_cpus_for_driver"], - "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"], - }] + [ - { - # RolloutWorkers. - "CPU": cf["num_cpus_per_worker"], - "GPU": cf["num_gpus_per_worker"], - } for _ in range(cf["num_workers"]) - ] + ([ - { - # Evaluation workers. - # Note: The local eval worker is located on the driver CPU. - "CPU": eval_config.get("num_cpus_per_worker", - cf["num_cpus_per_worker"]), - "GPU": eval_config.get("num_gpus_per_worker", - cf["num_gpus_per_worker"]), - } for _ in range(cf["evaluation_num_workers"]) - ] if cf["evaluation_interval"] else []), - strategy=config.get("placement_strategy", "PACK")) - - @override(Trainable) - @PublicAPI - def train(self) -> ResultDict: - """Overrides super.train to synchronize global vars.""" - - result = None - for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): - try: - result = Trainable.train(self) - except RayError as e: - if self.config["ignore_worker_failures"]: - logger.exception( - "Error in train call, attempting to recover") - self._try_recover() - else: - logger.info( - "Worker crashed during call to train(). 
To attempt to " - "continue training without the failed worker, set " - "`'ignore_worker_failures': True`.") - raise e - except Exception as e: - time.sleep(0.5) # allow logs messages to propagate - raise e - else: - break - if result is None: - raise RuntimeError("Failed to recover from worker crash") - - if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): - self._sync_filters_if_needed(self.workers) - - return result - - def _sync_filters_if_needed(self, workers: WorkerSet): - if self.config.get("observation_filter", "NoFilter") != "NoFilter": - FilterManager.synchronize( - workers.local_worker().filters, - workers.remote_workers(), - update_remote=self.config["synchronize_filters"]) - logger.debug("synchronized filters: {}".format( - workers.local_worker().filters)) - - @override(Trainable) - def log_result(self, result: ResultDict): - self.callbacks.on_train_result(trainer=self, result=result) - # log after the callback is invoked, so that the user has a chance - # to mutate the result - Trainable.log_result(self, result) - - @DeveloperAPI - def _create_local_replay_buffer_if_necessary(self, config): - """Create a LocalReplayBuffer instance if necessary. - - Args: - config (dict): Algorithm-specific configuration data. - - Returns: - LocalReplayBuffer instance based on trainer config. - None, if local replay buffer is not needed. - """ - # These are the agents that utilizes a local replay buffer. - if ("replay_buffer_config" not in config - or not config["replay_buffer_config"]): - # Does not need a replay buffer. - return None - - replay_buffer_config = config["replay_buffer_config"] - if ("type" not in replay_buffer_config - or replay_buffer_config["type"] != "LocalReplayBuffer"): - # DistributedReplayBuffer coming soon. - return None - - capacity = config.get("buffer_size", DEPRECATED_VALUE) - if capacity != DEPRECATED_VALUE: - # Print a deprecation warning. - deprecation_warning( - old="config['buffer_size']", - new="config['replay_buffer_config']['capacity']", - error=False) - else: - # Get capacity out of replay_buffer_config. - capacity = replay_buffer_config["capacity"] - - if config.get("prioritized_replay"): - prio_args = { - "prioritized_replay_alpha": config["prioritized_replay_alpha"], - "prioritized_replay_beta": config["prioritized_replay_beta"], - "prioritized_replay_eps": config["prioritized_replay_eps"], - } - else: - prio_args = {} - - return LocalReplayBuffer( - num_shards=1, - learning_starts=config["learning_starts"], - capacity=capacity, - replay_batch_size=config["train_batch_size"], - replay_mode=config["multiagent"]["replay_mode"], - replay_sequence_length=config.get("replay_sequence_length", 1), - replay_burn_in=config.get("burn_in", 0), - replay_zero_init_states=config.get("zero_init_states", True), - **prio_args) - - @DeveloperAPI - def _kwargs_for_execution_plan(self): - kwargs = {} - if self.local_replay_buffer: - kwargs["local_replay_buffer"] = self.local_replay_buffer - return kwargs - @override(Trainable) def setup(self, config: PartialTrainerConfigDict): env = self._env_id @@ -839,6 +713,8 @@ def env_creator_from_classpath(env_context): self.local_replay_buffer = ( self._create_local_replay_buffer_if_necessary(self.config)) + # Make the call to self._init. Sub-classes should override this + # method to implement custom initialization logic. self._init(self.config, self.env_creator) # Evaluation setup. 
@@ -875,69 +751,53 @@ def env_creator_from_classpath(env_context): config=evaluation_config, num_workers=self.config["evaluation_num_workers"]) - @override(Trainable) - def cleanup(self): - if hasattr(self, "workers"): - self.workers.stop() - if hasattr(self, "optimizer") and self.optimizer: - self.optimizer.stop() - - @override(Trainable) - def save_checkpoint(self, checkpoint_dir: str) -> str: - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) - - return checkpoint_path - - @override(Trainable) - def load_checkpoint(self, checkpoint_path: str): - extra_data = pickle.load(open(checkpoint_path, "rb")) - self.__setstate__(extra_data) - @DeveloperAPI - def _make_workers( - self, *, env_creator: Callable[[EnvContext], EnvType], - validate_env: Optional[Callable[[EnvType, EnvContext], None]], - policy_class: Type[Policy], config: TrainerConfigDict, - num_workers: int) -> WorkerSet: - """Default factory method for a WorkerSet running under this Trainer. + def _init(self, config: TrainerConfigDict, + env_creator: Callable[[EnvContext], EnvType]) -> None: + """Subclasses should override this for custom initialization. - Override this method by passing a custom `make_workers` into - `build_trainer`. + In the case of Trainer, this is called from inside `self.setup()`. Args: - env_creator (callable): A function that return and Env given an env - config. - validate_env (Optional[Callable[[EnvType, EnvContext], None]]): - Optional callable to validate the generated environment (only - on worker=0). - policy (Type[Policy]): The Policy class to use for creating the - policies of the workers. - config (TrainerConfigDict): The Trainer's config. - num_workers (int): Number of remote rollout workers to create. - 0 for local only. - - Returns: - WorkerSet: The created WorkerSet. + config: Algorithm-specific configuration dict. + env_creator: A callable taking an EnvContext as only arg and + returning an environment (of any type: e.g. gym.Env, RLlib + BaseEnv, MultiAgentEnv, etc..). """ - return WorkerSet( - env_creator=env_creator, - validate_env=validate_env, - policy_class=policy_class, - trainer_config=config, - num_workers=num_workers, - logdir=self.logdir) - - @DeveloperAPI - def _init(self, config: TrainerConfigDict, - env_creator: Callable[[EnvContext], EnvType]): - """Subclasses should override this for custom initialization.""" raise NotImplementedError - @Deprecated(new="Trainer.evaluate", error=False) - def _evaluate(self) -> dict: - return self.evaluate() + @override(Trainable) + @PublicAPI + def train(self) -> ResultDict: + """Overrides super.train to synchronize global vars.""" + + result = None + for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): + try: + result = Trainable.train(self) + except RayError as e: + if self.config["ignore_worker_failures"]: + logger.exception( + "Error in train call, attempting to recover") + self._try_recover() + else: + logger.info( + "Worker crashed during call to train(). 
To attempt to " + "continue training without the failed worker, set " + "`'ignore_worker_failures': True`.") + raise e + except Exception as e: + time.sleep(0.5) # allow log messages to propagate + raise e + else: + break + if result is None: + raise RuntimeError("Failed to recover from worker crash") + + if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): + self._sync_filters_if_needed(self.workers) + + return result @PublicAPI def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None @@ -948,10 +808,10 @@ def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None merging evaluation_config with the normal trainer config. Args: - episodes_left_fn (Optional[Callable[[int], int]]): An optional - callable taking the already run num episodes as only arg - and returning the number of episodes left to run. It's used - to find out whether evaluation should continue. + episodes_left_fn: An optional callable taking the already run + num episodes as only arg and returning the number of + episodes left to run. It's used to find out whether + evaluation should continue. """ # In case we are evaluating (in a thread) parallel to training, # we may have to re-enable eager mode here (gets disabled in the @@ -963,8 +823,8 @@ def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None # Call the `_before_evaluate` hook. self._before_evaluate() + # Sync weights to the evaluation WorkerSet. if self.evaluation_workers is not None: - # Sync weights to the evaluation WorkerSet. self._sync_weights_to_workers(worker_set=self.evaluation_workers) self._sync_filters_if_needed(self.evaluation_workers) @@ -1053,25 +913,6 @@ def episodes_left_fn(num_episodes_done): self.evaluation_workers.remote_workers()) return {"evaluation": metrics} - @DeveloperAPI - def _before_evaluate(self): - """Pre-evaluation callback.""" - pass - - @DeveloperAPI - def _sync_weights_to_workers( - self, - *, - worker_set: Optional[WorkerSet] = None, - workers: Optional[List[RolloutWorker]] = None, - ) -> None: - """Sync "main" weights to given WorkerSet or list of workers.""" - assert worker_set is not None - # Broadcast the new policy weights to all evaluation workers. - logger.info("Synchronizing weights to workers.") - weights = ray.put(self.workers.local_worker().save()) - worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) - @PublicAPI def compute_single_action( self, @@ -1223,10 +1064,6 @@ def compute_single_action( else: return action - @Deprecated(new="compute_single_action", error=False) - def compute_action(self, *args, **kwargs): - return self.compute_single_action(*args, **kwargs) - @PublicAPI def compute_actions( self, @@ -1253,7 +1090,7 @@ def compute_actions( self.get_policy(policy_id) and call compute_actions() on it directly. Args: - observation: observation from the environment. + observation: Observation from the environment. state: RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). @@ -1284,7 +1121,7 @@ def compute_actions( Returns: any: The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if - full_fetch=True or we have an RNN-based Policy. + full_fetch=True or we have an RNN-based Policy.
""" if normalize_actions is not None: deprecation_warning( @@ -1359,31 +1196,21 @@ def compute_actions( else: return actions - @property - def _name(self) -> str: - """Subclasses should override this to declare their name.""" - raise NotImplementedError - - @property - def _default_config(self) -> TrainerConfigDict: - """Subclasses should override this to declare their default config.""" - raise NotImplementedError - @PublicAPI def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy: """Return policy for the specified id, or None. Args: - policy_id (PolicyID): ID of the policy to return. + policy_id: ID of the policy to return. """ return self.workers.local_worker().get_policy(policy_id) @PublicAPI - def get_weights(self, policies: List[PolicyID] = None) -> dict: + def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict: """Return a dictionary of policy ids to weights. Args: - policies (list): Optional list of policies to return weights for, + policies: Optional list of policies to return weights for, or None for all policies. """ return self.workers.local_worker().get_weights(policies) @@ -1393,7 +1220,7 @@ def set_weights(self, weights: Dict[PolicyID, dict]): """Set policy weights by policy id. Args: - weights (dict): Map of policy ids to weights to set. + weights: Map of policy ids to weights to set. """ self.workers.local_worker().set_weights(weights) @@ -1502,35 +1329,38 @@ def fn(worker): def export_policy_model(self, export_dir: str, policy_id: PolicyID = DEFAULT_POLICY_ID, - onnx: Optional[int] = None): - """Export policy model with given policy_id to local directory. + onnx: Optional[int] = None) -> None: + """Exports policy model with given policy_id to a local directory. Args: - export_dir (string): Writable local directory. - policy_id (string): Optional policy id to export. - onnx (int): If given, will export model in ONNX format. The + export_dir: Writable local directory. + policy_id: Optional policy id to export. + onnx: If given, will export model in ONNX format. The value of this parameter set the ONNX OpSet version to use. + If None, the output format will be DL framework specific. Example: >>> trainer = MyTrainer() >>> for _ in range(10): >>> trainer.train() - >>> trainer.export_policy_model("/tmp/export_dir") + >>> trainer.export_policy_model("/tmp/dir") + >>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1) """ - self.workers.local_worker().export_policy_model( - export_dir, policy_id, onnx) + self.get_policy(policy_id).export_model(export_dir, onnx) @DeveloperAPI - def export_policy_checkpoint(self, - export_dir: str, - filename_prefix: str = "model", - policy_id: PolicyID = DEFAULT_POLICY_ID): - """Export tensorflow policy model checkpoint to local directory. + def export_policy_checkpoint( + self, + export_dir: str, + filename_prefix: str = "model", + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: + """Exports policy model checkpoint to a local directory. Args: - export_dir (string): Writable local directory. - filename_prefix (string): file name prefix of checkpoint files. - policy_id (string): Optional policy id to export. + export_dir: Writable local directory. + filename_prefix: file name prefix of checkpoint files. + policy_id: Optional policy id to export. 
Example: >>> trainer = MyTrainer() @@ -1538,18 +1368,20 @@ def export_policy_checkpoint(self, >>> trainer.train() >>> trainer.export_policy_checkpoint("/tmp/export_dir") """ - self.workers.local_worker().export_policy_checkpoint( - export_dir, filename_prefix, policy_id) + self.get_policy(policy_id).export_checkpoint(export_dir, + filename_prefix) @DeveloperAPI - def import_policy_model_from_h5(self, - import_file: str, - policy_id: PolicyID = DEFAULT_POLICY_ID): + def import_policy_model_from_h5( + self, + import_file: str, + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: """Imports a policy's model with given policy_id from a local h5 file. Args: - import_file (str): The h5 file to import from. - policy_id (string): Optional policy id to import into. + import_file: The h5 file to import from. + policy_id: Optional policy id to import into. Example: >>> trainer = MyTrainer() @@ -1557,8 +1389,9 @@ def import_policy_model_from_h5(self, >>> for _ in range(10): >>> trainer.train() """ - self.workers.local_worker().import_policy_model_from_h5( - import_file, policy_id) + self.get_policy(policy_id).import_model_from_h5(import_file) + # Sync new weights to remote workers. + self._sync_weights_to_workers() @DeveloperAPI def collect_metrics(self, @@ -1572,6 +1405,156 @@ def collect_metrics(self, min_history=self.config["metrics_smoothing_episodes"], selected_workers=selected_workers) + @override(Trainable) + def save_checkpoint(self, checkpoint_dir: str) -> str: + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + + return checkpoint_path + + @override(Trainable) + def load_checkpoint(self, checkpoint_path: str) -> None: + extra_data = pickle.load(open(checkpoint_path, "rb")) + self.__setstate__(extra_data) + + @override(Trainable) + def log_result(self, result: ResultDict) -> None: + # Log after the callback is invoked, so that the user has a chance + # to mutate the result. + self.callbacks.on_train_result(trainer=self, result=result) + # Then log according to Trainable's logging logic. + Trainable.log_result(self, result) + + @override(Trainable) + def cleanup(self) -> None: + # Stop all workers. + if hasattr(self, "workers"): + self.workers.stop() + # Stop all optimizers. + if hasattr(self, "optimizer") and self.optimizer: + self.optimizer.stop() + + @classmethod + @override(Trainable) + def default_resource_request( + cls, config: PartialTrainerConfigDict) -> \ + Union[Resources, PlacementGroupFactory]: + + # Default logic for RLlib algorithms (Trainers): + # Create one bundle per individual worker (local or remote). + # Use `num_cpus_for_driver` and `num_gpus` for the local worker and + # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote + # workers to determine their CPU/GPU resource needs. + + # Convenience config handles. + cf = dict(cls._default_config, **config) + eval_cf = cf["evaluation_config"] + + # TODO(ekl): add custom resources here once tune supports them + # Return PlacementGroupFactory containing all needed resources + # (already properly defined as device bundles). + return PlacementGroupFactory( + bundles=[{ + # Local worker. + "CPU": cf["num_cpus_for_driver"], + "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"], + }] + [ + { + # RolloutWorkers. + "CPU": cf["num_cpus_per_worker"], + "GPU": cf["num_gpus_per_worker"], + } for _ in range(cf["num_workers"]) + ] + ([ + { + # Evaluation workers. + # Note: The local eval worker is located on the driver CPU. 
+ "CPU": eval_cf.get("num_cpus_per_worker", + cf["num_cpus_per_worker"]), + "GPU": eval_cf.get("num_gpus_per_worker", + cf["num_gpus_per_worker"]), + } for _ in range(cf["evaluation_num_workers"]) + ] if cf["evaluation_interval"] else []), + strategy=config.get("placement_strategy", "PACK")) + + @DeveloperAPI + def _before_evaluate(self): + """Pre-evaluation callback.""" + pass + + @DeveloperAPI + def _make_workers( + self, + *, + env_creator: Callable[[EnvContext], EnvType], + validate_env: Optional[Callable[[EnvType, EnvContext], None]], + policy_class: Type[Policy], + config: TrainerConfigDict, + num_workers: int, + ) -> WorkerSet: + """Default factory method for a WorkerSet running under this Trainer. + + Override this method by passing a custom `make_workers` into + `build_trainer`. + + Args: + env_creator: A function that returns an Env given an env + config. + validate_env: Optional callable to validate the generated + environment. The env to be checked is the one returned from + the env creator, which may be a (single, not-yet-vectorized) + gym.Env or your custom RLlib env type (e.g. MultiAgentEnv, + VectorEnv, BaseEnv, etc..). + policy_class: The Policy class to use for creating the policies + of the workers. + config: The Trainer's config. + num_workers: Number of remote rollout workers to create. + 0 for local only. + + Returns: + The created WorkerSet. + """ + return WorkerSet( + env_creator=env_creator, + validate_env=validate_env, + policy_class=policy_class, + trainer_config=config, + num_workers=num_workers, + logdir=self.logdir) + + def _sync_filters_if_needed(self, workers: WorkerSet): + if self.config.get("observation_filter", "NoFilter") != "NoFilter": + FilterManager.synchronize( + workers.local_worker().filters, + workers.remote_workers(), + update_remote=self.config["synchronize_filters"]) + logger.debug("synchronized filters: {}".format( + workers.local_worker().filters)) + + @DeveloperAPI + def _sync_weights_to_workers( + self, + *, + worker_set: Optional[WorkerSet] = None, + workers: Optional[List[RolloutWorker]] = None, + ) -> None: + """Sync "main" weights to given WorkerSet or list of workers.""" + assert worker_set is not None + # Broadcast the new policy weights to all evaluation workers. + logger.info("Synchronizing weights to workers.") + weights = ray.put(self.workers.local_worker().save()) + worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) + + @property + def _name(self) -> str: + """Subclasses should override this to declare their name.""" + raise NotImplementedError + + @property + def _default_config(self) -> TrainerConfigDict: + """Subclasses should override this to declare their default config.""" + raise NotImplementedError + @classmethod @override(Trainable) def resource_help(cls, config: TrainerConfigDict) -> str: @@ -1909,6 +1892,69 @@ def with_updates(**overrides) -> Type["Trainer"]: "that were generated via the `ray.rllib.agents.trainer_template." "build_trainer()` function!") + @DeveloperAPI + def _create_local_replay_buffer_if_necessary( + self, + config: PartialTrainerConfigDict) -> Optional[LocalReplayBuffer]: + """Create a LocalReplayBuffer instance if necessary. + + Args: + config: Algorithm-specific configuration data. + + Returns: + LocalReplayBuffer instance based on trainer config. + None, if local replay buffer is not needed. + """ + # These are the agents that utilize a local replay buffer. + if ("replay_buffer_config" not in config + or not config["replay_buffer_config"]): + # Does not need a replay buffer.
+ return None + + replay_buffer_config = config["replay_buffer_config"] + if ("type" not in replay_buffer_config + or replay_buffer_config["type"] != "LocalReplayBuffer"): + # DistributedReplayBuffer coming soon. + return None + + capacity = config.get("buffer_size", DEPRECATED_VALUE) + if capacity != DEPRECATED_VALUE: + # Print a deprecation warning. + deprecation_warning( + old="config['buffer_size']", + new="config['replay_buffer_config']['capacity']", + error=False) + else: + # Get capacity out of replay_buffer_config. + capacity = replay_buffer_config["capacity"] + + if config.get("prioritized_replay"): + prio_args = { + "prioritized_replay_alpha": config["prioritized_replay_alpha"], + "prioritized_replay_beta": config["prioritized_replay_beta"], + "prioritized_replay_eps": config["prioritized_replay_eps"], + } + else: + prio_args = {} + + return LocalReplayBuffer( + num_shards=1, + learning_starts=config["learning_starts"], + capacity=capacity, + replay_batch_size=config["train_batch_size"], + replay_mode=config["multiagent"]["replay_mode"], + replay_sequence_length=config.get("replay_sequence_length", 1), + replay_burn_in=config.get("burn_in", 0), + replay_zero_init_states=config.get("zero_init_states", True), + **prio_args) + + @DeveloperAPI + def _kwargs_for_execution_plan(self): + kwargs = {} + if self.local_replay_buffer: + kwargs["local_replay_buffer"] = self.local_replay_buffer + return kwargs + def _register_if_needed(self, env_object: Union[str, EnvType, None], config) -> Optional[str]: if isinstance(env_object, str): @@ -1939,5 +1985,13 @@ def _is_multi_agent(self): "You can specify a custom env as either a class " "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").") + @Deprecated(new="Trainer.evaluate", error=False) + def _evaluate(self) -> dict: + return self.evaluate() + + @Deprecated(new="compute_single_action", error=False) + def compute_action(self, *args, **kwargs): + return self.compute_single_action(*args, **kwargs) + def __repr__(self): return self._name diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index b3a8ff71c29c..450d91cefbbf 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -75,58 +75,50 @@ def build_trainer( allow_unknown_subkeys: Optional[List[str]] = None, override_all_subkeys_if_type_changes: Optional[List[str]] = None, ) -> Type[Trainer]: - """Helper function for defining a custom trainer. + """Helper function for defining a custom Trainer class. Functions will be run in this order to initialize the trainer: - 1. Config setup: validate_config, get_policy - 2. Worker setup: before_init, execution_plan - 3. Post setup: after_init + 1. Config setup: validate_config, get_policy. + 2. Worker setup: before_init, execution_plan. + 3. Post setup: after_init. Args: - name (str): name of the trainer (e.g., "PPO") - default_config (Optional[TrainerConfigDict]): The default config dict - of the algorithm, otherwise uses the Trainer default config. - validate_config (Optional[Callable[[TrainerConfigDict], None]]): - Optional callable that takes the config to check for correctness. - It may mutate the config as needed. - default_policy (Optional[Type[Policy]]): The default Policy class to - use if `get_policy_class` returns None. - get_policy_class (Optional[Callable[ - TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable - that takes a config and returns the policy class or None. If None - is returned, will use `default_policy` (which must be provided - then). 
- validate_env (Optional[Callable[[EnvType, EnvContext], None]]): - Optional callable to validate the generated environment (only - on worker=0). - before_init (Optional[Callable[[Trainer], None]]): Optional callable to - run before anything is constructed inside Trainer (Workers with - Policies, execution plan, etc..). Takes the Trainer instance as - argument. - after_init (Optional[Callable[[Trainer], None]]): Optional callable to - run at the end of trainer init (after all Workers and the exec. - plan have been constructed). Takes the Trainer instance as - argument. - before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to - run before evaluation. This takes the trainer instance as argument. - mixins (list): list of any class mixins for the returned trainer class. + name: name of the trainer (e.g., "PPO") + default_config: The default config dict of the algorithm, + otherwise uses the Trainer default config. + validate_config: Optional callable that takes the config to check + for correctness. It may mutate the config as needed. + default_policy: The default Policy class to use if `get_policy_class` + returns None. + get_policy_class: Optional callable that takes a config and returns + the policy class or None. If None is returned, will use + `default_policy` (which must be provided then). + validate_env: Optional callable to validate the generated environment + (only on worker=0). + before_init: Optional callable to run before anything is constructed + inside Trainer (Workers with Policies, execution plan, etc..). + Takes the Trainer instance as argument. + after_init: Optional callable to run at the end of trainer init + (after all Workers and the exec. plan have been constructed). + Takes the Trainer instance as argument. + before_evaluate_fn: Callback to run before evaluation. This takes + the trainer instance as argument. + mixins: List of any class mixins for the returned trainer class. These mixins will be applied in order and will have higher precedence than the Trainer class. - execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict], - Iterable[ResultDict]]]): Optional callable that sets up the + execution_plan: Optional callable that sets up the distributed execution workflow. - allow_unknown_configs (bool): Whether to allow unknown top-level config - keys. - allow_unknown_subkeys (Optional[List[str]]): List of top-level keys + allow_unknown_configs: Whether to allow unknown top-level config keys. + allow_unknown_subkeys: List of top-level keys with value=dict, for which new sub-keys are allowed to be added to the value dict. Appends to Trainer class defaults. - override_all_subkeys_if_type_changes (Optional[List[str]]): List of top - level keys with value=dict, for which we always override the entire - value (dict), iff the "type" key in that value dict changes. - Appends to Trainer class defaults. + override_all_subkeys_if_type_changes: List of top level keys with + value=dict, for which we always override the entire value (dict), + iff the "type" key in that value dict changes. Appends to Trainer + class defaults. Returns: - Type[Trainer]: A Trainer sub-class configured by the specified args. + A Trainer sub-class configured by the specified args. 
""" original_kwargs = locals().copy() diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 2ee532096d82..360005c3c808 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -39,7 +39,7 @@ from ray.rllib.utils.filter import get_filter, Filter from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.sgd import do_minibatch_sgd -from ray.rllib.utils.tf_ops import get_gpu_devices as get_tf_gpu_devices +from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices from ray.rllib.utils.tf_run_builder import TFRunBuilder from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \ ModelConfigDict, ModelGradients, ModelWeights, \ @@ -664,11 +664,12 @@ def make_sub_env(vector_index): "will discard all sampler outputs and keep only metrics.") sample_async = True elif method == "is": - ise = ImportanceSamplingEstimator.create(self.io_context) + ise = ImportanceSamplingEstimator.\ + create_from_io_context(self.io_context) self.reward_estimators.append(ise) elif method == "wis": - wise = WeightedImportanceSamplingEstimator.create( - self.io_context) + wise = WeightedImportanceSamplingEstimator.\ + create_from_io_context(self.io_context) self.reward_estimators.append(wise) else: raise ValueError( diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index 7afeb14816f2..898970c308a7 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -7,7 +7,8 @@ from ray.rllib.evaluation.episode import Episode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.debug import summarize from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.typing import PolicyID, AgentID diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 09fdb3b968de..60b968404138 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -38,7 +38,7 @@ EntropyCoeffSchedule as TorchEntropyCoeffSchedule from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.torch_ops import convert_to_torch_tensor tf1, tf, tfv = try_import_tf() diff --git a/rllib/examples/env/multi_agent.py b/rllib/examples/env/multi_agent.py index 4e052e70ecb5..a1d19eea1bc5 100644 --- a/rllib/examples/env/multi_agent.py +++ b/rllib/examples/env/multi_agent.py @@ -4,7 +4,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated @Deprecated( diff --git a/rllib/examples/models/trajectory_view_utilizing_models.py b/rllib/examples/models/trajectory_view_utilizing_models.py index 0fd4e22cb145..c38ba2a6f7e2 100644 --- a/rllib/examples/models/trajectory_view_utilizing_models.py +++ 
b/rllib/examples/models/trajectory_view_utilizing_models.py @@ -3,7 +3,7 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot from ray.rllib.utils.torch_ops import one_hot as torch_one_hot tf1, tf, tfv = try_import_tf() diff --git a/rllib/execution/learner_thread.py b/rllib/execution/learner_thread.py index d8c6f93c146b..3fb8a3195eda 100644 --- a/rllib/execution/learner_thread.py +++ b/rllib/execution/learner_thread.py @@ -8,8 +8,8 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ LEARNER_INFO, LEARNER_STATS_KEY +from ray.rllib.utils.metrics.window_stat import WindowStat from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils.window_stat import WindowStat from ray.util.iter import _NextValueNotReady tf1, tf, tfv = try_import_tf() diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index 3a2e461cd38b..3c07272760d0 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -17,10 +17,10 @@ from ray.rllib.utils.annotations import DeveloperAPI, override from ray.util.iter import ParallelIteratorWorker from ray.util.debug import log_once -from ray.rllib.utils.annotations import Deprecated -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \ + deprecation_warning from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils.window_stat import WindowStat +from ray.rllib.utils.metrics.window_stat import WindowStat from ray.rllib.utils.typing import SampleBatchType # Constant that represents all policies in lockstep replay mode. 
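The surrounding hunks all follow the same utility re-organization; as a quick reference, an illustrative before/after of the import paths used throughout this patch (mirroring only what the diff itself shows):

# Old locations (removed throughout this patch):
#   from ray.rllib.utils.tf_ops import explained_variance, huber_loss, one_hot
#   from ray.rllib.utils.annotations import Deprecated
#   from ray.rllib.utils.window_stat import WindowStat
# New locations (added throughout this patch):
from ray.rllib.utils.tf_utils import explained_variance, huber_loss, one_hot
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat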
diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 180b08e2c665..048ca478c2eb 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -18,8 +18,8 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchDeterministic, TorchDiagGaussian, \ TorchMultiActionDistribution, TorchMultiCategorical -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, \ +from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \ deprecation_warning from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf, try_import_torch diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index db234dc4247e..971fc952c593 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -10,7 +10,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils import NullContextManager -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI +from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ TensorType from ray.rllib.utils.spaces.repeated import Repeated diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 05e28e989eaf..5d0fe6aafb68 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -22,7 +22,7 @@ from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType, List tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index c7323c41cab9..8b4ffa801f4b 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -11,7 +11,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.space_utils import flatten_space -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index 862763304e63..aa68808bbbbf 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -11,7 +11,7 @@ from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py index 43a75b281003..0487bcf19de2 100644 --- a/rllib/models/torch/torch_action_dist.py +++ b/rllib/models/torch/torch_action_dist.py @@ -11,7 +11,6 @@ from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \ MAX_LOG_NN_OUTPUT from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.torch_ops 
import atanh from ray.rllib.utils.typing import TensorType, List, Union, \ Tuple, ModelConfigDict @@ -300,7 +299,7 @@ def _unsquash(self, values: TensorType) -> TensorType: # Stabilize input to atanh. save_normed_values = torch.clamp(normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER) - unsquashed = atanh(save_normed_values) + unsquashed = torch.atanh(save_normed_values) return unsquashed @staticmethod diff --git a/rllib/offline/d4rl_reader.py b/rllib/offline/d4rl_reader.py index d191d65c61f3..dae9cccb019a 100644 --- a/rllib/offline/d4rl_reader.py +++ b/rllib/offline/d4rl_reader.py @@ -17,11 +17,11 @@ class D4RLReader(InputReader): @PublicAPI def __init__(self, inputs: str, ioctx: IOContext = None): - """Initialize a D4RLReader. + """Initializes a D4RLReader instance. Args: - inputs (str): String corresponding to D4RL environment name - ioctx (IOContext): Current IO context object. + inputs: String corresponding to the D4RL environment name. + ioctx: Current IO context object. """ import d4rl self.env = gym.make(inputs) diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py index 3b05e4772402..12ac65474027 100644 --- a/rllib/offline/input_reader.py +++ b/rllib/offline/input_reader.py @@ -16,15 +16,16 @@ @PublicAPI class InputReader(metaclass=ABCMeta): - """Input object for loading experiences in policy evaluation.""" + """API for collecting and returning experiences during policy evaluation. + """ @abstractmethod @PublicAPI - def next(self): - """Returns the next batch of experiences read. + def next(self) -> SampleBatchType: + """Returns the next batch of read experiences. Returns: - Union[SampleBatch, MultiAgentBatch]: The experience read. + The experience read (SampleBatch or MultiAgentBatch). """ raise NotImplementedError @@ -40,7 +41,7 @@ def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]: reader repeatedly to feed the TensorFlow queue. Args: - queue_size (int): Max elements to allow in the TF queue. + queue_size: Max elements to allow in the TF queue. Example: >>> class MyModel(rllib.model.Model): @@ -56,7 +57,7 @@ def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]: You can find a runnable version of this in examples/custom_loss.py. Returns: - dict of Tensors, one for each column of the read SampleBatch. + Dict of Tensors, one for each column of the read SampleBatch. """ if hasattr(self, "_queue_runner"): diff --git a/rllib/offline/io_context.py b/rllib/offline/io_context.py index f13103b7f295..c74db614c4f3 100644 --- a/rllib/offline/io_context.py +++ b/rllib/offline/io_context.py @@ -1,37 +1,53 @@ import os +from typing import Any, Optional, TYPE_CHECKING from ray.rllib.utils.annotations import PublicAPI -from typing import Any +from ray.rllib.utils.typing import TrainerConfigDict + +if TYPE_CHECKING: + from ray.rllib.evaluation.sampler import SamplerInput @PublicAPI class IOContext: - """Attributes to pass to input / output class constructors. - - RLlib auto-sets these attributes when constructing input / output classes. + """Class containing attributes to pass to input/output class constructors. - Attributes: - log_dir (str): Default logging directory. - config (dict): Configuration of the agent. - worker_index (int): When there are multiple workers created, this - uniquely identifies the current worker. - worker (RolloutWorker): RolloutWorker object reference. - input_config (dict): The input configuration for custom input. 
+ RLlib auto-sets these attributes when constructing input/output classes, + such as InputReaders and OutputWriters. """ @PublicAPI def __init__(self, - log_dir: str = None, - config: dict = None, + log_dir: Optional[str] = None, + config: Optional[TrainerConfigDict] = None, worker_index: int = 0, - worker: Any = None): + worker: Optional[Any] = None): + """Initializes an IOContext object. + + Args: + log_dir: The logging directory to read from/write to. + config: The Trainer's main config dict. + worker_index (int): When there are multiple workers created, this + uniquely identifies the current worker. 0 for the local + worker, >0 for any of the remote workers. + worker (RolloutWorker): The RolloutWorker object reference. + """ self.log_dir = log_dir or os.getcwd() self.config = config or {} self.worker_index = worker_index self.worker = worker @PublicAPI - def default_sampler_input(self) -> Any: + def default_sampler_input(self) -> Optional["SamplerInput"]: + """Returns the RolloutWorker's SamplerInput object, if any. + + Returns None if the RolloutWorker has no SamplerInput. Note that a + local worker does not create a SamplerInput object by default if one + or more remote workers exist. + + Returns: + The RolloutWorker's SamplerInput object or None if none exists. + """ return self.worker.sampler @PublicAPI diff --git a/rllib/offline/is_estimator.py b/rllib/offline/is_estimator.py index 119eb2e1c97f..242c5f291fa8 100644 --- a/rllib/offline/is_estimator.py +++ b/rllib/offline/is_estimator.py @@ -14,7 +14,7 @@ def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: self.check_can_estimate_for(batch) rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_prob(batch) + new_prob = self.action_log_likelihood(batch) # calculate importance ratios p = [] diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index 2177ea27ef98..01da73d87489 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -5,7 +5,7 @@ from pathlib import Path import random import re -from typing import List, Optional +from typing import List, Optional, Union from urllib.parse import urlparse import zipfile @@ -32,17 +32,20 @@ class JsonReader(InputReader): """Reader object that loads experiences from JSON file chunks. - The input files will be read from in an random order.""" + The input files will be read in random order. + """ @PublicAPI - def __init__(self, inputs: List[str], ioctx: IOContext = None): - """Initialize a JsonReader. + def __init__(self, + inputs: Union[str, List[str]], + ioctx: Optional[IOContext] = None): + """Initializes a JsonReader instance. Args: - inputs (str|list): Either a glob expression for files, e.g., - "/tmp/**/*.json", or a list of single file paths or URIs, e.g., + inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`, + or a list of single file paths or URIs, e.g., ["s3://bucket/file.json", "s3://bucket/file2.json"]. - ioctx (IOContext): Current IO context object. + ioctx: Current IO context object or None.
""" self.ioctx = ioctx or IOContext() @@ -72,8 +75,8 @@ def __init__(self, inputs: List[str], ioctx: IOContext = None): self.files = [] for i in inputs: self.files.extend(glob.glob(i)) - elif type(inputs) is list: - self.files = inputs + elif isinstance(inputs, (list, tuple)): + self.files = list(inputs) else: raise ValueError( "type of inputs must be list or str, not {}".format(inputs)) @@ -98,6 +101,26 @@ def next(self) -> SampleBatchType: return self._postprocess_if_needed(batch) + def read_all_files(self) -> SampleBatchType: + """Reads through all files and yields one SampleBatchType per line. + + When reaching the end of the last file, will start from the beginning + again. + + Yields: + One SampleBatch or MultiAgentBatch per line in all input files. + """ + for path in self.files: + file = self._try_open_file(path) + while True: + line = file.readline() + if not line: + break + batch = self._try_parse(line) + if batch is None: + break + yield batch + def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType: if not self.ioctx.config.get("postprocess_inputs"): @@ -182,18 +205,6 @@ def _try_parse(self, line: str) -> Optional[SampleBatchType]: self.ioctx.worker.policy_map[pid].action_space_struct) return batch - def read_all_files(self): - for path in self.files: - file = self._try_open_file(path) - while True: - line = file.readline() - if not line: - break - batch = self._try_parse(line) - if batch is None: - break - yield batch - def _next_line(self) -> str: if not self.cur_file: self.cur_file = self._next_file() diff --git a/rllib/offline/json_writer.py b/rllib/offline/json_writer.py index d3c849684e49..77777872dc55 100644 --- a/rllib/offline/json_writer.py +++ b/rllib/offline/json_writer.py @@ -34,15 +34,14 @@ def __init__(self, ioctx: IOContext = None, max_file_size: int = 64 * 1024 * 1024, compress_columns: List[str] = frozenset(["obs", "new_obs"])): - """Initialize a JsonWriter. + """Initializes a JsonWriter instance. Args: - path (str): a path/URI of the output directory to save files in. - ioctx (IOContext): current IO context object. - max_file_size (int): max size of single files before rolling over. - compress_columns (list): list of sample batch columns to compress. + path: a path/URI of the output directory to save files in. + ioctx: current IO context object. + max_file_size: max size of single files before rolling over. + compress_columns: list of sample batch columns to compress. """ - self.ioctx = ioctx or IOContext() self.max_file_size = max_file_size self.compress_columns = compress_columns diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index 871ad67e0436..e04f8238682f 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -7,6 +7,7 @@ from ray.rllib.policy import Policy from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.offline.io_context import IOContext +from ray.rllib.utils.annotations import Deprecated from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.typing import TensorType, SampleBatchType from typing import List @@ -23,19 +24,30 @@ class OffPolicyEstimator: @DeveloperAPI def __init__(self, policy: Policy, gamma: float): - """Creates an off-policy estimator. + """Initializes an OffPolicyEstimator instance. Args: - policy (Policy): Policy to evaluate. - gamma (float): Discount of the MDP. + policy: Policy to evaluate. + gamma: Discount factor of the environment. 
""" self.policy = policy self.gamma = gamma self.new_estimates = [] @classmethod - def create(cls, ioctx: IOContext) -> "OffPolicyEstimator": - """Create an off-policy estimator from a IOContext.""" + def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator": + """Creates an off-policy estimator from an IOContext object. + + Extracts Policy and gamma (discount factor) information from the + IOContext. + + Args: + ioctx: The IOContext object to create the OffPolicyEstimator + from. + + Returns: + The OffPolicyEstimator object created from the IOContext object. + """ gamma = ioctx.worker.policy_config["gamma"] # Grab a reference to the current model keys = list(ioctx.worker.policy_map.keys()) @@ -47,18 +59,36 @@ def create(cls, ioctx: IOContext) -> "OffPolicyEstimator": return cls(policy, gamma) @DeveloperAPI - def estimate(self, batch: SampleBatchType): - """Returns an estimate for the given batch of experiences. + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: + """Returns an off policy estimate for the given batch of experiences. + + The batch will at most only contain data from one episode, + but it may also only be a fragment of an episode. - The batch will only contain data from one episode, but it may only be - a fragment of an episode. + Args: + batch: The batch to calculate the off policy estimate (OPE) on. + + Returns: + The off-policy estimates (OPE) calculated on the given batch. """ raise NotImplementedError @DeveloperAPI - def action_prob(self, batch: SampleBatchType) -> np.ndarray: - """Returns the probs for the batch actions for the current policy.""" + def action_log_likelihood(self, batch: SampleBatchType) -> TensorType: + """Returns log likelihoods for actions in given batch for policy. + + Computes likelihoods by passing the observations through the current + policy's `compute_log_likelihoods()` method. + + Args: + batch: The SampleBatch or MultiAgentBatch to calculate action + log likelihoods from. This batch/batches must contain OBS + and ACTIONS keys. + Returns: + The log likelihoods of the actions in the batch, given the + observations and the policy. + """ num_state_inputs = 0 for k in batch.keys(): if k.startswith("state_in_"): @@ -66,7 +96,7 @@ def action_prob(self, batch: SampleBatchType) -> np.ndarray: state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] log_likelihoods: TensorType = self.policy.compute_log_likelihoods( actions=batch[SampleBatch.ACTIONS], - obs_batch=batch[SampleBatch.CUR_OBS], + obs_batch=batch[SampleBatch.OBS], state_batches=[batch[k] for k in state_keys], prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS), prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS), @@ -76,12 +106,29 @@ def action_prob(self, batch: SampleBatchType) -> np.ndarray: return np.exp(log_likelihoods) @DeveloperAPI - def process(self, batch: SampleBatchType): + def process(self, batch: SampleBatchType) -> None: + """Computes off policy estimates (OPE) on batch and stores results. + + Thus-far collected results can be retrieved then by calling + `self.get_metrics` (which flushes the internal results storage). + + Args: + batch: The batch to process (call `self.estimate()` on) and + store results (OPEs) for. 
+ """ self.new_estimates.append(self.estimate(batch)) @DeveloperAPI - def check_can_estimate_for(self, batch: SampleBatchType): - """Returns whether we can support OPE for this batch.""" + def check_can_estimate_for(self, batch: SampleBatchType) -> None: + """Checks if we support off policy estimation (OPE) on given batch. + + Args: + batch: The batch to check. + + Raises: + ValueError: In case `action_prob` key is not in batch OR batch + is a MultiAgentBatch. + """ if isinstance(batch, MultiAgentBatch): raise ValueError( @@ -98,11 +145,19 @@ def check_can_estimate_for(self, batch: SampleBatchType): @DeveloperAPI def get_metrics(self) -> List[OffPolicyEstimate]: - """Return a list of new episode metric estimates since the last call. + """Returns list of new episode metric estimates since the last call. Returns: - list of OffPolicyEstimate objects. + List of OffPolicyEstimate objects. """ out = self.new_estimates self.new_estimates = [] return out + + @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False) + def create(self, *args, **kwargs): + return self.create_from_io_context(*args, **kwargs) + + @Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False) + def action_prob(self, *args, **kwargs): + return self.action_log_likelihood(*args, **kwargs) diff --git a/rllib/offline/output_writer.py b/rllib/offline/output_writer.py index 8d168dfb451c..2389c3d741b6 100644 --- a/rllib/offline/output_writer.py +++ b/rllib/offline/output_writer.py @@ -1,15 +1,14 @@ -from ray.rllib.utils.annotations import override -from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.typing import SampleBatchType @PublicAPI class OutputWriter: - """Writer object for saving experiences from policy evaluation.""" + """Writer API for saving experiences from policy evaluation.""" @PublicAPI def write(self, sample_batch: SampleBatchType): - """Save a batch of experiences. + """Saves a batch of experiences. Args: sample_batch: SampleBatch or MultiAgentBatch to save. @@ -22,4 +21,5 @@ class NoopOutput(OutputWriter): @override(OutputWriter) def write(self, sample_batch: SampleBatchType): + # Do nothing. pass diff --git a/rllib/offline/shuffled_input.py b/rllib/offline/shuffled_input.py index 24522c87aa2d..a7c261018594 100644 --- a/rllib/offline/shuffled_input.py +++ b/rllib/offline/shuffled_input.py @@ -18,11 +18,11 @@ class ShuffledInput(InputReader): @DeveloperAPI def __init__(self, child: InputReader, n: int = 0): - """Initialize a MixedInput. + """Initializes a ShuffledInput instance. Args: - child (InputReader): child input reader to shuffle. - n (int): if positive, shuffle input over this many batches. + child: child input reader to shuffle. + n: If positive, shuffle input over this many batches. """ self.n = n self.child = child diff --git a/rllib/offline/wis_estimator.py b/rllib/offline/wis_estimator.py index 74eb342a440e..00bbf3145dd8 100644 --- a/rllib/offline/wis_estimator.py +++ b/rllib/offline/wis_estimator.py @@ -1,54 +1,54 @@ -from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ - OffPolicyEstimate -from ray.rllib.policy import Policy -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import SampleBatchType - - -class WeightedImportanceSamplingEstimator(OffPolicyEstimator): - """The weighted step-wise IS estimator. 
- - Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" - - def __init__(self, policy: Policy, gamma: float): - super().__init__(policy, gamma) - self.filter_values = [] - self.filter_counts = [] - - @override(OffPolicyEstimator) - def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: - self.check_can_estimate_for(batch) - - rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_prob(batch) - - # calculate importance ratios - p = [] - for t in range(batch.count): - if t == 0: - pt_prev = 1.0 - else: - pt_prev = p[t - 1] - p.append(pt_prev * new_prob[t] / old_prob[t]) - for t, v in enumerate(p): - if t >= len(self.filter_values): - self.filter_values.append(v) - self.filter_counts.append(1.0) - else: - self.filter_values[t] += v - self.filter_counts[t] += 1.0 - - # calculate stepwise weighted IS estimate - V_prev, V_step_WIS = 0.0, 0.0 - for t in range(batch.count): - V_prev += rewards[t] * self.gamma**t - w_t = self.filter_values[t] / self.filter_counts[t] - V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t - - estimation = OffPolicyEstimate( - "wis", { - "V_prev": V_prev, - "V_step_WIS": V_step_WIS, - "V_gain_est": V_step_WIS / max(1e-8, V_prev), - }) - return estimation +from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ + OffPolicyEstimate +from ray.rllib.policy import Policy +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import SampleBatchType + + +class WeightedImportanceSamplingEstimator(OffPolicyEstimator): + """The weighted step-wise IS estimator. + + Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" + + def __init__(self, policy: Policy, gamma: float): + super().__init__(policy, gamma) + self.filter_values = [] + self.filter_counts = [] + + @override(OffPolicyEstimator) + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: + self.check_can_estimate_for(batch) + + rewards, old_prob = batch["rewards"], batch["action_prob"] + new_prob = self.action_log_likelihood(batch) + + # calculate importance ratios + p = [] + for t in range(batch.count): + if t == 0: + pt_prev = 1.0 + else: + pt_prev = p[t - 1] + p.append(pt_prev * new_prob[t] / old_prob[t]) + for t, v in enumerate(p): + if t >= len(self.filter_values): + self.filter_values.append(v) + self.filter_counts.append(1.0) + else: + self.filter_values[t] += v + self.filter_counts[t] += 1.0 + + # calculate stepwise weighted IS estimate + V_prev, V_step_WIS = 0.0, 0.0 + for t in range(batch.count): + V_prev += rewards[t] * self.gamma**t + w_t = self.filter_values[t] / self.filter_counts[t] + V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t + + estimation = OffPolicyEstimate( + "wis", { + "V_prev": V_prev, + "V_step_WIS": V_step_WIS, + "V_gain_est": V_step_WIS / max(1e-8, V_prev), + }) + return estimation diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index acb272f3bfa1..3f5ac3044a2b 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -17,7 +17,7 @@ from ray.rllib.utils.debug import summarize from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import get_placeholder +from ray.rllib.utils.tf_utils import get_placeholder from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 
75563663c357..7248bd9feb91 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -20,7 +20,7 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.spaces.space_utils import normalize_action -from ray.rllib.utils.tf_ops import get_gpu_devices +from ray.rllib.utils.tf_utils import get_gpu_devices from ray.rllib.utils.threading import with_lock from ray.rllib.utils.typing import LocalOptimizer, TensorType @@ -724,7 +724,7 @@ def get_session(self): def get_placeholder(self, ph): raise ValueError( "get_placeholder() is not allowed in eager mode. Try using " - "rllib.utils.tf_ops.make_tf_callable() to write " + "rllib.utils.tf_utils.make_tf_callable() to write " "functions that work in both graph and eager mode.") def loss_initialized(self): diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 7e67624b7fe5..1fd9b70e8d93 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -10,7 +10,8 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import from_config diff --git a/rllib/policy/policy_map.py b/rllib/policy/policy_map.py index abdf591c6efe..ad7e940f59d3 100644 --- a/rllib/policy/policy_map.py +++ b/rllib/policy/policy_map.py @@ -8,7 +8,7 @@ from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary +from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary from ray.rllib.utils.threading import with_lock from ray.rllib.utils.typing import PartialTrainerConfigDict, \ PolicyID, TrainerConfigDict diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 389278a1a432..4dbe3436cc0c 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -6,12 +6,12 @@ from typing import Dict, Iterator, List, Optional, Set, Union from ray.util import log_once -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, \ +from ray.rllib.utils.annotations import DeveloperAPI, \ PublicAPI from ray.rllib.utils.compression import pack, unpack, is_compressed -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, deprecation_warning from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.memory import concat_aligned +from ray.rllib.utils.numpy import concat_aligned from ray.rllib.utils.typing import PolicyID, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index c4f67f8e17ef..7f3b3fd81b29 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -15,14 +15,14 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils import force_list -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override +from ray.rllib.utils.annotations import DeveloperAPI, override from 
ray.rllib.utils.debug import summarize -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, deprecation_warning from ray.rllib.utils.framework import try_import_tf, get_variable from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action -from ray.rllib.utils.tf_ops import get_gpu_devices +from ray.rllib.utils.tf_utils import get_gpu_devices from ray.rllib.utils.tf_run_builder import TFRunBuilder from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 54df8d487fbc..7ea18771b44d 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -22,11 +22,11 @@ from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.threading import with_lock -from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \ - convert_to_torch_tensor +from ray.rllib.utils.torch_ops import convert_to_torch_tensor from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \ TensorStructType, TrainerConfigDict @@ -671,7 +671,7 @@ def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]: state = super().get_state() state["_optimizer_variables"] = [] for i, o in enumerate(self._optimizers): - optim_state_dict = convert_to_non_torch_type(o.state_dict()) + optim_state_dict = convert_to_numpy(o.state_dict()) state["_optimizer_variables"].append(optim_state_dict) # Add exploration state. state["_exploration_state"] = \ @@ -940,7 +940,7 @@ def _compute_action_helper(self, input_dict, state_batches, seq_lens, # Update our global timestep by the batch size. self.global_timestep += len(input_dict[SampleBatch.CUR_OBS]) - return convert_to_non_torch_type((actions, state_out, extra_fetches)) + return convert_to_numpy((actions, state_out, extra_fetches)) def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None): # TODO: (sven): Keep for a while to ensure backward compatibility. 
diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index ee7a2d8abc8e..2a72e1224e96 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -7,7 +7,7 @@ from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import ModelGradients, TensorType, \ TrainerConfigDict diff --git a/rllib/utils/annotations.py b/rllib/utils/annotations.py index de815b5ba311..2df9ef2eb7d2 100644 --- a/rllib/utils/annotations.py +++ b/rllib/utils/annotations.py @@ -1,7 +1,4 @@ -import inspect - -from ray.util import log_once -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import Deprecated def override(cls): @@ -53,55 +50,20 @@ def DeveloperAPI(obj): return obj -def Deprecated(old=None, *, new=None, help=None, error): - """Annotation for documenting a (soon-to-be) deprecated method. +def ExperimentalAPI(obj): + """Annotation for documenting experimental APIs. + + Experimental APIs are classes and methods that are in development and may + change at any time in their development process. You should not expect + these APIs to be stable until their tag is changed to `DeveloperAPI` or + `PublicAPI`. - Methods tagged with this decorator should produce a - `ray.rllib.utils.deprecation.deprecation_warning(old=..., error=False)` - to not break existing code at this point. - In a next major release, this warning can then be made an error - (error=True), which means at this point that the method is already - no longer supported but will still inform the user about the - deprecation event. - In a further major release, the method should be erased. + Subclasses that inherit from a ``@ExperimentalAPI`` base class can be + assumed experimental as well. """ - def _inner(obj): - # A deprecated class. - if inspect.isclass(obj): - # Patch the class' init method to raise the warning/error. - obj_init = obj.__init__ - - def patched_init(*args, **kwargs): - if log_once(old or obj.__name__): - deprecation_warning( - old=old or obj.__name__, - new=new, - help=help, - error=error, - ) - return obj_init(*args, **kwargs) - - obj.__init__ = patched_init - # Return the patched class (with the warning/error when - # instantiated). - return obj - - # A deprecated class method or function. - # Patch with the warning/error at the beginning. - def _ctor(*args, **kwargs): - if log_once(old or obj.__name__): - deprecation_warning( - old=old or obj.__name__, - new=new, - help=help, - error=error, - ) - # Call the deprecated method/function. - return obj(*args, **kwargs) - - # Return the patched class method/function. - return _ctor - - # Return the prepared decorator. - return _inner + return obj + + +# Backward compatibility. 
+Deprecated = Deprecated diff --git a/rllib/utils/debug.py b/rllib/utils/debug.py index e6f769f0f04e..90d475cdf9e1 100644 --- a/rllib/utils/debug.py +++ b/rllib/utils/debug.py @@ -2,7 +2,7 @@ import os import pprint import random -from typing import Mapping, Optional +from typing import Any, Mapping, Optional from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -10,11 +10,17 @@ _printer = pprint.PrettyPrinter(indent=2, width=60) -def summarize(obj): +def summarize(obj: Any) -> Any: """Return a pretty-formatted string for an object. This has special handling for pretty-formatting of commonly used data types in RLlib, such as SampleBatch, numpy arrays, etc. + + Args: + obj: The object to format. + + Returns: + The summarized object. """ return _printer.pformat(_summarize(obj)) @@ -76,8 +82,8 @@ def update_global_seed_if_necessary(framework: Optional[str] = None, This is useful for debugging and testing. Args: - framework (Optional[str]): The framework specifier (may be None). - seed (Optional[int]): An optional int seed. If None, will not do + framework: The framework specifier (may be None). + seed: An optional int seed. If None, will not do anything. """ if seed is None: diff --git a/rllib/utils/deprecation.py b/rllib/utils/deprecation.py index 8cda88eaf9d8..ec4559f74aa2 100644 --- a/rllib/utils/deprecation.py +++ b/rllib/utils/deprecation.py @@ -1,6 +1,9 @@ +import inspect import logging from typing import Optional, Union +from ray.util import log_once + logger = logging.getLogger(__name__) # A constant to use for any configuration that should be deprecated @@ -23,8 +26,12 @@ def deprecation_warning( help (Optional[str]): An optional help text to tell the user, what to do instead of using `old`. error (Optional[Union[bool, Exception]]): Whether or which exception to - throw. If True, throw ValueError. If False, just warn. - If Exception, throw that Exception. + raise. If True, raise ValueError. If False, just warn. + If error is-a subclass of Exception, raise that Exception. + + Raises: + ValueError: If `error=True`. + Exception: Of type `error`, iff error is-a Exception subclass. """ msg = "`{}` has been deprecated.{}".format( old, (" Use `{}` instead.".format(new) if new else f" {help}" @@ -37,3 +44,57 @@ def deprecation_warning( else: logger.warning("DeprecationWarning: " + msg + " This will raise an error in the future!") + + +def Deprecated(old=None, *, new=None, help=None, error): + """Annotation for documenting a (soon-to-be) deprecated method. + + Methods tagged with this decorator should produce a + `ray.rllib.utils.deprecation.deprecation_warning(old=..., error=False)` + to not break existing code at this point. + In a next major release, this warning can then be made an error + (error=True), which means at this point that the method is already + no longer supported but will still inform the user about the + deprecation event. + In a further major release, the method should be erased. + """ + + def _inner(obj): + # A deprecated class. + if inspect.isclass(obj): + # Patch the class' init method to raise the warning/error. + obj_init = obj.__init__ + + def patched_init(*args, **kwargs): + if log_once(old or obj.__name__): + deprecation_warning( + old=old or obj.__name__, + new=new, + help=help, + error=error, + ) + return obj_init(*args, **kwargs) + + obj.__init__ = patched_init + # Return the patched class (with the warning/error when + # instantiated). 
+ return obj + + # A deprecated class method or function. + # Patch with the warning/error at the beginning. + def _ctor(*args, **kwargs): + if log_once(old or obj.__name__): + deprecation_warning( + old=old or obj.__name__, + new=new, + help=help, + error=error, + ) + # Call the deprecated method/function. + return obj(*args, **kwargs) + + # Return the patched class method/function. + return _ctor + + # Return the prepared decorator. + return _inner diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index 9cad2d25443f..7b7c586aa294 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -17,7 +17,7 @@ from ray.rllib.utils.framework import try_import_tf, \ try_import_torch from ray.rllib.utils.from_config import from_config -from ray.rllib.utils.tf_ops import get_placeholder, one_hot as tf_one_hot +from ray.rllib.utils.tf_utils import get_placeholder, one_hot as tf_one_hot from ray.rllib.utils.torch_ops import one_hot from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType diff --git a/rllib/utils/exploration/exploration.py b/rllib/utils/exploration/exploration.py index b6eb32b005db..8942e03a10a2 100644 --- a/rllib/utils/exploration/exploration.py +++ b/rllib/utils/exploration/exploration.py @@ -5,7 +5,8 @@ from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_torch, TensorType from ray.rllib.utils.typing import LocalOptimizer, TrainerConfigDict diff --git a/rllib/utils/exploration/gaussian_noise.py b/rllib/utils/exploration/gaussian_noise.py index 3c1972d1e5fd..e49a2ba12960 100644 --- a/rllib/utils/exploration/gaussian_noise.py +++ b/rllib/utils/exploration/gaussian_noise.py @@ -12,7 +12,7 @@ from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import Schedule from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py index ba7582903cf5..d8f9f8b39715 100644 --- a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py @@ -8,7 +8,7 @@ get_variable, TensorType from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import Schedule -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/exploration/random.py b/rllib/utils/exploration/random.py index d1d6c4d0ad98..a375e1e9deb6 100644 --- a/rllib/utils/exploration/random.py +++ b/rllib/utils/exploration/random.py @@ -12,7 +12,7 @@ TensorType from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git 
a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index 593233625de1..4c9f53645832 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -10,7 +10,7 @@ from ray.rllib.utils.exploration.random import Random from ray.rllib.utils.framework import get_variable, try_import_tf, \ try_import_torch, TensorType -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 8057434d67e4..0670f832882b 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -4,20 +4,20 @@ import sys from typing import Any, Optional -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import TensorShape, TensorType logger = logging.getLogger(__name__) -def try_import_jax(error=False): +def try_import_jax(error: bool = False): """Tries importing JAX and FLAX and returns both modules (or Nones). Args: - error (bool): Whether to raise an error if JAX/FLAX cannot be imported. + error: Whether to raise an error if JAX/FLAX cannot be imported. Returns: - Tuple: The jax- and the flax modules. + Tuple containing the jax- and the flax modules. Raises: ImportError: If error=True and JAX is not installed. @@ -39,18 +39,17 @@ def try_import_jax(error=False): return jax, flax -def try_import_tf(error=False): +def try_import_tf(error: bool = False): """Tries importing tf and returns the module (or None). Args: - error (bool): Whether to raise an error if tf cannot be imported. + error: Whether to raise an error if tf cannot be imported. Returns: - Tuple: - - tf1.x module (either from tf2.x.compat.v1 OR as tf1.x). - - tf module (resulting from `import tensorflow`). - Either tf1.x or 2.x. - - The actually installed tf version as int: 1 or 2. + Tuple containing + 1) tf1.x module (either from tf2.x.compat.v1 OR as tf1.x). + 2) tf module (resulting from `import tensorflow`). Either tf1.x or + 2.x. 3) The actually installed tf version as int: 1 or 2. Raises: ImportError: If error=True and tf is not installed. @@ -119,11 +118,11 @@ def decorator(func): return decorator -def try_import_tfp(error=False): +def try_import_tfp(error: bool = False): """Tries importing tfp and returns the module (or None). Args: - error (bool): Whether to raise an error if tfp cannot be imported. + error: Whether to raise an error if tfp cannot be imported. Returns: The tfp module. @@ -159,14 +158,14 @@ def __init__(self, *a, **kw): raise ImportError("Could not import `torch`.") -def try_import_torch(error=False): +def try_import_torch(error: bool = False): """Tries importing torch and returns the module (or None). Args: - error (bool): Whether to raise an error if torch cannot be imported. + error: Whether to raise an error if torch cannot be imported. Returns: - tuple: torch AND torch.nn modules. + Tuple consisting of the torch- AND torch.nn modules. Raises: ImportError: If error=True and PyTorch is not installed. @@ -201,7 +200,8 @@ def get_variable(value: Any, device: Optional[str] = None, shape: Optional[TensorShape] = None, dtype: Optional[TensorType] = None) -> Any: - """ + """Creates a tf variable, a torch tensor, or a python primitive. + Args: value: The initial value to use. In the non-tf case, this will be returned as is. 
In the tf case, this could be a tf-Initializer @@ -223,7 +223,7 @@ def get_variable(value: Any, Returns: A framework-specific variable (tf.Variable, torch.tensor, or - python primitive). + python primitive). """ if framework in ["tf2", "tf", "tfe"]: import tensorflow as tf @@ -258,8 +258,8 @@ def get_variable(value: Any, @Deprecated( - old="rllib/models/utils.py::get_activation_fn", - new="rllib/utils/framework.py::get_activation_fn", + old="rllib/utils/framework.py::get_activation_fn", + new="rllib/models/utils.py::get_activation_fn", error=False) def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): """Returns a framework specific activation function, given a name string. diff --git a/rllib/utils/memory.py b/rllib/utils/memory.py index c2989a407ce8..48248c602e5b 100644 --- a/rllib/utils/memory.py +++ b/rllib/utils/memory.py @@ -1,73 +1,8 @@ -import numpy as np - - -def aligned_array(size, dtype, align=64): - """Returns an array of a given size that is 64-byte aligned. - - The returned array can be efficiently copied into GPU memory by TensorFlow. - """ - - n = size * dtype.itemsize - empty = np.empty(n + (align - 1), dtype=np.uint8) - data_align = empty.ctypes.data % align - offset = 0 if data_align == 0 else (align - data_align) - if n == 0: - # stop np from optimising out empty slice reference - output = empty[offset:offset + 1][0:0].view(dtype) - else: - output = empty[offset:offset + n].view(dtype) - - assert len(output) == size, len(output) - assert output.ctypes.data % align == 0, output.ctypes.data - return output - - -def concat_aligned(items, time_major=None): - """Concatenate arrays, ensuring the output is 64-byte aligned. - - We only align float arrays; other arrays are concatenated as normal. - - This should be used instead of np.concatenate() to improve performance - when the output array is likely to be fed into TensorFlow. - - Args: - items (List(np.ndarray)): The list of items to concatenate and align. - time_major (bool): Whether the data in items is time-major, in which - case, we will concatenate along axis=1. - - Returns: - np.ndarray: The concat'd and aligned array. - """ - - if len(items) == 0: - return [] - elif len(items) == 1: - # we assume the input is aligned. In any case, it doesn't help - # performance to force align it since that incurs a needless copy. 
- return items[0] - elif (isinstance(items[0], np.ndarray) - and items[0].dtype in [np.float32, np.float64, np.uint8]): - dtype = items[0].dtype - flat = aligned_array(sum(s.size for s in items), dtype) - if time_major is not None: - if time_major is True: - batch_dim = sum(s.shape[1] for s in items) - new_shape = ( - items[0].shape[0], - batch_dim, - ) + items[0].shape[2:] - else: - batch_dim = sum(s.shape[0] for s in items) - new_shape = ( - batch_dim, - items[0].shape[1], - ) + items[0].shape[2:] - else: - batch_dim = sum(s.shape[0] for s in items) - new_shape = (batch_dim, ) + items[0].shape[1:] - output = flat.reshape(new_shape) - assert output.ctypes.data % 64 == 0, output.ctypes.data - np.concatenate(items, out=output, axis=1 if time_major else 0) - return output - else: - return np.concatenate(items, axis=1 if time_major else 0) +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.numpy import aligned_array, concat_aligned # noqa + +deprecation_warning( + old="ray.rllib.utils.memory.[...]", + new="ray.rllib.utils.numpy.[...]", + error=False, +) diff --git a/rllib/utils/metrics/window_stat.py b/rllib/utils/metrics/window_stat.py new file mode 100644 index 000000000000..9aa0d9f301df --- /dev/null +++ b/rllib/utils/metrics/window_stat.py @@ -0,0 +1,28 @@ +import numpy as np + + +class WindowStat: + def __init__(self, name, n): + self.name = name + self.items = [None] * n + self.idx = 0 + self.count = 0 + + def push(self, obj): + self.items[self.idx] = obj + self.idx += 1 + self.count += 1 + self.idx %= len(self.items) + + def stats(self): + if not self.count: + _quantiles = [] + else: + _quantiles = np.nanpercentile(self.items[:self.count], + [0, 10, 50, 90, 100]).tolist() + return { + self.name + "_count": int(self.count), + self.name + "_mean": float(np.nanmean(self.items[:self.count])), + self.name + "_std": float(np.nanstd(self.items[:self.count])), + self.name + "_quantiles": _quantiles, + } diff --git a/rllib/utils/multi_agent.py b/rllib/utils/multi_agent.py index 50d5227c54e7..82bf6b2089a1 100644 --- a/rllib/utils/multi_agent.py +++ b/rllib/utils/multi_agent.py @@ -11,12 +11,11 @@ def check_multi_agent(config: PartialTrainerConfigDict) -> \ """Checks, whether a (partial) config defines a multi-agent setup. Args: - config (PartialTrainerConfigDict): The user/Trainer/Policy config - to check for multi-agent. + config: The user/Trainer/Policy config to check for multi-agent. Returns: - The resulting (all fixed) multi-agent policy dict and whether we - have a multi-agent setup or not. + Tuple consisting of the resulting (all fixed) multi-agent policy + dict and bool indicating whether we have a multi-agent setup or not. 
""" multiagent_config = config["multiagent"] policies = multiagent_config.get("policies") diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py index 2a77db1a61fc..91f56f259125 100644 --- a/rllib/utils/numpy.py +++ b/rllib/utils/numpy.py @@ -1,8 +1,9 @@ import numpy as np import tree # pip install dm_tree +from typing import List, Optional from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.typing import TensorType, Union +from ray.rllib.utils.typing import TensorType, TensorStructType, Union tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -17,146 +18,132 @@ MAX_LOG_NN_OUTPUT = 2 -def huber_loss(x, delta=1.0): - """Reference: https://en.wikipedia.org/wiki/Huber_loss""" - return np.where( - np.abs(x) < delta, - np.power(x, 2.0) * 0.5, delta * (np.abs(x) - 0.5 * delta)) - - -def l2_loss(x): - """Computes half the L2 norm of a tensor (w/o the sqrt): sum(x**2) / 2 - - Args: - x (np.ndarray): The input tensor. +def aligned_array(size: int, dtype, align: int = 64) -> np.ndarray: + """Returns an array of a given size that is 64-byte aligned. - Returns: - The l2-loss output according to the above formula given `x`. - """ - return np.sum(np.square(x)) / 2.0 - - -def sigmoid(x, derivative=False): - """ - Returns the sigmoid function applied to x. - Alternatively, can return the derivative or the sigmoid function. + The returned array can be efficiently copied into GPU memory by TensorFlow. Args: - x (np.ndarray): The input to the sigmoid function. - derivative (bool): Whether to return the derivative or not. - Default: False. + size: The size (total number of items) of the array. For example, + array([[0.0, 1.0], [2.0, 3.0]]) would have size=4. + dtype: The numpy dtype of the array. + align: The alignment to use. Returns: - np.ndarray: The sigmoid function (or its derivative) applied to x. + A np.ndarray with the given specifications. """ - if derivative: - return x * (1 - x) + n = size * dtype.itemsize + empty = np.empty(n + (align - 1), dtype=np.uint8) + data_align = empty.ctypes.data % align + offset = 0 if data_align == 0 else (align - data_align) + if n == 0: + # stop np from optimising out empty slice reference + output = empty[offset:offset + 1][0:0].view(dtype) else: - return 1 / (1 + np.exp(-x)) + output = empty[offset:offset + n].view(dtype) + assert len(output) == size, len(output) + assert output.ctypes.data % align == 0, output.ctypes.data + return output -def softmax(x, axis=-1): - """ - Returns the softmax values for x as: - S(xi) = e^xi / SUMj(e^xj), where j goes over all elements in x. - Args: - x (np.ndarray): The input to the softmax function. - axis (int): The axis along which to softmax. +def concat_aligned(items: List[np.ndarray], + time_major: Optional[bool] = None) -> np.ndarray: + """Concatenate arrays, ensuring the output is 64-byte aligned. - Returns: - np.ndarray: The softmax over x. - """ - # x_exp = np.maximum(np.exp(x), SMALL_NUMBER) - x_exp = np.exp(x) - # return x_exp / - # np.maximum(np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) - return np.maximum(x_exp / np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) + We only align float arrays; other arrays are concatenated as normal. - -def relu(x, alpha=0.0): - """ - Implementation of the leaky ReLU function: - y = x * alpha if x < 0 else x + This should be used instead of np.concatenate() to improve performance + when the output array is likely to be fed into TensorFlow. Args: - x (np.ndarray): The input values. 
- alpha (float): A scaling ("leak") factor to use for negative x. + items: The list of items to concatenate and align. + time_major: Whether the data in items is time-major, in which + case, we will concatenate along axis=1. Returns: - np.ndarray: The leaky ReLU output for x. + The concat'd and aligned array. """ - return np.maximum(x, x * alpha, x) + if len(items) == 0: + return [] + elif len(items) == 1: + # we assume the input is aligned. In any case, it doesn't help + # performance to force align it since that incurs a needless copy. + return items[0] + elif (isinstance(items[0], np.ndarray) + and items[0].dtype in [np.float32, np.float64, np.uint8]): + dtype = items[0].dtype + flat = aligned_array(sum(s.size for s in items), dtype) + if time_major is not None: + if time_major is True: + batch_dim = sum(s.shape[1] for s in items) + new_shape = ( + items[0].shape[0], + batch_dim, + ) + items[0].shape[2:] + else: + batch_dim = sum(s.shape[0] for s in items) + new_shape = ( + batch_dim, + items[0].shape[1], + ) + items[0].shape[2:] + else: + batch_dim = sum(s.shape[0] for s in items) + new_shape = (batch_dim, ) + items[0].shape[1:] + output = flat.reshape(new_shape) + assert output.ctypes.data % 64 == 0, output.ctypes.data + np.concatenate(items, out=output, axis=1 if time_major else 0) + return output + else: + return np.concatenate(items, axis=1 if time_major else 0) -def one_hot(x: Union[TensorType, int], - depth: int = 0, - on_value: int = 1.0, - off_value: float = 0.0): - """ - One-hot utility function for numpy. - Thanks to qianyizhang: - https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30. + +def convert_to_numpy(x: TensorStructType, reduce_floats: bool = False): + """Converts values in `stats` to non-Tensor numpy or python types. Args: - x (TensorType): The input to be one-hot encoded. - depth (int): The max. number to be one-hot encoded (size of last rank). - on_value (float): The value to use for on. Default: 1.0. - off_value (float): The value to use for off. Default: 0.0. + x: Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all torch/tf tensors + being converted to numpy types. + reduce_floats: Whether to reduce all float64 data into float32 + automatically. Returns: - np.ndarray: The one-hot encoded equivalent of the input array. + A new struct with the same structure as `x`, but with all + values converted to numpy arrays (on CPU). """ - # Handle simple ints properly. - if isinstance(x, int): - x = np.array(x, dtype=np.int32) - # Handle torch arrays properly. - elif torch and isinstance(x, torch.Tensor): - x = x.numpy() - - # Handle bool arrays correctly. - if x.dtype == np.bool_: - x = x.astype(np.int) - depth = 2 - - # If depth is not given, try to infer it from the values in the array. - if depth == 0: - depth = np.max(x) + 1 - assert np.max(x) < depth, \ - "ERROR: The max. index of `x` ({}) is larger than depth ({})!".\ - format(np.max(x), depth) - shape = x.shape + # The mapping function used to numpyize torch/tf Tensors (and move them + # to the CPU beforehand). 
+ def mapping(item): + if torch and isinstance(item, torch.Tensor): + ret = item.cpu().item() if len(item.size()) == 0 else \ + item.detach().cpu().numpy() + elif tf and isinstance(item, (tf.Tensor, tf.Variable)): + assert tf.executing_eagerly() + ret = item.numpy() + else: + ret = item + if reduce_floats and isinstance(ret, np.ndarray) and \ + ret.dtype == np.float64: + ret = ret.astype(np.float32) + return ret - # Python 2.7 compatibility, (*shape, depth) is not allowed. - shape_list = list(shape[:]) - shape_list.append(depth) - out = np.ones(shape_list) * off_value - indices = [] - for i in range(x.ndim): - tiles = [1] * x.ndim - s = [1] * x.ndim - s[i] = -1 - r = np.arange(shape[i]).reshape(s) - if i > 0: - tiles[i - 1] = shape[i - 1] - r = np.tile(r, tiles) - indices.append(r) - indices.append(x) - out[tuple(indices)] = on_value - return out + return tree.map_structure(mapping, x) -def fc(x, weights, biases=None, framework=None): - """ - Calculates the outputs of a fully-connected (dense) layer given - weights/biases and an input. +def fc(x: np.ndarray, + weights: np.ndarray, + biases: Optional[np.ndarray] = None, + framework: Optional[str] = None) -> np.ndarray: + """Calculates FC (dense) layer outputs given weights/biases and input. Args: - x (np.ndarray): The input to the dense layer. - weights (np.ndarray): The weights matrix. - biases (Optional[np.ndarray]): The biases vector. All 0s if None. - framework (Optional[str]): An optional framework hint (to figure out, + x: The input to the dense layer. + weights: The weights matrix. + biases: The biases vector. All 0s if None. + framework: An optional framework hint (to figure out, e.g. whether to transpose torch weight matrices). Returns: @@ -184,36 +171,48 @@ def map_(data, transpose=False): return np.matmul(x, weights) + (0.0 if biases is None else biases) -def lstm(x, - weights, - biases=None, - initial_internal_states=None, - time_major=False, - forget_bias=1.0): - """ - Calculates the outputs of an LSTM layer given weights/biases, - internal_states, and input. +def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray: + """Reference: https://en.wikipedia.org/wiki/Huber_loss.""" + return np.where( + np.abs(x) < delta, + np.power(x, 2.0) * 0.5, delta * (np.abs(x) - 0.5 * delta)) + + +def l2_loss(x: np.ndarray) -> np.ndarray: + """Computes half the L2 norm of a tensor (w/o the sqrt): sum(x**2) / 2. Args: - x (np.ndarray): The inputs to the LSTM layer including time-rank - (0th if time-major, else 1st) and the batch-rank - (1st if time-major, else 0th). + x: The input tensor. - weights (np.ndarray): The weights matrix. - biases (Optional[np.ndarray]): The biases vector. All 0s if None. + Returns: + The l2-loss output according to the above formula given `x`. + """ + return np.sum(np.square(x)) / 2.0 - initial_internal_states (Optional[np.ndarray]): The initial internal - states to pass into the layer. All 0s if None. - time_major (bool): Whether to use time-major or not. Default: False. +def lstm(x, + weights: np.ndarray, + biases: Optional[np.ndarray] = None, + initial_internal_states: Optional[np.ndarray] = None, + time_major: bool = False, + forget_bias: float = 1.0): + """Calculates LSTM layer output given weights/biases, states, and input. - forget_bias (float): Gets added to first sigmoid (forget gate) output. + Args: + x: The inputs to the LSTM layer including time-rank + (0th if time-major, else 1st) and the batch-rank + (1st if time-major, else 0th). + weights: The weights matrix. + biases: The biases vector. 
All 0s if None. + initial_internal_states: The initial internal + states to pass into the layer. All 0s if None. + time_major: Whether to use time-major or not. Default: False. + forget_bias: Gets added to first sigmoid (forget gate) output. Default: 1.0. Returns: - Tuple: - - The LSTM layer's output. - - Tuple: Last (c-state, h-state). + Tuple consisting of 1) The LSTM layer's output and + 2) Tuple: Last (c-state, h-state). """ sequence_length = x.shape[0 if time_major else 1] batch_size = x.shape[1 if time_major else 0] @@ -259,36 +258,113 @@ def lstm(x, return unrolled_outputs, (c_states, h_states) -# TODO: (sven) this will replace `TorchPolicy._convert_to_non_torch_tensor()`. -def convert_to_numpy(x, reduce_floats=False): - """Converts values in `stats` to non-Tensor numpy or python types. +def one_hot(x: Union[TensorType, int], + depth: int = 0, + on_value: int = 1.0, + off_value: float = 0.0) -> np.ndarray: + """One-hot utility function for numpy. + + Thanks to qianyizhang: + https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30. Args: - stats (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all torch/tf tensors - being converted to numpy types. - reduce_floats (bool): Whether to reduce all float64 data into float32 - automatically. + x: The input to be one-hot encoded. + depth: The max. number to be one-hot encoded (size of last rank). + on_value: The value to use for on. Default: 1.0. + off_value: The value to use for off. Default: 0.0. Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to numpy arrays (on CPU). + The one-hot encoded equivalent of the input array. """ - # The mapping function used to numpyize torch/tf Tensors (and move them - # to the CPU beforehand). - def mapping(item): - if torch and isinstance(item, torch.Tensor): - ret = item.cpu().item() if len(item.size()) == 0 else \ - item.detach().cpu().numpy() - elif tf and isinstance(item, (tf.Tensor, tf.Variable)): - assert tf.executing_eagerly() - ret = item.numpy() - else: - ret = item - if reduce_floats and isinstance(ret, np.ndarray) and \ - ret.dtype == np.float64: - ret = ret.astype(np.float32) - return ret + # Handle simple ints properly. + if isinstance(x, int): + x = np.array(x, dtype=np.int32) + # Handle torch arrays properly. + elif torch and isinstance(x, torch.Tensor): + x = x.numpy() - return tree.map_structure(mapping, x) + # Handle bool arrays correctly. + if x.dtype == np.bool_: + x = x.astype(np.int) + depth = 2 + + # If depth is not given, try to infer it from the values in the array. + if depth == 0: + depth = np.max(x) + 1 + assert np.max(x) < depth, \ + "ERROR: The max. index of `x` ({}) is larger than depth ({})!".\ + format(np.max(x), depth) + shape = x.shape + + # Python 2.7 compatibility, (*shape, depth) is not allowed. + shape_list = list(shape[:]) + shape_list.append(depth) + out = np.ones(shape_list) * off_value + indices = [] + for i in range(x.ndim): + tiles = [1] * x.ndim + s = [1] * x.ndim + s[i] = -1 + r = np.arange(shape[i]).reshape(s) + if i > 0: + tiles[i - 1] = shape[i - 1] + r = np.tile(r, tiles) + indices.append(r) + indices.append(x) + out[tuple(indices)] = on_value + return out + + +def relu(x: np.ndarray, alpha: float = 0.0) -> np.ndarray: + """Implementation of the leaky ReLU function. + + y = x * alpha if x < 0 else x + + Args: + x: The input values. + alpha: A scaling ("leak") factor to use for negative x. + + Returns: + The leaky ReLU output for x. 
+ """ + return np.maximum(x, x * alpha, x) + + +def sigmoid(x: np.ndarray, derivative: bool = False) -> np.ndarray: + """ + Returns the sigmoid function applied to x. + Alternatively, can return the derivative or the sigmoid function. + + Args: + x: The input to the sigmoid function. + derivative: Whether to return the derivative or not. + Default: False. + + Returns: + The sigmoid function (or its derivative) applied to x. + """ + if derivative: + return x * (1 - x) + else: + return 1 / (1 + np.exp(-x)) + + +def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray: + """Returns the softmax values for x. + + The exact formula used is: + S(xi) = e^xi / SUMj(e^xj), where j goes over all elements in x. + + Args: + x: The input to the softmax function. + axis: The axis along which to softmax. + + Returns: + The softmax over x. + """ + # x_exp = np.maximum(np.exp(x), SMALL_NUMBER) + x_exp = np.exp(x) + # return x_exp / + # np.maximum(np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) + return np.maximum(x_exp / np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index a616fd9b4112..99baf070cd51 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -279,11 +279,11 @@ def check_compute_single_action(trainer, """Tests different combinations of args for trainer.compute_single_action. Args: - trainer (Trainer): The Trainer object to test. - include_state (bool): Whether to include the initial state of the - Policy's Model in the `compute_single_action` call. - include_prev_action_reward (bool): Whether to include the prev-action - and -reward in the `compute_single_action` call. + trainer: The Trainer object to test. + include_state: Whether to include the initial state of the Policy's + Model in the `compute_single_action` call. + include_prev_action_reward: Whether to include the prev-action and + -reward in the `compute_single_action` call. Raises: ValueError: If anything unexpected happens. diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 1b577be7ef72..bdef1e5a4710 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -1,305 +1,8 @@ -import gym -from gym.spaces import Discrete, MultiDiscrete -import numpy as np -import tree # pip install dm_tree - -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.typing import TensorStructType, TensorType - -tf1, tf, tfv = try_import_tf() - - -def convert_to_non_tf_type(stats): - """Converts values in `stats` to non-Tensor numpy or python types. - - Args: - stats (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all tf (eager) tensors - being converted to numpy types. - - Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to non-tf Tensor types. - """ - - # The mapping function used to numpyize torch Tensors. - def mapping(item): - if isinstance(item, (tf.Tensor, tf.Variable)): - return item.numpy() - else: - return item - - return tree.map_structure(mapping, stats) - - -def explained_variance(y, pred): - _, y_var = tf.nn.moments(y, axes=[0]) - _, diff_var = tf.nn.moments(y - pred, axes=[0]) - return tf.maximum(-1.0, 1 - (diff_var / y_var)) - - -def get_gpu_devices(): - """Returns a list of GPU device names, e.g. ["/gpu:0", "/gpu:1"]. - - Supports both tf1.x and tf2.x. 
- """ - if tfv == 1: - from tensorflow.python.client import device_lib - devices = device_lib.list_local_devices() - else: - try: - devices = tf.config.list_physical_devices() - except Exception: - devices = tf.config.experimental.list_physical_devices() - - # Expect "GPU", but also stuff like: "XLA_GPU". - return [d.name for d in devices if "GPU" in d.device_type] - - -def get_placeholder(*, - space=None, - value=None, - name=None, - time_axis=False, - flatten=True): - from ray.rllib.models.catalog import ModelCatalog - - if space is not None: - if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): - if flatten: - return ModelCatalog.get_action_placeholder(space, None) - else: - return tree.map_structure_with_path( - lambda path, component: get_placeholder( - space=component, - name=name + "." + ".".join([str(p) for p in path]), - ), - get_base_struct_from_space(space), - ) - return tf1.placeholder( - shape=(None, ) + ((None, ) if time_axis else ()) + space.shape, - dtype=tf.float32 if space.dtype == np.float64 else space.dtype, - name=name, - ) - else: - assert value is not None - shape = value.shape[1:] - return tf1.placeholder( - shape=(None, ) + ((None, ) - if time_axis else ()) + (shape if isinstance( - shape, tuple) else tuple(shape.as_list())), - dtype=tf.float32 if value.dtype == np.float64 else value.dtype, - name=name, - ) - - -def get_tf_eager_cls_if_necessary(orig_cls, config): - cls = orig_cls - framework = config.get("framework", "tf") - if framework in ["tf2", "tf", "tfe"]: - if not tf1: - raise ImportError("Could not import tensorflow!") - if framework in ["tf2", "tfe"]: - assert tf1.executing_eagerly() - - from ray.rllib.policy.tf_policy import TFPolicy - - # Create eager-class. - if hasattr(orig_cls, "as_eager"): - cls = orig_cls.as_eager() - if config.get("eager_tracing"): - cls = cls.with_tracing() - # Could be some other type of policy. - elif not issubclass(orig_cls, TFPolicy): - pass - else: - raise ValueError("This policy does not support eager " - "execution: {}".format(orig_cls)) - return cls - - -def huber_loss(x, delta=1.0): - """Reference: https://en.wikipedia.org/wiki/Huber_loss""" - return tf.where( - tf.abs(x) < delta, - tf.math.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta)) - - -def zero_logps_from_actions(actions: TensorStructType) -> TensorType: - """Helper function useful for returning dummy logp's (0) for some actions. - - Args: - actions (TensorStructType): The input actions. This can be any struct - of complex action components or a simple tensor of different - dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}. - - Returns: - TensorType: A 1D tensor of 0.0 (dummy logp's) matching the batch - dim of `actions` (shape=[B]). - """ - # Need to flatten `actions` in case we have a complex action space. - # Take the 0th component to extract the batch dim. - action_component = tree.flatten(actions)[0] - logp_ = tf.zeros_like(action_component, dtype=tf.float32) - # Logp's should be single values (but with the same batch dim as - # `deterministic_actions` or `stochastic_actions`). In case - # actions are just [B], zeros_like works just fine here, but if - # actions are [B, ...], we have to reduce logp back to just [B]. 
- while len(logp_.shape) > 1: - logp_ = logp_[:, 0] - return logp_ - - -def one_hot(x, space): - if isinstance(space, Discrete): - return tf.one_hot(x, space.n, dtype=tf.float32) - elif isinstance(space, MultiDiscrete): - return tf.concat( - [ - tf.one_hot(x[:, i], n, dtype=tf.float32) - for i, n in enumerate(space.nvec) - ], - axis=-1) - else: - raise ValueError("Unsupported space for `one_hot`: {}".format(space)) - - -def reduce_mean_ignore_inf(x, axis): - """Same as tf.reduce_mean() but ignores -inf values.""" - mask = tf.not_equal(x, tf.float32.min) - x_zeroed = tf.where(mask, x, tf.zeros_like(x)) - return (tf.reduce_sum(x_zeroed, axis) / tf.reduce_sum( - tf.cast(mask, tf.float32), axis)) - - -def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0): - """Minimized `objective` using `optimizer` w.r.t. variables in - `var_list` while ensure the norm of the gradients for each - variable is clipped to `clip_val` - """ - # Accidentally passing values < 0.0 will break all gradients. - assert clip_val is None or clip_val > 0.0, clip_val - - if tf.executing_eagerly(): - tape = optimizer.tape - grads_and_vars = list( - zip(list(tape.gradient(objective, var_list)), var_list)) - else: - grads_and_vars = optimizer.compute_gradients( - objective, var_list=var_list) - - return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v) - for (g, v) in grads_and_vars if g is not None] - - -def make_tf_callable(session_or_none, dynamic_shape=False): - """Returns a function that can be executed in either graph or eager mode. - - The function must take only positional args. - - If eager is enabled, this will act as just a function. Otherwise, it - will build a function that executes a session run with placeholders - internally. - - Args: - session_or_none (tf.Session): tf.Session if in graph mode, else None. - dynamic_shape (bool): True if the placeholders should have a dynamic - batch dimension. Otherwise they will be fixed shape. - - Returns: - a Python function that can be called in either mode. - """ - - if tf.executing_eagerly(): - assert session_or_none is None - else: - assert session_or_none is not None - - def make_wrapper(fn): - # Static-graph mode: Create placeholders and make a session call each - # time the wrapped function is called. Returns the output of this - # session call. - if session_or_none is not None: - args_placeholders = [] - kwargs_placeholders = {} - - symbolic_out = [None] - - def call(*args, **kwargs): - args_flat = [] - for a in args: - if type(a) is list: - args_flat.extend(a) - else: - args_flat.append(a) - args = args_flat - - # We have not built any placeholders yet: Do this once here, - # then reuse the same placeholders each time we call this - # function again. 
- if symbolic_out[0] is None: - with session_or_none.graph.as_default(): - - def _create_placeholders(path, value): - if dynamic_shape: - if len(value.shape) > 0: - shape = (None, ) + value.shape[1:] - else: - shape = () - else: - shape = value.shape - return tf1.placeholder( - dtype=value.dtype, - shape=shape, - name=".".join([str(p) for p in path]), - ) - - placeholders = tree.map_structure_with_path( - _create_placeholders, args) - for ph in tree.flatten(placeholders): - args_placeholders.append(ph) - - placeholders = tree.map_structure_with_path( - _create_placeholders, kwargs) - for k, ph in placeholders.items(): - kwargs_placeholders[k] = ph - - symbolic_out[0] = fn(*args_placeholders, - **kwargs_placeholders) - feed_dict = dict(zip(args_placeholders, tree.flatten(args))) - tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v), - kwargs_placeholders, kwargs) - ret = session_or_none.run(symbolic_out[0], feed_dict) - return ret - - return call - # Eager mode (call function as is). - else: - return fn - - return make_wrapper - - -def scope_vars(scope, trainable_only=False): - """ - Get variables inside a scope - The scope can be specified as a string - - Parameters - ---------- - scope: str or VariableScope - scope in which the variables reside. - trainable_only: bool - whether or not to return only the variables that were marked as - trainable. - - Returns - ------- - vars: [tf.Variable] - list of variables in `scope`. - """ - return tf1.get_collection( - tf1.GraphKeys.TRAINABLE_VARIABLES - if trainable_only else tf1.GraphKeys.VARIABLES, - scope=scope if isinstance(scope, str) else scope.name) +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.tf_utils import * # noqa + +deprecation_warning( + old="ray.rllib.utils.tf_ops.[...]", + new="ray.rllib.utils.tf_utils.[...]", + error=False, +) diff --git a/rllib/utils/tf_utils.py b/rllib/utils/tf_utils.py new file mode 100644 index 000000000000..7af39988b910 --- /dev/null +++ b/rllib/utils/tf_utils.py @@ -0,0 +1,426 @@ +import gym +from gym.spaces import Discrete, MultiDiscrete +import numpy as np +import tree # pip install dm_tree +from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING, Union + +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \ + PartialTrainerConfigDict, TensorStructType, TensorType + +if TYPE_CHECKING: + from ray.rllib.policy.tf_policy import TFPolicy + +tf1, tf, tfv = try_import_tf() + + +@Deprecated(new="ray.rllib.utils.numpy.convert_to_numpy()", error=True) +def convert_to_non_tf_type(x: TensorStructType) -> TensorStructType: + """Converts values in `stats` to non-Tensor numpy or python types. + + Args: + x: Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all tf (eager) tensors + being converted to numpy types. + + Returns: + A new struct with the same structure as `x`, but with all + values converted to non-tf Tensor types. + """ + + # The mapping function used to numpyize torch Tensors. + def mapping(item): + if isinstance(item, (tf.Tensor, tf.Variable)): + return item.numpy() + else: + return item + + return tree.map_structure(mapping, x) + + +def explained_variance(y: TensorType, pred: TensorType) -> TensorType: + """Computes the explained variance for a pair of labels and predictions. 
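Illustrative usage sketch (not part of the diff): with the shim that replaces tf_ops.py above, imports from the old module keep working but emit a deprecation warning on import and resolve to the same objects now defined in tf_utils.

from ray.rllib.utils.tf_ops import huber_loss            # logs the deprecation warning once, on import
from ray.rllib.utils.tf_utils import huber_loss as new_huber_loss

assert huber_loss is new_huber_loss  # re-exported via the wildcard import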
+ + The formula used is: + max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2)) + + Args: + y: The labels. + pred: The predictions. + + Returns: + The explained variance given a pair of labels and predictions. + """ + _, y_var = tf.nn.moments(y, axes=[0]) + _, diff_var = tf.nn.moments(y - pred, axes=[0]) + return tf.maximum(-1.0, 1 - (diff_var / y_var)) + + +def get_gpu_devices() -> List[str]: + """Returns a list of GPU device names, e.g. ["/gpu:0", "/gpu:1"]. + + Supports both tf1.x and tf2.x. + + Returns: + List of GPU device names (str). + """ + if tfv == 1: + from tensorflow.python.client import device_lib + devices = device_lib.list_local_devices() + else: + try: + devices = tf.config.list_physical_devices() + except Exception: + devices = tf.config.experimental.list_physical_devices() + + # Expect "GPU", but also stuff like: "XLA_GPU". + return [d.name for d in devices if "GPU" in d.device_type] + + +def get_placeholder(*, + space: Optional[gym.Space] = None, + value: Optional[Any] = None, + name: Optional[str] = None, + time_axis: bool = False, + flatten: bool = True) -> "tf1.placeholder": + """Returns a tf1.placeholder object given optional hints, such as a space. + + Note that the returned placeholder will always have a leading batch + dimension (None). + + Args: + space: An optional gym.Space to hint the shape and dtype of the + placeholder. + value: An optional value to hint the shape and dtype of the + placeholder. + name: An optional name for the placeholder. + time_axis: Whether the placeholder should also receive a time + dimension (None). + flatten: Whether to flatten the given space into a plain Box space + and then create the placeholder from the resulting space. + + Returns: + The tf1 placeholder. + """ + from ray.rllib.models.catalog import ModelCatalog + + if space is not None: + if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): + if flatten: + return ModelCatalog.get_action_placeholder(space, None) + else: + return tree.map_structure_with_path( + lambda path, component: get_placeholder( + space=component, + name=name + "." + ".".join([str(p) for p in path]), + ), + get_base_struct_from_space(space), + ) + return tf1.placeholder( + shape=(None, ) + ((None, ) if time_axis else ()) + space.shape, + dtype=tf.float32 if space.dtype == np.float64 else space.dtype, + name=name, + ) + else: + assert value is not None + shape = value.shape[1:] + return tf1.placeholder( + shape=(None, ) + ((None, ) + if time_axis else ()) + (shape if isinstance( + shape, tuple) else tuple(shape.as_list())), + dtype=tf.float32 if value.dtype == np.float64 else value.dtype, + name=name, + ) + + +def get_tf_eager_cls_if_necessary( + orig_cls: Type["TFPolicy"], + config: PartialTrainerConfigDict) -> Type["TFPolicy"]: + """Returns the corresponding tf-eager class for a given TFPolicy class. + + Args: + orig_cls: The original TFPolicy class to get the corresponding tf-eager + class for. + config: The Trainer config dict. + + Returns: + The tf eager policy class corresponding to the given TFPolicy class. + """ + cls = orig_cls + framework = config.get("framework", "tf") + if framework in ["tf2", "tf", "tfe"]: + if not tf1: + raise ImportError("Could not import tensorflow!") + if framework in ["tf2", "tfe"]: + assert tf1.executing_eagerly() + + from ray.rllib.policy.tf_policy import TFPolicy + + # Create eager-class. + if hasattr(orig_cls, "as_eager"): + cls = orig_cls.as_eager() + if config.get("eager_tracing"): + cls = cls.with_tracing() + # Could be some other type of policy. 
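Illustrative usage sketch (not part of the diff): a quick numeric check of the `explained_variance` formula documented above, assuming TF2 eager execution.

import tensorflow as tf
from ray.rllib.utils.tf_utils import explained_variance

y = tf.constant([1.0, 2.0, 3.0, 4.0])
print(float(explained_variance(y, y)))                  # perfect predictions -> 1.0
print(float(explained_variance(y, tf.fill([4], 2.5))))  # predicting the mean -> 0.0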
+ elif not issubclass(orig_cls, TFPolicy): + pass + else: + raise ValueError("This policy does not support eager " + "execution: {}".format(orig_cls)) + return cls + + +def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: + """Computes the huber loss for a given term and delta parameter. + + Reference: https://en.wikipedia.org/wiki/Huber_loss + Note that the factor of 0.5 is implicitly included in the calculation. + + Formula: + L = 0.5 * x^2 for small abs x (delta threshold) + L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold) + + Args: + x: The input term, e.g. a TD error. + delta: The delta parmameter in the above formula. + + Returns: + The Huber loss resulting from `x` and `delta`. + """ + return tf.where( + tf.abs(x) < delta, # for small x -> apply the Huber correction + tf.math.square(x) * 0.5, + delta * (tf.abs(x) - 0.5 * delta), + ) + + +def make_tf_callable(session_or_none: Optional["tf1.Session"], + dynamic_shape: bool = False) -> Callable: + """Returns a function that can be executed in either graph or eager mode. + + The function must take only positional args. + + If eager is enabled, this will act as just a function. Otherwise, it + will build a function that executes a session run with placeholders + internally. + + Args: + session_or_none: tf.Session if in graph mode, else None. + dynamic_shape: True if the placeholders should have a dynamic + batch dimension. Otherwise they will be fixed shape. + + Returns: + A function that can be called in either eager or static-graph mode. + """ + + if tf.executing_eagerly(): + assert session_or_none is None + else: + assert session_or_none is not None + + def make_wrapper(fn): + # Static-graph mode: Create placeholders and make a session call each + # time the wrapped function is called. Returns the output of this + # session call. + if session_or_none is not None: + args_placeholders = [] + kwargs_placeholders = {} + + symbolic_out = [None] + + def call(*args, **kwargs): + args_flat = [] + for a in args: + if type(a) is list: + args_flat.extend(a) + else: + args_flat.append(a) + args = args_flat + + # We have not built any placeholders yet: Do this once here, + # then reuse the same placeholders each time we call this + # function again. + if symbolic_out[0] is None: + with session_or_none.graph.as_default(): + + def _create_placeholders(path, value): + if dynamic_shape: + if len(value.shape) > 0: + shape = (None, ) + value.shape[1:] + else: + shape = () + else: + shape = value.shape + return tf1.placeholder( + dtype=value.dtype, + shape=shape, + name=".".join([str(p) for p in path]), + ) + + placeholders = tree.map_structure_with_path( + _create_placeholders, args) + for ph in tree.flatten(placeholders): + args_placeholders.append(ph) + + placeholders = tree.map_structure_with_path( + _create_placeholders, kwargs) + for k, ph in placeholders.items(): + kwargs_placeholders[k] = ph + + symbolic_out[0] = fn(*args_placeholders, + **kwargs_placeholders) + feed_dict = dict(zip(args_placeholders, tree.flatten(args))) + tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v), + kwargs_placeholders, kwargs) + ret = session_or_none.run(symbolic_out[0], feed_dict) + return ret + + return call + # Eager mode (call function as is). + else: + return fn + + return make_wrapper + + +def minimize_and_clip( + optimizer: LocalOptimizer, + objective: TensorType, + var_list: List["tf.Variable"], + clip_val: float = 10.0, +) -> ModelGradients: + """Computes, then clips gradients using objective, optimizer and var list. 
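Illustrative usage sketch (not part of the diff): the piecewise Huber loss defined above, evaluated at one value below and one above the delta threshold (TF2 eager assumed).

import tensorflow as tf
from ray.rllib.utils.tf_utils import huber_loss

x = tf.constant([0.5, 2.0])
print(huber_loss(x, delta=1.0).numpy())
# [0.125 1.5]  -> 0.5 * 0.5^2 in the quadratic branch, 1.0 * (2.0 - 0.5) in the linear branch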
+ + Ensures the norm of the gradients for each variable is clipped to + `clip_val`. + + Args: + optimizer: Either a shim optimizer (tf eager) containing a + tf.GradientTape under `self.tape` or a tf1 local optimizer + object. + objective: The loss tensor to calculate gradients on. + var_list: The list of tf.Variables to compute gradients over. + clip_val: The global norm clip value. Will clip around -clip_val and + +clip_val. + + Returns: + The resulting model gradients (list or tuples of grads + vars) + corresponding to the input `var_list`. + """ + # Accidentally passing values < 0.0 will break all gradients. + assert clip_val is None or clip_val > 0.0, clip_val + + if tf.executing_eagerly(): + tape = optimizer.tape + grads_and_vars = list( + zip(list(tape.gradient(objective, var_list)), var_list)) + else: + grads_and_vars = optimizer.compute_gradients( + objective, var_list=var_list) + + return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v) + for (g, v) in grads_and_vars if g is not None] + + +def one_hot(x: TensorType, space: gym.Space) -> TensorType: + """Returns a one-hot tensor, given and int tensor and a space. + + Handles the MultiDiscrete case as well. + + Args: + x: The input tensor. + space: The space to use for generating the one-hot tensor. + + Returns: + The resulting one-hot tensor. + + Raises: + ValueError: If the given space is not a discrete one. + + Examples: + >>> x = tf.Variable([0, 3], dtype=tf.int32) # batch-dim=2 + >>> # Discrete space with 4 (one-hot) slots per batch item. + >>> s = gym.spaces.Discrete(4) + >>> one_hot(x, s) + + + >>> x = tf.Variable([[0, 1, 2, 3]], dtype=tf.int32) # batch-dim=1 + >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots + >>> # per batch item. + >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7]) + >>> one_hot(x, s) + + """ + if isinstance(space, Discrete): + return tf.one_hot(x, space.n, dtype=tf.float32) + elif isinstance(space, MultiDiscrete): + return tf.concat( + [ + tf.one_hot(x[:, i], n, dtype=tf.float32) + for i, n in enumerate(space.nvec) + ], + axis=-1) + else: + raise ValueError("Unsupported space for `one_hot`: {}".format(space)) + + +def reduce_mean_ignore_inf(x: TensorType, + axis: Optional[int] = None) -> TensorType: + """Same as tf.reduce_mean() but ignores -inf values. + + Args: + x: The input tensor to reduce mean over. + axis: The axis over which to reduce. None for all axes. + + Returns: + The mean reduced inputs, ignoring inf values. + """ + mask = tf.not_equal(x, tf.float32.min) + x_zeroed = tf.where(mask, x, tf.zeros_like(x)) + return (tf.math.reduce_sum(x_zeroed, axis) / tf.math.reduce_sum( + tf.cast(mask, tf.float32), axis)) + + +def scope_vars(scope: Union[str, "tf1.VariableScope"], + trainable_only: bool = False) -> List["tf.Variable"]: + """Get variables inside a given scope. + + Args: + scope: Scope in which the variables reside. + trainable_only: Whether or not to return only the variables that were + marked as trainable. + + Returns: + The list of variables in the given `scope`. + """ + return tf1.get_collection( + tf1.GraphKeys.TRAINABLE_VARIABLES + if trainable_only else tf1.GraphKeys.VARIABLES, + scope=scope if isinstance(scope, str) else scope.name) + + +def zero_logps_from_actions(actions: TensorStructType) -> TensorType: + """Helper function useful for returning dummy logp's (0) for some actions. + + Args: + actions: The input actions. This can be any struct + of complex action components or a simple tensor of different + dimensions, e.g. 
[B], [B, 2], or {"a": [B, 4, 5], "b": [B]}. + + Returns: + A 1D tensor of 0.0 (dummy logp's) matching the batch + dim of `actions` (shape=[B]). + """ + # Need to flatten `actions` in case we have a complex action space. + # Take the 0th component to extract the batch dim. + action_component = tree.flatten(actions)[0] + logp_ = tf.zeros_like(action_component, dtype=tf.float32) + # Logp's should be single values (but with the same batch dim as + # `deterministic_actions` or `stochastic_actions`). In case + # actions are just [B], zeros_like works just fine here, but if + # actions are [B, ...], we have to reduce logp back to just [B]. + while len(logp_.shape) > 1: + logp_ = logp_[:, 0] + return logp_ diff --git a/rllib/utils/threading.py b/rllib/utils/threading.py index a75f1d65c306..f6a3f7b4fa67 100644 --- a/rllib/utils/threading.py +++ b/rllib/utils/threading.py @@ -1,7 +1,7 @@ from typing import Callable -def with_lock(func: Callable): +def with_lock(func: Callable) -> Callable: """Use as decorator (@withlock) around object methods that need locking. Note: The object must have a self._lock = threading.Lock() property. @@ -9,10 +9,10 @@ def with_lock(func: Callable): object can be called asynchronously). Args: - func (Callable): The function to decorate/wrap. + func: The function to decorate/wrap. Returns: - Callable: The wrapped (object-level locked) function. + The wrapped (object-level locked) function. """ def wrapper(self, *a, **k): diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 90ccc64aad12..8bc9eebbd116 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -1,255 +1,8 @@ -from gym.spaces import Discrete, MultiDiscrete -import numpy as np -import os -import tree # pip install dm_tree -import warnings - -from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.numpy import SMALL_NUMBER - -torch, nn = try_import_torch() - -# Limit values suitable for use as close to a -inf logit. These are useful -# since -inf / inf cause NaNs during backprop. -FLOAT_MIN = -3.4e38 -FLOAT_MAX = 3.4e38 - - -def apply_grad_clipping(policy, optimizer, loss): - """Applies gradient clipping to already computed grads inside `optimizer`. - - Args: - policy (TorchPolicy): The TorchPolicy, which calculated `loss`. - optimizer (torch.optim.Optimizer): A local torch optimizer object. - loss (torch.Tensor): The torch loss tensor. - """ - info = {} - if policy.config["grad_clip"]: - for param_group in optimizer.param_groups: - # Make sure we only pass params with grad != None into torch - # clip_grad_norm_. Would fail otherwise. - params = list( - filter(lambda p: p.grad is not None, param_group["params"])) - if params: - grad_gnorm = nn.utils.clip_grad_norm_( - params, policy.config["grad_clip"]) - if isinstance(grad_gnorm, torch.Tensor): - grad_gnorm = grad_gnorm.cpu().numpy() - info["grad_gnorm"] = grad_gnorm - return info - - -def atanh(x): - return 0.5 * torch.log( - (1 + x).clamp(min=SMALL_NUMBER) / (1 - x).clamp(min=SMALL_NUMBER)) - - -def concat_multi_gpu_td_errors(policy): - td_error = torch.cat( - [ - t.tower_stats.get("td_error", torch.tensor([0.0])).to( - policy.device) for t in policy.model_gpu_towers - ], - dim=0) - policy.td_error = td_error - return { - "td_error": td_error, - "mean_td_error": torch.mean(td_error), - } - - -def convert_to_non_torch_type(stats): - """Converts values in `stats` to non-Tensor numpy or python types. 
- - Args: - stats (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all torch tensors - being converted to numpy types. - - Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to non-torch Tensor types. - """ - - # The mapping function used to numpyize torch Tensors. - def mapping(item): - if isinstance(item, torch.Tensor): - return item.cpu().item() if len(item.size()) == 0 else \ - item.detach().cpu().numpy() - else: - return item - - return tree.map_structure(mapping, stats) - - -def convert_to_torch_tensor(x, device=None): - """Converts any struct to torch.Tensors. - - x (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all leaves converted - to torch tensors. - - Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to torch Tensor types. - """ - - def mapping(item): - # Already torch tensor -> make sure it's on right device. - if torch.is_tensor(item): - return item if device is None else item.to(device) - # Special handling of "Repeated" values. - elif isinstance(item, RepeatedValues): - return RepeatedValues( - tree.map_structure(mapping, item.values), item.lengths, - item.max_len) - # Numpy arrays. - if isinstance(item, np.ndarray): - # np.object_ type (e.g. info dicts in train batch): leave as-is. - if item.dtype == np.object_: - return item - # Non-writable numpy-arrays will cause PyTorch warning. - elif item.flags.writeable is False: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - tensor = torch.from_numpy(item) - # Already numpy: Wrap as torch tensor. - else: - tensor = torch.from_numpy(item) - # Everything else: Convert to numpy, then wrap as torch tensor. - else: - tensor = torch.from_numpy(np.asarray(item)) - # Floatify all float64 tensors. - if tensor.dtype == torch.double: - tensor = tensor.float() - return tensor if device is None else tensor.to(device) - - return tree.map_structure(mapping, x) - - -def explained_variance(y, pred): - y_var = torch.var(y, dim=[0]) - diff_var = torch.var(y - pred, dim=[0]) - min_ = torch.tensor([-1.0]).to(pred.device) - return torch.max(min_, 1 - (diff_var / y_var))[0] - - -def global_norm(tensors): - """Returns the global L2 norm over a list of tensors. - - output = sqrt(SUM(t ** 2 for t in tensors)), - where SUM reduces over all tensors and over all elements in tensors. - - Args: - tensors (List[torch.Tensor]): The list of tensors to calculate the - global norm over. - """ - # List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor. - single_l2s = [ - torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors - ] - # Compute global norm from all single tensors' L2 norms. - return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5) - - -def huber_loss(x, delta=1.0): - """Reference: https://en.wikipedia.org/wiki/Huber_loss""" - return torch.where( - torch.abs(x) < delta, - torch.pow(x, 2.0) * 0.5, delta * (torch.abs(x) - 0.5 * delta)) - - -def l2_loss(x): - """Computes half the L2 norm of a tensor without the sqrt. - - output = sum(x ** 2) / 2 - """ - return torch.sum(torch.pow(x, 2.0)) / 2.0 - - -def minimize_and_clip(optimizer, clip_val=10): - """Clips gradients found in `optimizer.param_groups` to given value. 
- - Ensures the norm of the gradients for each variable is clipped to - `clip_val` - """ - for param_group in optimizer.param_groups: - for p in param_group["params"]: - if p.grad is not None: - torch.nn.utils.clip_grad_norm_(p.grad, clip_val) - - -def one_hot(x, space): - if isinstance(space, Discrete): - return nn.functional.one_hot(x.long(), space.n) - elif isinstance(space, MultiDiscrete): - return torch.cat( - [ - nn.functional.one_hot(x[:, i].long(), n) - for i, n in enumerate(space.nvec) - ], - dim=-1) - else: - raise ValueError("Unsupported space for `one_hot`: {}".format(space)) - - -def reduce_mean_ignore_inf(x, axis): - """Same as torch.mean() but ignores -inf values.""" - mask = torch.ne(x, float("-inf")) - x_zeroed = torch.where(mask, x, torch.zeros_like(x)) - return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis) - - -def sequence_mask(lengths, maxlen=None, dtype=None, time_major=False): - """Offers same behavior as tf.sequence_mask for torch. - - Thanks to Dimitris Papatheodorou - (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/ - 39036). - """ - if maxlen is None: - maxlen = int(lengths.max()) - - mask = ~(torch.ones( - (len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths) - if not time_major: - mask = mask.t() - mask.type(dtype or torch.bool) - - return mask - - -def set_torch_seed(seed): - if seed is not None and torch: - torch.manual_seed(seed) - # See https://github.com/pytorch/pytorch/issues/47672. - cuda_version = torch.version.cuda - if cuda_version is not None and float(torch.version.cuda) >= 10.2: - os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8" - else: - # Not all Operations support this. - torch.use_deterministic_algorithms(True) - # This is only for Convolution no problem. - torch.backends.cudnn.deterministic = True - - -def softmax_cross_entropy_with_logits(logits, labels): - """Same behavior as tf.nn.softmax_cross_entropy_with_logits. - - Args: - x (TensorType): - - Returns: - - """ - return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1) - - -class Swish(nn.Module): - def __init__(self): - super().__init__() - self._beta = nn.Parameter(torch.tensor(1.0)) - - def forward(self, input_tensor): - return input_tensor * torch.sigmoid(self._beta * input_tensor) +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.torch_utils import * # noqa + +deprecation_warning( + old="ray.rllib.utils.torch_ops.[...]", + new="ray.rllib.utils.torch_utils.[...]", + error=False, +) diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py new file mode 100644 index 000000000000..19ea8ded7cfc --- /dev/null +++ b/rllib/utils/torch_utils.py @@ -0,0 +1,395 @@ +import gym +from gym.spaces import Discrete, MultiDiscrete +import numpy as np +import os +import tree # pip install dm_tree +from typing import Dict, List, Optional, TYPE_CHECKING +import warnings + +from ray.rllib.models.repeated_values import RepeatedValues +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import SMALL_NUMBER +from ray.rllib.utils.typing import LocalOptimizer, TensorType, TensorStructType + +if TYPE_CHECKING: + from ray.rllib.policy.torch_policy import TorchPolicy + +torch, nn = try_import_torch() + +# Limit values suitable for use as close to a -inf logit. These are useful +# since -inf / inf cause NaNs during backprop. 
+FLOAT_MIN = -3.4e38 +FLOAT_MAX = 3.4e38 + + +def apply_grad_clipping(policy: "TorchPolicy", optimizer: LocalOptimizer, + loss: TensorType) -> Dict[str, TensorType]: + """Applies gradient clipping to already computed grads inside `optimizer`. + + Args: + policy: The TorchPolicy, which calculated `loss`. + optimizer: A local torch optimizer object. + loss: The torch loss tensor. + + Returns: + An info dict containing the "grad_norm" key and the resulting clipped + gradients. + """ + info = {} + if policy.config["grad_clip"]: + for param_group in optimizer.param_groups: + # Make sure we only pass params with grad != None into torch + # clip_grad_norm_. Would fail otherwise. + params = list( + filter(lambda p: p.grad is not None, param_group["params"])) + if params: + grad_gnorm = nn.utils.clip_grad_norm_( + params, policy.config["grad_clip"]) + if isinstance(grad_gnorm, torch.Tensor): + grad_gnorm = grad_gnorm.cpu().numpy() + info["grad_gnorm"] = grad_gnorm + return info + + +@Deprecated( + old="ray.rllib.utils.torch_utils.atanh", + new="torch.math.atanh", + error=False) +def atanh(x: TensorType) -> TensorType: + """Atanh function for PyTorch.""" + return 0.5 * torch.log( + (1 + x).clamp(min=SMALL_NUMBER) / (1 - x).clamp(min=SMALL_NUMBER)) + + +def concat_multi_gpu_td_errors(policy: "TorchPolicy") -> Dict[str, TensorType]: + """Concatenates multi-GPU (per-tower) TD error tensors given TorchPolicy. + + TD-errors are extracted from the TorchPolicy via its tower_stats property. + + Args: + policy: The TorchPolicy to extract the TD-error values from. + + Returns: + A dict mapping strings "td_error" and "mean_td_error" to the + corresponding concatenated and mean-reduced values. + """ + td_error = torch.cat( + [ + t.tower_stats.get("td_error", torch.tensor([0.0])).to( + policy.device) for t in policy.model_gpu_towers + ], + dim=0) + policy.td_error = td_error + return { + "td_error": td_error, + "mean_td_error": torch.mean(td_error), + } + + +@Deprecated(new="ray/rllib/utils/numpy.py::convert_to_numpy", error=False) +def convert_to_non_torch_type(stats: TensorStructType) -> TensorStructType: + """Converts values in `stats` to non-Tensor numpy or python types. + + Args: + stats (any): Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all torch tensors + being converted to numpy types. + + Returns: + Any: A new struct with the same structure as `stats`, but with all + values converted to non-torch Tensor types. + """ + + # The mapping function used to numpyize torch Tensors. + def mapping(item): + if isinstance(item, torch.Tensor): + return item.cpu().item() if len(item.size()) == 0 else \ + item.detach().cpu().numpy() + else: + return item + + return tree.map_structure(mapping, stats) + + +def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None): + """Converts any struct to torch.Tensors. + + x (any): Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all leaves converted + to torch tensors. + + Returns: + Any: A new struct with the same structure as `stats`, but with all + values converted to torch Tensor types. + """ + + def mapping(item): + # Already torch tensor -> make sure it's on right device. + if torch.is_tensor(item): + return item if device is None else item.to(device) + # Special handling of "Repeated" values. 
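Illustrative usage sketch (not part of the diff): FLOAT_MIN as a finite stand-in for -inf when masking logits, which is the use case the comment above describes; the masking code itself is only an example, not something this patch adds.

import torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

logits = torch.tensor([1.0, 2.0, 3.0])
valid = torch.tensor([True, False, True])
# Invalid actions get an "almost -inf" logit, so softmax assigns them ~0
# probability without the NaN-in-backprop issues a true -inf would cause.
masked = torch.where(valid, logits, torch.full_like(logits, FLOAT_MIN))
print(torch.softmax(masked, dim=-1))  # ~[0.119, 0.000, 0.881]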
+ elif isinstance(item, RepeatedValues): + return RepeatedValues( + tree.map_structure(mapping, item.values), item.lengths, + item.max_len) + # Numpy arrays. + if isinstance(item, np.ndarray): + # np.object_ type (e.g. info dicts in train batch): leave as-is. + if item.dtype == np.object_: + return item + # Non-writable numpy-arrays will cause PyTorch warning. + elif item.flags.writeable is False: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + tensor = torch.from_numpy(item) + # Already numpy: Wrap as torch tensor. + else: + tensor = torch.from_numpy(item) + # Everything else: Convert to numpy, then wrap as torch tensor. + else: + tensor = torch.from_numpy(np.asarray(item)) + # Floatify all float64 tensors. + if tensor.dtype == torch.double: + tensor = tensor.float() + return tensor if device is None else tensor.to(device) + + return tree.map_structure(mapping, x) + + +def explained_variance(y: TensorType, pred: TensorType) -> TensorType: + """Computes the explained variance for a pair of labels and predictions. + + The formula used is: + max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2)) + + Args: + y: The labels. + pred: The predictions. + + Returns: + The explained variance given a pair of labels and predictions. + """ + y_var = torch.var(y, dim=[0]) + diff_var = torch.var(y - pred, dim=[0]) + min_ = torch.tensor([-1.0]).to(pred.device) + return torch.max(min_, 1 - (diff_var / y_var))[0] + + +def global_norm(tensors: List[TensorType]) -> TensorType: + """Returns the global L2 norm over a list of tensors. + + output = sqrt(SUM(t ** 2 for t in tensors)), + where SUM reduces over all tensors and over all elements in tensors. + + Args: + tensors: The list of tensors to calculate the global norm over. + + Returns: + The global L2 norm over the given tensor list. + """ + # List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor. + single_l2s = [ + torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors + ] + # Compute global norm from all single tensors' L2 norms. + return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5) + + +def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: + """Computes the huber loss for a given term and delta parameter. + + Reference: https://en.wikipedia.org/wiki/Huber_loss + Note that the factor of 0.5 is implicitly included in the calculation. + + Formula: + L = 0.5 * x^2 for small abs x (delta threshold) + L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold) + + Args: + x: The input term, e.g. a TD error. + delta: The delta parmameter in the above formula. + + Returns: + The Huber loss resulting from `x` and `delta`. + """ + return torch.where( + torch.abs(x) < delta, + torch.pow(x, 2.0) * 0.5, delta * (torch.abs(x) - 0.5 * delta)) + + +def l2_loss(x: TensorType) -> TensorType: + """Computes half the L2 norm over a tensor's values without the sqrt. + + output = 0.5 * sum(x ** 2) + + Args: + x: The input tensor. + + Returns: + 0.5 times the L2 norm over the given tensor's values (w/o sqrt). + """ + return 0.5 * torch.sum(torch.pow(x, 2.0)) + + +def minimize_and_clip(optimizer: "torch.optim.Optimizer", + clip_val: float = 10.0) -> None: + """Clips grads found in `optimizer.param_groups` to given value in place. + + Ensures the norm of the gradients for each variable is clipped to + `clip_val`. + + Args: + optimizer: The torch.optim.Optimizer to get the variables from. + clip_val: The global norm clip value. Will clip around -clip_val and + +clip_val. 
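Illustrative usage sketch (not part of the diff): `convert_to_torch_tensor` on a nested struct; float64 numpy arrays come back as float32 torch tensors and the dict structure is preserved.

import numpy as np
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

batch = {
    "obs": np.zeros((2, 4), dtype=np.float64),
    "rewards": np.array([1.0, 2.0]),
}
out = convert_to_torch_tensor(batch)
print(out["obs"].dtype, out["rewards"].shape)  # torch.float32 torch.Size([2])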
+ """ + # Loop through optimizer's variables and norm per variable. + for param_group in optimizer.param_groups: + for p in param_group["params"]: + if p.grad is not None: + torch.nn.utils.clip_grad_norm_(p.grad, clip_val) + + +def one_hot(x: TensorType, space: gym.Space) -> TensorType: + """Returns a one-hot tensor, given and int tensor and a space. + + Handles the MultiDiscrete case as well. + + Args: + x: The input tensor. + space: The space to use for generating the one-hot tensor. + + Returns: + The resulting one-hot tensor. + + Raises: + ValueError: If the given space is not a discrete one. + + Examples: + >>> x = torch.IntTensor([0, 3]) # batch-dim=2 + >>> # Discrete space with 4 (one-hot) slots per batch item. + >>> s = gym.spaces.Discrete(4) + >>> one_hot(x, s) + tensor([[1, 0, 0, 0], [0, 0, 0, 1]]) + + >>> x = torch.IntTensor([[0, 1, 2, 3]]) # batch-dim=1 + >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots + >>> # per batch item. + >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7]) + >>> one_hot(x, s) + tensor([[1, 0, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, 0, 0, 0]]) + """ + if isinstance(space, Discrete): + return nn.functional.one_hot(x.long(), space.n) + elif isinstance(space, MultiDiscrete): + return torch.cat( + [ + nn.functional.one_hot(x[:, i].long(), n) + for i, n in enumerate(space.nvec) + ], + dim=-1) + else: + raise ValueError("Unsupported space for `one_hot`: {}".format(space)) + + +def reduce_mean_ignore_inf(x: TensorType, + axis: Optional[int] = None) -> TensorType: + """Same as torch.mean() but ignores -inf values. + + Args: + x: The input tensor to reduce mean over. + axis: The axis over which to reduce. None for all axes. + + Returns: + The mean reduced inputs, ignoring inf values. + """ + mask = torch.ne(x, float("-inf")) + x_zeroed = torch.where(mask, x, torch.zeros_like(x)) + return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis) + + +def sequence_mask( + lengths: TensorType, + maxlen: Optional[int] = None, + dtype=None, + time_major: bool = False, +) -> TensorType: + """Offers same behavior as tf.sequence_mask for torch. + + Thanks to Dimitris Papatheodorou + (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/ + 39036). + + Args: + lengths: The tensor of individual lengths to mask by. + maxlen: The maximum length to use for the time axis. If None, use + the max of `lengths`. + dtype: The torch dtype to use for the resulting mask. + time_major: Whether to return the mask as [B, T] (False; default) or + as [T, B] (True). + + Returns: + The sequence mask resulting from the given input and parameters. + """ + # If maxlen not given, use the longest lengths in the `lengths` tensor. + if maxlen is None: + maxlen = int(lengths.max()) + + mask = ~(torch.ones( + (len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths) + # Time major transformation. + if not time_major: + mask = mask.t() + + # By default, set the mask to be boolean. + mask.type(dtype or torch.bool) + + return mask + + +def set_torch_seed(seed: Optional[int] = None) -> None: + """Sets the torch random seed to the given value. + + Args: + seed: The seed to use or None for no seeding. + """ + if seed is not None and torch: + torch.manual_seed(seed) + # See https://github.com/pytorch/pytorch/issues/47672. + cuda_version = torch.version.cuda + if cuda_version is not None and float(torch.version.cuda) >= 10.2: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8" + else: + # Not all Operations support this. 
+ torch.use_deterministic_algorithms(True) + # This is only for Convolution no problem. + torch.backends.cudnn.deterministic = True + + +def softmax_cross_entropy_with_logits( + logits: TensorType, + labels: TensorType, +) -> TensorType: + """Same behavior as tf.nn.softmax_cross_entropy_with_logits. + + Args: + x: The input predictions. + labels: The labels corresponding to `x`. + + Returns: + The resulting softmax cross-entropy given predictions and labels. + """ + return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1) + + +class Swish(nn.Module): + def __init__(self): + super().__init__() + self._beta = nn.Parameter(torch.tensor(1.0)) + + def forward(self, input_tensor): + return input_tensor * torch.sigmoid(self._beta * input_tensor) diff --git a/rllib/utils/window_stat.py b/rllib/utils/window_stat.py index 9aa0d9f301df..873d803914af 100644 --- a/rllib/utils/window_stat.py +++ b/rllib/utils/window_stat.py @@ -1,28 +1,9 @@ -import numpy as np - - -class WindowStat: - def __init__(self, name, n): - self.name = name - self.items = [None] * n - self.idx = 0 - self.count = 0 - - def push(self, obj): - self.items[self.idx] = obj - self.idx += 1 - self.count += 1 - self.idx %= len(self.items) - - def stats(self): - if not self.count: - _quantiles = [] - else: - _quantiles = np.nanpercentile(self.items[:self.count], - [0, 10, 50, 90, 100]).tolist() - return { - self.name + "_count": int(self.count), - self.name + "_mean": float(np.nanmean(self.items[:self.count])), - self.name + "_std": float(np.nanstd(self.items[:self.count])), - self.name + "_quantiles": _quantiles, - } +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.metrics.window_stat import WindowStat + +deprecation_warning( + old="ray.rllib.utils.window_stat.WindowStat", + new="ray.rllib.utils.metrics.window_stat.WindowStat", + error=False, +) +WindowStat = WindowStat diff --git a/src/mock/ray/core_worker/task_manager.h b/src/mock/ray/core_worker/task_manager.h index effea598da9d..cfc54a8b6a28 100644 --- a/src/mock/ray/core_worker/task_manager.h +++ b/src/mock/ray/core_worker/task_manager.h @@ -24,7 +24,7 @@ class MockTaskFinisherInterface : public TaskFinisherInterface { const rpc::Address &actor_addr), (override)); MOCK_METHOD(bool, PendingTaskFailed, - (const TaskID &task_id, rpc::ErrorType error_type, Status *status, + (const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception, bool immediately_mark_object_fail), (override)); diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 5d4b27e45055..12c753257b65 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -350,7 +350,7 @@ bool TaskManager::RetryTaskIfPossible(const TaskID &task_id) { } bool TaskManager::PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType error_type, Status *status, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception, bool immediately_mark_object_fail) { // Note that this might be the __ray_terminate__ task, so we don't log diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index c59d307fd973..f77a2c0957b6 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -34,7 +34,7 @@ class TaskFinisherInterface { virtual bool RetryTaskIfPossible(const TaskID &task_id) = 0; virtual bool PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType 
error_type, Status *status, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception = nullptr, bool immediately_mark_object_fail = true) = 0; @@ -146,7 +146,7 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// result object as failed. /// \return Whether the task will be retried or not. bool PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType error_type, Status *status = nullptr, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status = nullptr, const std::shared_ptr &creation_task_exception = nullptr, bool immediately_mark_object_fail = true) override; diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index ea5a15b48dc6..69d8d46d9bd7 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -41,7 +41,7 @@ TaskSpecification CreateActorTaskHelper(ActorID actor_id, WorkerID caller_worker int64_t counter, TaskID caller_id = TaskID::Nil()) { TaskSpecification task; - task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary()); task.GetMutableMessage().set_caller_id(caller_id.Binary()); task.GetMutableMessage().set_type(TaskType::ACTOR_TASK); task.GetMutableMessage().mutable_caller_address()->set_worker_id( @@ -137,7 +137,7 @@ TEST_F(DirectActorSubmitterTest, TestSubmitTask) { ASSERT_TRUE(submitter_.SubmitTask(task).ok()); ASSERT_EQ(worker_client_->callbacks.size(), 2); - EXPECT_CALL(*task_finisher_, CompletePendingTask(TaskID::Nil(), _, _)) + EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _)) .Times(worker_client_->callbacks.size()); EXPECT_CALL(*task_finisher_, PendingTaskFailed(_, _, _, _, _)).Times(0); while (!worker_client_->callbacks.empty()) { @@ -277,10 +277,10 @@ TEST_F(DirectActorSubmitterTest, TestActorDead) { } EXPECT_CALL(*task_finisher_, PendingTaskFailed(_, _, _, _, _)).Times(0); - submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Actor marked as dead. All queued tasks should get failed. EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(1); - submitter_.DisconnectActor(actor_id, 1, /*dead=*/true); + submitter_.DisconnectActor(actor_id, 2, /*dead=*/true); } TEST_F(DirectActorSubmitterTest, TestActorRestartNoRetry) { @@ -303,14 +303,16 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartNoRetry) { ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); - EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(2); - EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(2); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(1); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(1); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)).Times(1); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task4.TaskId(), _, _)).Times(1); // First task finishes. Second task fails. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); // Simulate the actor failing. 
- submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Third task fails after the actor is disconnected. It should not get // retried. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); @@ -346,17 +348,20 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartRetry) { ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); // All tasks will eventually finish. - EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(4); + EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _)).Times(4); // Tasks 2 and 3 will be retried. EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)) - .Times(2) + .Times(1) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)) + .Times(1) .WillRepeatedly(Return(true)); // First task finishes. Second task fails. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); // Simulate the actor failing. - submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Third task fails after the actor is disconnected. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); @@ -395,7 +400,7 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderRetry) { ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); // All tasks will eventually finish. - EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(3); + EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _)).Times(3); // Tasks 2 will be retried EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)) @@ -406,7 +411,7 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderRetry) { ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK(), /*index=*/1)); // Simulate the actor failing. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""), /*index=*/0)); - submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Actor gets restarted. addr.set_port(1); @@ -493,6 +498,47 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderGcs) { ASSERT_TRUE(submitter_.SubmitTask(task).ok()); } +TEST_F(DirectActorSubmitterTest, TestActorRestartFailInflightTasks) { + rpc::Address addr; + auto worker_id = WorkerID::FromRandom(); + addr.set_worker_id(worker_id.Binary()); + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + submitter_.AddActorQueueIfNotExists(actor_id); + addr.set_port(0); + submitter_.ConnectActor(actor_id, addr, 0); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + ASSERT_EQ(num_clients_connected_, 1); + + // Create 3 tasks for the actor. + auto task1 = CreateActorTaskHelper(actor_id, worker_id, 0); + auto task2 = CreateActorTaskHelper(actor_id, worker_id, 1); + auto task3 = CreateActorTaskHelper(actor_id, worker_id, 1); + // Submit a task. + ASSERT_TRUE(submitter_.SubmitTask(task1).ok()); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(1); + ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); + + // Submit 2 tasks. + ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); + ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); + // Actor failed, but the task replies are delayed (or in some scenarios, lost). + // We should still be able to fail the inflight tasks. 
+ EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(1); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)).Times(1); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); + + // The task replies are now received. Since the tasks are already failed, they will not + // be marked as failed or finished again. + EXPECT_CALL(*task_finisher_, CompletePendingTask(task2.TaskId(), _, _)).Times(0); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(0); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task3.TaskId(), _, _)).Times(0); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)).Times(0); + // Task 2 replied with OK. + ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); + // Task 3 replied with error. + ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); +} + class MockDependencyWaiter : public DependencyWaiter { public: MOCK_METHOD2(Wait, void(const std::vector &dependencies, diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index b631b1d37217..2db2ab426fc5 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -110,7 +110,7 @@ class MockTaskFinisher : public TaskFinisherInterface { } bool PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType error_type, Status *status, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception = nullptr, bool immediately_mark_object_fail = true) override { num_tasks_failed++; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index e04421f080d9..d085884729b8 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -134,121 +134,161 @@ void CoreWorkerDirectActorTaskSubmitter::DisconnectRpcClient(ClientQueue &queue) queue.pending_force_kill.reset(); } +void CoreWorkerDirectActorTaskSubmitter::FailInflightTasks( + const std::unordered_map> + &inflight_task_callbacks) { + // NOTE(kfstorm): We invoke the callbacks with a bad status to act like there's a + // network issue. We don't call `task_finisher_.PendingTaskFailed` directly because + // there's much more work to do in the callback. + auto status = Status::IOError("Fail all inflight tasks due to actor state change."); + rpc::PushTaskReply reply; + for (const auto &entry : inflight_task_callbacks) { + entry.second(status, reply); + } +} + void CoreWorkerDirectActorTaskSubmitter::ConnectActor(const ActorID &actor_id, const rpc::Address &address, int64_t num_restarts) { RAY_LOG(DEBUG) << "Connecting to actor " << actor_id << " at worker " << WorkerID::FromBinary(address.worker_id()); - absl::MutexLock lock(&mu_); - auto queue = client_queues_.find(actor_id); - RAY_CHECK(queue != client_queues_.end()); - if (num_restarts < queue->second.num_restarts) { - // This message is about an old version of the actor and the actor has - // already restarted since then. Skip the connection. 
- RAY_LOG(INFO) << "Skip actor connection that has already been restarted, actor_id=" - << actor_id; - return; - } + std::unordered_map> + inflight_task_callbacks; - if (queue->second.rpc_client && - queue->second.rpc_client->Addr().ip_address() == address.ip_address() && - queue->second.rpc_client->Addr().port() == address.port()) { - RAY_LOG(DEBUG) << "Skip actor that has already been connected, actor_id=" << actor_id; - return; - } + { + absl::MutexLock lock(&mu_); - if (queue->second.state == rpc::ActorTableData::DEAD) { - // This message is about an old version of the actor and the actor has - // already died since then. Skip the connection. - return; - } + auto queue = client_queues_.find(actor_id); + RAY_CHECK(queue != client_queues_.end()); + if (num_restarts < queue->second.num_restarts) { + // This message is about an old version of the actor and the actor has + // already restarted since then. Skip the connection. + RAY_LOG(INFO) << "Skip actor connection that has already been restarted, actor_id=" + << actor_id; + return; + } - queue->second.num_restarts = num_restarts; - if (queue->second.rpc_client) { - // Clear the client to the old version of the actor. - DisconnectRpcClient(queue->second); + if (queue->second.rpc_client && + queue->second.rpc_client->Addr().ip_address() == address.ip_address() && + queue->second.rpc_client->Addr().port() == address.port()) { + RAY_LOG(DEBUG) << "Skip actor that has already been connected, actor_id=" + << actor_id; + return; + } + + if (queue->second.state == rpc::ActorTableData::DEAD) { + // This message is about an old version of the actor and the actor has + // already died since then. Skip the connection. + return; + } + + queue->second.num_restarts = num_restarts; + if (queue->second.rpc_client) { + // Clear the client to the old version of the actor. + DisconnectRpcClient(queue->second); + inflight_task_callbacks = std::move(queue->second.inflight_task_callbacks); + queue->second.inflight_task_callbacks.clear(); + } + + queue->second.state = rpc::ActorTableData::ALIVE; + // Update the mapping so new RPCs go out with the right intended worker id. + queue->second.worker_id = address.worker_id(); + // Create a new connection to the actor. + queue->second.rpc_client = core_worker_client_pool_.GetOrConnect(address); + // This assumes that all replies from the previous incarnation + // of the actor have been received. This assumption should be OK + // because we fail all inflight tasks in `DisconnectRpcClient`. + RAY_LOG(DEBUG) << "Resetting caller starts at for actor " << actor_id << " from " + << queue->second.caller_starts_at << " to " + << queue->second.next_task_reply_position; + queue->second.caller_starts_at = queue->second.next_task_reply_position; + + RAY_LOG(INFO) << "Connecting to actor " << actor_id << " at worker " + << WorkerID::FromBinary(address.worker_id()); + ResendOutOfOrderTasks(actor_id); + SendPendingTasks(actor_id); } - queue->second.state = rpc::ActorTableData::ALIVE; - // Update the mapping so new RPCs go out with the right intended worker id. - queue->second.worker_id = address.worker_id(); - // Create a new connection to the actor. - queue->second.rpc_client = core_worker_client_pool_.GetOrConnect(address); - // TODO(swang): This assumes that all replies from the previous incarnation - // of the actor have been received. Fix this by setting an epoch for each - // actor task, so we can ignore completed tasks from old epochs. 
-  RAY_LOG(DEBUG) << "Resetting caller starts at for actor " << actor_id << " from "
-                 << queue->second.caller_starts_at << " to "
-                 << queue->second.next_task_reply_position;
-  queue->second.caller_starts_at = queue->second.next_task_reply_position;
-
-  RAY_LOG(INFO) << "Connecting to actor " << actor_id << " at worker "
-                << WorkerID::FromBinary(address.worker_id());
-  ResendOutOfOrderTasks(actor_id);
-  SendPendingTasks(actor_id);
+  // NOTE(kfstorm): We need to make sure the lock is released before invoking callbacks.
+  FailInflightTasks(inflight_task_callbacks);
 }
 
 void CoreWorkerDirectActorTaskSubmitter::DisconnectActor(
     const ActorID &actor_id, int64_t num_restarts, bool dead,
     const std::shared_ptr &creation_task_exception) {
   RAY_LOG(DEBUG) << "Disconnecting from actor " << actor_id;
-  absl::MutexLock lock(&mu_);
-  auto queue = client_queues_.find(actor_id);
-  RAY_CHECK(queue != client_queues_.end());
-  if (num_restarts <= queue->second.num_restarts && !dead) {
-    // This message is about an old version of the actor that has already been
-    // restarted successfully. Skip the message handling.
-    RAY_LOG(INFO) << "Skip actor disconnection that has already been restarted, actor_id="
-                  << actor_id;
-    return;
-  }
-  // The actor failed, so erase the client for now. Either the actor is
-  // permanently dead or the new client will be inserted once the actor is
-  // restarted.
-  DisconnectRpcClient(queue->second);
-
-  if (dead) {
-    queue->second.state = rpc::ActorTableData::DEAD;
-    queue->second.creation_task_exception = creation_task_exception;
-    // If there are pending requests, treat the pending tasks as failed.
-    RAY_LOG(INFO) << "Failing pending tasks for actor " << actor_id
-                  << " because the actor is already dead.";
-    auto &requests = queue->second.requests;
-    auto head = requests.begin();
-
-    auto status = Status::IOError("cancelling all pending tasks of dead actor");
-    while (head != requests.end()) {
-      const auto &task_spec = head->second.first;
-      task_finisher_.MarkTaskCanceled(task_spec.TaskId());
-      // No need to increment the number of completed tasks since the actor is
-      // dead.
-      RAY_UNUSED(!task_finisher_.PendingTaskFailed(task_spec.TaskId(),
-                                                   rpc::ErrorType::ACTOR_DIED, &status,
-                                                   creation_task_exception));
-      head = requests.erase(head);
+  std::unordered_map<TaskID, rpc::ClientCallback<rpc::PushTaskReply>>
+      inflight_task_callbacks;
+
+  {
+    absl::MutexLock lock(&mu_);
+    auto queue = client_queues_.find(actor_id);
+    RAY_CHECK(queue != client_queues_.end());
+    if (!dead) {
+      RAY_CHECK(num_restarts > 0);
+    }
+    if (num_restarts <= queue->second.num_restarts && !dead) {
+      // This message is about an old version of the actor that has already been
+      // restarted successfully. Skip the message handling.
+      RAY_LOG(INFO)
+          << "Skip actor disconnection that has already been restarted, actor_id="
+          << actor_id;
+      return;
     }
-    auto &wait_for_death_info_tasks = queue->second.wait_for_death_info_tasks;
+    // The actor failed, so erase the client for now. Either the actor is
+    // permanently dead or the new client will be inserted once the actor is
+    // restarted.
+    DisconnectRpcClient(queue->second);
+    inflight_task_callbacks = std::move(queue->second.inflight_task_callbacks);
+    queue->second.inflight_task_callbacks.clear();
+
+    if (dead) {
+      queue->second.state = rpc::ActorTableData::DEAD;
+      queue->second.creation_task_exception = creation_task_exception;
+      // If there are pending requests, treat the pending tasks as failed.
+      RAY_LOG(INFO) << "Failing pending tasks for actor " << actor_id
+                    << " because the actor is already dead.";
+      auto &requests = queue->second.requests;
+      auto head = requests.begin();
+
+      auto status = Status::IOError("cancelling all pending tasks of dead actor");
+      while (head != requests.end()) {
+        const auto &task_spec = head->second.first;
+        task_finisher_.MarkTaskCanceled(task_spec.TaskId());
+        // No need to increment the number of completed tasks since the actor is
+        // dead.
+        RAY_UNUSED(!task_finisher_.PendingTaskFailed(task_spec.TaskId(),
                                                    rpc::ErrorType::ACTOR_DIED, &status,
                                                    creation_task_exception));
+        head = requests.erase(head);
+      }
 
-    RAY_LOG(INFO) << "Failing tasks waiting for death info, size="
-                  << wait_for_death_info_tasks.size() << ", actor_id=" << actor_id;
-    for (auto &net_err_task : wait_for_death_info_tasks) {
-      RAY_UNUSED(task_finisher_.MarkPendingTaskFailed(
-          net_err_task.second, rpc::ErrorType::ACTOR_DIED, creation_task_exception));
-    }
+      auto &wait_for_death_info_tasks = queue->second.wait_for_death_info_tasks;
 
-    // No need to clean up tasks that have been sent and are waiting for
-    // replies. They will be treated as failed once the connection dies.
-    // We retain the sequencing information so that we can properly fail
-    // any tasks submitted after the actor death.
-  } else if (queue->second.state != rpc::ActorTableData::DEAD) {
-    // Only update the actor's state if it is not permanently dead. The actor
-    // will eventually get restarted or marked as permanently dead.
-    queue->second.state = rpc::ActorTableData::RESTARTING;
-    queue->second.num_restarts = num_restarts;
+      RAY_LOG(INFO) << "Failing tasks waiting for death info, size="
+                    << wait_for_death_info_tasks.size() << ", actor_id=" << actor_id;
+      for (auto &net_err_task : wait_for_death_info_tasks) {
+        RAY_UNUSED(task_finisher_.MarkPendingTaskFailed(
+            net_err_task.second, rpc::ErrorType::ACTOR_DIED, creation_task_exception));
+      }
+
+      // No need to clean up tasks that have been sent and are waiting for
+      // replies. They will be treated as failed once the connection dies.
+      // We retain the sequencing information so that we can properly fail
+      // any tasks submitted after the actor death.
+    } else if (queue->second.state != rpc::ActorTableData::DEAD) {
+      // Only update the actor's state if it is not permanently dead. The actor
+      // will eventually get restarted or marked as permanently dead.
+      queue->second.state = rpc::ActorTableData::RESTARTING;
+      queue->second.num_restarts = num_restarts;
+    }
   }
+
+  // NOTE(kfstorm): We need to make sure the lock is released before invoking callbacks.
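Editor's note: the NOTE above (and the matching one in ConnectActor) describes the core pattern of this change: reply callbacks for in-flight tasks are parked in a per-actor map, drained while holding the lock when the actor disconnects or reconnects, and invoked with an IOError only after the lock is released; the normal reply path then either consumes its entry or ignores a late reply (that is the wrapped callback added to PushActorTask further below). The following is a minimal, self-contained sketch of that pattern; SimpleStatus, Reply, and ActorQueueSketch are hypothetical stand-ins, not Ray types.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <string>
    #include <unordered_map>
    #include <utility>

    // Stand-ins for Status and rpc::PushTaskReply.
    struct SimpleStatus {
      bool ok;
      std::string message;
    };
    struct Reply {};

    using ReplyCallback = std::function<void(const SimpleStatus &, const Reply &)>;

    class ActorQueueSketch {
     public:
      // Register the callback before the RPC is sent (the real code does the
      // equivalent with queue.inflight_task_callbacks.emplace(...)).
      void AddInflight(int64_t task_id, ReplyCallback cb) {
        std::lock_guard<std::mutex> lock(mu_);
        inflight_.emplace(task_id, std::move(cb));
      }

      // Reply path: only run the callback if it is still registered. If the
      // actor already disconnected and FailAllInflight() consumed the entry,
      // the late reply is ignored.
      void HandleReply(int64_t task_id, const SimpleStatus &status, const Reply &reply) {
        ReplyCallback cb;
        {
          std::lock_guard<std::mutex> lock(mu_);
          auto it = inflight_.find(task_id);
          if (it == inflight_.end()) {
            std::cout << "task " << task_id << " already failed; ignoring reply\n";
            return;
          }
          cb = std::move(it->second);
          inflight_.erase(it);
        }
        cb(status, reply);  // Invoke outside the lock.
      }

      // Disconnect path: move the callbacks out while holding the lock, then
      // invoke them with an error status after the lock is released, so a
      // callback that re-enters this class cannot deadlock on mu_.
      void FailAllInflight() {
        std::unordered_map<int64_t, ReplyCallback> to_fail;
        {
          std::lock_guard<std::mutex> lock(mu_);
          to_fail.swap(inflight_);
        }
        SimpleStatus error{false, "Fail all inflight tasks due to actor state change."};
        for (auto &entry : to_fail) {
          entry.second(error, Reply{});
        }
      }

     private:
      std::mutex mu_;
      std::unordered_map<int64_t, ReplyCallback> inflight_;
    };

    int main() {
      ActorQueueSketch queue;
      auto print = [](int64_t id) {
        return [id](const SimpleStatus &s, const Reply &) {
          std::cout << "task " << id << " done, ok=" << s.ok << ": " << s.message << "\n";
        };
      };
      queue.AddInflight(2, print(2));
      queue.AddInflight(3, print(3));
      queue.FailAllInflight();                      // Both fail with the error status.
      queue.HandleReply(2, {true, "OK"}, Reply{});  // Late reply: ignored.
      return 0;
    }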
+  FailInflightTasks(inflight_task_callbacks);
 }
 
 void CoreWorkerDirectActorTaskSubmitter::CheckTimeoutTasks() {
@@ -319,7 +359,7 @@ void CoreWorkerDirectActorTaskSubmitter::ResendOutOfOrderTasks(const ActorID &ac
   client_queue.out_of_order_completed_tasks.clear();
 }
 
-void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue,
+void CoreWorkerDirectActorTaskSubmitter::PushActorTask(ClientQueue &queue,
                                                        const TaskSpecification &task_spec,
                                                        bool skip_queue) {
   auto request = std::make_unique<rpc::PushTaskRequest>();
@@ -349,10 +389,9 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue,
   }
 
   rpc::Address addr(queue.rpc_client->Addr());
-  queue.rpc_client->PushActorTask(
-      std::move(request), skip_queue,
+  rpc::ClientCallback<rpc::PushTaskReply> reply_callback =
       [this, addr, task_id, actor_id, actor_counter, task_spec, task_skipped](
-          Status status, const rpc::PushTaskReply &reply) {
+          const Status &status, const rpc::PushTaskReply &reply) {
         bool increment_completed_tasks = true;
 
         if (task_skipped) {
@@ -420,7 +459,30 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue,
                        << " and size of out_of_order_tasks set is "
                        << queue.out_of_order_completed_tasks.size();
         }
-      });
+      };
+
+  queue.inflight_task_callbacks.emplace(task_id, std::move(reply_callback));
+  rpc::ClientCallback<rpc::PushTaskReply> wrapped_callback =
+      [this, task_id, actor_id](const Status &status, const rpc::PushTaskReply &reply) {
+        rpc::ClientCallback<rpc::PushTaskReply> reply_callback;
+        {
+          absl::MutexLock lock(&mu_);
+          auto it = client_queues_.find(actor_id);
+          RAY_CHECK(it != client_queues_.end());
+          auto &queue = it->second;
+          auto callback_it = queue.inflight_task_callbacks.find(task_id);
+          if (callback_it == queue.inflight_task_callbacks.end()) {
+            RAY_LOG(DEBUG) << "The task " << task_id
+                           << " has already been marked as failed. Ignore the reply.";
+            return;
+          }
+          reply_callback = std::move(callback_it->second);
+          queue.inflight_task_callbacks.erase(callback_it);
+        }
+        reply_callback(status, reply);
+      };
+
+  queue.rpc_client->PushActorTask(std::move(request), skip_queue, wrapped_callback);
 }
 
 bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) const {
diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h
index 42a048ac2f16..162e9dd52b49 100644
--- a/src/ray/core_worker/transport/direct_actor_transport.h
+++ b/src/ray/core_worker/transport/direct_actor_transport.h
@@ -223,6 +223,11 @@ class CoreWorkerDirectActorTaskSubmitter
     /// A force-kill request that should be sent to the actor once an RPC
     /// client to the actor is available.
     absl::optional<rpc::KillActorRequest> pending_force_kill;
+
+    /// Stores all callbacks of inflight tasks. Note that this doesn't include tasks
+    /// without replies.
+    std::unordered_map<TaskID, rpc::ClientCallback<rpc::PushTaskReply>>
+        inflight_task_callbacks;
   };
 
   /// Push a task to a remote actor via the given client.
@@ -234,7 +239,7 @@ class CoreWorkerDirectActorTaskSubmitter
   /// \param[in] skip_queue Whether to skip the task queue. This will send the
   /// task for execution immediately.
   /// \return Void.
-  void PushActorTask(const ClientQueue &queue, const TaskSpecification &task_spec,
+  void PushActorTask(ClientQueue &queue, const TaskSpecification &task_spec,
                      bool skip_queue) EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   /// Send all pending tasks for an actor.
@@ -253,6 +258,11 @@ class CoreWorkerDirectActorTaskSubmitter
   /// Disconnect the RPC client for an actor.
   void DisconnectRpcClient(ClientQueue &queue) EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
+  /// Fail all in-flight tasks.
+  void FailInflightTasks(
+      const std::unordered_map<TaskID, rpc::ClientCallback<rpc::PushTaskReply>>
+          &inflight_task_callbacks) LOCKS_EXCLUDED(mu_);
+
   /// Whether the specified actor is alive.
   ///
   /// \param[in] actor_id The actor ID.
diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.cc b/src/ray/gcs/gcs_server/gcs_resource_manager.cc
index 9e36db6efafa..4266c635bf98 100644
--- a/src/ray/gcs/gcs_server/gcs_resource_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_resource_manager.cc
@@ -58,7 +58,7 @@ void GcsResourceManager::HandleUpdateResources(
     const rpc::UpdateResourcesRequest &request, rpc::UpdateResourcesReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
   NodeID node_id = NodeID::FromBinary(request.node_id());
-  RAY_LOG(INFO) << "Updating resources, node id = " << node_id;
+  RAY_LOG(DEBUG) << "Updating resources, node id = " << node_id;
   auto changed_resources = std::make_shared>();
   for (const auto &entry : request.resources()) {
     changed_resources->emplace(entry.first, entry.second.resource_capacity());
diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto
index c35223b9fbb8..eea2445b8988 100644
--- a/src/ray/protobuf/serve.proto
+++ b/src/ray/protobuf/serve.proto
@@ -51,40 +51,40 @@ message AutoscalingConfig {
   double upscale_delay_s = 8;
 }
 
-// Configuration options for a backend, to be set by the user.
-message BackendConfig {
-  // The number of processes to start up that will handle requests to this backend.
+// Configuration options for a deployment, to be set by the user.
+message DeploymentConfig {
+  // The number of processes to start up that will handle requests to this deployment.
   // Defaults to 1.
   int32 num_replicas = 1;
 
-  // The maximum number of queries that will be sent to a replica of this backend without
-  // receiving a response. Defaults to 100.
+  // The maximum number of queries that will be sent to a replica of this deployment
+  // without receiving a response. Defaults to 100.
   int32 max_concurrent_queries = 2;
 
-  // Arguments to pass to the reconfigure method of the backend. The reconfigure method is
-  // called if user_config is not None.
+  // Arguments to pass to the reconfigure method of the deployment. The reconfigure method
+  // is called if user_config is not None.
   bytes user_config = 3;
 
-  // Duration that backend workers will wait until there is no more work to be done before
-  // shutting down. Defaults to 2s.
+  // Duration that deployment replicas will wait until there is no more work to be done
+  // before shutting down. Defaults to 2s.
   double graceful_shutdown_wait_loop_s = 4;
 
   // Controller waits for this duration to forcefully kill the replica for shutdown.
   // Defaults to 20s.
   double graceful_shutdown_timeout_s = 5;
 
-  // Is the construction of backend is cross language?
+  // Is the construction of the deployment cross-language?
   bool is_cross_language = 6;
 
-  // The backend's programming language.
-  BackendLanguage backend_language = 7;
+  // The deployment's programming language.
+  DeploymentLanguage deployment_language = 7;
 
-  // The backend's autoscaling configuration.
+  // The deployment's autoscaling configuration.
   AutoscalingConfig autoscaling_config = 8;
 }
 
 // Backend language.
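Editor's note: for readers consuming the renamed message from C++, a hedged usage sketch follows. The include path and the ray::serve namespace are assumptions (they depend on the proto package and the build setup); the setters follow the standard protoc-generated names for the fields declared above, and DeploymentLanguage is the enum renamed in the next hunk.

    #include <iostream>
    #include <string>

    // Generated from serve.proto; path and namespace are assumed for illustration.
    #include "src/ray/protobuf/serve.pb.h"

    int main() {
      ray::serve::DeploymentConfig config;
      config.set_num_replicas(2);
      config.set_max_concurrent_queries(100);
      config.set_graceful_shutdown_wait_loop_s(2.0);
      config.set_graceful_shutdown_timeout_s(20.0);
      config.set_is_cross_language(true);
      config.set_deployment_language(ray::serve::DeploymentLanguage::JAVA);

      // Cross-language callers typically exchange the serialized bytes.
      std::string bytes;
      config.SerializeToString(&bytes);
      std::cout << "serialized DeploymentConfig: " << bytes.size() << " bytes\n";
      return 0;
    }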
-enum BackendLanguage {
+enum DeploymentLanguage {
   PYTHON = 0;
   JAVA = 1;
 }
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 260f259657de..f184bf6f38fc 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -891,6 +891,9 @@ void NodeManager::ResourceCreateUpdated(const NodeID &node_id,
   RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from node id " << node_id
                  << " with created or updated resources: "
                  << createUpdatedResources.ToString() << ". Updating resource map.";
+  if (node_id == self_node_id_) {
+    return;
+  }
 
   // Update local_available_resources_ and SchedulingResources
   for (const auto &resource_pair : createUpdatedResources.GetResourceMap()) {
@@ -900,11 +903,7 @@ void NodeManager::ResourceCreateUpdated(const NodeID &node_id,
                                          new_resource_capacity);
   }
   RAY_LOG(DEBUG) << "[ResourceCreateUpdated] Updated cluster_resource_map.";
-
-  if (node_id == self_node_id_) {
-    // The resource update is on the local node, check if we can reschedule tasks.
-    cluster_task_manager_->ScheduleAndDispatchTasks();
-  }
+  cluster_task_manager_->ScheduleAndDispatchTasks();
 }
 
 void NodeManager::ResourceDeleted(const NodeID &node_id,
@@ -1474,39 +1473,44 @@ void NodeManager::HandleUpdateResourceUsage(
     rpc::SendReplyCallback send_reply_callback) {
   rpc::ResourceUsageBroadcastData resource_usage_batch;
   resource_usage_batch.ParseFromString(request.serialized_resource_usage_batch());
-
-  if (resource_usage_batch.seq_no() != next_resource_seq_no_) {
+  // When next_resource_seq_no_ == 0, it means the raylet has just started.
+  // TODO: Fetch a snapshot from gcs for lightweight resource broadcasting
+  if (next_resource_seq_no_ != 0 &&
+      resource_usage_batch.seq_no() != next_resource_seq_no_) {
+    // TODO (Alex): Ideally we would be really robust, and potentially eagerly
+    // pull a full resource "snapshot" from gcs to make sure our state doesn't
+    // diverge from GCS.
     RAY_LOG(WARNING)
         << "Raylet may have missed a resource broadcast. This either means that GCS has "
           "restarted, the network is heavily congested and is dropping, reordering, or "
          "duplicating packets. Expected seq#: "
        << next_resource_seq_no_ << ", but got: " << resource_usage_batch.seq_no()
        << ".";
-    // TODO (Alex): Ideally we would be really robust, and potentially eagerly
-    // pull a full resource "snapshot" from gcs to make sure our state doesn't
-    // diverge from GCS.
+    if (resource_usage_batch.seq_no() < next_resource_seq_no_) {
+      RAY_LOG(WARNING) << "Discard the resource update since the local version is newer";
+      return;
+    }
   }
   next_resource_seq_no_ = resource_usage_batch.seq_no() + 1;
 
   for (const auto &resource_change_or_data : resource_usage_batch.batch()) {
     if (resource_change_or_data.has_data()) {
       const auto &resource_usage = resource_change_or_data.data();
-      const NodeID &node_id = NodeID::FromBinary(resource_usage.node_id());
-      if (node_id == self_node_id_) {
-        // Skip messages from self.
-        continue;
+      auto node_id = NodeID::FromBinary(resource_usage.node_id());
+      // Skip messages from self.
+      if (node_id != self_node_id_) {
+        UpdateResourceUsage(node_id, resource_usage);
       }
-      UpdateResourceUsage(node_id, resource_usage);
     } else if (resource_change_or_data.has_change()) {
       const auto &resource_notification = resource_change_or_data.change();
-      auto id = NodeID::FromBinary(resource_notification.node_id());
+      auto node_id = NodeID::FromBinary(resource_notification.node_id());
       if (resource_notification.updated_resources_size() != 0) {
         ResourceSet resource_set(
             MapFromProtobuf(resource_notification.updated_resources()));
-        ResourceCreateUpdated(id, resource_set);
+        ResourceCreateUpdated(node_id, resource_set);
       }
       if (resource_notification.deleted_resources_size() != 0) {
-        ResourceDeleted(id,
+        ResourceDeleted(node_id,
                         VectorFromProtobuf(resource_notification.deleted_resources()));
       }
     }
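Editor's note: the sequence-number handling added to HandleUpdateResourceUsage above reduces to a small decision rule: accept anything when the local counter is still 0 (just started), warn on any mismatch, discard batches older than the local counter, and otherwise apply the batch and advance the counter past it. The sketch below isolates that rule with a plain function and a bare counter in place of the raylet state.

    #include <cstdint>
    #include <iostream>

    // Returns true if the broadcast batch with the given seq_no should be applied,
    // updating next_resource_seq_no accordingly.
    bool AcceptBroadcast(int64_t seq_no, int64_t &next_resource_seq_no) {
      // next_resource_seq_no == 0 means we just started, so accept anything.
      if (next_resource_seq_no != 0 && seq_no != next_resource_seq_no) {
        std::cout << "warning: expected seq# " << next_resource_seq_no << ", got "
                  << seq_no << "\n";
        if (seq_no < next_resource_seq_no) {
          // Stale broadcast: local state is already newer, discard it.
          return false;
        }
        // A gap means at least one broadcast was missed; apply the newer batch
        // and jump the counter forward (ideally a full snapshot would be resynced).
      }
      next_resource_seq_no = seq_no + 1;
      return true;
    }

    int main() {
      int64_t next = 0;
      std::cout << AcceptBroadcast(7, next) << "\n";   // just started: accepted, next = 8
      std::cout << AcceptBroadcast(8, next) << "\n";   // in order: accepted, next = 9
      std::cout << AcceptBroadcast(5, next) << "\n";   // stale: discarded
      std::cout << AcceptBroadcast(12, next) << "\n";  // gap: accepted with a warning
      return 0;
    }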