-
Notifications
You must be signed in to change notification settings - Fork 903
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pay off some JNI RMM API tech debt (#12632)
This makes the java APIs for RMM more closely match the C++ APIs. Authors: - Robert (Bobby) Evans (https://github.com/revans2) Approvers: - Jason Lowe (https://github.com/jlowe) URL: #12632
- Loading branch information
Showing
14 changed files
with
1,288 additions
and
280 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
67 changes: 67 additions & 0 deletions
67
java/src/main/java/ai/rapids/cudf/RmmArenaMemoryResource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package ai.rapids.cudf; | ||
|
||
/** | ||
* A device memory resource that will pre-allocate a pool of resources and sub-allocate from this | ||
* pool to improve memory performance. This uses an algorithm to try and reduce fragmentation | ||
* much more than the RmmPoolMemoryResource does. | ||
*/ | ||
public class RmmArenaMemoryResource<C extends RmmDeviceMemoryResource> | ||
extends RmmWrappingDeviceMemoryResource<C> { | ||
private final long size; | ||
private final boolean dumpLogOnFailure; | ||
private long handle = 0; | ||
|
||
|
||
/** | ||
* Create a new arena memory resource taking ownership of the RmmDeviceMemoryResource that it is | ||
* wrapping. | ||
* @param wrapped the memory resource to use for the pool. This should not be reused. | ||
* @param size the size of the pool | ||
* @param dumpLogOnFailure if true, dump memory log when running out of memory. | ||
*/ | ||
public RmmArenaMemoryResource(C wrapped, long size, boolean dumpLogOnFailure) { | ||
super(wrapped); | ||
this.size = size; | ||
this.dumpLogOnFailure = dumpLogOnFailure; | ||
handle = Rmm.newArenaMemoryResource(wrapped.getHandle(), size, dumpLogOnFailure); | ||
} | ||
|
||
@Override | ||
public long getHandle() { | ||
return handle; | ||
} | ||
|
||
public long getSize() { | ||
return size; | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (handle != 0) { | ||
Rmm.releaseArenaMemoryResource(handle); | ||
handle = 0; | ||
} | ||
super.close(); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Long.toHexString(getHandle()) + "/ARENA(" + wrapped + | ||
", " + size + ", " + dumpLogOnFailure + ")"; | ||
} | ||
} |
59 changes: 59 additions & 0 deletions
59
java/src/main/java/ai/rapids/cudf/RmmCudaAsyncMemoryResource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package ai.rapids.cudf; | ||
|
||
/** | ||
* A device memory resource that uses `cudaMallocAsync` and `cudaFreeAsync` for allocation and | ||
* deallocation. | ||
*/ | ||
public class RmmCudaAsyncMemoryResource implements RmmDeviceMemoryResource { | ||
private final long releaseThreshold; | ||
private final long size; | ||
private long handle = 0; | ||
|
||
/** | ||
* Create a new async memory resource | ||
* @param size the initial size of the pool | ||
* @param releaseThreshold size in bytes for when memory is released back to cuda | ||
*/ | ||
public RmmCudaAsyncMemoryResource(long size, long releaseThreshold) { | ||
this.size = size; | ||
this.releaseThreshold = releaseThreshold; | ||
handle = Rmm.newCudaAsyncMemoryResource(size, releaseThreshold); | ||
} | ||
|
||
@Override | ||
public long getHandle() { | ||
return handle; | ||
} | ||
|
||
public long getSize() { | ||
return size; | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (handle != 0) { | ||
Rmm.releaseCudaAsyncMemoryResource(handle); | ||
handle = 0; | ||
} | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Long.toHexString(getHandle()) + "/ASYNC(" + size + ", " + releaseThreshold + ")"; | ||
} | ||
} |
44 changes: 44 additions & 0 deletions
44
java/src/main/java/ai/rapids/cudf/RmmCudaMemoryResource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package ai.rapids.cudf; | ||
|
||
/** | ||
* A device memory resource that uses `cudaMalloc` and `cudaFree` for allocation and deallocation. | ||
*/ | ||
public class RmmCudaMemoryResource implements RmmDeviceMemoryResource { | ||
private long handle = 0; | ||
|
||
public RmmCudaMemoryResource() { | ||
handle = Rmm.newCudaMemoryResource(); | ||
} | ||
@Override | ||
public long getHandle() { | ||
return handle; | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (handle != 0) { | ||
Rmm.releaseCudaMemoryResource(handle); | ||
handle = 0; | ||
} | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Long.toHexString(getHandle()) + "/CUDA()"; | ||
} | ||
} |
31 changes: 31 additions & 0 deletions
31
java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package ai.rapids.cudf; | ||
|
||
/** | ||
* A resource that allocates/deallocates device memory. This is not intended to be something that | ||
* a user will just subclass. This is intended to be a wrapper around a C++ class that RMM will | ||
* use directly. | ||
*/ | ||
public interface RmmDeviceMemoryResource extends AutoCloseable { | ||
/** | ||
* Returns a pointer to the underlying C++ class that implements rmm::mr::device_memory_resource | ||
*/ | ||
long getHandle(); | ||
|
||
// Remove the exception... | ||
void close(); | ||
} |
76 changes: 76 additions & 0 deletions
76
java/src/main/java/ai/rapids/cudf/RmmEventHandlerResourceAdaptor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package ai.rapids.cudf; | ||
|
||
import java.util.Arrays; | ||
|
||
/** | ||
* A device memory resource that will give callbacks in specific situations. | ||
*/ | ||
public class RmmEventHandlerResourceAdaptor<C extends RmmDeviceMemoryResource> | ||
extends RmmWrappingDeviceMemoryResource<C> { | ||
private long handle = 0; | ||
private final long [] allocThresholds; | ||
private final long [] deallocThresholds; | ||
private final boolean debug; | ||
|
||
/** | ||
* Create a new logging resource adaptor. | ||
* @param wrapped the memory resource to get callbacks for. This should not be reused. | ||
* @param handler the handler that will get the callbacks | ||
* @param tracker the tracking event handler | ||
* @param debug true if you want all the callbacks, else false | ||
*/ | ||
public RmmEventHandlerResourceAdaptor(C wrapped, RmmTrackingResourceAdaptor<?> tracker, | ||
RmmEventHandler handler, boolean debug) { | ||
super(wrapped); | ||
this.debug = debug; | ||
allocThresholds = sortThresholds(handler.getAllocThresholds()); | ||
deallocThresholds = sortThresholds(handler.getDeallocThresholds()); | ||
handle = Rmm.newEventHandlerResourceAdaptor(wrapped.getHandle(), tracker.getHandle(), handler, | ||
allocThresholds, deallocThresholds, debug); | ||
} | ||
|
||
private static long[] sortThresholds(long[] thresholds) { | ||
if (thresholds == null) { | ||
return null; | ||
} | ||
long[] result = Arrays.copyOf(thresholds, thresholds.length); | ||
Arrays.sort(result); | ||
return result; | ||
} | ||
|
||
@Override | ||
public long getHandle() { | ||
return handle; | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (handle != 0) { | ||
Rmm.releaseEventHandlerResourceAdaptor(handle, debug); | ||
handle = 0; | ||
} | ||
super.close(); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Long.toHexString(getHandle()) + "/EVENT(" + wrapped + | ||
", " + debug + ", " + Arrays.toString(allocThresholds) + ", " + | ||
Arrays.toString(deallocThresholds) + ")"; | ||
} | ||
} |
59 changes: 59 additions & 0 deletions
59
java/src/main/java/ai/rapids/cudf/RmmLimitingResourceAdaptor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package ai.rapids.cudf; | ||
|
||
/** | ||
* A device memory resource that will limit the maximum amount allocated. | ||
*/ | ||
public class RmmLimitingResourceAdaptor<C extends RmmDeviceMemoryResource> | ||
extends RmmWrappingDeviceMemoryResource<C> { | ||
private final long limit; | ||
private final long alignment; | ||
private long handle = 0; | ||
|
||
/** | ||
* Create a new limiting resource adaptor. | ||
* @param wrapped the memory resource to limit. This should not be reused. | ||
* @param limit the allocation limit in bytes | ||
* @param alignment the alignment | ||
*/ | ||
public RmmLimitingResourceAdaptor(C wrapped, long limit, long alignment) { | ||
super(wrapped); | ||
this.limit = limit; | ||
this.alignment = alignment; | ||
handle = Rmm.newLimitingResourceAdaptor(wrapped.getHandle(), limit, alignment); | ||
} | ||
|
||
@Override | ||
public long getHandle() { | ||
return handle; | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (handle != 0) { | ||
Rmm.releaseLimitingResourceAdaptor(handle); | ||
handle = 0; | ||
} | ||
super.close(); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Long.toHexString(getHandle()) + "/LIMIT(" + wrapped + | ||
", " + limit + ", " + alignment + ")"; | ||
} | ||
} |
Oops, something went wrong.