From 7607e77c5e66b67467239a2b983bd012113ea81c Mon Sep 17 00:00:00 2001 From: mentegy Date: Thu, 12 Apr 2018 21:41:59 +0300 Subject: [PATCH] Add support of upsert (insert or update, on conflict) for Postgres and MySQL --- README.md | 69 ++++++++++++++++- .../async/mysql/MysqlAsyncContextSpec.scala | 5 ++ .../async/mysql/OnConflictAsyncSpec.scala | 38 ++++++++++ .../async/postgres/OnConflictAsyncSpec.scala | 37 +++++++++ .../postgres/PostgresAsyncContextSpec.scala | 5 ++ .../getquill/context/cassandra/CqlIdiom.scala | 2 +- .../main/scala/io/getquill/MirrorIdiom.scala | 35 +++++++++ .../src/main/scala/io/getquill/ast/Ast.scala | 14 ++++ .../io/getquill/ast/StatefulTransformer.scala | 35 ++++++++- .../getquill/ast/StatelessTransformer.scala | 22 +++++- .../main/scala/io/getquill/dsl/QueryDsl.scala | 29 +++++++ .../io/getquill/norm/NormalizeReturning.scala | 14 ++-- .../io/getquill/norm/RenameProperties.scala | 4 + .../io/getquill/quotation/FreeVariables.scala | 2 + .../io/getquill/quotation/Liftables.scala | 29 +++++-- .../scala/io/getquill/quotation/Parsing.scala | 32 +++++++- .../io/getquill/quotation/Unliftables.scala | 27 +++++-- .../ast/StatefulTransformerSpec.scala | 64 ++++++++++++++++ .../ast/StatelessTransformerSpec.scala | 39 ++++++++++ .../context/mirror/MirrorIdiomSpec.scala | 23 ++++++ .../mysql/FinagleMysqlContextSpec.scala | 5 ++ .../finagle/mysql/OnConflictFinagleSpec.scala | 38 ++++++++++ .../postgres/FinaglePostgresContextSpec.scala | 5 ++ .../postgres/OnConflictFinagleSpec.scala | 38 ++++++++++ .../jdbc/mysql/OnConflictJdbcSpec.scala | 35 +++++++++ .../jdbc/postgres/OnConflictJdbcSpec.scala | 35 +++++++++ .../context/orientdb/OrientDBIdiom.scala | 2 +- .../main/scala/io/getquill/MySQLDialect.scala | 60 +++++++++++---- .../scala/io/getquill/PostgresDialect.scala | 48 +++++++++++- .../getquill/context/sql/idiom/SqlIdiom.scala | 75 ++++++++++--------- .../getquill/context/sql/OnConflictSpec.scala | 65 ++++++++++++++++ .../context/sql/idiom/MySQLDialectSpec.scala | 29 +++++-- .../context/sql/idiom/OnConflictSpec.scala | 27 +++++++ .../sql/idiom/PostgresDialectSpec.scala | 33 ++++++-- quill-sql/src/test/sql/mysql-schema.sql | 2 +- quill-sql/src/test/sql/postgres-schema.sql | 2 +- 36 files changed, 931 insertions(+), 93 deletions(-) create mode 100644 quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/OnConflictAsyncSpec.scala create mode 100644 quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/OnConflictAsyncSpec.scala create mode 100644 quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/OnConflictFinagleSpec.scala create mode 100644 quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/OnConflictFinagleSpec.scala create mode 100644 quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/OnConflictJdbcSpec.scala create mode 100644 quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/OnConflictJdbcSpec.scala create mode 100644 quill-sql/src/test/scala/io/getquill/context/sql/OnConflictSpec.scala create mode 100644 quill-sql/src/test/scala/io/getquill/context/sql/idiom/OnConflictSpec.scala diff --git a/README.md b/README.md index 01b01b6660..bcf8fbf7f5 100644 --- a/README.md +++ b/README.md @@ -703,7 +703,74 @@ val a = quote { ctx.run(a) // DELETE FROM Person WHERE name = '' ``` - + +### insert or update (upsert, conflict) + +Upsert is only supported by Postgres and MySQL + +#### Postgres +Ignore conflict +```scala +val a = quote { + query[Product].insert(_.id -> 1, _.sku -> 10).onConflictIgnore +} + +// INSERT INTO Product AS t (id,sku) VALUES (1, 10) ON CONFLICT DO NOTHING +``` + +Ignore conflict by explicitly setting conflict target +```scala +val a = quote { + query[Product].insert(_.id -> 1, _.sku -> 10).onConflictIgnore(_.id) +} + +// INSERT INTO Product AS t (id,sku) VALUES (1, 10) ON CONFLICT (id) DO NOTHING +``` + +Resolve conflict by updating existing row if needed. In `onConflictUpdate(target)((t, e) => assignment)`: `target` refers to +conflict target, `t` - to existing row and `e` - to excluded, e.g. row proposed for insert. +```scala +val a = quote { + query[Product] + .insert(_.id -> 1, _.sku -> 10) + .onConflictUpdate(_.id)((t, e) => t.sku -> (t.sku + e.sku)) +} + +// INSERT INTO Product AS t (id,sku) VALUES (1, 10) ON CONFLICT (id) DO UPDATE SET sku = (t.sku + EXCLUDED.sku) +``` + +#### MySQL + +Ignore any conflict, e.g. `insert ignore` +```scala +val a = quote { + query[Product].insert(_.id -> 1, _.sku -> 10).onConflictIgnore +} + +// INSERT IGNORE INTO Product (id,sku) VALUES (1, 10) +``` + +Ignore duplicate key conflict by explicitly setting it +```scala +val a = quote { + query[Product].insert(_.id -> 1, _.sku -> 10).onConflictIgnore(_.id) +} + +// INSERT INTO Product (id,sku) VALUES (1, 10) ON DUPLICATE KEY UPDATE id=id +``` + +Resolve duplicate key by updating existing row if needed. In `onConflictUpdate((t, e) => assignment)`: `t` refers to +existing row and `e` - to values, e.g. values proposed for insert. +```scala +val a = quote { + query[Product] + .insert(_.id -> 1, _.sku -> 10) + .onConflictUpdate((t, e) => t.sku -> (t.sku + e.sku)) +} + +// INSERT INTO Product (id,sku) VALUES (1, 10) ON DUPLICATE KEY UPDATE sku = (sku + VALUES(sku)) +``` + ## IO Monad Quill provides an IO monad that allows the user to express multiple computations and execute them separately. This mechanism is also known as a free monad, which provides a way of expressing computations as referentially-transparent values and isolates the unsafe IO operations into a single operation. For instance: diff --git a/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala index d596227d75..9bb3962169 100644 --- a/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala +++ b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/MysqlAsyncContextSpec.scala @@ -45,4 +45,9 @@ class MysqlAsyncContextSpec extends Spec { } ctx.close } + + override protected def beforeAll(): Unit = { + await(testContext.run(qr1.delete)) + () + } } diff --git a/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/OnConflictAsyncSpec.scala b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/OnConflictAsyncSpec.scala new file mode 100644 index 0000000000..d6f0b0bb9e --- /dev/null +++ b/quill-async-mysql/src/test/scala/io/getquill/context/async/mysql/OnConflictAsyncSpec.scala @@ -0,0 +1,38 @@ +package io.getquill.context.async.mysql + +import io.getquill.context.sql.OnConflictSpec + +import scala.concurrent.ExecutionContext.Implicits.global + +class OnConflictAsyncSpec extends OnConflictSpec { + val ctx = testContext + import ctx._ + + override protected def beforeAll(): Unit = { + await(ctx.run(qr1.delete)) + () + } + + "INSERT IGNORE" in { + import `onConflictIgnore`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON DUPLICATE KEY UPDATE i=i " in { + import `onConflictIgnore(_.i)`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON DUPLICATE KEY UPDATE ..." in { + import `onConflictUpdate((t, e) => ...)`._ + await(ctx.run(testQuery(e1))) mustEqual res1 + await(ctx.run(testQuery(e2))) mustEqual res2 + 1 + await(ctx.run(testQuery(e3))) mustEqual res3 + 1 + await(ctx.run(testQuery4)) mustEqual res4 + } +} + diff --git a/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/OnConflictAsyncSpec.scala b/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/OnConflictAsyncSpec.scala new file mode 100644 index 0000000000..e75cd8e8e5 --- /dev/null +++ b/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/OnConflictAsyncSpec.scala @@ -0,0 +1,37 @@ +package io.getquill.context.async.postgres + +import io.getquill.context.sql.OnConflictSpec + +import scala.concurrent.ExecutionContext.Implicits.global + +class OnConflictAsyncSpec extends OnConflictSpec { + val ctx = testContext + import ctx._ + + override protected def beforeAll(): Unit = { + await(ctx.run(qr1.delete)) + () + } + + "ON CONFLICT DO NOTHING" in { + import `onConflictIgnore`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON CONFLICT (i) DO NOTHING" in { + import `onConflictIgnore(_.i)`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON CONFLICT (i) DO UPDATE ..." in { + import `onConflictUpdate(_.i)((t, e) => ...)`._ + await(ctx.run(testQuery(e1))) mustEqual res1 + await(ctx.run(testQuery(e2))) mustEqual res2 + await(ctx.run(testQuery(e3))) mustEqual res3 + await(ctx.run(testQuery4)) mustEqual res4 + } +} diff --git a/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala b/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala index ee3fa550ac..19c3f84e0b 100644 --- a/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala +++ b/quill-async-postgres/src/test/scala/io/getquill/context/async/postgres/PostgresAsyncContextSpec.scala @@ -45,4 +45,9 @@ class PostgresAsyncContextSpec extends Spec { } ctx.close } + + override protected def beforeAll(): Unit = { + await(testContext.run(qr1.delete)) + () + } } diff --git a/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala b/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala index 3cf4147cbb..bf4f3fdaf7 100644 --- a/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala +++ b/quill-cassandra/src/main/scala/io/getquill/context/cassandra/CqlIdiom.scala @@ -42,7 +42,7 @@ trait CqlIdiom extends Idiom { case a: TraversableOperation => a.token case a @ ( _: Function | _: FunctionApply | _: Dynamic | _: OptionOperation | _: Block | - _: Val | _: Ordering | _: QuotedReference | _: If + _: Val | _: Ordering | _: QuotedReference | _: If | _: OnConflict.Excluded | _: OnConflict.Existing ) => fail(s"Invalid cql: '$a'") } diff --git a/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala b/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala index 501db682f1..88618f6070 100644 --- a/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala +++ b/quill-core/src/main/scala/io/getquill/MirrorIdiom.scala @@ -40,6 +40,8 @@ class MirrorIdiom extends Idiom { case ast: QuotedReference => ast.ast.token case ast: Lift => ast.token case ast: Assignment => ast.token + case ast: OnConflict.Excluded => ast.token + case ast: OnConflict.Existing => ast.token } implicit def ifTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[If] = Tokenizer[If] { @@ -181,12 +183,45 @@ class MirrorIdiom extends Idiom { case e => stmt"${e.name.token}" } + implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] = Tokenizer[OnConflict.Excluded] { + case OnConflict.Excluded(ident) => stmt"${ident.token}" + } + + implicit val existingTokenizer: Tokenizer[OnConflict.Existing] = Tokenizer[OnConflict.Existing] { + case OnConflict.Existing(ident) => stmt"${ident.token}" + } + implicit def actionTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Action] = Tokenizer[Action] { case Update(query, assignments) => stmt"${query.token}.update(${assignments.token})" case Insert(query, assignments) => stmt"${query.token}.insert(${assignments.token})" case Delete(query) => stmt"${query.token}.delete" case Returning(query, alias, body) => stmt"${query.token}.returning((${alias.token}) => ${body.token})" case Foreach(query, alias, body) => stmt"${query.token}.foreach((${alias.token}) => ${body.token})" + case c: OnConflict => stmt"${c.token}" + } + + implicit def conflictTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[OnConflict] = { + + def targetProps(l: List[Property]) = l.map(p => Transform(p) { + case Ident(_) => Ident("_") + }) + + implicit val conflictTargetTokenizer = Tokenizer[OnConflict.Target] { + case OnConflict.NoTarget => stmt"" + case OnConflict.Properties(props) => stmt"(${targetProps(props).token})" + } + + val updateAssignsTokenizer = Tokenizer[Assignment] { + case Assignment(i, p, v) => + stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}" + } + + Tokenizer[OnConflict] { + case OnConflict(i, t, OnConflict.Update(assign)) => + stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})" + case OnConflict(i, t, OnConflict.Ignore) => + stmt"${i.token}.onConflictIgnore${t.token}" + } } implicit def assignmentTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Assignment] = Tokenizer[Assignment] { diff --git a/quill-core/src/main/scala/io/getquill/ast/Ast.scala b/quill-core/src/main/scala/io/getquill/ast/Ast.scala index 78504c67f2..d9ef61d16c 100644 --- a/quill-core/src/main/scala/io/getquill/ast/Ast.scala +++ b/quill-core/src/main/scala/io/getquill/ast/Ast.scala @@ -129,6 +129,20 @@ case class Returning(action: Ast, alias: Ident, property: Ast) extends Action case class Foreach(query: Ast, alias: Ident, body: Ast) extends Action +case class OnConflict(insert: Ast, target: OnConflict.Target, action: OnConflict.Action) extends Action +object OnConflict { + + case class Excluded(alias: Ident) extends Ast + case class Existing(alias: Ident) extends Ast + + sealed trait Target + case object NoTarget extends Target + case class Properties(props: List[Property]) extends Target + + sealed trait Action + case object Ignore extends Action + case class Update(assignments: List[Assignment]) extends Action +} //************************************************************ case class Dynamic(tree: Any) extends Ast diff --git a/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala b/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala index 53a221d0c7..49e2af26e1 100644 --- a/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala +++ b/quill-core/src/main/scala/io/getquill/ast/StatefulTransformer.scala @@ -14,15 +14,14 @@ trait StatefulTransformer[T] { case e: Ident => (e, this) case e: OptionOperation => apply(e) case e: TraversableOperation => apply(e) + case e: Property => apply(e) + case e: OnConflict.Existing => (e, this) + case e: OnConflict.Excluded => (e, this) case Function(a, b) => val (bt, btt) = apply(b) (Function(a, bt), btt) - case Property(a, b) => - val (at, att) = apply(a) - (Property(at, b), att) - case Infix(a, b) => val (bt, btt) = apply(b)(_.apply) (Infix(a, bt), btt) @@ -179,6 +178,13 @@ trait StatefulTransformer[T] { (Assignment(a, bt, ct), ctt) } + def apply(e: Property): (Property, StatefulTransformer[T]) = + e match { + case Property(a, b) => + val (at, att) = apply(a) + (Property(at, b), att) + } + def apply(e: Operation): (Operation, StatefulTransformer[T]) = e match { case UnaryOperation(o, a) => @@ -228,6 +234,27 @@ trait StatefulTransformer[T] { val (at, att) = apply(a) val (ct, ctt) = att.apply(c) (Foreach(at, b, ct), ctt) + case OnConflict(a, b, c) => + val (at, att) = apply(a) + val (bt, btt) = att.apply(b) + val (ct, ctt) = btt.apply(c) + (OnConflict(at, bt, ct), ctt) + } + + def apply(e: OnConflict.Target): (OnConflict.Target, StatefulTransformer[T]) = + e match { + case OnConflict.NoTarget => (e, this) + case OnConflict.Properties(a) => + val (at, att) = apply(a)(_.apply) + (OnConflict.Properties(at), att) + } + + def apply(e: OnConflict.Action): (OnConflict.Action, StatefulTransformer[T]) = + e match { + case OnConflict.Ignore => (e, this) + case OnConflict.Update(a) => + val (at, att) = apply(a)(_.apply) + (OnConflict.Update(at), att) } def apply[U, R](list: List[U])(f: StatefulTransformer[T] => U => (R, StatefulTransformer[T])) = diff --git a/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala b/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala index 5fbacee760..286f3ba46e 100644 --- a/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala +++ b/quill-core/src/main/scala/io/getquill/ast/StatelessTransformer.scala @@ -11,7 +11,7 @@ trait StatelessTransformer { case e: Assignment => apply(e) case Function(params, body) => Function(params, apply(body)) case e: Ident => e - case Property(a, name) => Property(apply(a), name) + case e: Property => apply(e) case Infix(a, b) => Infix(a, b.map(apply)) case e: OptionOperation => apply(e) case e: TraversableOperation => apply(e) @@ -22,6 +22,8 @@ trait StatelessTransformer { case Block(statements) => Block(statements.map(apply)) case Val(name, body) => Val(name, apply(body)) case o: Ordering => o + case e: OnConflict.Excluded => e + case e: OnConflict.Existing => e } def apply(o: OptionOperation): OptionOperation = @@ -72,6 +74,11 @@ trait StatelessTransformer { case Assignment(a, b, c) => Assignment(a, apply(b), apply(c)) } + def apply(e: Property): Property = + e match { + case Property(a, name) => Property(apply(a), name) + } + def apply(e: Operation): Operation = e match { case UnaryOperation(o, a) => UnaryOperation(o, apply(a)) @@ -97,6 +104,19 @@ trait StatelessTransformer { case Delete(query) => Delete(apply(query)) case Returning(query, alias, property) => Returning(apply(query), alias, apply(property)) case Foreach(query, alias, body) => Foreach(apply(query), alias, apply(body)) + case OnConflict(query, target, action) => OnConflict(apply(query), apply(target), apply(action)) + } + + def apply(e: OnConflict.Target): OnConflict.Target = + e match { + case OnConflict.NoTarget => e + case OnConflict.Properties(props) => OnConflict.Properties(props.map(apply)) + } + + def apply(e: OnConflict.Action): OnConflict.Action = + e match { + case OnConflict.Ignore => e + case OnConflict.Update(assigns) => OnConflict.Update(assigns.map(apply)) } } diff --git a/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala b/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala index 14ea7dd5a5..04635b9696 100644 --- a/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala +++ b/quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala @@ -91,6 +91,35 @@ private[dsl] trait QueryDsl { sealed trait Insert[E] extends Action[E] { @compileTimeOnly(NonQuotedException.message) def returning[R](f: E => R): ActionReturning[E, R] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictIgnore: Insert[E] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictIgnore(target: E => Any, targets: (E => Any)*): Insert[E] = NonQuotedException() + + @compileTimeOnly(NonQuotedException.message) + def onConflictUpdate(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException() + + /** + * Generates an atomic INSERT or UPDATE (upsert) action if supported. + * + * @param targets - conflict target + * @param assigns - update statement, declared as function: `(table, excluded) => (assign, result)` + * `table` - is used to extract column for update assignment and reference existing row + * `excluded` - aliases excluded table, e.g. row proposed for insertion. + * `assign` - left hand side of assignment. Should be accessed from `table` argument + * `result` - right hand side of assignment. + * + * Example usage: + * {{{ + * insert.onConflictUpdate(_.id)((t, e) => t.col -> (e.col + t.col)) + * }}} + * If insert statement violates conflict target then the column `col` of row will be updated with sum of + * existing value and and proposed `col` in insert. + */ + @compileTimeOnly(NonQuotedException.message) + def onConflictUpdate(target: E => Any, targets: (E => Any)*)(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException() } sealed trait ActionReturning[E, Output] extends Action[E] diff --git a/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala b/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala index 84f72914a6..9ddf1d8ccf 100644 --- a/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala +++ b/quill-core/src/main/scala/io/getquill/norm/NormalizeReturning.scala @@ -6,14 +6,18 @@ object NormalizeReturning { def apply(e: Action): Action = { e match { - case Returning(Insert(query, assignments), alias, body) => - Returning(Insert(query, filterReturnedColumn(assignments, body)), alias, body) - case Returning(Update(query, assignments), alias, body) => - Returning(Update(query, filterReturnedColumn(assignments, body)), alias, body) - case e => e + case Returning(a: Action, alias, body) => Returning(apply(a, body), alias, body) + case _ => e } } + private def apply(e: Action, body: Ast): Action = e match { + case Insert(query, assignments) => Insert(query, filterReturnedColumn(assignments, body)) + case Update(query, assignments) => Update(query, filterReturnedColumn(assignments, body)) + case OnConflict(a: Action, target, act) => OnConflict(apply(a, body), target, act) + case _ => e + } + private def filterReturnedColumn(assignments: List[Assignment], column: Ast): List[Assignment] = assignments.flatMap(filterReturnedColumn(_, column)) diff --git a/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala b/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala index c0f5fb995a..4ba11dfbf0 100644 --- a/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala +++ b/quill-core/src/main/scala/io/getquill/norm/RenameProperties.scala @@ -29,6 +29,10 @@ object RenameProperties extends StatelessTransformer { val bodyr = BetaReduction(body, replace: _*) (Returning(action, alias, bodyr), schema) } + case OnConflict(a: Action, target, act) => + applySchema(a) match { + case (action, schema) => (OnConflict(action, target, act), schema) + } case q => (q, Tuple(List.empty)) } diff --git a/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala b/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala index 3980a8a186..f9255c28c7 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/FreeVariables.scala @@ -51,6 +51,8 @@ case class FreeVariables(state: State) super.apply(other) } + override def apply(e: OnConflict.Target): (OnConflict.Target, StatefulTransformer[State]) = (e, this) + override def apply(query: Query): (Query, StatefulTransformer[State]) = query match { case q @ Filter(a, b, c) => (q, free(a, b, c)) diff --git a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala index d95d06c8d4..b71716dcb8 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Liftables.scala @@ -21,9 +21,9 @@ trait Liftables { case ast: Assignment => assignmentLiftable(ast) case ast: OptionOperation => optionOperationLiftable(ast) case ast: TraversableOperation => traversableOperationLiftable(ast) + case ast: Property => propertyLiftable(ast) case Val(name, body) => q"$pack.Val($name, $body)" case Block(statements) => q"$pack.Block($statements)" - case Property(a, b) => q"$pack.Property($a, $b)" case Function(a, b) => q"$pack.Function($a, $b)" case FunctionApply(a, b) => q"$pack.FunctionApply($a, $b)" case BinaryOperation(a, b, c) => q"$pack.BinaryOperation($a, $b, $c)" @@ -33,6 +33,8 @@ trait Liftables { case Dynamic(tree: Tree) if (tree.tpe <:< c.weakTypeOf[CoreDsl#Quoted[Any]]) => q"$tree.ast" case Dynamic(tree: Tree) => q"$pack.Constant($tree)" case QuotedReference(tree: Tree, ast) => q"$ast" + case OnConflict.Excluded(a) => q"$pack.OnConflict.Excluded($a)" + case OnConflict.Existing(a) => q"$pack.OnConflict.Existing($a)" } implicit val optionOperationLiftable: Liftable[OptionOperation] = Liftable[OptionOperation] { @@ -116,6 +118,10 @@ trait Liftables { case PropertyAlias(a, b) => q"$pack.PropertyAlias($a, $b)" } + implicit val propertyLiftable: Liftable[Property] = Liftable[Property] { + case Property(a, b) => q"$pack.Property($a, $b)" + } + implicit val orderingLiftable: Liftable[Ordering] = Liftable[Ordering] { case TupleOrdering(elems) => q"$pack.TupleOrdering($elems)" case Asc => q"$pack.Asc" @@ -134,11 +140,22 @@ trait Liftables { } implicit val actionLiftable: Liftable[Action] = Liftable[Action] { - case Update(a, b) => q"$pack.Update($a, $b)" - case Insert(a, b) => q"$pack.Insert($a, $b)" - case Delete(a) => q"$pack.Delete($a)" - case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)" - case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)" + case Update(a, b) => q"$pack.Update($a, $b)" + case Insert(a, b) => q"$pack.Insert($a, $b)" + case Delete(a) => q"$pack.Delete($a)" + case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)" + case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)" + case OnConflict(a, b, c) => q"$pack.OnConflict($a, $b, $c)" + } + + implicit val conflictTargetLiftable: Liftable[OnConflict.Target] = Liftable[OnConflict.Target] { + case OnConflict.NoTarget => q"$pack.OnConflict.NoTarget" + case OnConflict.Properties(a) => q"$pack.OnConflict.Properties.apply($a)" + } + + implicit val conflictActionLiftable: Liftable[OnConflict.Action] = Liftable[OnConflict.Action] { + case OnConflict.Ignore => q"$pack.OnConflict.Ignore" + case OnConflict.Update(a) => q"$pack.OnConflict.Update.apply($a)" } implicit val assignmentLiftable: Liftable[Assignment] = Liftable[Assignment] { diff --git a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala index 2f7ab46e31..8163050a45 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Parsing.scala @@ -41,6 +41,7 @@ trait Parsing { case `quotedAstParser`(value) => value case `functionParser`(value) => value case `actionParser`(value) => value + case `conflictParser`(value) => value case `infixParser`(value) => value case `orderingParser`(value) => value case `operationParser`(value) => value @@ -323,7 +324,7 @@ trait Parsing { ListContains(astParser(col), astParser(body)) } - val propertyParser: Parser[Ast] = Parser[Ast] { + val propertyParser: Parser[Property] = Parser[Property] { case q"$e.get" if is[Option[Any]](e) => c.fail("Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead.") case q"$e.$property" => Property(astParser(e), property.decodedName.toString) @@ -545,10 +546,39 @@ trait Parsing { checkTypes(prop, value) Assignment(i1, astParser(prop), astParser(value)) + case q"((${ identParser(i1) }, ${ identParser(i2) }) => $pack.Predef.ArrowAssoc[$t]($prop).$arrow[$v]($value))" => + checkTypes(prop, value) + val valueAst = Transform(astParser(value)) { + case `i1` => OnConflict.Existing(i1) + case `i2` => OnConflict.Excluded(i2) + } + Assignment(i1, astParser(prop), valueAst) // Unused, it's here only to make eclipse's presentation compiler happy case astParser(ast) => Assignment(Ident("unused"), Ident("unused"), Constant("unused")) } + val conflictParser: Parser[Ast] = Parser[Ast] { + case q"$query.onConflictIgnore" => + OnConflict(astParser(query), OnConflict.NoTarget, OnConflict.Ignore) + case q"$query.onConflictIgnore(..$targets)" => + OnConflict(astParser(query), parseConflictProps(targets), OnConflict.Ignore) + + case q"$query.onConflictUpdate(..$assigns)" => + OnConflict(astParser(query), OnConflict.NoTarget, parseConflictAssigns(assigns)) + case q"$query.onConflictUpdate(..$targets)(..$assigns)" => + OnConflict(astParser(query), parseConflictProps(targets), parseConflictAssigns(assigns)) + } + + private def parseConflictProps(targets: List[Tree]) = OnConflict.Properties { + targets.map { + case q"($e) => $prop" => propertyParser(prop) + case tree => c.fail(s"Tree '$tree' can't be parsed as conflict target") + } + } + + private def parseConflictAssigns(targets: List[Tree]) = + OnConflict.Update(targets.map(assignmentParser(_))) + private def checkTypes(lhs: Tree, rhs: Tree): Unit = { def unquoted(tree: Tree) = is[CoreDsl#Quoted[Any]](tree) match { diff --git a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala index f815409c80..c2979bcb62 100644 --- a/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala +++ b/quill-core/src/main/scala/io/getquill/quotation/Unliftables.scala @@ -18,7 +18,7 @@ trait Unliftables { case orderingUnliftable(ast) => ast case optionOperationUnliftable(ast) => ast case traversableOperationUnliftable(ast) => ast - case q"$pack.Property.apply(${ a: Ast }, ${ b: String })" => Property(a, b) + case propertyUnliftable(ast) => ast case q"$pack.Function.apply(${ a: List[Ident] }, ${ b: Ast })" => Function(a, b) case q"$pack.FunctionApply.apply(${ a: Ast }, ${ b: List[Ast] })" => FunctionApply(a, b) case q"$pack.BinaryOperation.apply(${ a: Ast }, ${ b: BinaryOperator }, ${ c: Ast })" => BinaryOperation(a, b, c) @@ -26,6 +26,8 @@ trait Unliftables { case q"$pack.Aggregation.apply(${ a: AggregationOperator }, ${ b: Ast })" => Aggregation(a, b) case q"$pack.Infix.apply(${ a: List[String] }, ${ b: List[Ast] })" => Infix(a, b) case q"$pack.If.apply(${ a: Ast }, ${ b: Ast }, ${ c: Ast })" => If(a, b, c) + case q"$pack.OnConflict.Excluded.apply(${ a: Ident })" => OnConflict.Excluded(a) + case q"$pack.OnConflict.Existing.apply(${ a: Ident })" => OnConflict.Existing(a) case q"$tree.ast" => Dynamic(tree) } @@ -126,6 +128,10 @@ trait Unliftables { case q"$pack.PropertyAlias.apply(${ a: List[String] }, ${ b: String })" => PropertyAlias(a, b) } + implicit val propertyUnliftable: Unliftable[Property] = Unliftable[Property] { + case q"$pack.Property.apply(${ a: Ast }, ${ b: String })" => Property(a, b) + } + implicit def optionUnliftable[T](implicit u: Unliftable[T]): Unliftable[Option[T]] = Unliftable[Option[T]] { case q"scala.None" => None case q"scala.Some.apply[$t]($v)" => Some(u.unapply(v).get) @@ -139,11 +145,22 @@ trait Unliftables { } implicit val actionUnliftable: Unliftable[Action] = Unliftable[Action] { - case q"$pack.Update.apply(${ a: Ast }, ${ b: List[Assignment] })" => Update(a, b) - case q"$pack.Insert.apply(${ a: Ast }, ${ b: List[Assignment] })" => Insert(a, b) - case q"$pack.Delete.apply(${ a: Ast })" => Delete(a) + case q"$pack.Update.apply(${ a: Ast }, ${ b: List[Assignment] })" => Update(a, b) + case q"$pack.Insert.apply(${ a: Ast }, ${ b: List[Assignment] })" => Insert(a, b) + case q"$pack.Delete.apply(${ a: Ast })" => Delete(a) case q"$pack.Returning.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Returning(a, b, c) - case q"$pack.Foreach.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Foreach(a, b, c) + case q"$pack.Foreach.apply(${ a: Ast }, ${ b: Ident }, ${ c: Ast })" => Foreach(a, b, c) + case q"$pack.OnConflict.apply(${ a: Ast }, ${ b: OnConflict.Target }, ${ c: OnConflict.Action })" => OnConflict(a, b, c) + } + + implicit val conflictTargetUnliftable: Unliftable[OnConflict.Target] = Unliftable[OnConflict.Target] { + case q"$pack.OnConflict.NoTarget" => OnConflict.NoTarget + case q"$pack.OnConflict.Properties.apply(${ a: List[Property] })" => OnConflict.Properties(a) + } + + implicit val conflictActionUnliftable: Unliftable[OnConflict.Action] = Unliftable[OnConflict.Action] { + case q"$pack.OnConflict.Ignore" => OnConflict.Ignore + case q"$pack.OnConflict.Update.apply(${ a: List[Assignment] })" => OnConflict.Update(a) } implicit val assignmentUnliftable: Unliftable[Assignment] = Unliftable[Assignment] { diff --git a/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala b/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala index 688496102c..c2d14319f7 100644 --- a/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala +++ b/quill-core/src/test/scala/io/getquill/ast/StatefulTransformerSpec.scala @@ -216,6 +216,70 @@ class StatefulTransformerSpec extends Spec { att.state mustEqual List(Ident("a")) } } + "onConflict" in { + val ast: Ast = OnConflict(Insert(Ident("a"), Nil), OnConflict.NoTarget, OnConflict.Ignore) + Subject(Nil, Ident("a") -> Ident("a'"))(ast) match { + case (at, att) => + at mustEqual OnConflict(Insert(Ident("a'"), Nil), OnConflict.NoTarget, OnConflict.Ignore) + att.state mustEqual List(Ident("a")) + } + } + } + + "onConflict.target" - { + "no" in { + val target: OnConflict.Target = OnConflict.NoTarget + Subject(Nil)(target) match { + case (at, att) => + at mustEqual target + att.state mustEqual Nil + } + } + "properties" in { + val target: OnConflict.Target = OnConflict.Properties(List(Property(Ident("a"), "b"))) + Subject(Nil, Ident("a") -> Ident("a'"))(target) match { + case (at, att) => + at mustEqual OnConflict.Properties(List(Property(Ident("a'"), "b"))) + att.state mustEqual List(Ident("a")) + } + } + } + + "onConflict.action" - { + "ignore" in { + val action: OnConflict.Action = OnConflict.Ignore + Subject(Nil)(action) match { + case (at, att) => + at mustEqual action + att.state mustEqual Nil + } + } + "update" in { + val action: OnConflict.Action = OnConflict.Update(List(Assignment(Ident("a"), Ident("b"), Ident("c")))) + Subject(Nil, Ident("a") -> Ident("a'"), Ident("b") -> Ident("b'"), Ident("c") -> Ident("c'"))(action) match { + case (at, att) => + at mustEqual OnConflict.Update(List(Assignment(Ident("a"), Ident("b'"), Ident("c'")))) + att.state mustEqual List(Ident("b"), Ident("c")) + } + } + } + + "onConflict.excluded" in { + val ast: Ast = OnConflict.Excluded(Ident("a")) + Subject(Nil)(ast) match { + case (at, att) => + at mustEqual ast + att.state mustEqual Nil + } + } + + "onConflict.existing" in { + val ast: Ast = OnConflict.Existing(Ident("a")) + Subject(Nil)(ast) match { + case (at, att) => + at mustEqual ast + att.state mustEqual Nil + } } "function" in { diff --git a/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala b/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala index 5e31e5eeec..2a341bbcbf 100644 --- a/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala +++ b/quill-core/src/test/scala/io/getquill/ast/StatelessTransformerSpec.scala @@ -148,6 +148,45 @@ class StatelessTransformerSpec extends Spec { Subject(Ident("a") -> Ident("a'"))(ast) mustEqual Delete(Ident("a'")) } + "onConflict" in { + val ast: Ast = OnConflict(Insert(Ident("a"), Nil), OnConflict.NoTarget, OnConflict.Ignore) + Subject(Ident("a") -> Ident("a'"))(ast) mustEqual + OnConflict(Insert(Ident("a'"), Nil), OnConflict.NoTarget, OnConflict.Ignore) + } + } + + "onConflict.target" - { + "no" in { + val target: OnConflict.Target = OnConflict.NoTarget + Subject()(target) mustEqual target + } + "properties" in { + val target: OnConflict.Target = OnConflict.Properties(List(Property(Ident("a"), "b"))) + Subject(Ident("a") -> Ident("a'"))(target) mustEqual + OnConflict.Properties(List(Property(Ident("a'"), "b"))) + } + } + + "onConflict.action" - { + "ignore" in { + val action: OnConflict.Action = OnConflict.Ignore + Subject()(action) mustEqual action + } + "update" in { + val action: OnConflict.Action = OnConflict.Update(List(Assignment(Ident("a"), Ident("b"), Ident("c")))) + Subject(Ident("a") -> Ident("a'"), Ident("b") -> Ident("b'"), Ident("c") -> Ident("c'"))(action) mustEqual + OnConflict.Update(List(Assignment(Ident("a"), Ident("b'"), Ident("c'")))) + } + } + + "onConflict.excluded" in { + val ast: Ast = OnConflict.Excluded(Ident("a")) + Subject()(ast) mustEqual ast + } + + "onConflict.existing" in { + val ast: Ast = OnConflict.Existing(Ident("a")) + Subject()(ast) mustEqual ast } "function" in { diff --git a/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala b/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala index b184f7232c..e7a4be4d2b 100644 --- a/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala +++ b/quill-core/src/test/scala/io/getquill/context/mirror/MirrorIdiomSpec.scala @@ -447,6 +447,29 @@ class MirrorIdiomSpec extends Spec { stmt"${(q.ast: Ast).token}" mustEqual stmt"""querySchema("TestEntity").delete""" } + + "onConflict" - { + val i = quote { + query[TestEntity].insert(t => t.s -> "a") + } + val t = stmt"""querySchema("TestEntity").insert(t => t.s -> "a")""" + "onConflictIgnore" in { + stmt"${(i.onConflictIgnore.ast: Ast).token}" mustEqual + stmt"$t.onConflictIgnore" + } + "onConflictIgnore(targets*)" in { + stmt"${(i.onConflictIgnore(_.i, _.s).ast: Ast).token}" mustEqual + stmt"$t.onConflictIgnore(_.i, _.s)" + } + "onConflictUpdate(assigns*)" in { + stmt"${(i.onConflictUpdate((t, e) => t.s -> e.s, (t, e) => t.i -> (t.i + 1)).ast: Ast).token}" mustEqual + stmt"$t.onConflictUpdate((t, e) => t.s -> e.s, (t, e) => t.i -> (t.i + 1))" + } + "onConflictUpdate(targets*)(assigns*)" in { + stmt"${(i.onConflictUpdate(_.i)((t, e) => t.s -> e.s).ast: Ast).token}" mustEqual + stmt"$t.onConflictUpdate(_.i)((t, e) => t.s -> e.s)" + } + } } "shows infix" - { diff --git a/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala index ebae55fac3..3fbcb2d6ae 100644 --- a/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala +++ b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/FinagleMysqlContextSpec.scala @@ -53,4 +53,9 @@ class FinagleMysqlContextSpec extends Spec { intercept[IllegalStateException](ctx.toOk(Error(-1, "no ok", "test"))) ctx.close } + + override protected def beforeAll(): Unit = { + await(testContext.run(qr1.delete)) + () + } } diff --git a/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/OnConflictFinagleSpec.scala b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/OnConflictFinagleSpec.scala new file mode 100644 index 0000000000..a21d24a05e --- /dev/null +++ b/quill-finagle-mysql/src/test/scala/io/getquill/context/finagle/mysql/OnConflictFinagleSpec.scala @@ -0,0 +1,38 @@ +package io.getquill.context.finagle.mysql + +import com.twitter.util.{ Await, Future } +import io.getquill.context.sql.OnConflictSpec + +class OnConflictFinagleSpec extends OnConflictSpec { + val ctx = testContext + import ctx._ + + def await[T](future: Future[T]) = Await.result(future) + + override protected def beforeAll(): Unit = { + await(ctx.run(qr1.delete)) + () + } + + "INSERT IGNORE" in { + import `onConflictIgnore`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON DUPLICATE KEY UPDATE i=i " in { + import `onConflictIgnore(_.i)`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + 1 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON DUPLICATE KEY UPDATE ..." in { + import `onConflictUpdate((t, e) => ...)`._ + await(ctx.run(testQuery(e1))) mustEqual res1 + await(ctx.run(testQuery(e2))) mustEqual res2 + 1 + await(ctx.run(testQuery(e3))) mustEqual res3 + 1 + await(ctx.run(testQuery4)) mustEqual res4 + } +} \ No newline at end of file diff --git a/quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/FinaglePostgresContextSpec.scala b/quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/FinaglePostgresContextSpec.scala index ea7d55cead..903de82194 100644 --- a/quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/FinaglePostgresContextSpec.scala +++ b/quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/FinaglePostgresContextSpec.scala @@ -34,4 +34,9 @@ class FinaglePostgresContextSpec extends Spec { ctx.probe("select 1").toOption mustBe defined ctx.close } + + override protected def beforeAll(): Unit = { + await(testContext.run(qr1.delete)) + () + } } diff --git a/quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/OnConflictFinagleSpec.scala b/quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/OnConflictFinagleSpec.scala new file mode 100644 index 0000000000..dd041af835 --- /dev/null +++ b/quill-finagle-postgres/src/test/scala/io/getquill/context/finagle/postgres/OnConflictFinagleSpec.scala @@ -0,0 +1,38 @@ +package io.getquill.context.finagle.postgres + +import com.twitter.util.{ Await, Future } +import io.getquill.context.sql.OnConflictSpec + +class OnConflictFinagleSpec extends OnConflictSpec { + val ctx = testContext + import ctx._ + + def await[T](future: Future[T]) = Await.result(future) + + override protected def beforeAll(): Unit = { + await(ctx.run(qr1.delete)) + () + } + + "ON CONFLICT DO NOTHING" in { + import `onConflictIgnore`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON CONFLICT (i) DO NOTHING" in { + import `onConflictIgnore(_.i)`._ + await(ctx.run(testQuery1)) mustEqual res1 + await(ctx.run(testQuery2)) mustEqual res2 + await(ctx.run(testQuery3)) mustEqual res3 + } + + "ON CONFLICT (i) DO UPDATE ..." in { + import `onConflictUpdate(_.i)((t, e) => ...)`._ + await(ctx.run(testQuery(e1))) mustEqual res1 + await(ctx.run(testQuery(e2))) mustEqual res2 + await(ctx.run(testQuery(e3))) mustEqual res3 + await(ctx.run(testQuery4)) mustEqual res4 + } +} diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/OnConflictJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/OnConflictJdbcSpec.scala new file mode 100644 index 0000000000..ff80565363 --- /dev/null +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/mysql/OnConflictJdbcSpec.scala @@ -0,0 +1,35 @@ +package io.getquill.context.jdbc.mysql + +import io.getquill.context.sql.OnConflictSpec + +class OnConflictJdbcSpec extends OnConflictSpec { + val ctx = testContext + import ctx._ + + override protected def beforeAll(): Unit = { + ctx.run(qr1.delete) + () + } + + "INSERT IGNORE" in { + import `onConflictIgnore`._ + ctx.run(testQuery1) mustEqual res1 + ctx.run(testQuery2) mustEqual res2 + ctx.run(testQuery3) mustEqual res3 + } + + "ON DUPLICATE KEY UPDATE i=i " in { + import `onConflictIgnore(_.i)`._ + ctx.run(testQuery1) mustEqual res1 + ctx.run(testQuery2) mustEqual res2 + 1 + ctx.run(testQuery3) mustEqual res3 + } + + "ON DUPLICATE KEY UPDATE ..." in { + import `onConflictUpdate((t, e) => ...)`._ + ctx.run(testQuery(e1)) mustEqual res1 + ctx.run(testQuery(e2)) mustEqual res2 + 1 + ctx.run(testQuery(e3)) mustEqual res3 + 1 + ctx.run(testQuery4) mustEqual res4 + } +} diff --git a/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/OnConflictJdbcSpec.scala b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/OnConflictJdbcSpec.scala new file mode 100644 index 0000000000..81ab984d81 --- /dev/null +++ b/quill-jdbc/src/test/scala/io/getquill/context/jdbc/postgres/OnConflictJdbcSpec.scala @@ -0,0 +1,35 @@ +package io.getquill.context.jdbc.postgres + +import io.getquill.context.sql.OnConflictSpec + +class OnConflictJdbcSpec extends OnConflictSpec { + val ctx = testContext + import ctx._ + + override protected def beforeAll(): Unit = { + ctx.run(qr1.delete) + () + } + + "ON CONFLICT DO NOTHING" in { + import `onConflictIgnore`._ + ctx.run(testQuery1) mustEqual res1 + ctx.run(testQuery2) mustEqual res2 + ctx.run(testQuery3) mustEqual res3 + } + + "ON CONFLICT (i) DO NOTHING" in { + import `onConflictIgnore(_.i)`._ + ctx.run(testQuery1) mustEqual res1 + ctx.run(testQuery2) mustEqual res2 + ctx.run(testQuery3) mustEqual res3 + } + + "ON CONFLICT (i) DO UPDATE ..." in { + import `onConflictUpdate(_.i)((t, e) => ...)`._ + ctx.run(testQuery(e1)) mustEqual res1 + ctx.run(testQuery(e2)) mustEqual res2 + ctx.run(testQuery(e3)) mustEqual res3 + ctx.run(testQuery4) mustEqual res4 + } +} diff --git a/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala b/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala index ab7a22d4c5..fdcdd55b25 100644 --- a/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala +++ b/quill-orientdb/src/main/scala/io/getquill/context/orientdb/OrientDBIdiom.scala @@ -60,7 +60,7 @@ trait OrientDBIdiom extends Idiom { a.token case a @ ( _: Function | _: FunctionApply | _: Dynamic | _: OptionOperation | _: Block | - _: Val | _: Ordering | _: QuotedReference | _: TraversableOperation + _: Val | _: Ordering | _: QuotedReference | _: TraversableOperation | _: OnConflict.Excluded | _: OnConflict.Existing ) => fail(s"Malformed or unsupported construct: $a.") } diff --git a/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala b/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala index f1b94f566f..7d9f0cfc2d 100644 --- a/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/MySQLDialect.scala @@ -1,21 +1,11 @@ package io.getquill -import io.getquill.idiom.StatementInterpolator._ -import io.getquill.ast.Asc -import io.getquill.ast.AscNullsFirst -import io.getquill.ast.AscNullsLast -import io.getquill.ast.BinaryOperation -import io.getquill.ast.Desc -import io.getquill.ast.DescNullsFirst -import io.getquill.ast.DescNullsLast -import io.getquill.ast.Operation -import io.getquill.ast.StringOperator -import io.getquill.context.sql.idiom.SqlIdiom +import io.getquill.ast.{ Ast, _ } import io.getquill.context.sql.OrderByCriteria -import io.getquill.context.sql.idiom.QuestionMarkBindVariables +import io.getquill.context.sql.idiom.{ NoConcatSupport, QuestionMarkBindVariables, SqlIdiom } +import io.getquill.idiom.StatementInterpolator._ import io.getquill.idiom.{ Statement, Token } -import io.getquill.ast.Ast -import io.getquill.context.sql.idiom.NoConcatSupport +import io.getquill.util.Messages.fail trait MySQLDialect extends SqlIdiom @@ -29,6 +19,48 @@ trait MySQLDialect override def defaultAutoGeneratedToken(field: Token) = stmt"($field) VALUES (DEFAULT)" + override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Ast] = + Tokenizer[Ast] { + case c: OnConflict => c.token + case ast => super.astTokenizer.token(ast) + } + + implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[OnConflict] = { + import OnConflict._ + + lazy val insertIgnoreTokenizer = + Tokenizer[Entity] { + case Entity(name, _) => stmt"IGNORE INTO ${strategy.table(name).token}" + } + + def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = + Tokenizer[OnConflict] { + case OnConflict(i, NoTarget, Update(a)) => + stmt"${i.token} ON DUPLICATE KEY UPDATE ${a.token}" + + case OnConflict(i, Properties(p), Ignore) => + val assignments = p + .map(p => astTokenizer.token(p)) + .map(t => stmt"$t=$t") + .mkStmt(",") + stmt"${i.token} ON DUPLICATE KEY UPDATE $assignments" + + case OnConflict(i: io.getquill.ast.Action, NoTarget, Ignore) => + actionTokenizer(insertIgnoreTokenizer)(actionAstTokenizer, strategy).token(i) + + case _ => + fail("This upsert construct is not supported in MySQL. Please refer documentation for details.") + } + + val customAstTokenizer = + Tokenizer.withFallback[Ast](MySQLDialect.this.astTokenizer(_, strategy)) { + case Property(Excluded(_), name) => stmt"VALUES(${strategy.column(name).token})" + case Property(_, name) => strategy.column(name).token + } + + tokenizer(customAstTokenizer) + } + override implicit def operationTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Operation] = Tokenizer[Operation] { case BinaryOperation(a, StringOperator.`+`, b) => stmt"CONCAT(${a.token}, ${b.token})" diff --git a/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala b/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala index 11952a9b87..afa7d378ba 100644 --- a/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala +++ b/quill-sql/src/main/scala/io/getquill/PostgresDialect.scala @@ -3,9 +3,10 @@ package io.getquill import java.util.concurrent.atomic.AtomicInteger import io.getquill.ast._ -import io.getquill.context.sql.idiom.{ QuestionMarkBindVariables, SqlIdiom } +import io.getquill.context.sql.idiom.{ ConcatSupport, QuestionMarkBindVariables, SqlIdiom } import io.getquill.idiom.StatementInterpolator._ -import io.getquill.context.sql.idiom.ConcatSupport +import io.getquill.idiom.Token +import io.getquill.util.Messages.fail trait PostgresDialect extends SqlIdiom @@ -15,9 +16,52 @@ trait PostgresDialect override def astTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Ast] = Tokenizer[Ast] { case ListContains(ast, body) => stmt"${body.token} = ANY(${ast.token})" + case c: OnConflict => conflictTokenizer.token(c) case ast => super.astTokenizer.token(ast) } + implicit def conflictTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[OnConflict] = { + + val customEntityTokenizer = Tokenizer[Entity] { + case Entity(name, _) => stmt"INTO ${strategy.table(name).token} AS t" + } + + val customAstTokenizer = + Tokenizer.withFallback[Ast](PostgresDialect.this.astTokenizer(_, strategy)) { + case _: OnConflict.Excluded => stmt"EXCLUDED" + case OnConflict.Existing(a) => stmt"${a.token}" + case a: Action => super.actionTokenizer(customEntityTokenizer)(actionAstTokenizer, strategy).token(a) + } + + import OnConflict._ + + def doUpdateStmt(i: Token, t: Token, u: Update) = { + val assignments = u.assignments + .map(a => stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(customAstTokenizer)}") + .mkStmt() + + stmt"$i ON CONFLICT $t DO UPDATE SET $assignments" + } + + def doNothingStmt(i: Ast, t: Token) = stmt"${i.token} ON CONFLICT $t DO NOTHING" + + implicit val conflictTargetPropsTokenizer = + Tokenizer[Properties] { + case OnConflict.Properties(props) => stmt"(${props.map(n => strategy.column(n.name)).mkStmt(",")})" + } + + def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = + Tokenizer[OnConflict] { + case OnConflict(_, NoTarget, _: Update) => fail("'DO UPDATE' statement requires explicit conflict target") + case OnConflict(i, p: Properties, u: Update) => doUpdateStmt(i.token, p.token, u) + + case OnConflict(i, NoTarget, Ignore) => stmt"${astTokenizer.token(i)} ON CONFLICT DO NOTHING" + case OnConflict(i, p: Properties, Ignore) => doNothingStmt(i, p.token) + } + + tokenizer(customAstTokenizer) + } + override implicit def operationTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Operation] = Tokenizer[Operation] { case UnaryOperation(StringOperator.`toLong`, ast) => stmt"${scopedTokenizer(ast)}::bigint" diff --git a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala index 67c17e9209..5b01e0796a 100644 --- a/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala +++ b/quill-sql/src/main/scala/io/getquill/context/sql/idiom/SqlIdiom.scala @@ -5,9 +5,7 @@ import io.getquill.ast.BooleanOperator._ import io.getquill.ast.Lift import io.getquill.context.sql._ import io.getquill.context.sql.norm._ -import io.getquill.idiom.Idiom -import io.getquill.idiom.SetContainsToken -import io.getquill.idiom.Statement +import io.getquill.idiom.{ Idiom, SetContainsToken, Statement } import io.getquill.idiom.StatementInterpolator._ import io.getquill.NamingStrategy import io.getquill.util.Interleave @@ -60,7 +58,7 @@ trait SqlIdiom extends Idiom { case a: OptionOperation => a.token case a @ ( _: Function | _: FunctionApply | _: Dynamic | _: OptionOperation | _: Block | - _: Val | _: Ordering | _: QuotedReference | _: TraversableOperation + _: Val | _: Ordering | _: QuotedReference | _: TraversableOperation | _: OnConflict.Excluded | _: OnConflict.Existing ) => fail(s"Malformed or unsupported construct: $a.") } @@ -317,49 +315,52 @@ trait SqlIdiom extends Idiom { stmt"${prop.token} = ${scopedTokenizer(value)}" } - implicit def actionTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Action] = { + implicit def defaultAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Action] = { + val insertEntityTokenizer = Tokenizer[Entity] { + case Entity(name, _) => stmt"INTO ${strategy.table(name).token}" + } + actionTokenizer(insertEntityTokenizer)(actionAstTokenizer, strategy) + } - def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = - Tokenizer[Action] { + protected def actionAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case q: Query => astTokenizer.token(q) + case Property(Property(_, name), "isEmpty") => stmt"${strategy.column(name).token} IS NULL" + case Property(Property(_, name), "isDefined") => stmt"${strategy.column(name).token} IS NOT NULL" + case Property(Property(_, name), "nonEmpty") => stmt"${strategy.column(name).token} IS NOT NULL" + case Property(_, name) => strategy.column(name).token + } - case Insert(table: Entity, assignments) => - val columns = assignments.map(_.property.token) - val values = assignments.map(_.value) - stmt"INSERT INTO ${table.token} (${columns.mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})" + protected def actionTokenizer(insertEntityTokenizer: Tokenizer[Entity])(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Action] = + Tokenizer[Action] { - case Update(table: Entity, assignments) => - stmt"UPDATE ${table.token} SET ${assignments.token}" + case Insert(entity: Entity, assignments) => + val table = insertEntityTokenizer.token(entity) + val columns = assignments.map(_.property.token) + val values = assignments.map(_.value) + stmt"INSERT $table (${columns.mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})" - case Update(Filter(table: Entity, x, where), assignments) => - stmt"UPDATE ${table.token} SET ${assignments.token} WHERE ${where.token}" + case Update(table: Entity, assignments) => + stmt"UPDATE ${table.token} SET ${assignments.token}" - case Delete(Filter(table: Entity, x, where)) => - stmt"DELETE FROM ${table.token} WHERE ${where.token}" + case Update(Filter(table: Entity, x, where), assignments) => + stmt"UPDATE ${table.token} SET ${assignments.token} WHERE ${where.token}" - case Delete(table: Entity) => - stmt"DELETE FROM ${table.token}" + case Delete(Filter(table: Entity, x, where)) => + stmt"DELETE FROM ${table.token} WHERE ${where.token}" - case Returning(Insert(table: Entity, Nil), alias, prop) => - stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" + case Delete(table: Entity) => + stmt"DELETE FROM ${table.token}" - case Returning(action, alias, prop) => - action.token + case Returning(Insert(table: Entity, Nil), alias, prop) => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" - case other => - fail(s"Action ast can't be translated to sql: '$other'") - } + case Returning(action, alias, prop) => + action.token - val customAstTokenizer = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { - case q: Query => astTokenizer.token(q) - case Property(Property(_, name), "isEmpty") => stmt"${strategy.column(name).token} IS NULL" - case Property(Property(_, name), "isDefined") => stmt"${strategy.column(name).token} IS NOT NULL" - case Property(Property(_, name), "nonEmpty") => stmt"${strategy.column(name).token} IS NOT NULL" - case Property(_, name) => strategy.column(name).token - } - - tokenizer(customAstTokenizer) - } + case other => + fail(s"Action ast can't be translated to sql: '$other'") + } implicit def entityTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Tokenizer[Entity] = Tokenizer[Entity] { case Entity(name, _) => strategy.table(name).token diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/OnConflictSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/OnConflictSpec.scala new file mode 100644 index 0000000000..b30897afa4 --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/context/sql/OnConflictSpec.scala @@ -0,0 +1,65 @@ +package io.getquill.context.sql + +import io.getquill.{ Spec, TestEntities } + +trait OnConflictSpec extends Spec { + val ctx: SqlContext[_, _] with TestEntities + import ctx._ + + object `onConflictIgnore` { + val testQuery1, testQuery2 = quote { + qr1.insert(lift(TestEntity("", 1, 0, None))).onConflictIgnore + } + val res1 = 1 + val res2 = 0 + + val testQuery3 = quote { + qr1.filter(_.i == 1) + } + val res3 = List(TestEntity("", 1, 0, None)) + } + + object `onConflictIgnore(_.i)` { + val name = "ON CONFLICT (...) DO NOTHING" + val testQuery1, testQuery2 = quote { + qr1.insert(lift(TestEntity("s", 2, 0, None))).onConflictIgnore(_.i) + } + val res1 = 1 + val res2 = 0 + + val testQuery3 = quote { + qr1.filter(_.i == 2) + } + val res3 = List(TestEntity("s", 2, 0, None)) + } + + abstract class onConflictUpdate(id: Int) { + val e1 = TestEntity("r1", id, 0, None) + val e2 = TestEntity("r2", id, 0, None) + val e3 = TestEntity("r3", id, 0, None) + + val res1, res2, res3 = 1 + + val testQuery4 = quote { + qr1.filter(_.i == lift(id)) + } + val res4 = List(TestEntity("r1-r2-r3", id, 2, None)) + } + + object `onConflictUpdate((t, e) => ...)` extends onConflictUpdate(3) { + def testQuery(e: TestEntity) = quote { + qr1 + .insert(lift(e)) + .onConflictUpdate((t, e) => t.s -> (t.s + "-" + e.s), (t, _) => t.l -> (t.l + 1)) + } + } + + object `onConflictUpdate(_.i)((t, e) => ...)` extends onConflictUpdate(4) { + def testQuery(e: TestEntity) = quote { + qr1 + .insert(lift(e)) + .onConflictUpdate(_.i)((t, e) => t.s -> (t.s + "-" + e.s), (t, _) => t.l -> (t.l + 1)) + } + } + +} diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala index 872ca9e62d..ac96d125a3 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/MySQLDialectSpec.scala @@ -1,13 +1,9 @@ package io.getquill.context.sql.idiom -import io.getquill.Spec -import io.getquill.Literal -import io.getquill.MySQLDialect -import io.getquill.SqlMirrorContext -import io.getquill.TestEntities +import io.getquill._ import io.getquill.idiom.StringToken -class MySQLDialectSpec extends Spec { +class MySQLDialectSpec extends OnConflictSpec { val ctx = new SqlMirrorContext(MySQLDialect, Literal) with TestEntities import ctx._ @@ -86,4 +82,25 @@ class MySQLDialectSpec extends Spec { ctx.run(q).string mustEqual "INSERT INTO TestEntity4 (i) VALUES (DEFAULT)" } + + "OnConflict" - { + "no target - ignore" in { + ctx.run(`no target - ignore`.dynamic).string mustEqual + "INSERT IGNORE INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?)" + + } + "cols target - ignore" in { + ctx.run(`cols target - ignore`.dynamic).string mustEqual + "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE i=i" + } + "no target - update" in { + ctx.run(`no target - update`).string mustEqual + "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE l = ((l + VALUES(l)) / 2), s = VALUES(s)" + } + "cols target - update" in { + intercept[IllegalStateException] { + ctx.run(`cols target - update`.dynamic) + } + } + } } diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/OnConflictSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/OnConflictSpec.scala new file mode 100644 index 0000000000..22bdd8f92b --- /dev/null +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/OnConflictSpec.scala @@ -0,0 +1,27 @@ +package io.getquill.context.sql.idiom + +import io.getquill.{ Spec, TestEntities } +import io.getquill.context.sql.SqlContext + +trait OnConflictSpec extends Spec { + val ctx: SqlContext[_, _] with TestEntities + import ctx._ + + lazy val e = TestEntity("s1", 1, 1, None) + + def ins = quote(query[TestEntity].insert(lift(e))) + def del = quote(query[TestEntity].delete) + + def `no target - ignore` = quote { + ins.onConflictIgnore + } + def `cols target - ignore` = quote { + ins.onConflictIgnore(_.i) + } + def `no target - update` = quote { + ins.onConflictUpdate((t, e) => t.l -> (t.l + e.l) / 2, _.s -> _.s) + } + def `cols target - update` = quote { + ins.onConflictUpdate(_.i, _.s)((t, e) => t.l -> (t.l + e.l) / 2, _.s -> _.s) + } +} diff --git a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala index 40d56c6c8e..9bbe64a631 100644 --- a/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala +++ b/quill-sql/src/test/scala/io/getquill/context/sql/idiom/PostgresDialectSpec.scala @@ -1,35 +1,34 @@ package io.getquill.context.sql.idiom -import io.getquill.Spec import io.getquill.PostgresDialect import io.getquill.SqlMirrorContext import io.getquill.Literal import io.getquill.TestEntities -class PostgresDialectSpec extends Spec { +class PostgresDialectSpec extends OnConflictSpec { - val context = new SqlMirrorContext(PostgresDialect, Literal) with TestEntities - import context._ + val ctx = new SqlMirrorContext(PostgresDialect, Literal) with TestEntities + import ctx._ "applies explicit casts" - { "toLong" in { val q = quote { qr1.map(t => t.s.toLong) } - context.run(q).string mustEqual "SELECT t.s::bigint FROM TestEntity t" + ctx.run(q).string mustEqual "SELECT t.s::bigint FROM TestEntity t" } "toInt" in { val q = quote { qr1.map(t => t.s.toInt) } - context.run(q).string mustEqual "SELECT t.s::integer FROM TestEntity t" + ctx.run(q).string mustEqual "SELECT t.s::integer FROM TestEntity t" } } "Array Operations" - { case class ArrayOps(id: Int, numbers: Vector[Int]) "contains" in { - context.run(query[ArrayOps].filter(_.numbers.contains(10))).string mustEqual + ctx.run(query[ArrayOps].filter(_.numbers.contains(10))).string mustEqual "SELECT x1.id, x1.numbers FROM ArrayOps x1 WHERE 10 = ANY(x1.numbers)" } } @@ -44,4 +43,24 @@ class PostgresDialectSpec extends Spec { prepareForProbing("INSERT INTO tb (x1,x2,x3) VALUES (?,?,?)") mustEqual s"PREPARE p${id + 2} AS INSERT INTO tb (x1,x2,x3) VALUES ($$1,$$2,$$3)" } + + "OnConflict" - { + "no target - ignore" in { + ctx.run(`no target - ignore`).string mustEqual + "INSERT INTO TestEntity AS t (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING" + } + "cols target - ignore" in { + ctx.run(`cols target - ignore`).string mustEqual + "INSERT INTO TestEntity (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT (i) DO NOTHING" + } + "no target - update" in { + intercept[IllegalStateException] { + ctx.run(`no target - update`.dynamic) + } + } + "cols target - update" in { + ctx.run(`cols target - update`).string mustEqual + "INSERT INTO TestEntity AS t (s,i,l,o) VALUES (?, ?, ?, ?) ON CONFLICT (i,s) DO UPDATE SET l = ((t.l + EXCLUDED.l) / 2), s = EXCLUDED.s" + } + } } diff --git a/quill-sql/src/test/sql/mysql-schema.sql b/quill-sql/src/test/sql/mysql-schema.sql index eb2cfb618b..d64aa9065d 100644 --- a/quill-sql/src/test/sql/mysql-schema.sql +++ b/quill-sql/src/test/sql/mysql-schema.sql @@ -77,7 +77,7 @@ Create TABLE BooleanEncodingTestEntity( CREATE TABLE TestEntity( s VARCHAR(255), - i INTEGER, + i INTEGER primary key, l BIGINT, o INTEGER ); diff --git a/quill-sql/src/test/sql/postgres-schema.sql b/quill-sql/src/test/sql/postgres-schema.sql index a337b5f98f..31b624b3e2 100644 --- a/quill-sql/src/test/sql/postgres-schema.sql +++ b/quill-sql/src/test/sql/postgres-schema.sql @@ -60,7 +60,7 @@ CREATE TABLE EncodingUUIDTestEntity( CREATE TABLE TestEntity( s VARCHAR(255), - i INTEGER, + i INTEGER primary key, l BIGINT, o INTEGER );