Skip to content

Commit

Permalink
Fix Spark 3.5.x support + support Spark 3.4.3, deprecate python 3.7+s…
Browse files Browse the repository at this point in the history
…upport python 3.10 (#411)

* support 3.4.1

* Support Spark 3.5.0

* support 3.4.3

* support 3.5.1

* build and test with python 3.10

deprecate spark support prior to 3.4

* deprecate databricks.koalas

* fix test

* Fix spark 3.5.x support

* fix lint

* Fix spark 3.5.x support

* final touch

* pin numpy < 2.0.0 and pyarrow < 15.0.0
  • Loading branch information
pang-wu authored Jun 18, 2024
1 parent 3b5e219 commit d4fd70d
Show file tree
Hide file tree
Showing 21 changed files with 112 additions and 56 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
- name: Set up Python 3.7
- name: Set up Python 3.9
uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
with:
python-version: 3.7
python-version: 3.9
- name: Set up JDK 1.8
uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
with:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ray_nightly_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
python-version: [3.8, 3.9]
spark-version: [3.1.3, 3.2.4, 3.3.2, 3.4.0]
python-version: [3.8, 3.9, 3.10.14]
spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0]

runs-on: ${{ matrix.os }}

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/raydp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
python-version: [3.8, 3.9]
spark-version: [3.1.3, 3.2.4, 3.3.2, 3.4.0]
python-version: [3.8, 3.9, 3.10.14]
spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0]

runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:
else
pip install torch
fi
pip install pyarrow==6.0.1 ray[train] pytest koalas tensorflow==2.13.1 tabulate grpcio-tools wget
pip install pyarrow==6.0.1 ray[train] pytest tensorflow==2.13.1 tabulate grpcio-tools wget
pip install "xgboost_ray[default]<=0.1.13"
pip install torchmetrics
- name: Cache Maven
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/raydp_nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
- name: Set up Python 3.7
- name: Set up Python 3.9
uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
with:
python-version: 3.7
python-version: 3.9
- name: Set up JDK 1.8
uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
with:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,19 @@

package org.apache.spark.sql.raydp


import com.intel.raydp.shims.SparkShimLoader
import io.ray.api.{ActorHandle, ObjectRef, PyActorHandle, Ray}
import io.ray.runtime.AbstractRayRuntime
import java.io.ByteArrayOutputStream
import java.util.{List, UUID}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
import java.util.function.{Function => JFunction}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import io.ray.api.{ActorHandle, ObjectRef, PyActorHandle, Ray}
import io.ray.runtime.AbstractRayRuntime
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types.pojo.Schema
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{RayDPException, SparkContext}
import org.apache.spark.deploy.raydp._
Expand Down Expand Up @@ -105,7 +103,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
Iterator(iter)
}

val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val arrowSchema = SparkShimLoader.getSparkShims.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"ray object store writer", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
Expand Down Expand Up @@ -217,7 +215,7 @@ object ObjectStoreWriter {
def toArrowSchema(df: DataFrame): Schema = {
val conf = df.queryExecution.sparkSession.sessionState.conf
val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
ArrowUtils.toArrowSchema(df.schema, timeZoneId)
SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId)
}

def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

package com.intel.raydp.shims

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.executor.RayDPExecutorBackendFactory
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SparkSession}

sealed abstract class ShimDescriptor
Expand All @@ -36,4 +38,6 @@ trait SparkShims {
def getExecutorBackendFactory(): RayDPExecutorBackendFactory

def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import org.apache.spark.executor.spark322._
import org.apache.spark.spark322.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark322.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark322Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
Expand All @@ -44,4 +45,8 @@ class Spark322Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@

package org.apache.spark.sql.spark322

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session))
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import org.apache.spark.executor.spark330._
import org.apache.spark.spark330.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark330.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark330Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
Expand All @@ -44,4 +45,8 @@ class Spark330Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@

package org.apache.spark.sql.spark330

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
ArrowConverters.toDataFrame(rdd, schema, session)
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ object SparkShimProvider {
val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0)
val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1)
val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR")
val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR",
s"$SPARK343_DESCRIPTOR")
val DESCRIPTOR = SPARK341_DESCRIPTOR
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import org.apache.spark.executor.spark340._
import org.apache.spark.spark340.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark340.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark340Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
Expand All @@ -44,4 +45,8 @@ class Spark340Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

package org.apache.spark.sql.spark340

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(
Expand All @@ -36,4 +38,8 @@ object SparkSqlUtils {
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}

object SparkShimProvider {
val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR")
val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1)
val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR")
val DESCRIPTOR = SPARK350_DESCRIPTOR
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import org.apache.spark.executor.spark350._
import org.apache.spark.spark350.TaskContextUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.spark350.SparkSqlUtils

import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.types.StructType

class Spark350Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
Expand All @@ -44,4 +45,8 @@ class Spark350Shims extends SparkShims {
override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
TaskContextUtils.getDummyTaskContext(partitionId, env)
}

override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

package org.apache.spark.sql.spark350

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

object SparkSqlUtils {
def toDataFrame(
Expand All @@ -36,4 +38,8 @@ object SparkSqlUtils {
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
}

def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false)
}
}
4 changes: 2 additions & 2 deletions python/raydp/spark/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from raydp.utils import convert_to_spark

DF = Union["pyspark.sql.DataFrame", "koalas.DataFrame"]
OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["koalas.DataFrame"]]
DF = Union["pyspark.sql.DataFrame", "pyspark.pandas.DataFrame"]
OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["pyspark.pandas.DataFrame"]]


class SparkEstimatorInterface:
Expand Down
20 changes: 11 additions & 9 deletions python/raydp/tests/test_spark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import math
import sys

import databricks.koalas as ks
# https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html
# import databricks.koalas as ks
import pyspark.pandas as ps
import pyspark
import pytest

Expand All @@ -27,13 +29,13 @@

def test_df_type_check(spark_session):
spark_df = spark_session.range(0, 10)
koalas_df = ks.range(0, 10)
koalas_df = ps.range(0, 10)
assert utils.df_type_check(spark_df)
assert utils.df_type_check(koalas_df)

other_df = "df"
error_msg = (f"The type: {type(other_df)} is not supported, only support " +
"pyspark.sql.DataFrame and databricks.koalas.DataFrame")
"pyspark.sql.DataFrame and pyspark.pandas.DataFrame")
with pytest.raises(Exception) as exinfo:
utils.df_type_check(other_df)
assert str(exinfo.value) == error_msg
Expand All @@ -45,15 +47,15 @@ def test_convert_to_spark(spark_session):
assert is_spark_df
assert spark_df is converted

koalas_df = ks.range(0, 10)
converted, is_spark_df = utils.convert_to_spark(koalas_df)
pandas_on_spark_df = ps.range(0, 10)
converted, is_spark_df = utils.convert_to_spark(pandas_on_spark_df)
assert not is_spark_df
assert isinstance(converted, pyspark.sql.DataFrame)
assert converted.count() == 10

other_df = "df"
error_msg = (f"The type: {type(other_df)} is not supported, only support " +
"pyspark.sql.DataFrame and databricks.koalas.DataFrame")
"pyspark.sql.DataFrame and pyspark.pandas.DataFrame")
with pytest.raises(Exception) as exinfo:
utils.df_type_check(other_df)
assert str(exinfo.value) == error_msg
Expand All @@ -64,10 +66,10 @@ def test_random_split(spark_session):
splits = utils.random_split(spark_df, [0.7, 0.3])
assert len(splits) == 2

koalas_df = ks.range(0, 10)
koalas_df = ps.range(0, 10)
splits = utils.random_split(koalas_df, [0.7, 0.3])
assert isinstance(splits[0], ks.DataFrame)
assert isinstance(splits[1], ks.DataFrame)
assert isinstance(splits[0], ps.DataFrame)
assert isinstance(splits[1], ps.DataFrame)
assert len(splits) == 2


Expand Down
Loading

0 comments on commit d4fd70d

Please sign in to comment.