diff --git a/build.sbt b/build.sbt index 371b3f61ec..8bde7fa3b1 100644 --- a/build.sbt +++ b/build.sbt @@ -63,6 +63,12 @@ lazy val `quill-sql` = lazy val `quill-sql-jvm` = `quill-sql`.jvm lazy val `quill-sql-js` = `quill-sql`.js +lazy val `quill-test` = + (project in file("quill-test")) + .settings(commonSettings: _*) + .settings(mimaSettings: _*) + .dependsOn(`quill-sql-jvm` % "compile->compile;test->test") + lazy val `quill-jdbc` = (project in file("quill-jdbc")) .settings(commonSettings: _*) diff --git a/build/setup.sh b/build/setup.sh index e3d0d3b0a3..b219607da6 100755 --- a/build/setup.sh +++ b/build/setup.sh @@ -38,27 +38,4 @@ done echo -e "\nPostgres ready" psql -h postgres -U postgres -c "CREATE DATABASE quill_test" -psql -h postgres -U postgres -d quill_test -a -f quill-sql/src/test/sql/postgres-schema.sql - -echo "Waiting for Cassandra" -until nc -z cassandra 9042 -do - printf "." - sleep 1 -done -echo -e "\nCassandra ready" - -echo "CREATE KEYSPACE quill_test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1};" > /tmp/create-keyspace.cql -cqlsh cassandra -f /tmp/create-keyspace.cql -cqlsh cassandra -k quill_test -f quill-cassandra/src/test/cql/cassandra-schema.cql - -echo "Waiting for Sql Server" -until sqlcmd -S sqlserver -U SA -P 'QuillRocks!' -Q "SELECT 1" &> /dev/null -do - printf "." - sleep 1 -done -echo -e "\nSql Server ready" - -sqlcmd -S sqlserver -U SA -P "QuillRocks!" -Q "CREATE DATABASE quill_test" -sqlcmd -S sqlserver -U SA -P "QuillRocks!" -d quill_test -i quill-sql/src/test/sql/sqlserver-schema.sql +psql -h postgres -U postgres -d quill_test -a -f quill-sql/src/test/sql/postgres-schema.sql \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index beae195194..2c022d95bb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -52,9 +52,9 @@ services: links: - postgres:postgres - mysql:mysql - - cassandra:cassandra - - orientdb:orientdb - - sqlserver:sqlserver + #- cassandra:cassandra + #- orientdb:orientdb + #- sqlserver:sqlserver volumes: - ./:/app command: @@ -73,9 +73,9 @@ services: links: - postgres:postgres - mysql:mysql - - cassandra:cassandra - - orientdb:orientdb - - sqlserver:sqlserver + #- cassandra:cassandra + #- orientdb:orientdb + #- sqlserver:sqlserver volumes: - ./:/app - ~/.ivy2:/root/.ivy2 diff --git a/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala b/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala index 561eafe72c..d0145c6c1c 100644 --- a/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala +++ b/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala @@ -7,6 +7,7 @@ import io.getquill.idiom.Statement import io.getquill.idiom.StatementInterpolator._ import io.getquill.norm.Normalize import io.getquill.util.Interleave +import io.getquill.util.Messages.fail object MirrorIdiom extends MirrorIdiom @@ -22,24 +23,25 @@ class MirrorIdiom extends Idiom { } implicit def astTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Ast] = Tokenizer[Ast] { - case ast: Query => ast.token - case ast: Function => ast.token - case ast: Value => ast.token - case ast: Operation => ast.token - case ast: Action => ast.token - case ast: Ident => ast.token - case ast: Property => ast.token - case ast: Infix => ast.token - case ast: OptionOperation => ast.token - case ast: TraversableOperation => ast.token - case ast: Dynamic => ast.token - case ast: If => ast.token - case ast: Block => ast.token - case ast: Val => ast.token - case ast: Ordering => ast.token - case ast: QuotedReference => ast.ast.token - case ast: Lift => ast.token - case ast: Assignment => ast.token + case ast: Query => ast.token + case ast: Function => ast.token + case ast: Value => ast.token + case ast: Operation => ast.token + case ast: Action => ast.token + case ast: Ident => ast.token + case ast: Property => ast.token + case ast: Infix => ast.token + case ast: OptionOperation => ast.token + case ast: TraversableOperation => ast.token + case ast: Dynamic => ast.token + case ast: If => ast.token + case ast: Block => ast.token + case ast: Val => ast.token + case ast: Ordering => ast.token + case ast: QuotedReference => ast.ast.token + case ast: Lift => ast.token + case ast: Assignment => ast.token + case ast @ (Excluded | Existing) => fail(s"$ast is not supported at this place") } implicit def ifTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[If] = Tokenizer[If] { @@ -184,6 +186,37 @@ class MirrorIdiom extends Idiom { case Delete(query) => stmt"${query.token}.delete" case Returning(query, alias, body) => stmt"${query.token}.returning((${alias.token}) => ${body.token})" case Foreach(query, alias, body) => stmt"${query.token}.foreach((${alias.token}) => ${body.token})" + case c: Conflict => stmt"${c.token}" + } + + implicit def conflictTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Conflict] = { + + def targetProps(l: List[Property]) = l.map(p => Transform(p) { + case Ident(_) => Ident("_") + }) + + implicit val conflictTargetTokenizer = Tokenizer[Conflict.Target] { + case Conflict.NoTarget => stmt"" + case Conflict.StringValue(s) => stmt"""("${s.token}")""" + case Conflict.Properties(props) => stmt"(${targetProps(props).token})" + } + + def updateValue(v: Ast) = Transform(v) { + case Excluded => Ident("e") + case Existing => Ident("t") + } + + val updateAssignsTokenizer = Tokenizer[Assignment] { + case Assignment(i, p, v) => + stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(updateValue(v))}" + } + + Tokenizer[Conflict] { + case Conflict(i, t, Conflict.Update(assign)) => + stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})" + case Conflict(i, t, Conflict.Ignore) => + stmt"${i.token}.onConflictIgnore${t.token}" + } } implicit def assignmentTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Assignment] = Tokenizer[Assignment] { diff --git a/quill-core/src/main/scala/io/getquill/ast/Ast.scala b/quill-core/src/main/scala/io/getquill/ast/Ast.scala index 1141a3f1da..bf82c424c3 100644 --- a/quill-core/src/main/scala/io/getquill/ast/Ast.scala +++ b/quill-core/src/main/scala/io/getquill/ast/Ast.scala @@ -89,6 +89,8 @@ case class If(condition: Ast, `then`: Ast, `else`: Ast) extends Ast case class Assignment(alias: Ident, property: Ast, value: Ast) extends Ast +case object Excluded extends Ast +case object Existing extends Ast //************************************************************ sealed trait Operation extends Ast @@ -126,6 +128,17 @@ case class Returning(action: Ast, alias: Ident, property: Ast) extends Action case class Foreach(query: Ast, alias: Ident, body: Ast) extends Action +case class Conflict(insert: Ast, target: Conflict.Target, action: Conflict.Action) extends Action +object Conflict { + trait Target + case object NoTarget extends Target + case class StringValue(value: String) extends Target + case class Properties(props: List[Property]) extends Target + + trait Action + case object Ignore extends Action + case class Update(assignments: List[Assignment]) extends Action +} //************************************************************ case class Dynamic(tree: Any) extends Ast diff --git a/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala b/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala index 966075bb9f..477a1f9a83 100644 --- a/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala +++ b/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala @@ -14,15 +14,13 @@ trait StatefulTransformer[T] { case e: Ident => (e, this) case e: OptionOperation => apply(e) case e: TraversableOperation => apply(e) + case e: Property => apply(e) + case Existing | Excluded => (e, this) case Function(a, b) => val (bt, btt) = apply(b) (Function(a, bt), btt) - case Property(a, b) => - val (at, att) = apply(a) - (Property(at, b), att) - case Infix(a, b) => val (bt, btt) = apply(b)(_.apply) (Infix(a, bt), btt) @@ -168,6 +166,13 @@ trait StatefulTransformer[T] { (Assignment(a, bt, ct), ctt) } + def apply(e: Property): (Property, StatefulTransformer[T]) = + e match { + case Property(a, b) => + val (at, att) = apply(a) + (Property(at, b), att) + } + def apply(e: Operation): (Operation, StatefulTransformer[T]) = e match { case UnaryOperation(o, a) => @@ -217,6 +222,27 @@ trait StatefulTransformer[T] { val (at, att) = apply(a) val (ct, ctt) = att.apply(c) (Foreach(at, b, ct), ctt) + case Conflict(a, b, c) => + val (at, att) = apply(a) + val (bt, btt) = att.apply(b) + val (ct, ctt) = btt.apply(c) + (Conflict(at, bt, ct), ctt) + } + + def apply(e: Conflict.Target): (Conflict.Target, StatefulTransformer[T]) = + e match { + case Conflict.NoTarget | Conflict.StringValue(_) => (e, this) + case Conflict.Properties(a) => + val (at, att) = apply(a)(_.apply) + (Conflict.Properties(at), att) + } + + def apply(e: Conflict.Action): (Conflict.Action, StatefulTransformer[T]) = + e match { + case Conflict.Ignore => (e, this) + case Conflict.Update(a) => + val (at, att) = apply(a)(_.apply) + (Conflict.Update(at), att) } def apply[U, R](list: List[U])(f: StatefulTransformer[T] => U => (R, StatefulTransformer[T])) = diff --git a/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala b/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala index 08f8506060..d418468082 100644 --- a/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala +++ b/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala @@ -11,7 +11,7 @@ trait StatelessTransformer { case e: Assignment => apply(e) case Function(params, body) => Function(params, apply(body)) case e: Ident => e - case Property(a, name) => Property(apply(a), name) + case e: Property => apply(e) case Infix(a, b) => Infix(a, b.map(apply)) case e: OptionOperation => apply(e) case e: TraversableOperation => apply(e) @@ -22,6 +22,7 @@ trait StatelessTransformer { case Block(statements) => Block(statements.map(apply)) case Val(name, body) => Val(name, apply(body)) case o: Ordering => o + case Excluded | Existing => e } def apply(o: OptionOperation): OptionOperation = @@ -69,6 +70,11 @@ trait StatelessTransformer { case Assignment(a, b, c) => Assignment(a, apply(b), apply(c)) } + def apply(e: Property): Property = + e match { + case Property(a, name) => Property(apply(a), name) + } + def apply(e: Operation): Operation = e match { case UnaryOperation(o, a) => UnaryOperation(o, apply(a)) @@ -94,6 +100,19 @@ trait StatelessTransformer { case Delete(query) => Delete(apply(query)) case Returning(query, alias, property) => Returning(apply(query), alias, apply(property)) case Foreach(query, alias, body) => Foreach(apply(query), alias, apply(body)) + case Conflict(query, target, action) => Conflict(apply(query), apply(target), apply(action)) + } + + def apply(e: Conflict.Target): Conflict.Target = + e match { + case Conflict.NoTarget | Conflict.StringValue(_) => e + case Conflict.Properties(props) => Conflict.Properties(props.map(apply)) + } + + def apply(e: Conflict.Action): Conflict.Action = + e match { + case Conflict.Ignore => e + case Conflict.Update(assigns) => Conflict.Update(assigns.map(apply)) } } diff --git a/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala b/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala index 14ea7dd5a5..36402322da 100644 --- a/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala +++ b/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala @@ -91,6 +91,41 @@ private[dsl] trait QueryDsl { sealed trait Insert[E] extends Action[E] { @compileTimeOnly(NonQuotedException.message) def returning[R](f: E => R): ActionReturning[E, R] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictIgnore: Insert[E] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictIgnore(target: String): Insert[E] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictIgnore(target: E => Any, targets: (E => Any)*): Insert[E] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictUpdate(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictUpdate(target: String)(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException() + + /** + * Generates an atomic INSERT or UPDATE (upsert) action if supported. + * + * @param targets - conflict target + * @param assigns - update statement, declared as function: (table, excluded) => (assign, result). + * `table` - is used to extract column for update assignment and reference existing row + * `excluded` - aliases excluded table, e.g. row proposed for insertion. + * `assign` - left hand side of assignment. Should be accessed from `table` argument + * `result` - right hand side of assignment. + * + * Example usage: + * ``` + * insert.onConflictUpdate(_.id)((t, e) => t.col -> (e.col + t.col)) + * ``` + * If insert statement violates conflict target then the column `col` of row will be updated with sum of + * existing value and and proposed `col` in insert. + */ + @compileTimeOnly(NonQuotedException.message) + def onConflictUpdate(target: E => Any, targets: (E => Any)*)(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException() } sealed trait ActionReturning[E, Output] extends Action[E] diff --git a/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala b/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala index 84f72914a6..060d24bdcf 100644 --- a/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala +++ b/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala @@ -6,14 +6,18 @@ object NormalizeReturning { def apply(e: Action): Action = { e match { - case Returning(Insert(query, assignments), alias, body) => - Returning(Insert(query, filterReturnedColumn(assignments, body)), alias, body) - case Returning(Update(query, assignments), alias, body) => - Returning(Update(query, filterReturnedColumn(assignments, body)), alias, body) - case e => e + case Returning(a: Action, alias, body) => Returning(apply(a, body), alias, body) + case _ => e } } + private def apply(e: Action, body: Ast): Action = e match { + case Insert(query, assignments) => Insert(query, filterReturnedColumn(assignments, body)) + case Update(query, assignments) => Update(query, filterReturnedColumn(assignments, body)) + case Conflict(a: Action, target, act) => Conflict(apply(a, body), target, act) + case _ => e + } + private def filterReturnedColumn(assignments: List[Assignment], column: Ast): List[Assignment] = assignments.flatMap(filterReturnedColumn(_, column)) diff --git a/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala b/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala index c0f5fb995a..a2a21f99a5 100644 --- a/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala +++ b/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala @@ -29,6 +29,10 @@ object RenameProperties extends StatelessTransformer { val bodyr = BetaReduction(body, replace: _*) (Returning(action, alias, bodyr), schema) } + case Conflict(a: Action, target, act) => + applySchema(a) match { + case (action, schema) => (Conflict(action, target, act), schema) + } case q => (q, Tuple(List.empty)) } diff --git a/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala b/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala index 9ea186859b..a84eefb1f6 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala @@ -49,6 +49,8 @@ case class FreeVariables(state: State) super.apply(other) } + override def apply(e: Conflict.Target): (Conflict.Target, StatefulTransformer[State]) = (e, this) + override def apply(query: Query): (Query, StatefulTransformer[State]) = query match { case q @ Filter(a, b, c) => (q, free(a, b, c)) diff --git a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala index 06d336532d..eb9df36f8d 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala @@ -21,9 +21,9 @@ trait Liftables { case ast: Assignment => assignmentLiftable(ast) case ast: OptionOperation => optionOperationLiftable(ast) case ast: TraversableOperation => traversableOperationLiftable(ast) + case ast: Property => propertyLiftable(ast) case Val(name, body) => q"$pack.Val($name, $body)" case Block(statements) => q"$pack.Block($statements)" - case Property(a, b) => q"$pack.Property($a, $b)" case Function(a, b) => q"$pack.Function($a, $b)" case FunctionApply(a, b) => q"$pack.FunctionApply($a, $b)" case BinaryOperation(a, b, c) => q"$pack.BinaryOperation($a, $b, $c)" @@ -33,6 +33,8 @@ trait Liftables { case Dynamic(tree: Tree) if (tree.tpe <:< c.weakTypeOf[CoreDsl#Quoted[Any]]) => q"$tree.ast" case Dynamic(tree: Tree) => q"$pack.Constant($tree)" case QuotedReference(tree: Tree, ast) => q"$ast" + case Excluded => q"$pack.Excluded" + case Existing => q"$pack.Existing" } implicit val optionOperationLiftable: Liftable[OptionOperation] = Liftable[OptionOperation] { @@ -113,6 +115,10 @@ trait Liftables { case PropertyAlias(a, b) => q"$pack.PropertyAlias($a, $b)" } + implicit val propertyLiftable: Liftable[Property] = Liftable[Property] { + case Property(a, b) => q"$pack.Property($a, $b)" + } + implicit val orderingLiftable: Liftable[Ordering] = Liftable[Ordering] { case TupleOrdering(elems) => q"$pack.TupleOrdering($elems)" case Asc => q"$pack.Asc" @@ -136,6 +142,18 @@ trait Liftables { case Delete(a) => q"$pack.Delete($a)" case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)" case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)" + case Conflict(a, b, c) => q"$pack.Conflict($a, $b, $c)" + } + + implicit val conflictTargetLiftable: Liftable[Conflict.Target] = Liftable[Conflict.Target] { + case Conflict.NoTarget => q"$pack.Conflict.NoTarget" + case Conflict.StringValue(a) => q"$pack.Conflict.StringValue.apply($a)" + case Conflict.Properties(a) => q"$pack.Conflict.Properties.apply($a)" + } + + implicit val conflictActionLiftable: Liftable[Conflict.Action] = Liftable[Conflict.Action] { + case Conflict.Ignore => q"$pack.Conflict.Ignore" + case Conflict.Update(a) => q"$pack.Conflict.Update.apply($a)" } implicit val assignmentLiftable: Liftable[Assignment] = Liftable[Assignment] { diff --git a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala index b00b5ffff7..2fde870bfa 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala @@ -41,6 +41,7 @@ trait Parsing { case `quotedAstParser`(value) => value case `functionParser`(value) => value case `actionParser`(value) => value + case `conflictParser`(value) => value case `infixParser`(value) => value case `orderingParser`(value) => value case `operationParser`(value) => value @@ -317,7 +318,7 @@ trait Parsing { ListContains(astParser(col), astParser(body)) } - val propertyParser: Parser[Ast] = Parser[Ast] { + val propertyParser: Parser[Property] = Parser[Property] { case q"$e.get" if is[Option[Any]](e) => c.fail("Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead.") case q"$e.$property" => Property(astParser(e), property.decodedName.toString) @@ -535,10 +536,51 @@ trait Parsing { checkTypes(prop, value) Assignment(i1, astParser(prop), astParser(value)) + case tr @ q"((${ identParser(i1) }, ${ identParser(i2) }) => $pack.Predef.ArrowAssoc[$t]($prop).$arrow[$v]($value))" => + println(tr) + println(s"$i1, $i2, $prop, $value") + checkTypes(prop, value) + + val valueAst = Transform(astParser(value)) { + case `i1` => Existing + case `i2` => Excluded + } + Assignment(i1, astParser(prop), valueAst) // Unused, it's here only to make eclipse's presentation compiler happy case astParser(ast) => Assignment(Ident("unused"), Ident("unused"), Constant("unused")) } + /*private def excludedParser(tpe: Type): Parser[Excluded] = Parser[Excluded] { + case q"$pack.excluded[$t].$prop" => Excluded(Property(Ident("excluded"), prop.decodedName.toString)) + }*/ + + val conflictParser: Parser[Ast] = Parser[Ast] { + case q"$query.onConflictIgnore" => + Conflict(astParser(query), Conflict.NoTarget, Conflict.Ignore) + case q"$query.onConflictIgnore(${ target: String })" => + Conflict(astParser(query), Conflict.StringValue(target), Conflict.Ignore) + case q"$query.onConflictIgnore(..$targets)" => + Conflict(astParser(query), parseConflictProps(targets), Conflict.Ignore) + + case q"$query.onConflictUpdate(..$assigns)" => + Conflict(astParser(query), Conflict.NoTarget, parseConflictAssigns(assigns)) + case q"$query.onConflictUpdate(${ target: String })(..$assigns)" => + Conflict(astParser(query), Conflict.StringValue(target), parseConflictAssigns(assigns)) + case q"$query.onConflictUpdate(..$targets)(..$assigns)" => + Conflict(astParser(query), parseConflictProps(targets), parseConflictAssigns(assigns)) + } + + private def parseConflictProps(targets: List[Tree]) = Conflict.Properties { + targets.map { + case q"($e) => $prop" => propertyParser(prop) + case tree => c.fail(s"Tree '$tree' can't be parsed as conflict target") + } + } + + private def parseConflictAssigns(targets: List[Tree]) = Conflict.Update { + targets.map(assignmentParser(_)) + } + private def checkTypes(lhs: Tree, rhs: Tree): Unit = { def unquoted(tree: Tree) = is[CoreDsl#Quoted[Any]](tree) match { diff --git a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala index 1db009462e..c3d0af59b3 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala @@ -18,7 +18,7 @@ trait Unliftables { case orderingUnliftable(ast) => ast case optionOperationUnliftable(ast) => ast case traversableOperationUnliftable(ast) => ast - case q"$pack.Property.apply(${ a: Ast }, ${ b: String })" => Property(a, b) + case propertyUnliftable(ast) => ast case q"$pack.Function.apply(${ a: List[Ident] }, ${ b: Ast })" => Function(a, b) case q"$pack.FunctionApply.apply(${ a: Ast }, ${ b: List[Ast] })" => FunctionApply(a, b) case q"$pack.BinaryOperation.apply(${ a: Ast }, ${ b: BinaryOperator }, ${ c: Ast })" => BinaryOperation(a, b, c) @@ -26,6 +26,8 @@ trait Unliftables { case q"$pack.Aggregation.apply(${ a: AggregationOperator }, ${ b: Ast })" => Aggregation(a, b) case q"$pack.Infix.apply(${ a: List[String] }, ${ b: List[Ast] })" => Infix(a, b) case q"$pack.If.apply(${ a: Ast }, ${ b: Ast }, ${ c: Ast })" => If(a, b, c) + case q"$pack.Excluded" => Excluded + case q"$pack.Existing" => Existing case q"$tree.ast" => Dynamic(tree) } @@ -123,6 +125,10 @@ trait Unliftables { case q"$pack.PropertyAlias.apply(${ a: List[String] }, ${ b: String })" => PropertyAlias(a, b) } + implicit val propertyUnliftable: Unliftable[Property] = Unliftable[Property] { + case q"$pack.Property.apply(${ a: Ast }, ${ b: String })" => Property(a, b) + } + implicit def optionUnliftable[T](implicit u: Unliftable[T]): Unliftable[Option[T]] = Unliftable[Option[T]] { case q"scala.None" => None case q"scala.Some.apply[$t]($v)" => Some(u.unapply(v).get) @@ -136,11 +142,23 @@ trait Unliftables { } implicit val actionUnliftable: Unliftable[Action] = Unliftable[Action] { - case q"$pack.Update.apply(${ a: Ast }, ${ b: List[Assignment] })" => Update(a, b) - case q"$pack.Insert.apply(${ a: Ast }, ${ b: List[Assignment] })" => Insert(a, b) - case q"$pack.Delete.apply(${ a: Ast })" => Delete(a) + case q"$pack.Update.apply(${ a: Ast }, ${ b: List[Assignment] })" => Update(a, b) + case q"$pack.Insert.apply(${ a: Ast }, ${ b: List[Assignment] })" => Insert(a, b) + case q"$pack.Delete.apply(${ a: Ast })" => Delete(a) case q"$pack.Returning.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Returning(a, b, c) - case q"$pack.Foreach.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Foreach(a, b, c) + case q"$pack.Foreach.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Foreach(a, b, c) + case q"$pack.Conflict.apply(${ a: Ast }, ${ b: Conflict.Target }, ${ c: Conflict.Action })" => Conflict(a, b, c) + } + + implicit val conflictTargetUnliftable: Unliftable[Conflict.Target] = Unliftable[Conflict.Target] { + case q"$pack.Conflict.NoTarget" => Conflict.NoTarget + case q"$pack.Conflict.StringValue.apply(${ a: String })" => Conflict.StringValue(a) + case q"$pack.Conflict.Properties.apply(${ a: List[Property] })" => Conflict.Properties(a) + } + + implicit val conflictActionUnliftable: Unliftable[Conflict.Action] = Unliftable[Conflict.Action] { + case q"$pack.Conflict.Ignore" => Conflict.Ignore + case q"$pack.Conflict.Update.apply(${ a: List[Assignment] })" => Conflict.Update(a) } implicit val assignmentUnliftable: Unliftable[Assignment] = Unliftable[Assignment] { diff --git a/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala b/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala index 8fcc77118f..76aebd4528 100644 --- a/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala +++ b/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala @@ -216,6 +216,78 @@ class StatefulTransformerSpec extends Spec { att.state mustEqual List(Ident("a")) } } + "conflict" in { + val ast: Ast = Conflict(Insert(Ident("a"), Nil), Conflict.NoTarget, Conflict.Ignore) + Subject(Nil, Ident("a") -> Ident("a'"))(ast) match { + case (at, att) => + at mustEqual Conflict(Insert(Ident("a'"), Nil), Conflict.NoTarget, Conflict.Ignore) + att.state mustEqual List(Ident("a")) + } + } + } + + "conflict.target" - { + "no" in { + val target: Conflict.Target = Conflict.NoTarget + Subject(Nil)(target) match { + case (at, att) => + at mustEqual target + att.state mustEqual Nil + } + } + "string" in { + val target: Conflict.Target = Conflict.StringValue("a") + Subject(Nil)(target) match { + case (at, att) => + at mustEqual target + att.state mustEqual Nil + } + } + "properties" in { + val target: Conflict.Target = Conflict.Properties(List(Property(Ident("a"), "b"))) + Subject(Nil, Ident("a") -> Ident("a'"))(target) match { + case (at, att) => + at mustEqual Conflict.Properties(List(Property(Ident("a'"), "b"))) + att.state mustEqual List(Ident("a")) + } + } + } + + "conflict.action" - { + "ignore" in { + val action: Conflict.Action = Conflict.Ignore + Subject(Nil)(action) match { + case (at, att) => + at mustEqual action + att.state mustEqual Nil + } + } + "update" in { + val action: Conflict.Action = Conflict.Update(List(Assignment(Ident("a"), Ident("b"), Ident("c")))) + Subject(Nil, Ident("a") -> Ident("a'"), Ident("b") -> Ident("b'"), Ident("c") -> Ident("c'"))(action) match { + case (at, att) => + at mustEqual Conflict.Update(List(Assignment(Ident("a"), Ident("b'"), Ident("c'")))) + att.state mustEqual List(Ident("b"), Ident("c")) + } + } + } + + "excluded" in { + val ast: Ast = Excluded + Subject(Nil)(ast) match { + case (at, att) => + at mustEqual ast + att.state mustEqual Nil + } + } + + "existing" in { + val ast: Ast = Existing + Subject(Nil)(ast) match { + case (at, att) => + at mustEqual ast + att.state mustEqual Nil + } } "function" in { diff --git a/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala b/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala index 866003db29..7c1b65e46a 100644 --- a/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala +++ b/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala @@ -148,6 +148,49 @@ class StatelessTransformerSpec extends Spec { Subject(Ident("a") -> Ident("a'"))(ast) mustEqual Delete(Ident("a'")) } + "conflict" in { + val ast: Ast = Conflict(Insert(Ident("a"), Nil), Conflict.NoTarget, Conflict.Ignore) + Subject(Ident("a") -> Ident("a'"))(ast) mustEqual + Conflict(Insert(Ident("a'"), Nil), Conflict.NoTarget, Conflict.Ignore) + } + } + + "conflict.target" - { + "no" in { + val target: Conflict.Target = Conflict.NoTarget + Subject()(target) mustEqual target + } + "string" in { + val target: Conflict.Target = Conflict.StringValue("a") + Subject()(target) mustEqual target + } + "properties" in { + val target: Conflict.Target = Conflict.Properties(List(Property(Ident("a"), "b"))) + Subject(Ident("a") -> Ident("a'"))(target) mustEqual + Conflict.Properties(List(Property(Ident("a'"), "b"))) + } + } + + "conflict.action" - { + "ignore" in { + val action: Conflict.Action = Conflict.Ignore + Subject()(action) mustEqual action + } + "update" in { + val action: Conflict.Action = Conflict.Update(List(Assignment(Ident("a"), Ident("b"), Ident("c")))) + Subject(Ident("a") -> Ident("a'"), Ident("b") -> Ident("b'"), Ident("c") -> Ident("c'"))(action) mustEqual + Conflict.Update(List(Assignment(Ident("a"), Ident("b'"), Ident("c'")))) + } + } + + "excluded" in { + val ast: Ast = Excluded + Subject()(ast) mustEqual ast + } + + "existing" in { + val ast: Ast = Existing + Subject()(ast) mustEqual ast } "function" in { diff --git a/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala b/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala index b899b584af..20510c8312 100644 --- a/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala +++ b/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala @@ -447,6 +447,37 @@ class MirrorIdiomSpec extends Spec { stmt"${(q.ast: Ast).token}" mustEqual stmt"""querySchema("TestEntity").delete""" } + + "conflict" - { + val i = quote { + query[TestEntity].insert(t => t.s -> "a") + } + val t = stmt"""querySchema("TestEntity").insert(t => t.s -> "a")""" + "onConflictIgnore" in { + stmt"${(i.onConflictIgnore.ast: Ast).token}" mustEqual + stmt"$t.onConflictIgnore" + } + "onConflictIgnore(str)" in { + stmt"${(i.onConflictIgnore("str").ast: Ast).token}" mustEqual + stmt"""$t.onConflictIgnore("str")""" + } + "onConflictIgnore(targets*)" in { + stmt"${(i.onConflictIgnore(_.i, _.s).ast: Ast).token}" mustEqual + stmt"$t.onConflictIgnore(_.i, _.s)" + } + "onConflictUpdate(assigns*)" in { + stmt"${(i.onConflictUpdate((t, e) => t.s -> e.s, (t, e) => t.i -> (t.i + 1)).ast: Ast).token}" mustEqual + stmt"$t.onConflictUpdate((t, e) => t.s -> e.s, (t, e) => t.i -> (t.i + 1))" + } + "onConflictUpdate(str)(assigns*)" in { + stmt"${(i.onConflictUpdate("str")((t, e) => t.s -> e.s).ast: Ast).token}" mustEqual + stmt"""$t.onConflictUpdate("str")((t, e) => t.s -> e.s)""" + } + "onConflictUpdate(targets*)(assigns*)" in { + stmt"${(i.onConflictUpdate(_.i)((t, e) => t.s -> e.s).ast: Ast).token}" mustEqual + stmt"$t.onConflictUpdate(_.i)((t, e) => t.s -> e.s)" + } + } } "shows infix" - { diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/JdbcUpsertSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/JdbcUpsertSpec.scala new file mode 100644 index 0000000000..bd4cba79cb --- /dev/null +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/JdbcUpsertSpec.scala @@ -0,0 +1,21 @@ +package io.getquill.context.jdbc.postgres + +import io.getquill.context.sql.idiom.ConflictSpec +import org.scalatest.BeforeAndAfterEach + +class JdbcUpsertSpec extends ConflictSpec with BeforeAndAfterEach { + val ctx = testContext + import ctx._ + + "Ex1" in { + ctx.run(`Ex1 query`) mustBe 1 + } + + private def prepareE1() = ctx.run(ins) + private def prepareE2() = ctx.run(ins2) + override protected def afterEach(): Unit = { + ctx.run(del1) + ctx.run(del2) + () + } +} diff --git a/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala b/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala index f1b94f566f..2b310f9ddf 100644 --- a/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala @@ -1,21 +1,11 @@ package io.getquill -import io.getquill.idiom.StatementInterpolator._ -import io.getquill.ast.Asc -import io.getquill.ast.AscNullsFirst -import io.getquill.ast.AscNullsLast -import io.getquill.ast.BinaryOperation -import io.getquill.ast.Desc -import io.getquill.ast.DescNullsFirst -import io.getquill.ast.DescNullsLast -import io.getquill.ast.Operation -import io.getquill.ast.StringOperator -import io.getquill.context.sql.idiom.SqlIdiom +import io.getquill.ast.{ Ast, _ } import io.getquill.context.sql.OrderByCriteria -import io.getquill.context.sql.idiom.QuestionMarkBindVariables +import io.getquill.context.sql.idiom.{ NoConcatSupport, QuestionMarkBindVariables, SqlIdiom } +import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.{ Statement, Token } -import io.getquill.ast.Ast -import io.getquill.context.sql.idiom.NoConcatSupport +import io.getquill.util.Messages.fail trait MySQLDialect extends SqlIdiom @@ -29,6 +19,36 @@ trait MySQLDialect override def defaultAutoGeneratedToken(field: Token) = stmt"($field) VALUES (DEFAULT)" + override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Ast] = + Tokenizer[Ast] { + case c: Conflict => c.token + case ast => super.astTokenizer.token(ast) + } + + implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Conflict] = { + import Conflict._ + + def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = + Tokenizer[Conflict] { + case Conflict(i, NoTarget, Update(a)) => + stmt"${i.token} ON DUPLICATE KEY UPDATE ${a.token}" + case Conflict(i, Properties(p), Ignore) => + stmt"${i.token} ON DUPLICATE KEY UPDATE ${p.map(_.token).map(t => stmt"$t=$t").token}" + case Conflict(i, NoTarget, Ignore) => + fail("maybe insert ignore here?") + case _ => + fail("This upsert construct is not supported in MySQL. Please refer documentation for details.") + } + + val customAstTokenizer = + Tokenizer.withFallback[Ast](MySQLDialect.this.astTokenizer(_, strategy)) { + case Property(Excluded, name) => stmt"VALUES(${strategy.column(name).token})" + case Property(_, name) => strategy.column(name).token + } + + tokenizer(customAstTokenizer) + } + override implicit def operationTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Operation] = Tokenizer[Operation] { case BinaryOperation(a, StringOperator.`+`, b) => stmt"CONCAT(${a.token}, ${b.token})" diff --git a/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala b/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala index 11952a9b87..bce9ac7c23 100644 --- a/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala @@ -3,9 +3,10 @@ package io.getquill import java.util.concurrent.atomic.AtomicInteger import io.getquill.ast._ -import io.getquill.context.sql.idiom.{ QuestionMarkBindVariables, SqlIdiom } +import io.getquill.context.sql.idiom.{ ConcatSupport, QuestionMarkBindVariables, SqlIdiom } import io.getquill.idiom.StatementInterpolator._ -import io.getquill.context.sql.idiom.ConcatSupport +import io.getquill.idiom.Token +import io.getquill.util.Messages.fail trait PostgresDialect extends SqlIdiom @@ -15,9 +16,61 @@ trait PostgresDialect override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Ast] = Tokenizer[Ast] { case ListContains(ast, body) => stmt"${body.token} = ANY(${ast.token})" + case c: Conflict => conflictTokenizer.token(c) case ast => super.astTokenizer.token(ast) } + implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Conflict] = { + + val customEntityTokenizer = Tokenizer[Entity] { + case Entity(name, _) => stmt"${strategy.table(name).token} AS t" + } + + val customAstTokenizer = + Tokenizer.withFallback[Ast](PostgresDialect.this.astTokenizer(_, strategy)) { + case Excluded => stmt"EXCLUDED" + case Existing => stmt"t" + case a: Action => super.actionTokenizer(astTokenizer, customEntityTokenizer, strategy).token(a) + } + + implicit val conflictTargetPropsTokenizer = + Tokenizer[Conflict.Properties] { + case Conflict.Properties(props) => stmt"(${props.map(n => strategy.column(n.name)).mkStmt(",")})" + } + + def doUpdateStmt(i: Token, t: Token, a: Token) = stmt"$i ON CONFLICT $t DO UPDATE SET $a" + + def doNothingStmt(i: Ast, t: Token) = stmt"${i.token} ON CONFLICT $t DO NOTHING" + + def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = + Tokenizer[Conflict] { + + case Conflict(_, Conflict.NoTarget, _: Conflict.Update) => + fail("'DO UPDATE' statement requires explicit conflict target") + + case Conflict(i, Conflict.StringValue(s), Conflict.Update(a)) => doUpdateStmt(i.token, s.token, a.token) + case Conflict(i, p: Conflict.Properties, Conflict.Update(a)) => doUpdateStmt(i.token, p.token, a.token) + + case Conflict(i, Conflict.NoTarget, Conflict.Ignore) => stmt"${i.token} ON CONFLICT DO NOTHING" + case Conflict(i, Conflict.StringValue(s), Conflict.Ignore) => doNothingStmt(i, s.token) + case Conflict(i, p: Conflict.Properties, Conflict.Ignore) => doNothingStmt(i, p.token) + } + + tokenizer(customAstTokenizer) + } + + private implicit def conflictTargetTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Conflict.Target] = + Tokenizer[Conflict.Target] { + case Conflict.StringValue(v) => v.token + case Conflict.Properties(props) => stmt"(${props.map(n => strategy.column(n.name)).mkStmt(",")})" + } + + private implicit def conflictActionTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Conflict.Action] = + Tokenizer[Conflict.Action] { + case Conflict.Ignore => stmt"DO NOTHING" + case Conflict.Update(assigns) => stmt"DO UPDATE SET ${assigns.token}" + } + override implicit def operationTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Operation] = Tokenizer[Operation] { case UnaryOperation(StringOperator.`toLong`, ast) => stmt"${scopedTokenizer(ast)}::bigint" diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala index 67c17e9209..26af43c1ff 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala @@ -60,7 +60,7 @@ trait SqlIdiom extends Idiom { case a: OptionOperation => a.token case a @ ( _: Function | _: FunctionApply | _: Dynamic | _: OptionOperation | _: Block | - _: Val | _: Ordering | _: QuotedReference | _: TraversableOperation + _: Val | _: Ordering | _: QuotedReference | _: TraversableOperation | Excluded | Existing ) => fail(s"Malformed or unsupported construct: $a.") } @@ -317,50 +317,50 @@ trait SqlIdiom extends Idiom { stmt"${prop.token} = ${scopedTokenizer(value)}" } - implicit def actionTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Action] = { + implicit def actionTokenizer(implicit astTokenizer: Tokenizer[Ast], entityTokenizer: Tokenizer[Entity], strategy: NamingStrategy): Tokenizer[Action] = { + actionTokenizer(customAstTokenizer, entityTokenizer, strategy) + } - def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = - Tokenizer[Action] { + protected def customActionTokenizer(implicit astTokenizer: Tokenizer[Ast], entityTokenizer: Tokenizer[Entity], strategy: NamingStrategy): Tokenizer[Action] = { + Tokenizer[Action] { - case Insert(table: Entity, assignments) => - val columns = assignments.map(_.property.token) - val values = assignments.map(_.value) - stmt"INSERT INTO ${table.token} (${columns.mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})" + case Insert(table: Entity, assignments) => + val columns = assignments.map(_.property.token) + val values = assignments.map(_.value) + stmt"INSERT INTO ${table.token} (${columns.mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})" - case Update(table: Entity, assignments) => - stmt"UPDATE ${table.token} SET ${assignments.token}" + case Update(table: Entity, assignments) => + stmt"UPDATE ${table.token} SET ${assignments.token}" - case Update(Filter(table: Entity, x, where), assignments) => - stmt"UPDATE ${table.token} SET ${assignments.token} WHERE ${where.token}" + case Update(Filter(table: Entity, x, where), assignments) => + stmt"UPDATE ${table.token} SET ${assignments.token} WHERE ${where.token}" - case Delete(Filter(table: Entity, x, where)) => - stmt"DELETE FROM ${table.token} WHERE ${where.token}" + case Delete(Filter(table: Entity, x, where)) => + stmt"DELETE FROM ${table.token} WHERE ${where.token}" - case Delete(table: Entity) => - stmt"DELETE FROM ${table.token}" + case Delete(table: Entity) => + stmt"DELETE FROM ${table.token}" - case Returning(Insert(table: Entity, Nil), alias, prop) => - stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" + case Returning(Insert(table: Entity, Nil), alias, prop) => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" - case Returning(action, alias, prop) => - action.token + case Returning(action, alias, prop) => + action.token - case other => - fail(s"Action ast can't be translated to sql: '$other'") - } - - val customAstTokenizer = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { - case q: Query => astTokenizer.token(q) - case Property(Property(_, name), "isEmpty") => stmt"${strategy.column(name).token} IS NULL" - case Property(Property(_, name), "isDefined") => stmt"${strategy.column(name).token} IS NOT NULL" - case Property(Property(_, name), "nonEmpty") => stmt"${strategy.column(name).token} IS NOT NULL" - case Property(_, name) => strategy.column(name).token - } - - tokenizer(customAstTokenizer) + case other => + fail(s"Action ast can't be translated to sql: '$other'") + } } + protected def customAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case q: Query => astTokenizer.token(q) + case Property(Property(_, name), "isEmpty") => stmt"${strategy.column(name).token} IS NULL" + case Property(Property(_, name), "isDefined") => stmt"${strategy.column(name).token} IS NOT NULL" + case Property(Property(_, name), "nonEmpty") => stmt"${strategy.column(name).token} IS NOT NULL" + case Property(_, name) => strategy.column(name).token + } + implicit def entityTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Entity] = Tokenizer[Entity] { case Entity(name, _) => strategy.table(name).token } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/ConflictSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/ConflictSpec.scala new file mode 100644 index 0000000000..bd5773b79d --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/ConflictSpec.scala @@ -0,0 +1,32 @@ +package io.getquill.context.sql.idiom + +import io.getquill.{ Spec, TestEntities } +import io.getquill.context.sql.SqlContext + +abstract class ConflictSpec[T <: SqlContext[_, _] with TestEntities](protected val ctx: T) extends Spec { + import ctx._ + + val e = TestEntity("s1", 1, 1, None) + + val ins = quote(query[TestEntity].insert(lift(e))) + val del = quote(query[TestEntity].delete) + + val `no target - ignore` = quote { + ins.onConflictIgnore + } + val `string target - ignore` = quote { + ins.onConflictIgnore("string_target") + } + val `cols target - ignore` = quote { + ins.onConflictIgnore(_.i) + } + val `no target - update` = quote { + ins.onConflictUpdate((t, e) => t.l -> (t.l + e.l) / 2, _.s -> _.s) + } + def `string target - update`(s: String) = quote { + ins.onConflictUpdate("string_target")((t, _) => t.s -> lift(s)) + } + val `cols target - update` = quote { + ins.onConflictUpdate(_.i, _.s)((t, _) => t.l -> (t.l + 1)) + } +} diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala index 872ca9e62d..f0047b17fa 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala @@ -1,15 +1,10 @@ package io.getquill.context.sql.idiom -import io.getquill.Spec -import io.getquill.Literal -import io.getquill.MySQLDialect -import io.getquill.SqlMirrorContext -import io.getquill.TestEntities +import io.getquill._ import io.getquill.idiom.StringToken -class MySQLDialectSpec extends Spec { +class MySQLDialectSpec extends ConflictSpec(new SqlMirrorContext(MySQLDialect, Literal) with TestEntities) { - val ctx = new SqlMirrorContext(MySQLDialect, Literal) with TestEntities import ctx._ "workaround for offset without limit" in { @@ -86,4 +81,36 @@ class MySQLDialectSpec extends Spec { ctx.run(q).string mustEqual "INSERT INTO TestEntity4 (i) VALUES (DEFAULT)" } + + "Conflict" - { + "no target - ignore" in { + intercept[IllegalStateException] { + ctx.run(`no target - ignore`.dynamic) + } + } + "string target - ignore" in { + intercept[IllegalStateException] { + ctx.run(`string target - ignore`.dynamic) + } + } + "cols target - ignore" in { + intercept[IllegalStateException] { + ctx.run(`cols target - ignore`.dynamic) + } + } + "no target - update" in { + ctx.run(`no target - update`).string mustEqual + "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE l = ((l + VALUES(l)) / 2), s = VALUES(s)" + } + "string target - update" in { + intercept[IllegalStateException] { + ctx.run(`string target - update`("123").dynamic) + } + } + "cols target - update" in { + intercept[IllegalStateException] { + ctx.run(`cols target - update`.dynamic) + } + } + } } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala index 40d56c6c8e..af0ca1acd9 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala @@ -1,35 +1,32 @@ package io.getquill.context.sql.idiom -import io.getquill.Spec import io.getquill.PostgresDialect import io.getquill.SqlMirrorContext import io.getquill.Literal import io.getquill.TestEntities -class PostgresDialectSpec extends Spec { - - val context = new SqlMirrorContext(PostgresDialect, Literal) with TestEntities - import context._ +class PostgresDialectSpec extends ConflictSpec(new SqlMirrorContext(PostgresDialect, Literal) with TestEntities) { + import ctx._ "applies explicit casts" - { "toLong" in { val q = quote { qr1.map(t => t.s.toLong) } - context.run(q).string mustEqual "SELECT t.s::bigint FROM TestEntity t" + ctx.run(q).string mustEqual "SELECT t.s::bigint FROM TestEntity t" } "toInt" in { val q = quote { qr1.map(t => t.s.toInt) } - context.run(q).string mustEqual "SELECT t.s::integer FROM TestEntity t" + ctx.run(q).string mustEqual "SELECT t.s::integer FROM TestEntity t" } } "Array Operations" - { case class ArrayOps(id: Int, numbers: Vector[Int]) "contains" in { - context.run(query[ArrayOps].filter(_.numbers.contains(10))).string mustEqual + ctx.run(query[ArrayOps].filter(_.numbers.contains(10))).string mustEqual "SELECT x1.id, x1.numbers FROM ArrayOps x1 WHERE 10 = ANY(x1.numbers)" } } @@ -44,4 +41,34 @@ class PostgresDialectSpec extends Spec { prepareForProbing("INSERT INTO tb (x1,x2,x3) VALUES (?,?,?)") mustEqual s"PREPARE p${id + 2} AS INSERT INTO tb (x1,x2,x3) VALUES ($$1,$$2,$$3)" } + + "Conflict" - { + "no target - ignore" in { + ctx.run(`no target - ignore`).string mustEqual + "INSERT INTO TestEntity AS t (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING" + } + "string target - ignore" in { + ctx.run(`string target - ignore`).string mustEqual + "INSERT INTO TestEntity AS t (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT string_target DO NOTHING" + } + "cols target - ignore" in { + ctx.run(`cols target - ignore`).string mustEqual + "INSERT INTO TestEntity AS t (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT (i) DO NOTHING" + } + "no target - update" in { + + intercept[IllegalStateException] { + ctx.run(`no target - update`.dynamic) + } + } + "string target - update" in { + def x = "123" + ctx.run(`string target - update`(x)).string mustEqual + "INSERT INTO TestEntity AS t (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT string_target DO UPDATE SET s = ?" + } + "cols target - update" in { + ctx.run(`cols target - update`).string mustEqual + "INSERT INTO TestEntity AS t (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT (i,s) DO UPDATE SET l = (t.l + 1)" + } + } } diff --git a/quill-sql/src/test/sql/mysql-schema.sql b/quill-sql/src/test/sql/mysql-schema.sql index eb2cfb618b..5cd5c5fab2 100644 --- a/quill-sql/src/test/sql/mysql-schema.sql +++ b/quill-sql/src/test/sql/mysql-schema.sql @@ -76,16 +76,17 @@ Create TABLE BooleanEncodingTestEntity( ); CREATE TABLE TestEntity( + i INTEGER PRIMARY KEY, s VARCHAR(255), - i INTEGER, l BIGINT, o INTEGER ); CREATE TABLE TestEntity2( - s VARCHAR(255), i INTEGER, - l BIGINT + s VARCHAR(255), + l BIGINT, + PRIMARY KEY (i, s) ); CREATE TABLE TestEntity3( @@ -119,34 +120,4 @@ CREATE TABLE Address( street VARCHAR(255), zip int, otherExtraInfo VARCHAR(255) -); - -CREATE TABLE Contact( - firstName VARCHAR(255), - lastName VARCHAR(255), - age int, - addressFk int, - extraInfo VARCHAR(255) -); - -CREATE TABLE Address( - id int, - street VARCHAR(255), - zip int, - otherExtraInfo VARCHAR(255) -); - -CREATE TABLE Contact( - firstName VARCHAR(255), - lastName VARCHAR(255), - age int, - addressFk int, - extraInfo VARCHAR(255) -); - -CREATE TABLE Address( - id int, - street VARCHAR(255), - zip int, - otherExtraInfo VARCHAR(255) -); +); \ No newline at end of file diff --git a/quill-sql/src/test/sql/postgres-schema.sql b/quill-sql/src/test/sql/postgres-schema.sql index a337b5f98f..8e8f959601 100644 --- a/quill-sql/src/test/sql/postgres-schema.sql +++ b/quill-sql/src/test/sql/postgres-schema.sql @@ -59,16 +59,17 @@ CREATE TABLE EncodingUUIDTestEntity( ); CREATE TABLE TestEntity( + i INTEGER PRIMARY KEY, s VARCHAR(255), - i INTEGER, l BIGINT, o INTEGER ); CREATE TABLE TestEntity2( - s VARCHAR(255), i INTEGER, - l BIGINT + s VARCHAR(255), + l BIGINT, + PRIMARY KEY (i, s) ); CREATE TABLE TestEntity3( diff --git a/quill-test/src/main/scala/hey/Hey.scala b/quill-test/src/main/scala/hey/Hey.scala new file mode 100644 index 0000000000..4f72040408 --- /dev/null +++ b/quill-test/src/main/scala/hey/Hey.scala @@ -0,0 +1,28 @@ +import io.getquill.{ PostgresDialect, SnakeCase, SqlMirrorContext } + +object Hey { + val ctx = new SqlMirrorContext(PostgresDialect, SnakeCase) + + import ctx._ + + case class Inner(hisAge: Int) extends Embedded + + case class MyTable(myId: Int, str: String, in: Inner) + + val e = MyTable(1, "hey", Inner(1)) + val q = quote { + querySchema[MyTable]("myy_table", _.str -> "my_str") + } + + run { + q.update(lift(e)) + } + + run { + q.insert(lift(e)) + } + run { + q.insert(lift(e)).onConflictUpdate(_.str)((t, e) => t.in.hisAge -> (e.myId + 1 + t.in.hisAge)) + } + +} \ No newline at end of file