-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add read/write plugins for displacement fields
- Loading branch information
Showing
2 changed files
with
359 additions
and
0 deletions.
There are no files selected for viewing
167 changes: 167 additions & 0 deletions
167
src/main/java/bigwarp/scripts/ReadDisplacementField.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
package bigwarp.scripts; | ||
|
||
import java.util.concurrent.Callable; | ||
|
||
import org.janelia.saalfeldlab.n5.DatasetAttributes; | ||
import org.janelia.saalfeldlab.n5.N5Exception; | ||
import org.janelia.saalfeldlab.n5.N5Reader; | ||
import org.janelia.saalfeldlab.n5.imglib2.N5DisplacementField; | ||
import org.janelia.saalfeldlab.n5.universe.N5Factory; | ||
import org.scijava.command.Command; | ||
import org.scijava.log.LogService; | ||
import org.scijava.plugin.Parameter; | ||
import org.scijava.plugin.Plugin; | ||
import org.scijava.ui.UIService; | ||
|
||
import net.imagej.Dataset; | ||
import net.imagej.DatasetService; | ||
import net.imagej.axis.CalibratedAxis; | ||
import net.imagej.axis.DefaultAxisType; | ||
import net.imagej.axis.DefaultLinearAxis; | ||
import net.imglib2.RandomAccessibleInterval; | ||
import net.imglib2.type.NativeType; | ||
import net.imglib2.type.numeric.RealType; | ||
import net.imglib2.type.numeric.real.DoubleType; | ||
import net.imglib2.type.numeric.real.FloatType; | ||
import net.imglib2.view.Views; | ||
import net.imglib2.type.numeric.integer.ByteType; | ||
import net.imglib2.type.numeric.integer.ShortType; | ||
|
||
@Plugin(type = Command.class, menuPath = "Plugins>Transform>Read Displacement Field") | ||
public class ReadDisplacementField implements Callable<Void>, Command { | ||
|
||
public static final String NATIVE = "NATIVE"; | ||
public static final String[] AXIS_LABELS = new String[] { "x", "y", "z" }; | ||
|
||
@Parameter | ||
private UIService ui; | ||
|
||
@Parameter | ||
private DatasetService ds; | ||
|
||
@Parameter | ||
private LogService log; | ||
|
||
@Parameter | ||
private String n5Root; | ||
|
||
@Parameter | ||
private String n5Dataset; | ||
|
||
@Parameter(label = "ResultType type", style = "listBox", choices = { "FLOAT64", "FLOAT32", "NATIVE" }) | ||
private String resultType; | ||
|
||
@Parameter(label = "Thread count", required = true, min = "1", max = "999") | ||
private int nThreads = 1; | ||
|
||
private CalibratedAxis[] axes; | ||
|
||
@Override | ||
public void run() { | ||
|
||
call(); | ||
} | ||
|
||
@Override | ||
public Void call() { | ||
|
||
process(); | ||
return null; | ||
} | ||
|
||
public <T extends RealType<T> & NativeType<T>> void process() { | ||
|
||
final RandomAccessibleInterval<T> dfieldRai = readDataAndMetadata(); | ||
final Dataset dataset = ds.create(dfieldRai); | ||
dataset.setName(n5Dataset); | ||
dataset.setAxes(axes); | ||
ui.show(dataset); | ||
} | ||
|
||
@SuppressWarnings("unchecked") | ||
private <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> readDataAndMetadata() { | ||
|
||
try (N5Reader n5 = new N5Factory().openReader(n5Root)) { | ||
createAxes(n5, n5Dataset); | ||
final RandomAccessibleInterval<T> img; | ||
if( resultType.equals(NATIVE)) | ||
img = N5DisplacementField.openRaw(n5, n5Dataset, (T)getRawZero(n5, n5Dataset)); | ||
else | ||
img = N5DisplacementField.openRaw(n5, n5Dataset, (T)getTargetType()); | ||
|
||
final int nd = img.numDimensions(); | ||
final RandomAccessibleInterval<T> imgp; | ||
if( nd == 4 ) | ||
imgp = Views.moveAxis(img, nd-2, nd-1); | ||
else if( nd == 3 ) | ||
imgp = img; | ||
else | ||
throw new N5Exception("Dataset must be 3D or 4D, but had " + nd + " dimensions"); | ||
|
||
return imgp; | ||
|
||
}catch( Exception e ) { | ||
e.printStackTrace(); | ||
} | ||
throw new N5Exception("Could not read displacement field from: " + n5Root + " " + n5Dataset); | ||
} | ||
|
||
private CalibratedAxis[] createAxes(N5Reader n5, final String n5Dataset) { | ||
final DatasetAttributes dsetAttrs = n5.getDatasetAttributes(n5Dataset); | ||
if (dsetAttrs == null) | ||
throw new N5Exception("No dataset at" + n5Dataset); | ||
|
||
final double[] spacing = n5.getAttribute(n5Dataset, N5DisplacementField.SPACING_ATTR, double[].class); | ||
final double[] offset = n5.getAttribute(n5Dataset, N5DisplacementField.OFFSET_ATTR, double[].class); | ||
|
||
final int nd = dsetAttrs.getNumDimensions(); | ||
// last axis always hold vector dimension | ||
axes = new CalibratedAxis[nd]; | ||
int j = 0; | ||
for (int i = 0; i < nd; i++) { | ||
if (i == 2) | ||
axes[i] = new DefaultLinearAxis(new DefaultAxisType("v", false), "px"); | ||
else { | ||
axes[i] = new DefaultLinearAxis(new DefaultAxisType(AXIS_LABELS[j], true), "px", spacing[j], offset[j]); | ||
j++; | ||
} | ||
} | ||
|
||
return axes; | ||
} | ||
|
||
@SuppressWarnings({ "incomplete-switch", "unchecked" }) | ||
private static <T extends RealType<T> & NativeType<T>> T getRawZero( N5Reader n5, final String n5Dataset ) { | ||
|
||
// The types enumerated here are the only allowed types for displacement fields | ||
final DatasetAttributes dsetAttrs = n5.getDatasetAttributes(n5Dataset); | ||
if( dsetAttrs == null) | ||
throw new N5Exception("No dataset at" + n5Dataset); | ||
|
||
switch( dsetAttrs.getDataType() ) { | ||
case INT8: | ||
return (T)new ByteType(); | ||
case INT16: | ||
return (T)new ShortType(); | ||
case FLOAT32: | ||
return (T)new FloatType(); | ||
case FLOAT64: | ||
return (T)new DoubleType(); | ||
} | ||
throw new N5Exception("Unexpected type: " + dsetAttrs.getDataType()); | ||
} | ||
|
||
|
||
@SuppressWarnings("unchecked") | ||
private <T extends RealType<T> & NativeType<T>> T getTargetType() { | ||
switch (resultType) { | ||
case WriteDisplacementField.FLOAT32: | ||
return (T)new FloatType(); | ||
case WriteDisplacementField.FLOAT64: | ||
return (T)new DoubleType(); | ||
} | ||
return null; | ||
} | ||
|
||
|
||
} |
192 changes: 192 additions & 0 deletions
192
src/main/java/bigwarp/scripts/WriteDisplacementField.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
package bigwarp.scripts; | ||
|
||
import java.util.Arrays; | ||
import java.util.concurrent.Callable; | ||
import java.util.stream.IntStream; | ||
|
||
import org.janelia.saalfeldlab.n5.Compression; | ||
import org.janelia.saalfeldlab.n5.N5Writer; | ||
import org.janelia.saalfeldlab.n5.ij.N5ScalePyramidExporter; | ||
import org.janelia.saalfeldlab.n5.imglib2.N5DisplacementField; | ||
import org.janelia.saalfeldlab.n5.universe.N5Factory; | ||
import org.scijava.command.Command; | ||
import org.scijava.log.LogService; | ||
import org.scijava.plugin.Parameter; | ||
import org.scijava.plugin.Plugin; | ||
import org.scijava.ui.UIService; | ||
|
||
import net.imagej.Dataset; | ||
import net.imglib2.RandomAccessibleInterval; | ||
import net.imglib2.converter.Converter; | ||
import net.imglib2.converter.Converters; | ||
import net.imglib2.realtransform.AffineGet; | ||
import net.imglib2.type.NativeType; | ||
import net.imglib2.type.numeric.IntegerType; | ||
import net.imglib2.type.numeric.RealType; | ||
import net.imglib2.type.numeric.real.DoubleType; | ||
import net.imglib2.type.numeric.real.FloatType; | ||
import net.imglib2.type.numeric.integer.ByteType; | ||
import net.imglib2.type.numeric.integer.ShortType; | ||
import net.imglib2.view.Views; | ||
|
||
@Plugin(type = Command.class, menuPath = "Plugins>Transform>Write Displacement Field") | ||
public class WriteDisplacementField implements Callable<Void>, Command { | ||
|
||
// Output types | ||
public static final String INT8 = "INT8"; | ||
public static final String INT16 = "INT16"; | ||
public static final String FLOAT32 = "FLOAT32"; | ||
public static final String FLOAT64 = "FLOAT64"; | ||
|
||
@Parameter | ||
private UIService ui; | ||
|
||
@Parameter | ||
private LogService log; | ||
|
||
@Parameter | ||
private String n5Root; | ||
|
||
@Parameter | ||
private String n5Dataset; | ||
|
||
@Parameter | ||
private Dataset dataset; | ||
|
||
@Parameter(label = "Chunk size", description = "The size of chunks. Comma separated, for example: \"64,32,16\".\n " | ||
+ | ||
"ImageJ's axis order is X,Y,C,Z,T. The chunk size must be specified in this order.\n" + | ||
"You must skip any axis whose size is 1, e.g. a 2D time-series without channels\n" + | ||
"may have a chunk size of 1024,1024,1 (X,Y,T).\n" + | ||
"You may provide fewer values than the data dimension. In that case, the size will\n" + | ||
"be expanded to necessary size with the last value, for example \"64\", will expand\n" + | ||
"to \"64,64,64\" for 3D data.") | ||
private String chunkSizeArg; | ||
|
||
@Parameter(label = "Compression", style = "listBox", choices = { | ||
N5ScalePyramidExporter.GZIP_COMPRESSION, | ||
N5ScalePyramidExporter.RAW_COMPRESSION, | ||
N5ScalePyramidExporter.LZ4_COMPRESSION, | ||
N5ScalePyramidExporter.XZ_COMPRESSION, | ||
N5ScalePyramidExporter.BLOSC_COMPRESSION, | ||
N5ScalePyramidExporter.ZSTD_COMPRESSION}) | ||
private String compressionArg = N5ScalePyramidExporter.GZIP_COMPRESSION; | ||
|
||
@Parameter(label = "Output type", style = "listBox", choices = { | ||
"FLOAT64", "FLOAT32", "INT16", "INT8" | ||
}) | ||
private String outputType; | ||
|
||
@Parameter(label = "Thread count", required = true, min = "1", max = "999") | ||
private int nThreads = 1; | ||
|
||
private int nd = -1; | ||
private int vectorDim = -1; | ||
private int vectorSize = -1; | ||
|
||
@SuppressWarnings({ "unchecked" }) | ||
public <T extends RealType<T> & NativeType<T>, S extends RealType<S> & NativeType<S>, Q extends NativeType<Q> & IntegerType<Q>> void process() { | ||
final AffineGet affine = null; | ||
final Compression compression = N5ScalePyramidExporter.getCompression(compressionArg); | ||
|
||
nd = dataset.numDimensions() - 1; | ||
final long[] spatialDims = new long[nd]; | ||
final double[] offset = new double[nd]; | ||
final double[] spacing = new double[nd]; | ||
Arrays.fill(spacing, 1.0); | ||
|
||
int j = 0; | ||
for (int i = 0; i < dataset.numDimensions(); i++) { | ||
|
||
if (dataset.axis(i).type().isSpatial()) { | ||
spatialDims[j] = dataset.dimension(i); | ||
offset[j] = dataset.axis(i).calibratedValue(0.0); | ||
spacing[j++] = dataset.averageScale(i); | ||
} else { | ||
vectorDim = i; | ||
vectorSize = (int) dataset.dimension(i); | ||
} | ||
} | ||
|
||
validateAndWarn(); | ||
|
||
final int[] chunkSizeSpatial = N5ScalePyramidExporter.parseBlockSize(chunkSizeArg, spatialDims); | ||
final int[] chunkSize = IntStream.concat( | ||
IntStream.of(vectorSize), | ||
Arrays.stream(chunkSizeSpatial)).toArray(); | ||
|
||
final RandomAccessibleInterval<T> vectorAxisFirst = (RandomAccessibleInterval<T>) Views.moveAxis( (RandomAccessibleInterval<T>)dataset, vectorDim, 0); | ||
try (N5Writer n5 = new N5Factory().openWriter(n5Root)) { | ||
|
||
if (outputType.equals(FLOAT32) || outputType.equals(FLOAT64)) { | ||
final RandomAccessibleInterval<S> converted = convertIfNecessary(vectorAxisFirst, (S)getTargetType()); | ||
N5DisplacementField.save(n5, n5Dataset, affine, converted, spacing, offset, chunkSize, compression); | ||
} | ||
else { | ||
final Q quantizedType = (Q)getTargetType(); | ||
N5DisplacementField.save(n5, n5Dataset, affine, vectorAxisFirst, spacing, offset, chunkSize, compression, quantizedType, 1e-6); | ||
} | ||
|
||
} catch (Exception e) { | ||
System.err.println("Failed to write displacement field at " + n5Root); | ||
e.printStackTrace(); | ||
} | ||
|
||
} | ||
|
||
@SuppressWarnings("unchecked") | ||
private <T extends RealType<T> & NativeType<T>> T getTargetType() { | ||
switch (outputType) { | ||
case FLOAT32: | ||
return (T)new FloatType(); | ||
case FLOAT64: | ||
return (T)new DoubleType(); | ||
case INT16: | ||
return (T)new ShortType(); | ||
case INT8: | ||
return (T)new ByteType(); | ||
} | ||
return null; | ||
} | ||
|
||
@SuppressWarnings("unchecked") | ||
private <T extends RealType<T> & NativeType<T>, S extends RealType<S> & NativeType<S>> RandomAccessibleInterval<S> convertIfNecessary( | ||
final RandomAccessibleInterval<T> dfield, final S targetType) { | ||
|
||
if (dfield.getType().getClass().equals(targetType.getClass())) | ||
return (RandomAccessibleInterval<S>) dfield; | ||
|
||
final Converter<T, S> conv = new Converter<T, S>() { | ||
|
||
@Override | ||
public void convert(T input, S output) { | ||
|
||
output.setReal(input.getRealDouble()); | ||
} | ||
}; | ||
return Converters.convertRAI(dfield, conv, targetType); | ||
} | ||
|
||
private void validateAndWarn() { | ||
|
||
if (vectorSize != nd) { | ||
ui.showDialog(String.format("Error: channel dimension size (%d) must match dimensionality (%d). Exiting.", | ||
vectorSize, nd)); | ||
return; | ||
} | ||
} | ||
|
||
@Override | ||
public void run() { | ||
|
||
call(); | ||
} | ||
|
||
@Override | ||
public Void call() { | ||
|
||
process(); | ||
return null; | ||
} | ||
|
||
} |