[SXnVeQi9] change ApocStreamHandlerFactory to a service provider to avoid deadlock on startup (neo4j-contrib#210)
nadja-muller authored Oct 27, 2022
1 parent 3a01fe6 commit 81c777c
Showing 7 changed files with 118 additions and 122 deletions.
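What changed, in short: APOC previously installed its custom URL stream handlers eagerly, by calling URL.setURLStreamHandlerFactory(...) from a static initializer in ApocExtensionFactory (removed in the first file below). This commit switches to the Java 9+ java.net.spi.URLStreamHandlerProvider service-provider mechanism, which the JDK discovers lazily through ServiceLoader, so nothing has to run while extension classes are being initialized and the startup deadlock named in the commit title is avoided. A minimal sketch of that discovery mechanism, not part of the commit (the class name is made up for the demo):

// Roughly the lookup the JDK performs, since Java 9, the first time it needs a
// handler for a protocol it does not already know: every provider listed in a
// META-INF/services/java.net.spi.URLStreamHandlerProvider resource is
// instantiated lazily, with no eager registration call required.
import java.net.spi.URLStreamHandlerProvider;
import java.util.ServiceLoader;

public class ProviderDiscoveryDemo {
    public static void main(String[] args) {
        for (URLStreamHandlerProvider provider : ServiceLoader.load(URLStreamHandlerProvider.class)) {
            System.out.println(provider.getClass().getName());
        }
    }
}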
9 changes: 0 additions & 9 deletions common/src/main/java/apoc/ApocExtensionFactory.java
@@ -1,6 +1,5 @@
package apoc;

import apoc.util.ApocUrlStreamHandlerFactory;
import org.neo4j.annotations.service.ServiceProvider;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.graphdb.GraphDatabaseService;
@@ -19,7 +18,6 @@
import org.neo4j.scheduler.JobScheduler;
import org.neo4j.service.Services;

import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
@@ -36,13 +34,6 @@
@ServiceProvider
public class ApocExtensionFactory extends ExtensionFactory<ApocExtensionFactory.Dependencies> {

static {
try {
URL.setURLStreamHandlerFactory(new ApocUrlStreamHandlerFactory());
} catch (Error e) {
System.err.println("APOC couln't set a URLStreamHandlerFactory since some other tool already did this (e.g. tomcat). This means you cannot use s3:// or hdfs:// style URLs in APOC. This is caused by a limitation of the JVM which we cannot fix. ");
}
}
public ApocExtensionFactory() {
super(ExtensionType.DATABASE, "APOC");
}
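The static block deleted here was the fragile part: URL.setURLStreamHandlerFactory may be called at most once per JVM (which is exactly what the caught Error signals), and per the commit title, running it inside a class initializer during extension loading could deadlock startup. A standalone sketch of that one-shot limitation, assuming nothing has installed a factory yet (class name made up):

import java.net.URL;

public class OneShotFactoryDemo {
    public static void main(String[] args) {
        URL.setURLStreamHandlerFactory(protocol -> null); // first call wins
        try {
            URL.setURLStreamHandlerFactory(protocol -> null); // any later call...
        } catch (Error alreadyDefined) {
            System.err.println("second install rejected: " + alreadyDefined); // ...throws java.lang.Error
        }
    }
}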
12 changes: 7 additions & 5 deletions common/src/main/java/apoc/util/ApocUrlStreamHandlerFactory.java
@@ -1,14 +1,16 @@
package apoc.util;

import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
import java.net.spi.URLStreamHandlerProvider;

public class ApocUrlStreamHandlerFactory implements URLStreamHandlerFactory {


public class ApocUrlStreamHandlerFactory extends URLStreamHandlerProvider
{

@Override
public URLStreamHandler createURLStreamHandler(String protocol) {
FileUtils.SupportedProtocols supportedProtocol = FileUtils.SupportedProtocols.of(protocol);
return supportedProtocol == null ? null : supportedProtocol.createURLStreamHandler();
SupportedProtocols supportedProtocol = FileUtils.of(protocol);
return supportedProtocol == null ? null : FileUtils.createURLStreamHandler(supportedProtocol);
}

}
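Note the switch from implements to extends: java.net.spi.URLStreamHandlerProvider is an abstract class (itself a URLStreamHandlerFactory), and returning null from createURLStreamHandler simply tells the JDK to keep looking and fall back to its built-in handlers. A hedged usage sketch, assuming the provider is on the classpath with its META-INF/services entry (added at the end of this commit) and that the AWS SDK is present so s3 is enabled; bucket and class names are made up:

import java.net.URL;

public class LazyHandlerDemo {
    public static void main(String[] args) throws Exception {
        // Without a registered handler for "s3" this line would throw
        // MalformedURLException: unknown protocol: s3.
        URL url = new URL("s3://my-bucket/data.csv"); // triggers the lazy provider lookup
        System.out.println(url.getProtocol());
    }
}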
176 changes: 74 additions & 102 deletions common/src/main/java/apoc/util/FileUtils.java
@@ -40,117 +40,69 @@
*/
public class FileUtils {

public enum SupportedProtocols {
http(true, null),
https(true, null),
ftp(true, null),
s3(Util.classExists("com.amazonaws.services.s3.AmazonS3"),
"apoc.util.s3.S3UrlStreamHandlerFactory"),
gs(Util.classExists("com.google.cloud.storage.Storage"),
"apoc.util.google.cloud.GCStorageURLStreamHandlerFactory"),
hdfs(Util.classExists("org.apache.hadoop.fs.FileSystem"),
"org.apache.hadoop.fs.FsUrlStreamHandlerFactory"),
file(true, null);

private final boolean enabled;

private final String urlStreamHandlerClassName;

SupportedProtocols(boolean enabled, String urlStreamHandlerClassName) {
this.enabled = enabled;
this.urlStreamHandlerClassName = urlStreamHandlerClassName;
}

public StreamConnection getStreamConnection(String urlAddress, Map<String, Object> headers, String payload) throws IOException {
switch (this) {
case s3:
return FileUtils.openS3InputStream(urlAddress);
case hdfs:
return FileUtils.openHdfsInputStream(urlAddress);
case ftp:
case http:
case https:
case gs:
return readHttpInputStream(urlAddress, headers, payload);
default:
try {
return new StreamConnection.FileStreamConnection(URI.create(urlAddress));
} catch (IllegalArgumentException iae) {
try {
return new StreamConnection.FileStreamConnection(new URL(urlAddress).getFile());
} catch (MalformedURLException mue) {
if (mue.getMessage().contains("no protocol")) {
return new StreamConnection.FileStreamConnection(urlAddress);
}
throw mue;
}
}
}
}

public OutputStream getOutputStream(String fileName, ExportConfig config) {
if (fileName == null) return null;
final CompressionAlgo compressionAlgo = CompressionAlgo.valueOf(config.getCompressionAlgo());
final OutputStream outputStream;
public static StreamConnection getStreamConnection(SupportedProtocols protocol, String urlAddress, Map<String, Object> headers, String payload) throws IOException {
switch (protocol) {
case s3:
return FileUtils.openS3InputStream(urlAddress);
case hdfs:
return FileUtils.openHdfsInputStream(urlAddress);
case ftp:
case http:
case https:
case gs:
return readHttpInputStream(urlAddress, headers, payload);
default:
try {
switch (this) {
case s3:
outputStream = S3UploadUtils.writeFile(fileName);
break;
case hdfs:
outputStream = HDFSUtils.writeFile(fileName);
break;
default:
final Path path = resolvePath(fileName);
outputStream = new FileOutputStream(path.toFile());
return new StreamConnection.FileStreamConnection(URI.create(urlAddress));
} catch (IllegalArgumentException iae) {
try {
return new StreamConnection.FileStreamConnection(new URL(urlAddress).getFile());
} catch (MalformedURLException mue) {
if (mue.getMessage().contains("no protocol")) {
return new StreamConnection.FileStreamConnection(urlAddress);
}
throw mue;
}
return new BufferedOutputStream(compressionAlgo.getOutputStream(outputStream));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

public boolean isEnabled() {
return enabled;
}

public URLStreamHandler createURLStreamHandler() {
return Optional.ofNullable(urlStreamHandlerClassName)
.map(Util::createInstanceOrNull)
.map(urlStreamHandlerFactory -> ((URLStreamHandlerFactory) urlStreamHandlerFactory).createURLStreamHandler(this.name()))
.orElse(null);
}
public static URLStreamHandler createURLStreamHandler(SupportedProtocols protocol) {
URLStreamHandler handler = Optional.ofNullable(protocol.getUrlStreamHandlerClassName())
.map(Util::createInstanceOrNull)
.map(urlStreamHandlerFactory -> ((URLStreamHandlerFactory) urlStreamHandlerFactory).createURLStreamHandler(protocol.name()))
.orElse(null);
return handler;
}

public static SupportedProtocols from(String source) {
try {
final URL url = new URL(source);
return from(url);
} catch (MalformedURLException e) {
if (!e.getMessage().contains("no protocol")) {
try {
// in case new URL(source) throw e.g. unknown protocol: hdfs, because of missing jar,
// we retrieve the related enum and throw the associated MissingDependencyException(..)
// otherwise we return unknown protocol: yyyyy
return SupportedProtocols.valueOf(new URI(source).getScheme());
} catch (Exception ignored) {}
throw new RuntimeException(e);
}
return SupportedProtocols.file;
}
public static SupportedProtocols of(String name) {
try {
return SupportedProtocols.valueOf(name);
} catch (Exception e) {
return SupportedProtocols.file;
}
}

public static SupportedProtocols from(URL url) {
return SupportedProtocols.of(url.getProtocol());
}
public static SupportedProtocols from(URL url) {
return of(url.getProtocol());
}

public static SupportedProtocols of(String name) {
try {
return SupportedProtocols.valueOf(name);
} catch (Exception e) {
return file;
public static SupportedProtocols from(String source) {
try {
final URL url = new URL(source);
return from(url);
} catch (MalformedURLException e) {
if (!e.getMessage().contains("no protocol")) {
try {
// in case new URL(source) throw e.g. unknown protocol: hdfs, because of missing jar,
// we retrieve the related enum and throw the associated MissingDependencyException(..)
// otherwise we return unknown protocol: yyyyy
return SupportedProtocols.valueOf(new URI(source).getScheme());
} catch (Exception ignored) {}
throw new RuntimeException(e);
}
return SupportedProtocols.file;
}

}

public static final String ERROR_READ_FROM_FS_NOT_ALLOWED = "Import file %s not enabled, please set " + APOC_IMPORT_FILE_ALLOW__READ__FROM__FILESYSTEM + "=true in your neo4j.conf";
@@ -251,7 +203,7 @@ private static boolean pathStartsWithOther(Path resolvedPath, Path basePath) thr
}

public static boolean isFile(String fileName) {
return SupportedProtocols.from(fileName) == SupportedProtocols.file;
return from(fileName) == SupportedProtocols.file;
}

public static OutputStream getOutputStream(String fileName) {
@@ -262,7 +214,27 @@ public static OutputStream getOutputStream(String fileName, ExportConfig config)
if (fileName.equals("-")) {
return null;
}
return SupportedProtocols.from(fileName).getOutputStream(fileName, config);
return getOutputStream(from(fileName), fileName, config);
}

public static OutputStream getOutputStream(SupportedProtocols protocol, String fileName, ExportConfig config) {
if (fileName == null) return null;
final CompressionAlgo compressionAlgo = CompressionAlgo.valueOf(config.getCompressionAlgo());
final OutputStream outputStream;
try {
switch ( protocol )
{
case s3 -> outputStream = S3UploadUtils.writeFile( fileName );
case hdfs -> outputStream = HDFSUtils.writeFile( fileName );
default -> {
final Path path = resolvePath( fileName );
outputStream = new FileOutputStream( path.toFile() );
}
}
return new BufferedOutputStream(compressionAlgo.getOutputStream(outputStream));
} catch (Exception e) {
throw new RuntimeException(e);
}
}

public static boolean isImportUsingNeo4jConfig() {
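The shape of this refactor: behavior that lived as instance methods on the nested SupportedProtocols enum becomes static FileUtils methods taking the protocol as an explicit first argument, while the enum itself moves to its own file (next) and keeps only its data. Call sites change accordingly; a before/after fragment, mirroring the Util.java hunk further down:

// before: dispatch through the nested enum
StreamConnection sc = FileUtils.SupportedProtocols.from(urlAddress)
        .getStreamConnection(urlAddress, headers, payload);

// after: the enum is plain data, FileUtils dispatches
StreamConnection sc = FileUtils.getStreamConnection(
        FileUtils.from(urlAddress), urlAddress, headers, payload);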
32 changes: 32 additions & 0 deletions common/src/main/java/apoc/util/SupportedProtocols.java
@@ -0,0 +1,32 @@
package apoc.util;

public enum SupportedProtocols
{
http(true, null),
https(true, null),
ftp(true, null),
s3(Util.classExists("com.amazonaws.services.s3.AmazonS3"),
"apoc.util.s3.S3UrlStreamHandlerFactory"),
gs(Util.classExists("com.google.cloud.storage.Storage"),
"apoc.util.google.cloud.GCStorageURLStreamHandlerFactory"),
hdfs(Util.classExists("org.apache.hadoop.fs.FileSystem"),
"org.apache.hadoop.fs.FsUrlStreamHandlerFactory"),
file(true, null);

private final boolean enabled;

private final String urlStreamHandlerClassName;

SupportedProtocols(boolean enabled, String urlStreamHandlerClassName) {
this.enabled = enabled;
this.urlStreamHandlerClassName = urlStreamHandlerClassName;
}

public boolean isEnabled() {
return enabled;
}

String getUrlStreamHandlerClassName() {
return urlStreamHandlerClassName;
}
}
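How the extracted enum and the new static helpers in FileUtils fit together, as an illustrative fragment (variable names made up):

SupportedProtocols protocol = FileUtils.of("s3");   // valueOf with a fallback to file for unknown names
boolean usable = protocol.isEnabled();              // false unless the AWS S3 classes are on the classpath
java.net.URLStreamHandler handler =
        FileUtils.createURLStreamHandler(protocol); // null whenever no handler class is configured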
4 changes: 1 addition & 3 deletions common/src/main/java/apoc/util/Util.java
@@ -413,9 +413,7 @@ private static CountingInputStream getStreamCompressedFile(String urlAddress, Ma
}

private static StreamConnection getStreamConnection(String urlAddress, Map<String, Object> headers, String payload) throws IOException {
return FileUtils.SupportedProtocols
.from(urlAddress)
.getStreamConnection(urlAddress, headers, payload);
return FileUtils.getStreamConnection( FileUtils.from( urlAddress), urlAddress, headers, payload);
}

private static InputStream getFileStreamIntoCompressedFile(InputStream is, String fileName, ArchiveType archiveType) throws IOException {
6 changes: 3 additions & 3 deletions common/src/main/java/apoc/util/google/cloud/GCStorageURLStreamHandlerFactory.java
@@ -1,6 +1,6 @@
package apoc.util.google.cloud;

import apoc.util.FileUtils;
import apoc.util.SupportedProtocols;

import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
@@ -11,7 +11,7 @@ public GCStorageURLStreamHandlerFactory() {}

@Override
public URLStreamHandler createURLStreamHandler(final String protocol) {
final FileUtils.SupportedProtocols supportedProtocols = FileUtils.SupportedProtocols.valueOf(protocol);
return supportedProtocols == FileUtils.SupportedProtocols.gs ? new GCStorageURLStreamHandler() : null;
final SupportedProtocols supportedProtocols = SupportedProtocols.valueOf(protocol);
return supportedProtocols == SupportedProtocols.gs ? new GCStorageURLStreamHandler() : null;
}
}
1 change: 1 addition & 0 deletions common/src/main/resources/META-INF/services/java.net.spi.URLStreamHandlerProvider
@@ -0,0 +1 @@
apoc.util.ApocUrlStreamHandlerFactory
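This one-line resource is what wires the provider up: by the ServiceLoader contract the file must live at META-INF/services/java.net.spi.URLStreamHandlerProvider (the fully qualified name of the service type), and each line names one implementation class, here the reworked ApocUrlStreamHandlerFactory.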
