From 3e6af1b814bf2c71e89d79a6ca4f88fb71608ebe Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Sun, 7 Jan 2024 17:06:30 +0100 Subject: [PATCH] [SYSTEMDS-3663] Low overhead join indexes This commit adds a few more variations to indexes to allow efficient combination and ordering of column indexes when co-coding. This is critical in cases where thousands of columns are combined, since the execution time suddenly is dominated not by combining columns but the column indexes. Closes #1979 --- .../compress/colgroup/indexes/AColIndex.java | 56 +++- .../colgroup/indexes/ColIndexFactory.java | 2 + .../colgroup/indexes/CombinedIndex.java | 246 ++++++++++++++++++ .../compress/colgroup/indexes/IColIndex.java | 80 +++++- .../compress/colgroup/indexes/RangeIndex.java | 84 ++++-- .../colgroup/indexes/TwoRangesIndex.java | 4 +- 6 files changed, 437 insertions(+), 35 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java index df4685a65d6..81a5f5b4803 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java @@ -21,6 +21,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCSR; public abstract class AColIndex implements IColIndex { @@ -69,11 +71,55 @@ private static int hashCode(IIterate it) { @Override public boolean containsAny(IColIndex idx) { - final IIterate it = idx.iterator(); - while(it.hasNext()) - if(contains(it.next())) - return true; + if(idx instanceof TwoRangesIndex){ + TwoRangesIndex o = (TwoRangesIndex) idx; + return this.containsAny(o.idx1) || this.containsAny(o.idx2); + } + else if(idx instanceof CombinedIndex){ + CombinedIndex ci = (CombinedIndex) idx; + return containsAny(ci.l) || containsAny(ci.r); + } + else{ + final IIterate it = idx.iterator(); + while(it.hasNext()) + if(contains(it.next())) + return true; + + return false; + } + } - return false; + @Override + public void decompressToDenseFromSparse(SparseBlock sb, int vr, int off, double[] c) { + if(sb instanceof SparseBlockCSR) + decompressToDenseFromSparseCSR((SparseBlockCSR)sb, vr, off, c); + else + decompressToDenseFromSparseGeneric(sb, vr, off, c); + } + + private void decompressToDenseFromSparseGeneric(SparseBlock sb, int vr, int off, double[] c) { + if(sb.isEmpty(vr)) + return; + final int apos = sb.pos(vr); + final int alen = sb.size(vr) + apos; + final int[] aix = sb.indexes(vr); + final double[] aval = sb.values(vr); + for(int j = apos; j < alen; j++) + c[off + get(aix[j])] += aval[j]; + } + + private void decompressToDenseFromSparseCSR(SparseBlockCSR sb, int vr, int off, double[] c) { + final int apos = sb.pos(vr); + final int alen = sb.size(vr) + apos; + final int[] aix = sb.indexes(vr); + final double[] aval = sb.values(vr); + for(int j = apos; j < alen; j++) + c[off + get(aix[j])] += aval[j]; + } + + @Override + public void decompressVec(int nCol, double[] c, int off, double[] values, int rowIdx) { + for(int j = 0; j < nCol; j++) + c[off + get(j)] += values[rowIdx + j]; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java index fd929b8a1aa..c9a45e4aeea 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java @@ -48,6 +48,8 @@ public static IColIndex read(DataInput in) throws IOException { return RangeIndex.read(in); case TWORANGE: return TwoRangesIndex.read(in); + case COMBINED: + return CombinedIndex.read(in); default: throw new DMLCompressionException("Failed reading column index of type: " + t); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java new file mode 100644 index 00000000000..f1a80a6d279 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.indexes; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.sysds.runtime.compress.DMLCompressionException; + +public class CombinedIndex extends AColIndex { + protected final IColIndex l; + protected final IColIndex r; + + public CombinedIndex(IColIndex l, IColIndex r) { + this.l = l; + this.r = r; + } + + @Override + public int size() { + return l.size() + r.size(); + } + + @Override + public int get(int i) { + if(i >= l.size()) + return r.get(i - l.size()); + else + return l.get(i); + } + + @Override + public IColIndex shift(int i) { + return new CombinedIndex(l.shift(i), r.shift(i)); + } + + @Override + public void write(DataOutput out) throws IOException { + out.write(ColIndexType.COMBINED.ordinal()); + l.write(out); + r.write(out); + } + + @Override + public long getExactSizeOnDisk() { + return 1 + l.getExactSizeOnDisk() + r.getExactSizeOnDisk(); + } + + @Override + public long estimateInMemorySize() { + return 16 + 8 + 8 + l.estimateInMemorySize() + r.estimateInMemorySize(); + } + + @Override + public IIterate iterator() { + return new CombinedIterator(); + } + + @Override + public int findIndex(int i) { + final int a = l.findIndex(i); + if(a < 0) { + final int b = r.findIndex(i); + if(b < 0) + return b + a + 1; + else + return b + l.size(); + } + else + return a; + } + + @Override + public SliceResult slice(int l, int u) { + return getArrayIndex().slice(l, u); + } + + @Override + public boolean equals(IColIndex other) { + if(other == this) + return true; + else if(size() == other.size()) { + if(other instanceof CombinedIndex) { + CombinedIndex o = (CombinedIndex) other; + return o.l.equals(l) && o.r.equals(r); + } + else { + IIterate t = iterator(); + IIterate o = other.iterator(); + + while(t.hasNext()) { + if(t.next() != o.next()) + return false; + } + return true; + } + } + return false; + } + + @Override + public IColIndex combine(IColIndex other) { + final int sr = other.size(); + final int sl = size(); + final int maxCombined = Math.max(this.get(this.size() - 1), other.get(other.size() - 1)); + final int minCombined = Math.min(this.get(0), other.get(0)); + if(sr + sl == maxCombined - minCombined + 1) { + return new RangeIndex(minCombined, maxCombined + 1); + } + + final int[] ret = new int[sr + sl]; + IIterate t = iterator(); + IIterate o = other.iterator(); + int i = 0; + while(t.hasNext() && o.hasNext()) { + final int tv = t.v(); + final int ov = o.v(); + if(tv < ov) { + ret[i++] = tv; + t.next(); + } + else { + ret[i++] = ov; + o.next(); + } + } + while(t.hasNext()) + ret[i++] = t.next(); + while(o.hasNext()) + ret[i++] = o.next(); + + return ColIndexFactory.create(ret); + + } + + @Override + public boolean isContiguous() { + return false; + } + + @Override + public int[] getReorderingIndex() { + return getArrayIndex().getReorderingIndex(); + } + + @Override + public boolean isSorted() { + return true; + } + + @Override + public IColIndex sort() { + throw new DMLCompressionException("CombinedIndex is always sorted"); + } + + @Override + public boolean contains(int i) { + return l.contains(i) || r.contains(i); + } + + @Override + public double avgOfIndex() { + double lv = l.avgOfIndex() * l.size(); + double rv = r.avgOfIndex() * r.size(); + return (lv + rv) / size(); + } + + private IColIndex getArrayIndex() { + int s = size(); + int[] vals = new int[s]; + IIterate a = iterator(); + for(int i = 0; i < s; i++) { + vals[i] = a.next(); + } + return ColIndexFactory.create(vals); + } + + public static CombinedIndex read(DataInput in) throws IOException { + return new CombinedIndex(ColIndexFactory.read(in), ColIndexFactory.read(in)); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append("["); + sb.append(l); + sb.append(", "); + sb.append(r); + sb.append("]"); + return sb.toString(); + } + + protected class CombinedIterator implements IIterate { + boolean doneFirst = false; + IIterate I = l.iterator(); + + @Override + public int next() { + int v = I.next(); + if(!I.hasNext() && !doneFirst) { + doneFirst = true; + I = r.iterator(); + } + return v; + + } + + @Override + public boolean hasNext() { + return I.hasNext() || doneFirst == false; + } + + @Override + public int v() { + return I.v(); + } + + @Override + public int i() { + if(doneFirst) + return I.i() + l.size(); + else + return I.i(); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java index 60c2cec4b23..8da8ad518ff 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java @@ -22,13 +22,16 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.matrix.data.Pair; + /** * Class to contain column indexes for the compression column groups. */ public interface IColIndex { public static enum ColIndexType { - SINGLE, TWO, ARRAY, RANGE, TWORANGE, UNKNOWN; + SINGLE, TWO, ARRAY, RANGE, TWORANGE, COMBINED, UNKNOWN; } /** @@ -212,6 +215,76 @@ public static enum ColIndexType { */ public double avgOfIndex(); + /** + * Decompress this + */ + /** + * Decompress this column index into the dense c array. + * + * @param sb A sparse block to extract values out of and insert into c + * @param vr The row to extract from the sparse block + * @param off The offset that the row starts at in c. + * @param c The dense output to decompress into + */ + public void decompressToDenseFromSparse(SparseBlock sb, int vr, int off, double[] c); + + /** + * Decompress into c using the values provided. The offset to start into c is off and then row index is similarly the + * offset of values. nCol specify the number of values to add over. + * + * @param nCol The number of columns to copy. + * @param c The output to add into + * @param off The offset to start in c + * @param values the values to copy from + * @param rowIdx The offset to start in values + */ + public void decompressVec(int nCol, double[] c, int off, double[] values, int rowIdx); + + /** + * Indicate if the two given column indexes are in order such that the first set of indexes all are of lower value + * than the second. + * + * @param a the first column index + * @param b the second column index + * @return If the first all is lower than the second. + */ + public static boolean inOrder(IColIndex a, IColIndex b) { + return a.get(a.size() - 1) < b.get(0); + } + + public static Pair reorderingIndexes(IColIndex a, IColIndex b){ + final int[] ar = new int[a.size()]; + final int[] br = new int[b.size()]; + final IIterate ai = a.iterator(); + final IIterate bi = b.iterator(); + + int ia = 0; + int ib = 0; + int i = 0; + while(ai.hasNext() && bi.hasNext()){ + if(ai.v()< bi.v()){ + ar[ia++] = i++; + ai.next(); + } + else{ + br[ib++] = i++; + bi.next(); + } + } + + while(ai.hasNext()){ + ar[ia++] = i++; + ai.next(); + } + + while(bi.hasNext()){ + br[ib++] = i++; + bi.next(); + } + + return new Pair(ar, br); + } + /** A Class for slice results containing indexes for the slicing of dictionaries, and the resulting column index */ public static class SliceResult { /** Start index to slice inside the dictionary */ @@ -223,9 +296,10 @@ public static class SliceResult { /** * The slice result + * * @param idStart The starting index - * @param idEnd The ending index (not inclusive) - * @param ret The resulting IColIndex + * @param idEnd The ending index (not inclusive) + * @param ret The resulting IColIndex */ protected SliceResult(int idStart, int idEnd, IColIndex ret) { this.idStart = idStart; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java index bbe5aeb8a5c..17c2bed3ba1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.utils.IntArrayList; @@ -133,7 +134,7 @@ else if(l <= this.l && u >= this.u) int minU = Math.min(u, this.u); int offL = maxL - this.l; int offR = minU - this.l; - return new SliceResult(offL, offR, new RangeIndex(maxL - l, minU - l )); + return new SliceResult(offL, offR, new RangeIndex(maxL - l, minU - l)); } } @@ -147,16 +148,40 @@ public boolean equals(IColIndex other) { return other.equals(this); } + @Override + public boolean containsAny(IColIndex idx) { + if(idx instanceof RangeIndex) { + RangeIndex o = (RangeIndex) idx; + if(o.l >= u) + return false; + else if(o.u <= l) + return false; + else if(o.l <= l && o.u > l) + return true; + else if(o.l < u && o.u > u) + return true; + else + throw new NotImplementedException(idx + " " + this); + } + else + return super.containsAny(idx); + } + @Override public IColIndex combine(IColIndex other) { + final int sr = other.size(); if(other.size() == 1) { int v = other.get(0); if(v + 1 == l) return new RangeIndex(l - 1, u); else if(v == u) return new RangeIndex(l, u + 1); + else if(v < l) + return new CombinedIndex(other, this); + else + return new CombinedIndex(this, other); } - if(other instanceof RangeIndex) { + else if(other instanceof RangeIndex) { if(other.get(0) == u) return new RangeIndex(l, other.get(other.size() - 1) + 1); else if(other.get(other.size() - 1) == l - 1) @@ -166,31 +191,40 @@ else if(other.get(0) < this.get(0)) else return new TwoRangesIndex(this, (RangeIndex) other); } - - final int sr = other.size(); - final int sl = size(); - final int[] ret = new int[sr + sl]; - - int pl = 0; - int pr = 0; - int i = 0; - while(pl < sl && pr < sr) { - final int vl = get(pl); - final int vr = other.get(pr); - if(vl < vr) { - ret[i++] = vl; - pl++; - } - else { - ret[i++] = vr; - pr++; + else if(other.get(sr - 1) < l) { + return new CombinedIndex(other, this); + } + else if(other.get(0) > u) { + return new CombinedIndex(this, other); + } + else { + // final int sr = other.size(); + final int sl = size(); + final int[] ret = new int[sr + sl]; + + int pl = 0; + int pr = 0; + int i = 0; + while(pl < sl && pr < sr) { + final int vl = get(pl); + final int vr = other.get(pr); + if(vl < vr) { + ret[i++] = vl; + pl++; + } + else { + ret[i++] = vr; + pr++; + } } + while(pl < sl) + ret[i++] = get(pl++); + while(pr < sr) + ret[i++] = other.get(pr++); + return ColIndexFactory.create(ret); + } - while(pl < sl) - ret[i++] = get(pl++); - while(pr < sr) - ret[i++] = other.get(pr++); - return ColIndexFactory.create(ret); + } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java index 51634b92692..f1c27d44156 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java @@ -28,9 +28,9 @@ public class TwoRangesIndex extends AColIndex { /** The lower index range */ - private final RangeIndex idx1; + protected final RangeIndex idx1; /** The upper index range */ - private final RangeIndex idx2; + protected final RangeIndex idx2; public TwoRangesIndex(RangeIndex lower, RangeIndex higher) { this.idx1 = lower;