Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Use a forked Java process to avoid the JVM consuming GPU memory #2882

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 105 additions & 3 deletions api/src/main/java/ai/djl/util/cuda/CudaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.MemoryUsage;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;

Expand All @@ -33,6 +37,8 @@ public final class CudaUtils {

private static final CudaLibrary LIB = loadLibrary();

private static String[] gpuInfo;

private CudaUtils() {}

/**
Expand All @@ -49,7 +55,15 @@ public static boolean hasCuda() {
*
* @return the number of GPUs available in the system
*/
@SuppressWarnings("PMD.NonThreadSafeSingleton")
public static int getGpuCount() {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
if (gpuInfo == null) {
gpuInfo = execute(-1); // NOPMD
}
return Integer.parseInt(gpuInfo[0]);
}

if (LIB == null) {
return 0;
}
Expand Down Expand Up @@ -79,7 +93,19 @@ public static int getGpuCount() {
*
* @return the version of CUDA runtime
*/
@SuppressWarnings("PMD.NonThreadSafeSingleton")
public static int getCudaVersion() {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
if (gpuInfo == null) {
gpuInfo = execute(-1);
}
int version = Integer.parseInt(gpuInfo[1]);
if (version == -1) {
throw new IllegalArgumentException("No cuda device found.");
}
return version;
}

if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
Expand All @@ -95,9 +121,6 @@ public static int getCudaVersion() {
* @return the version string of CUDA runtime
*/
public static String getCudaVersionString() {
if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
int version = getCudaVersion();
int major = version / 1000;
int minor = (version / 10) % 10;
Expand All @@ -111,6 +134,14 @@ public static String getCudaVersionString() {
* @return the CUDA compute capability
*/
public static String getComputeCapability(int device) {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
String[] ret = execute(device);
if (ret.length != 3) {
throw new IllegalArgumentException(ret[0]);
}
return ret[0];
}

if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
}
Expand All @@ -137,6 +168,16 @@ public static MemoryUsage getGpuMemory(Device device) {
throw new IllegalArgumentException("Only GPU device is allowed.");
}

if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
String[] ret = execute(device.getDeviceId());
if (ret.length != 3) {
throw new IllegalArgumentException(ret[0]);
}
long total = Long.parseLong(ret[1]);
long used = Long.parseLong(ret[2]);
return new MemoryUsage(-1, used, used, total);
}

if (LIB == null) {
throw new IllegalStateException("No GPU device detected.");
}
Expand All @@ -155,8 +196,42 @@ public static MemoryUsage getGpuMemory(Device device) {
return new MemoryUsage(-1, committed, committed, total[0]);
}

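/**
* The main entrypoint to get CUDA information with command line.
*
* <p>Without arguments it prints one line {@code gpuCount,cudaVersion}, or {@code 0,-1} when no
* GPU is found. With a device id as the first argument it prints one line {@code
* computeCapability,maxMemory,usedMemory}, or {@code Invalid device: <arg>} when the id is not a
* number or is out of range.
*
* @param args the command line arguments.
*/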
@SuppressWarnings("PMD.SystemPrintln")
public static void main(String[] args) {
int gpuCount = getGpuCount();
if (args.length == 0) {
if (gpuCount <= 0) {
System.out.println("0,-1");
return;
}
int cudaVersion = getCudaVersion();
System.out.println(gpuCount + "," + cudaVersion);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use println?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to read the console output from the Java side (the parent process parses what the child prints).

return;
}
try {
int deviceId = Integer.parseInt(args[0]);
if (deviceId < 0 || deviceId >= gpuCount) {
System.out.println("Invalid device: " + deviceId);
return;
}
MemoryUsage mem = getGpuMemory(Device.gpu(deviceId));
String cc = getComputeCapability(deviceId);
System.out.println(cc + ',' + mem.getMax() + ',' + mem.getUsed());
} catch (NumberFormatException e) {
System.out.println("Invalid device: " + args[0]);
}
}

private static CudaLibrary loadLibrary() {
try {
if (Boolean.getBoolean("ai.djl.util.cuda.folk")) {
return null;
}
if (System.getProperty("os.name").startsWith("Win")) {
String path = Utils.getenv("PATH");
if (path == null) {
Expand Down Expand Up @@ -199,6 +274,33 @@ private static CudaLibrary loadLibrary() {
}
}

/**
 * Queries CUDA information by running this class's {@code main} method in a child JVM.
 *
 * <p>The child is launched with the current {@code java.home} and {@code java.class.path}, and
 * its standard error is merged into its standard output. Only the last line of the output is
 * parsed, so diagnostics printed before the result (for example JVM or logging warnings) do not
 * corrupt the returned fields.
 *
 * @param deviceId the GPU device id to query; a negative value passes no argument to the child
 * @return the comma separated fields of the last line printed by the child process
 * @throws IllegalArgumentException if the child process cannot be started or its output cannot
 *     be read
 */
private static String[] execute(int deviceId) {
    String javaHome = System.getProperty("java.home");
    String classPath = System.getProperty("java.class.path");
    String os = System.getProperty("os.name");
    // Up to five entries: executable, "-cp", class path, main class and the optional device id.
    List<String> cmd = new ArrayList<>(5);
    if (os.startsWith("Win")) {
        cmd.add(javaHome + "\\bin\\java.exe");
    } else {
        cmd.add(javaHome + "/bin/java");
    }
    cmd.add("-cp");
    cmd.add(classPath);
    cmd.add("ai.djl.util.cuda.CudaUtils");
    if (deviceId >= 0) {
        cmd.add(String.valueOf(deviceId));
    }

    Process ps = null;
    try {
        ps = new ProcessBuilder(cmd).redirectErrorStream(true).start();
        try (InputStream is = ps.getInputStream()) {
            String output = Utils.toString(is).trim();
            // stderr is merged into stdout, so warnings may precede the result; the child
            // prints its answer last, hence only the final line is split into fields.
            int pos = output.lastIndexOf('\n');
            String line = pos < 0 ? output : output.substring(pos + 1).trim();
            return line.split(",");
        }
    } catch (IOException e) {
        throw new IllegalArgumentException("Failed get GPU information", e);
    } finally {
        if (ps != null) {
            // Release the remaining streams and the process handle deterministically.
            ps.destroy();
        }
    }
}

private static void checkCall(int ret) {
if (LIB == null) {
throw new IllegalStateException("No cuda library is loaded.");
Expand Down
9 changes: 6 additions & 3 deletions api/src/test/java/ai/djl/util/SecurityManagerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ public void checkPermission(Permission perm) {
}
};
System.setSecurityManager(sm);

Assert.assertFalse(CudaUtils.hasCuda());
Assert.assertEquals(CudaUtils.getGpuCount(), 0);
try {
Assert.assertFalse(CudaUtils.hasCuda());
Assert.assertEquals(CudaUtils.getGpuCount(), 0);
} finally {
System.setSecurityManager(null);
}
}
}
21 changes: 15 additions & 6 deletions api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import org.testng.annotations.Test;

import java.lang.management.MemoryUsage;
import java.util.Arrays;
import java.util.List;

public class CudaUtilsTest {

Expand All @@ -30,23 +28,34 @@ public class CudaUtilsTest {
@Test
public void testCudaUtils() {
if (!CudaUtils.hasCuda()) {
Assert.assertThrows(CudaUtils::getCudaVersionString);
Assert.assertThrows(() -> CudaUtils.getComputeCapability(0));
Assert.assertThrows(() -> CudaUtils.getGpuMemory(Device.gpu()));
return;
}
// Possible to have CUDA and not have a GPU.
if (CudaUtils.getGpuCount() == 0) {
return;
}

int cudaVersion = CudaUtils.getCudaVersion();
String cudaVersion = CudaUtils.getCudaVersionString();
String smVersion = CudaUtils.getComputeCapability(0);
MemoryUsage memoryUsage = CudaUtils.getGpuMemory(Device.gpu());

logger.info("CUDA runtime version: {}, sm: {}", cudaVersion, smVersion);
logger.info("Memory usage: {}", memoryUsage);

Assert.assertTrue(cudaVersion >= 9020, "cuda 9.2+ required.");
Assert.assertNotNull(cudaVersion);
Assert.assertNotNull(smVersion);
}

List<String> supportedSm = Arrays.asList("37", "52", "60", "61", "70", "75");
Assert.assertTrue(supportedSm.contains(smVersion), "Unsupported cuda sm: " + smVersion);
/**
 * Runs {@code testCudaUtils()} with the {@code ai.djl.util.cuda.folk} system property set to
 * {@code true}, then puts the property back to the state it had before the test.
 */
@Test
public void testCudaUtilsWithFolk() {
    String key = "ai.djl.util.cuda.folk";
    // System.setProperty returns the previous value, or null if the property was not set.
    String previous = System.setProperty(key, "true");
    try {
        testCudaUtils();
    } finally {
        // Restore rather than blindly clear, so a value supplied on the JVM command line
        // is still visible to tests that run afterwards.
        if (previous == null) {
            System.clearProperty(key);
        } else {
            System.setProperty(key, previous);
        }
    }
}
}
Loading