Skip to content

Commit

Permalink
Introduce conflict action (upsert) for Postgres and MySQL dialects
Browse files Browse the repository at this point in the history
  • Loading branch information
mentegy committed Apr 3, 2018
1 parent 3801e60 commit 29f41c2
Show file tree
Hide file tree
Showing 20 changed files with 622 additions and 83 deletions.
36 changes: 36 additions & 0 deletions quill-core/src/main/scala/io/getquill/MirrorIdiom.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class MirrorIdiom extends Idiom {
case ast: QuotedReference => ast.ast.token
case ast: Lift => ast.token
case ast: Assignment => ast.token
case ast: Excluded => ast.token
case ast: Existing => ast.token
}

implicit def ifTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[If] = Tokenizer[If] {
Expand Down Expand Up @@ -178,12 +180,46 @@ class MirrorIdiom extends Idiom {
case e => stmt"${e.name.token}"
}

implicit val excludedTokenizer: Tokenizer[Excluded] = Tokenizer[Excluded] {
case Excluded(ident) => stmt"${ident.token}"
}

implicit val existingTokenizer: Tokenizer[Existing] = Tokenizer[Existing] {
case Existing(ident) => stmt"${ident.token}"
}

implicit def actionTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Action] = Tokenizer[Action] {
case Update(query, assignments) => stmt"${query.token}.update(${assignments.token})"
case Insert(query, assignments) => stmt"${query.token}.insert(${assignments.token})"
case Delete(query) => stmt"${query.token}.delete"
case Returning(query, alias, body) => stmt"${query.token}.returning((${alias.token}) => ${body.token})"
case Foreach(query, alias, body) => stmt"${query.token}.foreach((${alias.token}) => ${body.token})"
case c: Conflict => stmt"${c.token}"
}

implicit def conflictTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Conflict] = {

def targetProps(l: List[Property]) = l.map(p => Transform(p) {
case Ident(_) => Ident("_")
})

implicit val conflictTargetTokenizer = Tokenizer[Conflict.Target] {
case Conflict.NoTarget => stmt""
case Conflict.StringValue(s) => stmt"""("${s.token}")"""
case Conflict.Properties(props) => stmt"(${targetProps(props).token})"
}

val updateAssignsTokenizer = Tokenizer[Assignment] {
case Assignment(i, p, v) =>
stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}"
}

Tokenizer[Conflict] {
case Conflict(i, t, Conflict.Update(assign)) =>
stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})"
case Conflict(i, t, Conflict.Ignore) =>
stmt"${i.token}.onConflictIgnore${t.token}"
}
}

implicit def assignmentTokenizer(implicit liftTokenizer: Tokenizer[Lift]): Tokenizer[Assignment] = Tokenizer[Assignment] {
Expand Down
13 changes: 13 additions & 0 deletions quill-core/src/main/scala/io/getquill/ast/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ case class If(condition: Ast, `then`: Ast, `else`: Ast) extends Ast

case class Assignment(alias: Ident, property: Ast, value: Ast) extends Ast

case class Excluded(alias: Ident) extends Ast
case class Existing(alias: Ident) extends Ast
//************************************************************

sealed trait Operation extends Ast
Expand Down Expand Up @@ -126,6 +128,17 @@ case class Returning(action: Ast, alias: Ident, property: Ast) extends Action

case class Foreach(query: Ast, alias: Ident, body: Ast) extends Action

case class Conflict(insert: Ast, target: Conflict.Target, action: Conflict.Action) extends Action
object Conflict {
trait Target
case object NoTarget extends Target
case class StringValue(value: String) extends Target
case class Properties(props: List[Property]) extends Target

trait Action
case object Ignore extends Action
case class Update(assignments: List[Assignment]) extends Action
}
//************************************************************

case class Dynamic(tree: Any) extends Ast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ trait StatefulTransformer[T] {
case e: Ident => (e, this)
case e: OptionOperation => apply(e)
case e: TraversableOperation => apply(e)
case e: Property => apply(e)
case e: Existing => (e, this)
case e: Excluded => (e, this)

case Function(a, b) =>
val (bt, btt) = apply(b)
(Function(a, bt), btt)

case Property(a, b) =>
val (at, att) = apply(a)
(Property(at, b), att)

case Infix(a, b) =>
val (bt, btt) = apply(b)(_.apply)
(Infix(a, bt), btt)
Expand Down Expand Up @@ -168,6 +167,13 @@ trait StatefulTransformer[T] {
(Assignment(a, bt, ct), ctt)
}

def apply(e: Property): (Property, StatefulTransformer[T]) =
e match {
case Property(a, b) =>
val (at, att) = apply(a)
(Property(at, b), att)
}

def apply(e: Operation): (Operation, StatefulTransformer[T]) =
e match {
case UnaryOperation(o, a) =>
Expand Down Expand Up @@ -217,6 +223,27 @@ trait StatefulTransformer[T] {
val (at, att) = apply(a)
val (ct, ctt) = att.apply(c)
(Foreach(at, b, ct), ctt)
case Conflict(a, b, c) =>
val (at, att) = apply(a)
val (bt, btt) = att.apply(b)
val (ct, ctt) = btt.apply(c)
(Conflict(at, bt, ct), ctt)
}

def apply(e: Conflict.Target): (Conflict.Target, StatefulTransformer[T]) =
e match {
case Conflict.NoTarget | Conflict.StringValue(_) => (e, this)
case Conflict.Properties(a) =>
val (at, att) = apply(a)(_.apply)
(Conflict.Properties(at), att)
}

def apply(e: Conflict.Action): (Conflict.Action, StatefulTransformer[T]) =
e match {
case Conflict.Ignore => (e, this)
case Conflict.Update(a) =>
val (at, att) = apply(a)(_.apply)
(Conflict.Update(at), att)
}

def apply[U, R](list: List[U])(f: StatefulTransformer[T] => U => (R, StatefulTransformer[T])) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ trait StatelessTransformer {
case e: Assignment => apply(e)
case Function(params, body) => Function(params, apply(body))
case e: Ident => e
case Property(a, name) => Property(apply(a), name)
case e: Property => apply(e)
case Infix(a, b) => Infix(a, b.map(apply))
case e: OptionOperation => apply(e)
case e: TraversableOperation => apply(e)
Expand All @@ -22,6 +22,8 @@ trait StatelessTransformer {
case Block(statements) => Block(statements.map(apply))
case Val(name, body) => Val(name, apply(body))
case o: Ordering => o
case e: Excluded => e
case e: Existing => e
}

def apply(o: OptionOperation): OptionOperation =
Expand Down Expand Up @@ -69,6 +71,11 @@ trait StatelessTransformer {
case Assignment(a, b, c) => Assignment(a, apply(b), apply(c))
}

def apply(e: Property): Property =
e match {
case Property(a, name) => Property(apply(a), name)
}

def apply(e: Operation): Operation =
e match {
case UnaryOperation(o, a) => UnaryOperation(o, apply(a))
Expand All @@ -94,6 +101,19 @@ trait StatelessTransformer {
case Delete(query) => Delete(apply(query))
case Returning(query, alias, property) => Returning(apply(query), alias, apply(property))
case Foreach(query, alias, body) => Foreach(apply(query), alias, apply(body))
case Conflict(query, target, action) => Conflict(apply(query), apply(target), apply(action))
}

def apply(e: Conflict.Target): Conflict.Target =
e match {
case Conflict.NoTarget | Conflict.StringValue(_) => e
case Conflict.Properties(props) => Conflict.Properties(props.map(apply))
}

def apply(e: Conflict.Action): Conflict.Action =
e match {
case Conflict.Ignore => e
case Conflict.Update(assigns) => Conflict.Update(assigns.map(apply))
}

}
35 changes: 35 additions & 0 deletions quill-core/src/main/scala/io/getquill/dsl/QueryDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,41 @@ private[dsl] trait QueryDsl {
sealed trait Insert[E] extends Action[E] {
@compileTimeOnly(NonQuotedException.message)
def returning[R](f: E => R): ActionReturning[E, R] = NonQuotedException()

@compileTimeOnly(NonQuotedException.message)
def onConflictIgnore: Insert[E] = NonQuotedException()

@compileTimeOnly(NonQuotedException.message)
def onConflictIgnore(target: String): Insert[E] = NonQuotedException()

@compileTimeOnly(NonQuotedException.message)
def onConflictIgnore(target: E => Any, targets: (E => Any)*): Insert[E] = NonQuotedException()

@compileTimeOnly(NonQuotedException.message)
def onConflictUpdate(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException()

@compileTimeOnly(NonQuotedException.message)
def onConflictUpdate(target: String)(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException()

/**
* Generates an atomic INSERT or UPDATE (upsert) action if supported.
*
* @param targets - conflict target
* @param assigns - update statement, declared as function: `(table, excluded) => (assign, result)`
* `table` - is used to extract column for update assignment and reference existing row
* `excluded` - aliases excluded table, e.g. row proposed for insertion.
* `assign` - left hand side of assignment. Should be accessed from `table` argument
* `result` - right hand side of assignment.
*
* Example usage:
* {{{
* insert.onConflictUpdate(_.id)((t, e) => t.col -> (e.col + t.col))
* }}}
* If insert statement violates conflict target then the column `col` of row will be updated with sum of
* existing value and and proposed `col` in insert.
*/
@compileTimeOnly(NonQuotedException.message)
def onConflictUpdate(target: E => Any, targets: (E => Any)*)(assign: ((E, E) => (Any, Any)), assigns: ((E, E) => (Any, Any))*): Insert[E] = NonQuotedException()
}

sealed trait ActionReturning[E, Output] extends Action[E]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ object NormalizeReturning {

def apply(e: Action): Action = {
e match {
case Returning(Insert(query, assignments), alias, body) =>
Returning(Insert(query, filterReturnedColumn(assignments, body)), alias, body)
case Returning(Update(query, assignments), alias, body) =>
Returning(Update(query, filterReturnedColumn(assignments, body)), alias, body)
case e => e
case Returning(a: Action, alias, body) => Returning(apply(a, body), alias, body)
case _ => e
}
}

private def apply(e: Action, body: Ast): Action = e match {
case Insert(query, assignments) => Insert(query, filterReturnedColumn(assignments, body))
case Update(query, assignments) => Update(query, filterReturnedColumn(assignments, body))
case Conflict(a: Action, target, act) => Conflict(apply(a, body), target, act)
case _ => e
}

private def filterReturnedColumn(assignments: List[Assignment], column: Ast): List[Assignment] =
assignments.flatMap(filterReturnedColumn(_, column))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ object RenameProperties extends StatelessTransformer {
val bodyr = BetaReduction(body, replace: _*)
(Returning(action, alias, bodyr), schema)
}
case Conflict(a: Action, target, act) =>
applySchema(a) match {
case (action, schema) => (Conflict(action, target, act), schema)
}
case q => (q, Tuple(List.empty))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ case class FreeVariables(state: State)
super.apply(other)
}

override def apply(e: Conflict.Target): (Conflict.Target, StatefulTransformer[State]) = (e, this)

override def apply(query: Query): (Query, StatefulTransformer[State]) =
query match {
case q @ Filter(a, b, c) => (q, free(a, b, c))
Expand Down
20 changes: 19 additions & 1 deletion quill-core/src/main/scala/io/getquill/quotation/Liftables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ trait Liftables {
case ast: Assignment => assignmentLiftable(ast)
case ast: OptionOperation => optionOperationLiftable(ast)
case ast: TraversableOperation => traversableOperationLiftable(ast)
case ast: Property => propertyLiftable(ast)
case Val(name, body) => q"$pack.Val($name, $body)"
case Block(statements) => q"$pack.Block($statements)"
case Property(a, b) => q"$pack.Property($a, $b)"
case Function(a, b) => q"$pack.Function($a, $b)"
case FunctionApply(a, b) => q"$pack.FunctionApply($a, $b)"
case BinaryOperation(a, b, c) => q"$pack.BinaryOperation($a, $b, $c)"
Expand All @@ -33,6 +33,8 @@ trait Liftables {
case Dynamic(tree: Tree) if (tree.tpe <:< c.weakTypeOf[CoreDsl#Quoted[Any]]) => q"$tree.ast"
case Dynamic(tree: Tree) => q"$pack.Constant($tree)"
case QuotedReference(tree: Tree, ast) => q"$ast"
case Excluded(a) => q"$pack.Excluded($a)"
case Existing(a) => q"$pack.Existing($a)"
}

implicit val optionOperationLiftable: Liftable[OptionOperation] = Liftable[OptionOperation] {
Expand Down Expand Up @@ -113,6 +115,10 @@ trait Liftables {
case PropertyAlias(a, b) => q"$pack.PropertyAlias($a, $b)"
}

implicit val propertyLiftable: Liftable[Property] = Liftable[Property] {
case Property(a, b) => q"$pack.Property($a, $b)"
}

implicit val orderingLiftable: Liftable[Ordering] = Liftable[Ordering] {
case TupleOrdering(elems) => q"$pack.TupleOrdering($elems)"
case Asc => q"$pack.Asc"
Expand All @@ -136,6 +142,18 @@ trait Liftables {
case Delete(a) => q"$pack.Delete($a)"
case Returning(a, b, c) => q"$pack.Returning($a, $b, $c)"
case Foreach(a, b, c) => q"$pack.Foreach($a, $b, $c)"
case Conflict(a, b, c) => q"$pack.Conflict($a, $b, $c)"
}

implicit val conflictTargetLiftable: Liftable[Conflict.Target] = Liftable[Conflict.Target] {
case Conflict.NoTarget => q"$pack.Conflict.NoTarget"
case Conflict.StringValue(a) => q"$pack.Conflict.StringValue.apply($a)"
case Conflict.Properties(a) => q"$pack.Conflict.Properties.apply($a)"
}

implicit val conflictActionLiftable: Liftable[Conflict.Action] = Liftable[Conflict.Action] {
case Conflict.Ignore => q"$pack.Conflict.Ignore"
case Conflict.Update(a) => q"$pack.Conflict.Update.apply($a)"
}

implicit val assignmentLiftable: Liftable[Assignment] = Liftable[Assignment] {
Expand Down
40 changes: 39 additions & 1 deletion quill-core/src/main/scala/io/getquill/quotation/Parsing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ trait Parsing {
case `quotedAstParser`(value) => value
case `functionParser`(value) => value
case `actionParser`(value) => value
case `conflictParser`(value) => value
case `infixParser`(value) => value
case `orderingParser`(value) => value
case `operationParser`(value) => value
Expand Down Expand Up @@ -317,7 +318,7 @@ trait Parsing {
ListContains(astParser(col), astParser(body))
}

val propertyParser: Parser[Ast] = Parser[Ast] {
val propertyParser: Parser[Property] = Parser[Property] {
case q"$e.get" if is[Option[Any]](e) =>
c.fail("Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead.")
case q"$e.$property" => Property(astParser(e), property.decodedName.toString)
Expand Down Expand Up @@ -535,10 +536,47 @@ trait Parsing {
checkTypes(prop, value)
Assignment(i1, astParser(prop), astParser(value))

case q"((${ identParser(i1) }, ${ identParser(i2) }) => $pack.Predef.ArrowAssoc[$t]($prop).$arrow[$v]($value))" =>
checkTypes(prop, value)
val valueAst = Transform(astParser(value)) {
case `i1` => Existing(i1)
case `i2` => Excluded(i2)
}
Assignment(i1, astParser(prop), valueAst)
// Unused, it's here only to make eclipse's presentation compiler happy
case astParser(ast) => Assignment(Ident("unused"), Ident("unused"), Constant("unused"))
}

/*private def excludedParser(tpe: Type): Parser[Excluded] = Parser[Excluded] {
case q"$pack.excluded[$t].$prop" => Excluded(Property(Ident("excluded"), prop.decodedName.toString))
}*/

val conflictParser: Parser[Ast] = Parser[Ast] {
case q"$query.onConflictIgnore" =>
Conflict(astParser(query), Conflict.NoTarget, Conflict.Ignore)
case q"$query.onConflictIgnore(${ target: String })" =>
Conflict(astParser(query), Conflict.StringValue(target), Conflict.Ignore)
case q"$query.onConflictIgnore(..$targets)" =>
Conflict(astParser(query), parseConflictProps(targets), Conflict.Ignore)

case q"$query.onConflictUpdate(..$assigns)" =>
Conflict(astParser(query), Conflict.NoTarget, parseConflictAssigns(assigns))
case q"$query.onConflictUpdate(${ target: String })(..$assigns)" =>
Conflict(astParser(query), Conflict.StringValue(target), parseConflictAssigns(assigns))
case q"$query.onConflictUpdate(..$targets)(..$assigns)" =>
Conflict(astParser(query), parseConflictProps(targets), parseConflictAssigns(assigns))
}

private def parseConflictProps(targets: List[Tree]) = Conflict.Properties {
targets.map {
case q"($e) => $prop" => propertyParser(prop)
case tree => c.fail(s"Tree '$tree' can't be parsed as conflict target")
}
}

private def parseConflictAssigns(targets: List[Tree]) =
Conflict.Update(targets.map(assignmentParser(_)))

private def checkTypes(lhs: Tree, rhs: Tree): Unit = {
def unquoted(tree: Tree) =
is[CoreDsl#Quoted[Any]](tree) match {
Expand Down
Loading

0 comments on commit 29f41c2

Please sign in to comment.