diff --git a/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala index 313b6a92..04be2a6b 100644 --- a/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala +++ b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala @@ -5,7 +5,7 @@ import tech.mlsql.common.utils.log.Logging import tech.mlsql.dsl.CommandCollection import tech.mlsql.ets.register.ETRegister import tech.mlsql.plugins.mllib.ets._ -import tech.mlsql.plugins.mllib.ets.fe.{DataTranspose, OnehotExt, PSIExt, SQLDataSummary, SQLDescriptiveMetrics, SQLMissingValueProcess, SQLPatternDistribution, SQLUniqueIdentifier} +import tech.mlsql.plugins.mllib.ets.fe.{DataTranspose, OnehotExt, PSIExt, SQLDataSummary, SQLDataSummaryV2, SQLDescriptiveMetrics, SQLMissingValueProcess, SQLPatternDistribution, SQLUniqueIdentifier} import tech.mlsql.plugins.mllib.ets.fintech.scorecard.{SQLBinning, SQLScoreCard} import tech.mlsql.version.VersionCompatibility @@ -20,7 +20,7 @@ class MLSQLMllib extends tech.mlsql.app.App with VersionCompatibility with Loggi ETRegister.register("SampleDatasetExt", classOf[SampleDatasetExt].getName) ETRegister.register("TakeRandomSampleExt", classOf[TakeRandomSampleExt].getName) ETRegister.register("ColumnsExt", classOf[ColumnsExt].getName) - ETRegister.register("DataSummary", classOf[SQLDataSummary].getName) + ETRegister.register("DataSummary", classOf[SQLDataSummaryV2].getName) ETRegister.register("DataMissingValueProcess", classOf[SQLMissingValueProcess].getName) ETRegister.register("Binning", classOf[SQLBinning].getName) ETRegister.register("ScoreCard", classOf[SQLScoreCard].getName) diff --git a/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2.scala b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2.scala new file mode 100644 index 00000000..6d94484e --- /dev/null +++ b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2.scala @@ -0,0 +1,442 @@ +package tech.mlsql.plugins.mllib.ets.fe + +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession, functions => F} +import streaming.dsl.ScriptSQLExec +import streaming.dsl.auth._ +import streaming.dsl.mmlib.algs.param.BaseParams +import streaming.dsl.mmlib.algs.{CodeExampleText, Functions, MllibFunctions} +import streaming.dsl.mmlib.{Code, SQLAlg, SQLCode} +import tech.mlsql.dsl.auth.ETAuth +import tech.mlsql.common.utils.log.Logging +import tech.mlsql.dsl.auth.dsl.mmlib.ETMethod.ETMethod + +import java.util.Date +import scala.util.{Failure, Success, Try} + +class SQLDataSummaryV2(override val uid: String) extends SQLAlg with MllibFunctions with Functions with BaseParams with ETAuth with Logging { + + def this() = this(BaseParams.randomUID()) + + var round_at = 2 + + var numericCols: Array[String] = null + + def colWithFilterBlank(sc: StructField): Column = { + val col_name = sc.name + sc.dataType match { + case DoubleType => col(col_name).isNotNull && !col(col_name).isNaN + case FloatType => col(col_name).isNotNull && !col(col_name).isNaN + case StringType => col(col_name).isNotNull && col(col_name) =!= "" + case _ => col(col_name).isNotNull + } + } + + + def countColsStdDevNumber(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + if (numeric_columns.contains(c)) { + 
val expr = stddev(when(colWithFilterBlank(sc), col(c))) + when(expr.isNull, lit("")).otherwise(expr).alias(c + "_standardDeviation") + } else { + lit("").alias(c + "_standardDeviation") + } + }).toArray + } + + def countColsStdErrNumber(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + if (numeric_columns.contains(c)) { + val expr = stddev(when(colWithFilterBlank(sc), col(c))) / sqrt(sum(when(colWithFilterBlank(sc), 1).otherwise(0))) + when(expr.isNull, lit("")).otherwise(expr).alias(c + "_standardError") + } else { + lit("").alias(c + "_standardError") + } + }).toArray + } + + def isPrimaryKey(schema: StructType, approx: Boolean): Array[Column] = { + schema.map(sc => { + val c = sc.name + val exp1 = if (approx) { + approx_count_distinct(when(colWithFilterBlank(sc), col(sc.name))) / sum(when(colWithFilterBlank(sc), 1).otherwise(0)) + } else { + countDistinct(when(colWithFilterBlank(sc), col(sc.name))) / sum(when(colWithFilterBlank(sc), 1).otherwise(0)) + } + when(exp1 === 1, 1).otherwise(0).alias(sc.name + "_primaryKeyCandidate") + }).toArray + } + + def countUniqueValueRatio(schema: StructType, approx: Boolean): Array[Column] = { + schema.map(sc => { + // TODO: + val sum_expr = sum(when(colWithFilterBlank(sc), 1).otherwise(0)) + val divide_expr = if (approx) { + approx_count_distinct(when(colWithFilterBlank(sc), col(sc.name))) / sum_expr + } else { + countDistinct(when(colWithFilterBlank(sc), col(sc.name))) / sum_expr + } + val ratio_expr = when(sum_expr === 0, 0.0).otherwise(divide_expr) + + ratio_expr.alias(sc.name + "_uniqueValueRatio") + }).toArray + } + + def getMaxNum(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + val max_expr = max(when(colWithFilterBlank(sc), col(c))) + when(max_expr.isNull, "").otherwise(max_expr.cast(StringType)).alias(c + "_max") + }).toArray + } + + def getMinNum(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + val min_expr = min(when(colWithFilterBlank(sc), col(c))) + when(min_expr.isNull, "").otherwise(min_expr.cast(StringType)).alias(c + "_min") + }).toArray + } + + def roundAtSingleCol(sc: StructField, column: Column): Column = { + if (numericCols.contains(sc.name)) { + return round(column, round_at).cast(StringType) + } + column.cast(StringType) + } + + def processModeValue(modeCandidates: Array[Row], modeFormat: String): Any = { + val mode = if (modeCandidates.lengthCompare(2) >= 0) { + modeFormat match { + case ModeValueFormat.empty => "" + case ModeValueFormat.all => "[" + modeCandidates.map(_.get(0).toString).mkString(",") + "]" + case ModeValueFormat.auto => modeCandidates.head.get(0) + } + } else { + modeCandidates.head.get(0) + } + mode + } + + def isArrayString(mode: Any): Boolean = { + mode.toString.startsWith("[") && mode.toString.endsWith("]") + } + + def countNonNullValue(schema: StructType): Array[Column] = { + schema.map(sc => { + sum(when(col(sc.name).isNotNull, 1).otherwise(0)) + }).toArray + } + + def nullValueCount(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType match { + case DoubleType | FloatType => (sum(when(col(sc.name).isNull || col(sc.name).isNaN, 1).otherwise(0))) / (sum(lit(1))).alias(sc.name + "_nullValueRatio") + case _ => (sum(when(col(sc.name).isNull, 1).otherwise(0))) / (sum(lit(1))).alias(sc.name + "_nullValueRatio") + } + }).toArray + } + + def emptyCount(schema: StructType): Array[Column] = { + schema.map(sc
=> { + sum(when(col(sc.name) === "", 1).otherwise(0)) / sum(lit(1.0)).alias(sc.name + "_blankValueRatio") + }).toArray + } + + def getMaxLength(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType match { + case StringType => max(length(col(sc.name))).alias(sc.name + "maximumLength") + case _ => lit("").alias(sc.name + "maximumLength") + } + }).toArray + } + + def getMinLength(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType match { + case StringType => min(length(col(sc.name))).alias(sc.name + "minimumLength") + case _ => lit("").alias(sc.name + "minimumLength") + } + }).toArray + } + + + def getMeanValue(schema: StructType): Array[Column] = { + schema.map(sc => { + val new_col = if (numericCols.contains(sc.name)) { + val avgExp = avg(when(colWithFilterBlank(sc), col(sc.name))) + // val roundExp = round(avgExp, round_at) + when(avgExp.isNull, lit("")).otherwise(avgExp).alias(sc.name + "_mean") + } else { + lit("").alias(sc.name + "_mean") + } + new_col + }).toArray + } + + def getTypeLength(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType.typeName match { + case "byte" => lit(1L).alias(sc.name) + case "short" => lit(2L).alias(sc.name) + case "integer" => lit(4L).alias(sc.name) + case "long" => lit(8L).alias(sc.name) + case "float" => lit(4L).alias(sc.name) + case "double" => lit(8L).alias(sc.name) + case "string" => max(length(col(sc.name))).alias(sc.name) + case "date" => lit(8L).alias(sc.name) + case "timestamp" => lit(8L).alias(sc.name) + case "boolean" => lit(1L).alias(sc.name) + case name: String if name.contains("decimal") => first(lit(16L)).alias(sc.name) + case _ => lit("").alias(sc.name) + } + }).toArray + } + + def roundNumericCols(df: DataFrame, round_at: Integer): DataFrame = { + df.select(df.schema.map(sc => { + sc.dataType match { + case DoubleType => expr(s"cast (${sc.name} as decimal(38,2)) as ${sc.name}") + case FloatType => expr(s"cast (${sc.name} as decimal(38,2)) as ${sc.name}") + case _ => col(sc.name) + } + }): _*) + } + + def dataFormat(resRow: Array[Seq[Any]], metricsIdx: Map[Int, String], roundAt: Int): Array[Seq[Any]] = { + resRow.map(row => { + row.zipWithIndex.map(el => { + val e = el._1 + val round_at = metricsIdx.getOrElse(el._2 - 1, "") match { + case t if t.endsWith("Ratio") => roundAt + 2 + case _ => roundAt + } + var newE = e + try { + val v = e.toString.toDouble + newE = BigDecimal(v).setScale(round_at, BigDecimal.RoundingMode.HALF_UP).toDouble + } catch { + case e: Exception => logInfo(e.toString) + } + newE + }) + }) + } + + def getPercentileRows(metrics: Array[String], schema: StructType, df: DataFrame, relativeError: Double): (Array[Array[Double]], Array[String]) = { + var percentilePoints: Array[Double] = Array() + var percentileCols: Array[String] = Array() + if (metrics.contains("%25")) { + percentilePoints = percentilePoints :+ 0.25 + percentileCols = percentileCols :+ "%25" + } + if (metrics.contains("median")) { + percentilePoints = percentilePoints :+ 0.5 + percentileCols = percentileCols :+ "median" + } + if (metrics.contains("%75")) { + percentilePoints = percentilePoints :+ 0.75 + percentileCols = percentileCols :+ "%75" + } + + val cols = schema.map(sc => { + var res = lit(0.0).as(sc.name) + if (numericCols.contains(sc.name)) { + res = col(sc.name) + } + res + }).toArray + val quantileRows: Array[Array[Double]] = df.select(cols: _*).na.fill(0.0).stat.approxQuantile(df.columns, percentilePoints, relativeError) + (quantileRows, percentileCols) + } + + def 
processSelectedMetrics(metrics: Array[String]): Array[String] = { + val normalMetrics = "maximumLength,minimumLength,uniqueValueRatio,nullValueRatio,blankValueRatio,mean,standardDeviation,standardError,max,min,dataLength,primaryKeyCandidate".split(",") + val computedMetrics = "%25,median,%75".split(",") + val modeMetric = "mode".split(",") + var leftMetrics: Array[String] = Array() + var rightMetrics: Array[String] = Array() + var appendMetrics: Array[String] = Array() + metrics.map(m => { + m match { + case metric if normalMetrics.contains(metric) => leftMetrics = leftMetrics :+ metric + case metric if computedMetrics.contains(metric) => rightMetrics = rightMetrics :+ metric + case metric if modeMetric.contains(metric) => appendMetrics = appendMetrics :+ metric + case _ => require(false, "The selected metrics contain an unknown calculation! " + m) + } + }) + leftMetrics ++ rightMetrics ++ appendMetrics + } + + def getModeValue(schema: StructType, df: DataFrame): Array[Any] = { + val mode = schema.toList.par.map(sc => { + val dfWithoutNa = df.select(col(sc.name)).na.drop() + val modeDF = dfWithoutNa.groupBy(col(sc.name)).count().orderBy(F.desc("count")).limit(2) + val modeList = modeDF.collect() + if (modeList.length != 0) { + modeList match { + case p if p.length < 2 => p(0).get(0) + case p if p(0).get(1) == p(1).get(1) => "" + case _ => modeList(0).get(0) + } + } else { + "" + } + }).toArray + mode + } + + def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = { + + round_at = Integer.valueOf(params.getOrElse("roundAt", "2")) + + val metrics = params.getOrElse(DataSummary.metrics, "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode").split(",").filter(!_.equalsIgnoreCase("")) + val relativeError = params.getOrElse("relativeError", "0.01").toDouble + val approxCountDistinct = params.getOrElse("approxCountDistinct", "false").toBoolean + val repartitionDF = df + val columns = repartitionDF.columns + + columns.map(col => { + if (col.contains(".") || col.contains("`")) { + throw new RuntimeException(s"The column name : ${col} contains special symbols, like . or `, please rename it first!! 
") + } + }) + + var start_time = new Date().getTime + numericCols = repartitionDF.schema.filter(sc => { + sc.dataType.typeName match { + case datatype: String => Array("integer", "short", "double", "float", "long").contains(datatype) || datatype.contains("decimal") + case _ => false + } + }).map(sc => { + sc.name + }).toArray + val schema = repartitionDF.schema + + val default_metrics = Map( + "dataLength" -> getTypeLength(schema), + "max" -> getMaxNum(schema, numericCols), + "min" -> getMinNum(schema, numericCols), + "maximumLength" -> getMaxLength(schema), + "minimumLength" -> getMinLength(schema), + "mean" -> getMeanValue(schema), + "standardDeviation" -> countColsStdDevNumber(schema, numericCols), + "standardError" -> countColsStdErrNumber(schema, numericCols), + "nullValueRatio" -> nullValueCount(schema), + "blankValueRatio" -> emptyCount(schema), + "uniqueValueRatio" -> countUniqueValueRatio(schema, approxCountDistinct), + "primaryKeyCandidate" -> isPrimaryKey(schema, approxCountDistinct), + ) + val processedSelectedMetrics = processSelectedMetrics(metrics) + val newCols = processedSelectedMetrics.map(name => default_metrics.getOrElse(name, null)).filter(_ != null).flatMap(arr => arr).toArray + val metricsIdx = processedSelectedMetrics.zipWithIndex.map(t => { + (t._2, t._1) + }).toMap + var resDF = repartitionDF.select(newCols: _*) + logInfo(s"normal metrics plan:\n${resDF.explain(true)}") + val rows = resDF.collect() + val rowN = schema.length + val ordinaryPosRow = df.columns.map(col_name => String.valueOf(df.columns.indexOf(col_name) + 1)).toSeq + val normalMetricsRow = (ordinaryPosRow ++ rows(0).toSeq).grouped(rowN).map(_.toSeq).toArray.toSeq.transpose + var end_time = new Date().getTime + + logInfo("The elapsed time for normal metrics is : " + (end_time - start_time)) + + // Calculate Percentile + start_time = new Date().getTime + val (quantileRows, quantileCols) = getPercentileRows(processedSelectedMetrics, schema, df, relativeError) + end_time = new Date().getTime + logInfo("The elapsed time for percentile metrics is: " + (end_time - start_time)) + + var datatype_schema: Array[StructField] = null + var resRows: Array[Seq[Any]] = null + quantileCols.length match { + case 0 => + resRows = Range(0, schema.length).map(i => { + Seq(schema(i).name) ++ normalMetricsRow(i) + }).toArray + case _ => + resRows = Range(0, schema.length).map(i => { + Seq(schema(i).name) ++ normalMetricsRow(i) ++ quantileRows(i).toSeq + }).toArray + } + datatype_schema = ("ColumnName" +: "ordinaryPosition" +: processedSelectedMetrics).map(t => { + StructField(t, StringType) + }) + + start_time = new Date().getTime + // Calculate Mode + if (processedSelectedMetrics.contains("mode")) { + val modeRows = getModeValue(schema, df) + resRows = Range(0, schema.length).map(i => { + resRows(i) :+ modeRows(i) + }).toArray + end_time = new Date().getTime + logInfo("The elapsed time for mode metric is: " + (end_time - start_time)) + } + + + resRows = dataFormat(resRows, metricsIdx, round_at) + val resAfterTransformed = resRows.map(row => { + val t = row.map(e => String.valueOf(e)) + t + }) + val spark = df.sparkSession + spark.createDataFrame(spark.sparkContext.parallelize(resAfterTransformed.map(Row.fromSeq(_)), 1), StructType(datatype_schema)) + } + + override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = { + } + + override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ??? 
+ + override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = + train(df, path, params) + + override def codeExample: Code = Code(SQLCode, CodeExampleText.jsonStr + + """ + | + |set abc=''' + |{"name": "elena", "age": 57, "phone": 15552231521, "income": 433000, "label": 0} + |{"name": "candy", "age": 67, "phone": 15552231521, "income": 1200, "label": 0} + |{"name": "bob", "age": 57, "phone": 15252211521, "income": 89000, "label": 0} + |{"name": "candy", "age": 25, "phone": 15552211522, "income": 36000, "label": 1} + |{"name": "candy", "age": 31, "phone": 15552211521, "income": 300000, "label": 1} + |{"name": "finn", "age": 23, "phone": 15552211521, "income": 238000, "label": 1} + |'''; + | + |load jsonStr.`abc` as table1; + |select age, income from table1 as table2; + |run table2 as DataSummary.`` as summaryTable; + |; + """.stripMargin) + + + override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = { + val vtable = MLSQLTable( + Option(DB_DEFAULT.MLSQL_SYSTEM.toString), + Option("__fe_data_summary_operator__"), + OperateType.SELECT, + Option("select"), + TableType.SYSTEM) + + val context = ScriptSQLExec.contextGetOrForTest() + context.execListener.getTableAuth match { + case Some(tableAuth) => + tableAuth.auth(List(vtable)) + case None => + List(TableAuthResult(granted = true, "")) + } + } +} + +object ModeValueFormat { + val all = "all" + val empty = "empty" + val auto = "auto" +} \ No newline at end of file diff --git a/mlsql-mllib/src/test/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2Test.scala b/mlsql-mllib/src/test/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2Test.scala new file mode 100644 index 00000000..6ff6a5de --- /dev/null +++ b/mlsql-mllib/src/test/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2Test.scala @@ -0,0 +1,96 @@ +package tech.mlsql.plugins.mllib.ets.fe + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.{col, explode, struct, desc} +import org.apache.spark.streaming.SparkOperationUtil +import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers} +import streaming.core.strategy.platform.SparkRuntime +import tech.mlsql.test.BasicMLSQLConfig + +import java.sql.Timestamp +import java.time.LocalDateTime +import java.util.{Date, UUID} + +/** + * + * @Author: Andie Huang + * @Date: 2022/6/27 19:07 + * + */ +class SQLDataSummaryV2Test extends FlatSpec with SparkOperationUtil with Matchers with BasicMLSQLConfig with BeforeAndAfterAll { + def startParams = Array( + "-streaming.master", "local[*]", + "-streaming.name", "unit-test", + "-streaming.rest", "false", + "-streaming.platform", "spark", + "-streaming.enableHiveSupport", "false", + "-streaming.hive.javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=metastore_db/${UUID.randomUUID().toString};create=true", + "-streaming.spark.service", "false", + "-streaming.unittest", "true", + "-spark.sql.shuffle.partitions", "12", + "-spark.default.parallelism", "12", + "-spark.executor.memoryOverheadFactor", "0.2", + "-spark.driver.maxResultSize", "2g" + ) + + "DataSummary" should "Summarize the Dataset" in { + withBatchContext(setupBatchContext(startParams)) { runtime: SparkRuntime => + implicit val spark: SparkSession = runtime.sparkSession + val et = new SQLDataSummaryV2() + val sseq1 = Seq( + ("elena", 57, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("abe", 50, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("AA", 10, "432000",
Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("cc", 40, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("", 30, "434000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("bb", 21, "533000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))) + ) + val seq_df1 = spark.createDataFrame(sseq1).toDF("name", "age", "income", "date") + val res1DF = et.train(seq_df1, "", Map("atRound" -> "2", "metrics" -> "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode")) + res1DF.show() + assert(res1DF.collect()(0).mkString(",") === "name,1.0,5.0,elena,AA,5.0,0.0,,,,0.0,0.1667,1.0,1.0,0.0,") + assert(res1DF.collect()(1).mkString(",") === "age,2.0,4.0,57.0,10.0,,,34.67,17.77,7.2556,0.0,0.0,1.0,1.0,30.0,") + assert(res1DF.collect()(2).mkString(",") === "income,3.0,6.0,533000.0,432000.0,6.0,6.0,,,,0.0,0.0,0.67,0.0,0.0,433000.0") + assert(res1DF.collect()(3).mkString(",") === "date,4.0,8.0,2021-03-08 18:00:00,2021-03-08 18:00:00,,,,,,0.0,0.0,0.17,0.0,0.0,2021-03-08 18:00:00.0") + val sseq = Seq( + ("elena", 57, 57, 110L, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 110F, true, null, null, BigDecimal.valueOf(12), 1.123D), + ("abe", 57, 50, 120L, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 120F, true, null, null, BigDecimal.valueOf(2), 1.123D), + ("AA", 57, 10, 130L, "432000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 130F, true, null, null, BigDecimal.valueOf(2), 2.224D), + ("cc", 0, 40, 100L, "", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), Float.NaN, true, null, null, BigDecimal.valueOf(2), 2D), + ("", -1, 30, 150L, "434000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 150F, true, null, null, BigDecimal.valueOf(2), 3.375D), + ("bb", 57, 21, 160L, "533000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), Float.NaN, false, null, null, BigDecimal.valueOf(2), 3.375D) + ) + val seq_df = spark.createDataFrame(sseq).toDF("name", "favoriteNumber", "age", "mock_col1", "income", "date", "mock_col2", "alived", "extra", "extra1", "extra2", "extra3") + val res2DF = et.train(seq_df, "", Map("atRound" -> "2", "metrics" -> "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode")) + res2DF.show() + assert(res2DF.collect()(0).mkString(",") === "name,1.0,5.0,elena,AA,5.0,0.0,,,,0.0,0.1667,1.0,1.0,0.0,") + assert(res2DF.collect()(1).mkString(",") === "favoriteNumber,2.0,4.0,57.0,-1.0,,,37.83,29.69,12.1228,0.0,0.0,0.5,0.0,57.0,57.0") + assert(res2DF.collect()(2).mkString(",") === "age,3.0,4.0,57.0,10.0,,,34.67,17.77,7.2556,0.0,0.0,1.0,1.0,30.0,") + assert(res2DF.collect()(3).mkString(",") === "mock_col1,4.0,8.0,160.0,100.0,,,128.33,23.17,9.4575,0.0,0.0,1.0,1.0,120.0,") + assert(res2DF.collect()(4).mkString(",") === "income,5.0,6.0,533000.0,432000.0,6.0,0.0,,,,0.0,0.1667,0.8,0.0,0.0,433000.0") + assert(res2DF.collect()(5).mkString(",") === "date,6.0,8.0,2021-03-08 18:00:00,2021-03-08 18:00:00,,,,,,0.0,0.0,0.17,0.0,0.0,2021-03-08 18:00:00.0") + assert(res2DF.collect()(6).mkString(",") === "mock_col2,7.0,4.0,150.0,110.0,,,127.5,17.08,8.5391,0.3333,0.0,1.0,1.0,110.0,") + assert(res2DF.collect()(7).mkString(",") === "alived,8.0,1.0,true,false,,,,,,0.0,0.0,0.33,0.0,0.0,true") + assert(res2DF.collect()(8).mkString(",") === "extra,9.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + 
assert(res2DF.collect()(9).mkString(",") === "extra1,10.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + assert(res2DF.collect()(10).mkString(",") === "extra2,11.0,16.0,12.0,2.0,,,3.67,4.08,1.6667,0.0,0.0,0.33,0.0,2.0,2.0") + assert(res2DF.collect()(11).mkString(",") === "extra3,12.0,8.0,3.38,1.12,,,2.2,1.01,0.4132,0.0,0.0,0.67,0.0,2.0,") + val sseq2 = Seq( + (null, null), + (null, null) + ) + val seq_df2 = spark.createDataFrame(sseq2).toDF("col1", "col2") + val res3DF = et.train(seq_df2, "", Map("atRound" -> "2", "metrics" -> "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode")) + res3DF.show() + assert(res3DF.collect()(0).mkString(",") === "col1,1.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + assert(res3DF.collect()(1).mkString(",") === "col2,2.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + // val paquetDF1 = spark.sqlContext.read.format("parquet").load("/Users/yonghui.huang/Data/benchmarkZL1") + // val paquetDF2 = paquetDF1.sample(true, 1) + // println(paquetDF2.count()) + // val df1 = et.train(paquetDF2, "", Map("atRound" -> "2", "relativeError" -> "0.01")) + // df1.show() + // val df2 = et.train(paquetDF2, "", Map("atRound" -> "2", "approxCountDistinct" -> "true")) + // df2.show() + } + } +} \ No newline at end of file
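A usage sketch for the DataSummary ET registered above (now backed by SQLDataSummaryV2). The parameter names are the ones read in SQLDataSummaryV2.train (metrics, roundAt, relativeError, approxCountDistinct); the where-clause form assumes the standard MLSQL run-statement syntax, and table1 stands in for any already-loaded table:

run table1 as DataSummary.``
where metrics="dataLength,max,min,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode"
and roundAt="2"
and relativeError="0.01"
and approxCountDistinct="false"
as summaryTable;

relativeError only affects the percentile metrics (%25, median, %75), and approxCountDistinct="true" switches the unique-value-ratio and primary-key-candidate checks to approx_count_distinct for large tables.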