diff --git a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala index b8c17706..852ce3c4 100644 --- a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala +++ b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala @@ -6,6 +6,7 @@ import com.redislabs.provider.redis.util.PipelineUtils._ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import scala.collection.JavaConversions.mapAsJavaMap +import scala.collection.JavaConverters.mapAsJavaMapConverter /** * RedisContext extends sparkContext's functionality with redis functions @@ -303,6 +304,17 @@ class RedisContext(@transient val sc: SparkContext) extends Serializable { kvs.foreachPartition(partition => setZset(zsetName, partition, ttl, redisConfig, readWriteConfig)) } + /** + * @param kvs Write RDD of (zset name, zset member -> score) + * @param ttl time to live + */ + def toRedisZSETs(kvs: RDD[(String, Map[String, String])], ttl: Int = 0) + (implicit + redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf), + readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) { + kvs.foreachPartition(partition => setZset(partition, ttl, redisConfig, readWriteConfig)) + } + /** * @param vs RDD of values * @param setName target set's name which hold all the vs @@ -503,6 +515,33 @@ object RedisContext extends Serializable { conn.close() } + /** + * @param zsets zsetName: map of member -> score to be saved in the target host + * @param ttl time to live + */ + def setZset(zsets: Iterator[(String, Map[String, String])], + ttl: Int, + redisConfig: RedisConfig, + readWriteConfig: ReadWriteConfig) { + implicit val rwConf: ReadWriteConfig = readWriteConfig + + zsets + .map { case (key, memberScores) => + (redisConfig.getHost(key), (key, memberScores)) + } + .toArray + .groupBy(_._1) + .foreach { case (node, arr) => + withConnection(node.endpoint.connect()) { conn => + foreachWithPipeline(conn, arr) { (pipeline, a) => + val (key, memberScores) = a._2 + pipeline.zadd(key, memberScores.mapValues((v) => Double.box(v.toDouble)).asJava) + if (ttl > 0) pipeline.expire(key, ttl.toLong) + } + } + } + } + /** * @param setName * @param arr values which should be saved in the target host diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala index 17102052..cf2d7b2b 100644 --- a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala +++ b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala @@ -73,6 +73,20 @@ trait RedisRddExtraSuite extends SparkRedisSuite with Keys with Matchers { verifyHash("hash2", map2) } + test("toRedisZETs") { + val map1 = Map("k1" -> "3.14", "k2" -> "2.71") + val map2 = Map("k3" -> "10", "k4" -> "12", "k5" -> "8", "k6" -> "2") + val zsets = Seq( + ("zset1", map1), + ("zset2", map2) + ) + val rdd = sc.parallelize(zsets) + sc.toRedisZSETs(rdd) + + verifyZSET("zset1", map1) + verifyZSET("zset2", map2) + } + test("connection fails with incorrect user/pass") { assertThrows[JedisConnectionException] { new RedisConfig(RedisEndpoint( @@ -112,4 +126,9 @@ trait RedisRddExtraSuite extends SparkRedisSuite with Keys with Matchers { } } + def verifyZSET(zset: String, vals: Map[String, String]): Unit = { + val zsetWithScore = sc.fromRedisZSetWithScore(zset).sortByKey().collect + zsetWithScore should be(vals.mapValues((v) => v.toDouble).toArray.sortBy(_._1)) + } + }