diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
index ad85376cec8ac..f70715fca6eea 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
@@ -21,7 +21,7 @@ package org.apache.spark.graphx
  * Represents an edge along with its neighboring vertices and allows sending messages along the
  * edge. Used in [[Graph#aggregateMessages]].
  */
-trait EdgeContext[VD, ED, A] {
+abstract class EdgeContext[VD, ED, A] {
   /** The vertex id of the edge's source vertex. */
   def srcId: VertexId
   /** The vertex id of the edge's destination vertex. */
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index c0c7ca19d3b76..e0ba9403ba75b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -185,6 +185,33 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
   def mapEdges[ED2: ClassTag](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2])
     : Graph[VD, ED2]
 
+  /**
+   * Transforms each edge attribute using the map function, passing it the adjacent vertex
+   * attributes as well. If adjacent vertex values are not required,
+   * consider using `mapEdges` instead.
+   *
+   * @note This does not change the structure of the
+   * graph or modify the values of this graph. As a consequence
+   * the underlying index structures can be reused.
+   *
+   * @param map the function from an edge object to a new edge value.
+   *
+   * @tparam ED2 the new edge data type
+   *
+   * @example This function might be used to initialize edge
+   * attributes based on the attributes associated with each vertex.
+   * {{{
+   * val rawGraph: Graph[Int, Int] = someLoadFunction()
+   * val graph = rawGraph.mapTriplets[Int]( edge =>
+   *   edge.src.data - edge.dst.data)
+   * }}}
+   *
+   */
+  def mapTriplets[ED2: ClassTag](
+      map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
+    mapTriplets((pid, iter) => iter.map(map), TripletFields.All)
+  }
+
   /**
    * Transforms each edge attribute using the map function, passing it the adjacent vertex
    * attributes as well. If adjacent vertex values are not required,
@@ -211,7 +238,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    */
   def mapTriplets[ED2: ClassTag](
       map: EdgeTriplet[VD, ED] => ED2,
-      tripletFields: TripletFields = TripletFields.All): Graph[VD, ED2] = {
+      tripletFields: TripletFields): Graph[VD, ED2] = {
     mapTriplets((pid, iter) => iter.map(map), tripletFields)
   }
 
@@ -305,13 +332,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    *  be commutative and associative and is used to combine the output
    *  of the map phase
    *
-   * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to
-   * consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on
-   * edges with destination in the active set. If the direction is `Out`,
-   * `mapFunc` will only be run on edges originating from vertices in the active set. If the
-   * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set
-   * . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the
-   * active set. The active set must have the same index as the graph's vertices.
+   * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
+   * desired. This is done by specifying a set of "active" vertices and an edge direction. The
+   * `sendMsg` function will then run only on edges connected to active vertices by edges in the
+   * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
+   * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
+   * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
+   * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
+   * will be run on edges with *both* vertices in the active set. The active set must have the
+   * same index as the graph's vertices.
    *
    * @example We can use this function to compute the in-degree of each
    * vertex
@@ -349,15 +378,6 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * combiner should be commutative and associative.
    * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
    * `sendMsg` function. If not all fields are needed, specifying this can improve performance.
-   * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
-   * desired. This is done by specifying a set of "active" vertices and an edge direction. The
-   * `sendMsg` function will then run on only edges connected to active vertices by edges in the
-   * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
-   * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
-   * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
-   * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
-   * will be run on edges with *both* vertices in the active set. The active set must have the
-   * same index as the graph's vertices.
    *
    * @example We can use this function to compute the in-degree of each
    * vertex
@@ -377,8 +397,43 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
   def aggregateMessages[A: ClassTag](
       sendMsg: EdgeContext[VD, ED, A] => Unit,
       mergeMsg: (A, A) => A,
-      tripletFields: TripletFields = TripletFields.All,
-      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
+      tripletFields: TripletFields = TripletFields.All)
+    : VertexRDD[A] = {
+    aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None)
+  }
+
+  /**
+   * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
+   * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
+   * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
+   * destined to the same vertex.
+   *
+   * This variant can take an active set to restrict the computation and is intended for internal
+   * use only.
+   *
+   * @tparam A the type of message to be sent to each vertex
+   *
+   * @param sendMsg runs on each edge, sending messages to neighboring vertices using the
+   * [[EdgeContext]].
+   * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
+   * combiner should be commutative and associative.
+   * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
+   * `sendMsg` function. If not all fields are needed, specifying this can improve performance.
+   * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
+   * desired. This is done by specifying a set of "active" vertices and an edge direction. The
+   * `sendMsg` function will then run only on edges connected to active vertices by edges in the
+   * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
+   * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
+   * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
+   * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
+   * will be run on edges with *both* vertices in the active set. The active set must have the
+   * same index as the graph's vertices.
+   */
+  private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag](
+      sendMsg: EdgeContext[VD, ED, A] => Unit,
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields,
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)])
     : VertexRDD[A]
 
   /**
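For orientation, here is a minimal sketch of the reworked public entry point after this change. The `graph` value and its attribute types are hypothetical; only the three-argument `aggregateMessages` shown above is assumed.

{{{
// Compute each vertex's in-degree. The sendMsg closure reads no triplet fields,
// so TripletFields.None lets the system skip shipping vertex and edge attributes.
val inDegrees: VertexRDD[Int] = graph.aggregateMessages[Int](
  ctx => ctx.sendToDst(1),  // one message per in-edge
  _ + _,                    // sum the messages arriving at each vertex
  TripletFields.None)
}}}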
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java
new file mode 100644
index 0000000000000..34df4b7ee7a06
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx;
+
+import java.io.Serializable;
+
+/**
+ * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the
+ * system to populate only those fields for efficiency.
+ */
+public class TripletFields implements Serializable {
+  public final boolean useSrc;
+  public final boolean useDst;
+  public final boolean useEdge;
+
+  public TripletFields() {
+    this(true, true, true);
+  }
+
+  public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) {
+    this.useSrc = useSrc;
+    this.useDst = useDst;
+    this.useEdge = useEdge;
+  }
+
+  public static final TripletFields None = new TripletFields(false, false, false);
+  public static final TripletFields EdgeOnly = new TripletFields(false, false, true);
+  public static final TripletFields SrcOnly = new TripletFields(true, false, false);
+  public static final TripletFields DstOnly = new TripletFields(false, true, false);
+  public static final TripletFields SrcDstOnly = new TripletFields(true, true, false);
+  public static final TripletFields SrcAndEdge = new TripletFields(true, false, true);
+  public static final TripletFields Src = SrcAndEdge;
+  public static final TripletFields DstAndEdge = new TripletFields(false, true, true);
+  public static final TripletFields Dst = DstAndEdge;
+  public static final TripletFields All = new TripletFields(true, true, true);
+}
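A sketch of how a caller chooses among these constants, assuming a hypothetical `graph: Graph[Int, _]`: the closure below reads only `ctx.srcAttr`, so `TripletFields.Src` suffices and spares the system from shipping destination attributes.

{{{
// Each vertex receives the largest attribute among its in-neighbors.
val maxInNeighbor: VertexRDD[Int] = graph.aggregateMessages[Int](
  ctx => ctx.sendToDst(ctx.srcAttr),  // reads the source attribute only
  math.max,                           // keep the largest message
  TripletFields.Src)
}}}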
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala
deleted file mode 100644
index e92e2763a0c06..0000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx
-
-/**
- * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the
- * system to populate only those fields for efficiency.
- */
-class TripletFields private (
-    val useSrc: Boolean,
-    val useDst: Boolean,
-    val useEdge: Boolean)
-  extends Serializable {
-  private def this() = this(true, true, true)
-}
-
-/**
- * Exposes all possible [[TripletFields]] objects.
- */
-object TripletFields {
-  final val None = new TripletFields(useSrc = false, useDst = false, useEdge = false)
-  final val EdgeOnly = new TripletFields(useSrc = false, useDst = false, useEdge = true)
-  final val SrcOnly = new TripletFields(useSrc = true, useDst = false, useEdge = false)
-  final val DstOnly = new TripletFields(useSrc = false, useDst = true, useEdge = false)
-  final val SrcDstOnly = new TripletFields(useSrc = true, useDst = true, useEdge = false)
-  final val SrcAndEdge = new TripletFields(useSrc = true, useDst = false, useEdge = true)
-  final val Src = SrcAndEdge
-  final val DstAndEdge = new TripletFields(useSrc = false, useDst = true, useEdge = true)
-  final val Dst = DstAndEdge
-  final val All = new TripletFields(useSrc = true, useDst = true, useEdge = true)
-
-  /** Returns the appropriate [[TripletFields]] object. */
-  private[graphx] def apply(useSrc: Boolean, useDst: Boolean, useEdge: Boolean) =
-    (useSrc, useDst, useEdge) match {
-      case (false, false, false) => TripletFields.None
-      case (false, false, true) => EdgeOnly
-      case (true, false, false) => SrcOnly
-      case (false, true, false) => DstOnly
-      case (true, true, false) => SrcDstOnly
-      case (true, false, true) => SrcAndEdge
-      case (false, true, true) => DstAndEdge
-      case (true, true, true) => All
-    }
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
index 86ee3923151b4..78d8ac24b5271 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -24,9 +24,17 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 import org.apache.spark.util.collection.BitSet
 
 /**
- * A collection of edges stored in columnar format, along with any vertex attributes referenced. The
- * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by
- * src. There is an optional active vertex set for filtering computation on the edges.
+ * A collection of edges, along with referenced vertex attributes and an optional active vertex set
+ * for filtering computation on the edges.
+ *
+ * The edges are stored in columnar format in `localSrcIds`, `localDstIds`, and `data`. All
+ * referenced global vertex ids are mapped to a compact set of local vertex ids according to the
+ * `global2local` map. Each local vertex id is a valid index into `vertexAttrs`, which stores the
+ * corresponding vertex attribute, and `local2global`, which stores the reverse mapping to global
+ * vertex id. The global vertex ids that are active are optionally stored in `activeSet`.
+ *
+ * The edges are clustered by source vertex id, and the mapping from global vertex id to the index
+ * of the corresponding edge cluster is stored in `index`.
  *
  * @tparam ED the edge attribute type
 * @tparam VD the vertex attribute type
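To make the layout described above concrete, here is a toy partition holding the two edges 10 -> 20 and 10 -> 30 (vertex ids and attribute values are placeholders):

{{{
// local2global: Array(10L, 20L, 30L)   // local vertex id -> global vertex id
// global2local: {10L -> 0, 20L -> 1, 30L -> 2}
// localSrcIds:  Array(0, 0)            // both edges leave local vertex 0 (global 10)
// localDstIds:  Array(1, 2)            // into local vertices 1 and 2 (global 20, 30)
// data:         Array(attrA, attrB)    // edge attributes, clustered by source
// index:        {10L -> 0}             // vertex 10's edge cluster starts at position 0
// vertexAttrs:  Array(a10, a20, a30)   // indexed by local vertex id
}}}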
@@ -46,15 +54,17 @@ import org.apache.spark.util.collection.BitSet
 private[graphx]
 class EdgePartition[
     @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag](
-    val localSrcIds: Array[Int] = null,
-    val localDstIds: Array[Int] = null,
-    val data: Array[ED] = null,
-    val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null,
-    val global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null,
-    val local2global: Array[VertexId] = null,
-    val vertexAttrs: Array[VD] = null,
-    val activeSet: Option[VertexSet] = None
-  ) extends Serializable {
+    localSrcIds: Array[Int],
+    localDstIds: Array[Int],
+    data: Array[ED],
+    index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+    global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+    local2global: Array[VertexId],
+    vertexAttrs: Array[VD],
+    activeSet: Option[VertexSet])
+  extends Serializable {
+
+  private def this() = this(null, null, null, null, null, null, null, null)
 
   /** Return a new `EdgePartition` with the specified edge data. */
   def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = {
@@ -85,16 +95,18 @@ class EdgePartition[
   }
 
   /** Return a new `EdgePartition` without any locally cached vertex attributes. */
-  def clearVertices[VD2: ClassTag](): EdgePartition[ED, VD2] = {
+  def withoutVertexAttributes[VD2: ClassTag](): EdgePartition[ED, VD2] = {
     val newVertexAttrs = new Array[VD2](vertexAttrs.length)
     new EdgePartition(
       localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
       activeSet)
   }
 
-  private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos))
+  @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos))
 
-  private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos))
+  @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos))
+
+  @inline private def attrs(pos: Int): ED = data(pos)
 
   /** Look up vid in activeSet, throwing an exception if it is None. */
   def isActive(vid: VertexId): Boolean = {
@@ -285,7 +297,7 @@ class EdgePartition[
       if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
         // ... run `f` on the matching edge
         builder.add(srcId, dstId, localSrcIds(i), localDstIds(i),
-          f(srcId, dstId, this.data(i), other.data(j)))
+          f(srcId, dstId, this.data(i), other.attrs(j)))
       }
     }
     i += 1
@@ -332,27 +344,53 @@ class EdgePartition[
   * It is safe to keep references to the objects from this iterator.
   */
  def tripletIterator(
-    includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = {
-    new EdgeTripletIterator(this, includeSrc, includeDst)
+      includeSrc: Boolean = true, includeDst: Boolean = true)
+    : Iterator[EdgeTriplet[VD, ED]] = new Iterator[EdgeTriplet[VD, ED]] {
+    private[this] var pos = 0
+
+    override def hasNext: Boolean = pos < EdgePartition.this.size
+
+    override def next() = {
+      val triplet = new EdgeTriplet[VD, ED]
+      val localSrcId = localSrcIds(pos)
+      val localDstId = localDstIds(pos)
+      triplet.srcId = local2global(localSrcId)
+      triplet.dstId = local2global(localDstId)
+      if (includeSrc) {
+        triplet.srcAttr = vertexAttrs(localSrcId)
+      }
+      if (includeDst) {
+        triplet.dstAttr = vertexAttrs(localDstId)
+      }
+      triplet.attr = data(pos)
+      pos += 1
+      triplet
+    }
   }
 
   /**
    * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning
-   * all edges sequentially and filtering them with `idPred`.
+   * all edges sequentially.
   *
   * @param sendMsg generates messages to neighboring vertices of an edge
   * @param mergeMsg the combiner applied to messages destined to the same vertex
-   * @param sendMsgUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute
-   * @param sendMsgUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute
-   * @param idPred a predicate to filter edges based on their source and destination vertex ids
+   * @param tripletFields which triplet fields `sendMsg` uses
+   * @param srcMustBeActive if true, edges will only be considered if their source vertex is in the
+   * active set
+   * @param dstMustBeActive if true, edges will only be considered if their destination vertex is in
+   * the active set
+   * @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be
+   * considered
   *
   * @return iterator aggregated messages keyed by the receiving vertex id
   */
-  def aggregateMessages[A: ClassTag](
+  def aggregateMessagesEdgeScan[A: ClassTag](
       sendMsg: EdgeContext[VD, ED, A] => Unit,
       mergeMsg: (A, A) => A,
       tripletFields: TripletFields,
-      idPred: (VertexId, VertexId) => Boolean): Iterator[(VertexId, A)] = {
+      srcMustBeActive: Boolean,
+      dstMustBeActive: Boolean,
+      maySatisfyEither: Boolean): Iterator[(VertexId, A)] = {
     val aggregates = new Array[A](vertexAttrs.length)
     val bitset = new BitSet(vertexAttrs.length)
 
@@ -363,14 +401,14 @@ class EdgePartition[
       val srcId = local2global(localSrcId)
       val localDstId = localDstIds(i)
       val dstId = local2global(localDstId)
-      if (idPred(srcId, dstId)) {
-        ctx.localSrcId = localSrcId
-        ctx.localDstId = localDstId
-        ctx.srcId = srcId
-        ctx.dstId = dstId
-        ctx.attr = data(i)
-        if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(localSrcId) }
-        if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) }
+      val srcIsActive = !srcMustBeActive || isActive(srcId)
+      val dstIsActive = !dstMustBeActive || isActive(dstId)
+      val edgeIsActive =
+        if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive
+      if (edgeIsActive) {
+        val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD]
+        val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
+        ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i))
         sendMsg(ctx)
       }
       i += 1
@@ -381,22 +419,27 @@ class EdgePartition[
 
   /**
    * Send messages along edges and aggregate them at the receiving vertices. Implemented by
-   * filtering the source vertex index with `srcIdPred`, then scanning edge clusters and filtering
-   * with `dstIdPred`. Both `srcIdPred` and `dstIdPred` must match for an edge to run.
+   * filtering the source vertex index, then scanning each edge cluster.
   *
   * @param sendMsg generates messages to neighboring vertices of an edge
   * @param mergeMsg the combiner applied to messages destined to the same vertex
-   * @param srcIdPred a predicate to filter edges based on their source vertex id
-   * @param dstIdPred a predicate to filter edges based on their destination vertex id
+   * @param tripletFields which triplet fields `sendMsg` uses
+   * @param srcMustBeActive if true, edges will only be considered if their source vertex is in the
+   * active set
+   * @param dstMustBeActive if true, edges will only be considered if their destination vertex is in
+   * the active set
+   * @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be
+   * considered
   *
   * @return iterator aggregated messages keyed by the receiving vertex id
   */
-  def aggregateMessagesWithIndex[A: ClassTag](
+  def aggregateMessagesIndexScan[A: ClassTag](
       sendMsg: EdgeContext[VD, ED, A] => Unit,
       mergeMsg: (A, A) => A,
       tripletFields: TripletFields,
-      srcIdPred: VertexId => Boolean,
-      dstIdPred: VertexId => Boolean): Iterator[(VertexId, A)] = {
+      srcMustBeActive: Boolean,
+      dstMustBeActive: Boolean,
+      maySatisfyEither: Boolean): Iterator[(VertexId, A)] = {
     val aggregates = new Array[A](vertexAttrs.length)
     val bitset = new BitSet(vertexAttrs.length)
 
@@ -405,19 +448,22 @@ class EdgePartition[
       val clusterSrcId = cluster._1
      val clusterPos = cluster._2
       val clusterLocalSrcId = localSrcIds(clusterPos)
-      if (srcIdPred(clusterSrcId)) {
+      val srcIsActive = !srcMustBeActive || isActive(clusterSrcId)
+      if (srcIsActive || maySatisfyEither) {
         var pos = clusterPos
-        ctx.srcId = clusterSrcId
-        ctx.localSrcId = clusterLocalSrcId
-        if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(clusterLocalSrcId) }
+        val srcAttr =
+          if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD]
+        ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr)
         while (pos < size && localSrcIds(pos) == clusterLocalSrcId) {
           val localDstId = localDstIds(pos)
           val dstId = local2global(localDstId)
-          if (dstIdPred(dstId)) {
-            ctx.dstId = dstId
-            ctx.localDstId = localDstId
-            ctx.attr = data(pos)
-            if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) }
+          val dstIsActive = !dstMustBeActive || isActive(dstId)
+          val edgeIsActive =
+            if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive
+          if (edgeIsActive) {
+            val dstAttr =
+              if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
+            ctx.setRest(dstId, localDstId, dstAttr, data(pos))
             sendMsg(ctx)
           }
           pos += 1
@@ -435,23 +481,55 @@ private class AggregatingEdgeContext[VD, ED, A](
     bitset: BitSet)
   extends EdgeContext[VD, ED, A] {
 
-  var srcId: VertexId = _
-  var dstId: VertexId = _
-  var srcAttr: VD = _
-  var dstAttr: VD = _
-  var attr: ED = _
+  private[this] var _srcId: VertexId = _
+  private[this] var _dstId: VertexId = _
+  private[this] var _localSrcId: Int = _
+  private[this] var _localDstId: Int = _
+  private[this] var _srcAttr: VD = _
+  private[this] var _dstAttr: VD = _
+  private[this] var _attr: ED = _
+
+  def set(
+      srcId: VertexId, dstId: VertexId,
+      localSrcId: Int, localDstId: Int,
+      srcAttr: VD, dstAttr: VD,
+      attr: ED) {
+    _srcId = srcId
+    _dstId = dstId
+    _localSrcId = localSrcId
+    _localDstId = localDstId
+    _srcAttr = srcAttr
+    _dstAttr = dstAttr
+    _attr = attr
+  }
+
+  def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) {
+    _srcId = srcId
+    _localSrcId = localSrcId
+    _srcAttr = srcAttr
+  }
+
+  def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) {
+    _dstId = dstId
+    _localDstId = localDstId
+    _dstAttr = dstAttr
+    _attr = attr
+  }
 
-  var localSrcId: Int = _
-  var localDstId: Int = _
+  override def srcId = _srcId
+  override def dstId = _dstId
+  override def srcAttr = _srcAttr
+  override def dstAttr = _dstAttr
+  override def attr = _attr
 
   override def sendToSrc(msg: A) {
-    send(localSrcId, msg)
+    send(_localSrcId, msg)
   }
   override def sendToDst(msg: A) {
-    send(localDstId, msg)
+    send(_localDstId, msg)
   }
 
-  private def send(localId: Int, msg: A) {
+  @inline private def send(localId: Int, msg: A) {
     if (bitset.get(localId)) {
       aggregates(localId) = mergeMsg(aggregates(localId), msg)
     } else {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
index 95a9dca3d16e7..b0cb0fe47d461 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -29,7 +29,7 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 private[graphx]
 class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
     size: Int = 64) {
-  var edges = new PrimitiveVector[Edge[ED]](size)
+  private[this] val edges = new PrimitiveVector[Edge[ED]](size)
 
   /** Add a new edge to the partition. */
   def add(src: VertexId, dst: VertexId, d: ED) {
@@ -71,7 +71,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
       vertexAttrs = new Array[VD](currLocalId + 1)
     }
     new EdgePartition(
-      localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs)
+      localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs,
+      None)
   }
 }
 
@@ -87,7 +88,7 @@ class ExistingEdgePartitionBuilder[
     vertexAttrs: Array[VD],
     activeSet: Option[VertexSet],
     size: Int = 64) {
-  var edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size)
+  private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size)
 
   /** Add a new edge to the partition. */
   def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) {
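The `set`/`setSrcOnly`/`setRest` split on `AggregatingEdgeContext` mirrors the two scan strategies above: the edge scan fills the whole reusable context once per edge, while the index scan fixes the source fields once per edge cluster and rewrites only the destination side for each edge. Condensed from the code above:

{{{
ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr)  // once per source cluster
// then, for each edge in the cluster:
ctx.setRest(dstId, localDstId, dstAttr, attr)             // destination side only
sendMsg(ctx)
}}}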
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
deleted file mode 100644
index a8f829ed20a34..0000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
-
-/**
- * The Iterator type returned when constructing edge triplets. This could be an anonymous class in
- * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile.
- */
-private[impl]
-class EdgeTripletIterator[VD: ClassTag, ED: ClassTag](
-    val edgePartition: EdgePartition[ED, VD],
-    val includeSrc: Boolean,
-    val includeDst: Boolean)
-  extends Iterator[EdgeTriplet[VD, ED]] {
-
-  // Current position in the array.
-  private var pos = 0
-
-  override def hasNext: Boolean = pos < edgePartition.size
-
-  override def next() = {
-    val triplet = new EdgeTriplet[VD, ED]
-    val localSrcId = edgePartition.localSrcIds(pos)
-    val localDstId = edgePartition.localDstIds(pos)
-    triplet.srcId = edgePartition.local2global(localSrcId)
-    triplet.dstId = edgePartition.local2global(localDstId)
-    if (includeSrc) {
-      triplet.srcAttr = edgePartition.vertexAttrs(localSrcId)
-    }
-    if (includeDst) {
-      triplet.dstAttr = edgePartition.vertexAttrs(localDstId)
-    }
-    triplet.attr = edgePartition.data(pos)
-    pos += 1
-    triplet
-  }
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index bcbb22b9100dc..a1fe158b7b490 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -186,16 +186,16 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
     val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
-    val tripletFields = TripletFields(mapUsesSrcAttr, mapUsesDstAttr, useEdge = true)
+    val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true)
 
-    aggregateMessages(sendMsg, reduceFunc, tripletFields, activeSetOpt)
+    aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt)
   }
 
-  override def aggregateMessages[A: ClassTag](
+  override def aggregateMessagesWithActiveSet[A: ClassTag](
       sendMsg: EdgeContext[VD, ED, A] => Unit,
       mergeMsg: (A, A) => A,
       tripletFields: TripletFields,
-      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {
     vertices.cache()
 
     // For each vertex, replicate its attribute only to partitions where it is
@@ -217,33 +217,31 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
         activeDirectionOpt match {
           case Some(EdgeDirection.Both) =>
             if (activeFraction < 0.8) {
-              edgePartition.aggregateMessagesWithIndex(sendMsg, mergeMsg, tripletFields,
-                srcId => edgePartition.isActive(srcId),
-                dstId => edgePartition.isActive(dstId))
+              edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
+                true, true, false)
             } else {
-              edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields,
-                (srcId, dstId) => edgePartition.isActive(srcId) && edgePartition.isActive(dstId))
+              edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+                true, true, false)
             }
           case Some(EdgeDirection.Either) =>
             // TODO: Because we only have a clustered index on the source vertex ID, we can't filter
             // the index here. Instead we have to scan all edges and then do the filter.
-            edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields,
-              (srcId, dstId) => edgePartition.isActive(srcId) || edgePartition.isActive(dstId))
+            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+              true, true, true)
           case Some(EdgeDirection.Out) =>
            if (activeFraction < 0.8) {
-              edgePartition.aggregateMessagesWithIndex(sendMsg, mergeMsg, tripletFields,
-                srcId => edgePartition.isActive(srcId),
-                dstId => true)
+              edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
+                true, false, false)
             } else {
-              edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields,
-                (srcId, dstId) => edgePartition.isActive(srcId))
+              edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+                true, false, false)
             }
           case Some(EdgeDirection.In) =>
-            edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields,
-              (srcId, dstId) => edgePartition.isActive(dstId))
+            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+              false, true, false)
           case _ => // None
-            edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields,
-              (srcId, dstId) => true)
+            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+              false, false, false)
         }
       }).setName("GraphImpl.aggregateMessages - preAgg")
@@ -327,7 +325,7 @@ object GraphImpl {
       vertices: VertexRDD[VD],
       edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = {
     // Convert the vertex partitions in edges to the correct type
-    val newEdges = edges.mapEdgePartitions((pid, part) => part.clearVertices[VD])
+    val newEdges = edges.mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD])
     GraphImpl.fromExistingRDDs(vertices, newEdges)
   }
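The dispatch in `aggregateMessagesWithActiveSet` above encodes each edge direction as the triple `(srcMustBeActive, dstMustBeActive, maySatisfyEither)`. The same mapping written out as a standalone sketch (this helper is illustrative only, not part of the patch):

{{{
def activeFlags(dir: Option[EdgeDirection]): (Boolean, Boolean, Boolean) = dir match {
  case Some(EdgeDirection.Both)   => (true, true, false)   // both endpoints must be active
  case Some(EdgeDirection.Either) => (true, true, true)    // either endpoint suffices
  case Some(EdgeDirection.Out)    => (true, false, false)  // source must be active
  case Some(EdgeDirection.In)     => (false, true, false)  // destination must be active
  case None                       => (false, false, false) // no active-set filtering
}
}}}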
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index c7a59990ce8e7..515f3a9cd02eb 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -102,6 +102,16 @@ class EdgePartitionSuite extends FunSuite {
     assert(ep.numActives == Some(2))
   }
 
+  test("tripletIterator") {
+    val builder = new EdgePartitionBuilder[Int, Int]
+    builder.add(1, 2, 0)
+    builder.add(1, 3, 0)
+    builder.add(1, 4, 0)
+    val ep = builder.toEdgePartition
+    val result = ep.tripletIterator().toList.map(et => (et.srcId, et.dstId))
+    assert(result === Seq((1, 2), (1, 3), (1, 4)))
+  }
+
   test("serialization") {
     val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5))
     val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
@@ -113,7 +123,6 @@ class EdgePartitionSuite extends FunSuite {
     for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
       val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
       assert(aSer.tripletIterator().toList === a.tripletIterator().toList)
-      assert(aSer.index != null)
     }
   }
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
deleted file mode 100644
index 49b2704390fea..0000000000000
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.reflect.ClassTag
-import scala.util.Random
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.graphx._
-
-class EdgeTripletIteratorSuite extends FunSuite {
-  test("iterator.toList") {
-    val builder = new EdgePartitionBuilder[Int, Int]
-    builder.add(1, 2, 0)
-    builder.add(1, 3, 0)
-    builder.add(1, 4, 0)
-    val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true)
-    val result = iter.toList.map(et => (et.srcId, et.dstId))
-    assert(result === Seq((1, 2), (1, 3), (1, 4)))
-  }
-}