refuse to insert into table with loaded soft copied data
Matagits committed Apr 11, 2024
1 parent b05266a commit 4a39403
Showing 2 changed files with 58 additions and 3 deletions.
@@ -26,6 +26,7 @@ import com._4paradigm.std.{ExprNodeVector, VectorString}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType,
ShortType, StringType, StructField, StructType, TimestampType}
import org.slf4j.LoggerFactory

import java.sql.{Date, Timestamp}
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
@@ -35,6 +36,8 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
object InsertPlan {
case class ColInfo(colDesc: ColumnDesc, field: StructField)

private val logger = LoggerFactory.getLogger(this.getClass)

def gen(ctx: PlanContext, node: PhysicalInsertNode): SparkInstance = {
val stmt = node.GetInsertStmt()
require(stmt != null, "Fail to get insert statement")
@@ -46,6 +49,13 @@ object InsertPlan {
require(tableInfo != null && tableInfo.getName.nonEmpty,
s"table $db.$table info does not exist (no table name): $tableInfo")

val hasOfflineTableInfo = tableInfo.hasOfflineTableInfo
logger.info(s"hasOfflineTableInfo: $hasOfflineTableInfo")
if (hasOfflineTableInfo) {
val symbolicPaths = tableInfo.getOfflineTableInfo.getSymbolicPathsList
require(symbolicPaths == null || symbolicPaths.isEmpty, "can't insert into table with soft copied data")
}

val colDescList = tableInfo.getColumnDescList
var oriSchema = new StructType
val colInfoMap = mutable.Map[String, ColInfo]()
@@ -66,7 +76,6 @@ object InsertPlan {

val offlineDataPath = getOfflineDataPath(ctx, db, table)
val newTableInfoBuilder = tableInfo.toBuilder
if (!hasOfflineTableInfo) {
val newOfflineInfo = OfflineTableInfo
.newBuilder()
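The guard added above distills to a single predicate. A minimal sketch, assuming the proto getters already used in the patch (`hasOfflineTableInfo`, `getOfflineTableInfo`, `getSymbolicPathsList`); the helper name `hasSoftCopiedData` is hypothetical, not part of the change:

```scala
import com._4paradigm.openmldb.proto.NS.TableInfo

// A table counts as "soft copied" when its offline info carries symbolic
// paths, i.e. the data was loaded with deep_copy=false and only linked.
def hasSoftCopiedData(tableInfo: TableInfo): Boolean =
  tableInfo.hasOfflineTableInfo && {
    val paths = tableInfo.getOfflineTableInfo.getSymbolicPathsList
    paths != null && !paths.isEmpty
  }
```

InsertPlan.gen can then fail fast with `require(!hasSoftCopiedData(tableInfo), "can't insert into table with soft copied data")`, which surfaces as the `IllegalArgumentException` the new test expects.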
@@ -18,6 +18,7 @@ package com._4paradigm.openmldb.batch

import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batch.utils.SparkUtil
import com._4paradigm.openmldb.proto.NS.TableInfo
import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{BooleanType, DateType, DoubleType, FloatType, IntegerType, LongType, StringType,
@@ -179,14 +180,59 @@ class TestInsertPlan extends SparkTestSuite {
StructField("c3", StringType, nullable = true),
StructField("c4", StringType, nullable = true)
))
val expectDf = sparkSession.createDataFrame(
sparkSession.sparkContext.parallelize(Seq(
Row(1, null, "a", null),
// If a column's type is string and the inserted value is null, InsertPlan can't tell whether the value
// is a real null or the string "null"
Row(5, null, "NuLl", "null"),
Row(null, null, null, null))),
schema)
assert(SparkUtil.approximateDfEqual(expectDf, queryResult.getSparkDf()))
}

test("Test table with loaded deep copied data") {
val table = "t6"
openmldbConnector.executeDDL(db, s"create table $table(c1 int, c2 int64, c3 double);")
openmldbConnector.refreshCatalog()
assert(openmldbConnector.getTableInfo(db, table).getName.nonEmpty)

val testFileWithHeader = "file://" + getClass.getResource("/load_data_test_src/test_with_any_header.csv")
.getPath
openmldbSession.sql(s"load data infile '$testFileWithHeader' into table $db.$table " +
s"options(format='csv', deep_copy=true);")
val loadInfo = getLatestTableInfo(db, table)
val oldFormat = loadInfo.getOfflineTableInfo.getFormat

var querySess = new OpenmldbSession(sparkSession)
var queryResult = querySess.sql(s"select * from $db.$table")
assert(queryResult.count() == 2)

val sql = s"insert into $db.$table values (1, 1, 1)"
openmldbSession.sql(sql)
val info = getLatestTableInfo(db, table)
assert(oldFormat.equals(info.getOfflineTableInfo.getFormat))

querySess = new OpenmldbSession(sparkSession)
queryResult = querySess.sql(s"select * from $db.$table")
assert(queryResult.count() == 3)
}

def getLatestTableInfo(db: String, table: String): TableInfo = {
openmldbConnector.refreshCatalog()
openmldbConnector.getTableInfo(db, table)
}

test("Test table with loaded soft copied data") {
val table = "t7"
openmldbConnector.executeDDL(db, s"create table $table(c1 int, c2 int64, c3 double);")
openmldbConnector.refreshCatalog()
assert(openmldbConnector.getTableInfo(db, table).getName.nonEmpty)

val testFileWithHeader = "file://" + getClass.getResource("/load_data_test_src/test_with_any_header.csv")
.getPath
openmldbSession.sql(s"load data infile '$testFileWithHeader' into table $db.$table " +
"options(format='csv', deep_copy=false);")
assertThrows[IllegalArgumentException](openmldbSession.sql(s"insert into $db.$table values (1, 1, 1)"))
}
}
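End to end, the behavior the two new tests pin down looks roughly like this (a sketch reusing the `OpenmldbSession` API exercised above; the database name, table name, and file path are placeholders):

```scala
val sess = new OpenmldbSession(sparkSession)

// Soft copy: deep_copy=false only records the CSV location as a symbolic
// path in the table's OfflineTableInfo instead of copying the data in.
sess.sql("load data infile 'file:///tmp/data.csv' into table mydb.t " +
  "options(format='csv', deep_copy=false);")

// With symbolic paths present, InsertPlan refuses the statement and
// require(...) throws IllegalArgumentException
// ("can't insert into table with soft copied data").
sess.sql("insert into mydb.t values (1, 1, 1.0)")
```

With `deep_copy=true` the data is materialized under the managed offline path, no symbolic paths are recorded, and inserts keep working, which is exactly what the `t6` test checks.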
