diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 61496ddfb0346..4825d12fc27b3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -410,10 +410,9 @@ object VertexRDD { def apply[VD: ClassTag]( vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD ): VertexRDD[VD] = { - val verticesDedup = vertices.reduceByKey((VD1, VD2) => mergeFunc(VD1, VD2)) - val vPartitioned: RDD[(VertexId, VD)] = verticesDedup.partitioner match { - case Some(p) => verticesDedup - case None => verticesDedup.copartitionWithVertices(new HashPartitioner(verticesDedup.partitions.size)) + val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { + case Some(p) => vertices + case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) } val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get) val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index dca54b8a7da86..d638d578ee300 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -40,13 +40,14 @@ object ShippableVertexPartition { /** * Construct a `ShippableVertexPartition` from the given vertices with the specified routing - * table, filling in missing vertices mentioned in the routing table using `defaultVal`. + * table, filling in missing vertices mentioned in the routing table using `defaultVal`, + * and merging duplicate vertex attributes with `mergeFunc`. 
*/ def apply[VD: ClassTag]( - iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD) + iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD, mergeFunc: (VD, VD) => VD) : ShippableVertexPartition[VD] = { val fullIter = iter ++ routingTable.iterator.map(vid => (vid, defaultVal)) - val (index, values, mask) = VertexPartitionBase.initFrom(fullIter, (a: VD, b: VD) => a) + val (index, values, mask) = VertexPartitionBase.initFrom(fullIter, mergeFunc) new ShippableVertexPartition(index, values, mask, routingTable) }