Skip to content

Commit

Permalink
Stream rows in the extended mode as well.
Browse files Browse the repository at this point in the history
This also applies the streaming of rows to the extended query mode.
  • Loading branch information
plaflamme committed Aug 6, 2019
1 parent 2e1804f commit 90b1c4f
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ class QuerySpec extends FreeSpec with Matchers with MockFactory {

val client = new PostgresClient {

def prepareAndQuery[T](sql: String, params: Param[_]*)(f: (Row) => T): Future[Seq[T]] =
def prepareAndQueryToStream[T](sql: String, params: Param[_]*)(f: (Row) => T): Future[AsyncStream[T]] =
mockClient.prepareAndQuery(sql, params.toList, f)
.map(AsyncStream.fromSeq)

def prepareAndExecute(sql: String, params: Param[_]*): Future[Int] =
mockClient.prepareAndExecute(sql, params.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ trait PostgresClient {
/*
* Issue a single, prepared SELECT query and wrap the response rows with the provided function.
*/
def prepareAndQuery[T](sql: String, params: Param[_]*)
(f: Row => T): Future[Seq[T]]
def prepareAndQuery[T](sql: String, params: Param[_]*)(f: Row => T): Future[Seq[T]] =
prepareAndQueryToStream(sql, params: _*)(f).flatMap(_.toSeq)
def prepareAndQueryToStream[T](sql: String, params: Param[_]*)(f: Row => T): Future[AsyncStream[T]]

/*
* Issue a single, prepared arbitrary query without an expected result set, and provide the affected row count
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ class PostgresClientImpl(
/*
* Issue a single, prepared SELECT query and wrap the response rows with the provided function.
*/
override def prepareAndQuery[T](sql: String, params: Param[_]*)(f: Row => T): Future[Seq[T]] = {
override def prepareAndQueryToStream[T](sql: String, params: Param[_]*)(f: Row => T): Future[AsyncStream[T]] = {
typeMap().flatMap { _ =>
for {
service <- factory()
statement = new PreparedStatementImpl("", sql, service)
result <- statement.select(params: _*)(f)
result <- statement.selectToStream(params: _*)(f)
} yield result
}
}
Expand Down Expand Up @@ -292,9 +292,8 @@ class PostgresClientImpl(
exec <- execute()
} yield exec match {
case CommandCompleteResponse(rows) => OK(rows)
case Rows(rows, true) =>
// TODO: actually make this async
ResultSet(fields, charset, AsyncStream.fromSeq(rows), types, receiveFunctions)
case Rows(rows) =>
ResultSet(fields, charset, rows, types, receiveFunctions)
}
f.transform {
result =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.twitter.finagle.postgres

import com.twitter.concurrent.AsyncStream
import com.twitter.finagle.postgres.codec.Errors
import com.twitter.util.Future

Expand All @@ -14,10 +15,12 @@ trait PreparedStatement {
case ResultSet(_) => Future.exception(Errors.client("Update query expected"))
}

def select[T](params: Param[_]*)(f: Row => T): Future[Seq[T]] = fire(params: _*) flatMap {
case ResultSet(rows) => rows.map(f).toSeq
case OK(_) => Future.Nil
def selectToStream[T](params: Param[_]*)(f: Row => T): Future[AsyncStream[T]] = fire(params: _*) map {
case ResultSet(rows) => rows.map(f)
case OK(_) => AsyncStream.empty
}
def select[T](params: Param[_]*)(f: Row => T): Future[Seq[T]] =
selectToStream(params: _*)(f).flatMap(_.toSeq)

def selectFirst[T](params: Param[_]*)(f: Row => T): Future[Option[T]] =
select[T](params:_*)(f) flatMap { rows => Future.value(rows.headOption) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import com.twitter.finagle.postgres.messages._
import com.twitter.logging.Logger
import com.twitter.util.{Promise, Return, Throw}

import scala.collection.mutable.ListBuffer

/*
* State machine that captures transitions between states.
*
Expand Down Expand Up @@ -105,7 +103,7 @@ class ConnectionStateMachine(state: State = AuthenticationRequired, val id: Int)

case (RowDescription(fields), SimpleQuery) =>
val complete = new Promise[Unit]()
val nextRow = StreamRows(complete)
val nextRow = StreamRows(complete, extended = false)
(Some(SelectResult(fields.map(f => Field(f.name, f.fieldFormat, f.dataType)), nextRow.asyncStream)(complete)), nextRow)
case (ErrorResponse(details), SimpleQuery) =>
(None, EmitOnReadyForQuery(Error(details)))
Expand All @@ -129,36 +127,45 @@ class ConnectionStateMachine(state: State = AuthenticationRequired, val id: Int)
case (CommandComplete(RollBack), ExecutePreparedStatement) => (Some(CommandCompleteResponse(1)), Connected)
case (CommandComplete(Commit), ExecutePreparedStatement) => (Some(CommandCompleteResponse(1)), Connected)
case (CommandComplete(Do), ExecutePreparedStatement) => (Some(CommandCompleteResponse(1)), Connected)
case (row:DataRow, ExecutePreparedStatement) => (None, AggregateRowsWithoutFields(ListBuffer(row)))
case (row:DataRow, state:AggregateRowsWithoutFields) =>
state.buff += row
(None, state)
case (PortalSuspended, AggregateRowsWithoutFields(buff)) => (Some(Rows(buff.toList, completed = false)), Connected)
case (CommandComplete(Select(0)), ExecutePreparedStatement) => (Some(Rows(List.empty, completed = true)), Connected)
case (CommandComplete(Select(_)), AggregateRowsWithoutFields(buff)) =>
(Some(Rows(buff.toList, completed = true)), Connected)
case (CommandComplete(Insert(_)), AggregateRowsWithoutFields(buff)) =>
(Some(Rows(buff.toList, completed = true)), Connected)
case (CommandComplete(Update(_)), AggregateRowsWithoutFields(buff)) =>
(Some(Rows(buff.toList, completed = true)), Connected)
case (CommandComplete(Delete(_)), AggregateRowsWithoutFields(buff)) =>
(Some(Rows(buff.toList, completed = true)), Connected)
case (ErrorResponse(details), ExecutePreparedStatement) => (Some(Error(details)), Connected)
case (ErrorResponse(details), AggregateRowsWithoutFields(_)) => (Some(Error(details)), Connected)
case (row: DataRow, ExecutePreparedStatement) =>
val complete = new Promise[Unit]
val nextRow = StreamRows(complete, extended = true)
val thisRow = AsyncStream.mk(row, nextRow.asyncStream)
val response = Rows(thisRow)(complete)
(Some(response), nextRow)
case (CommandComplete(Select(0)), ExecutePreparedStatement) =>
(Some(Rows.Empty), Connected)
case (ErrorResponse(details), ExecutePreparedStatement) =>
(Some(Error(details)), Connected)
}

transition {
case (row: DataRow, StreamRows(complete, thisRow)) =>
val nextRow = StreamRows(complete)
fullTransition {
case (row: DataRow, StreamRows(complete, extended, thisRow)) =>
val nextRow = StreamRows(complete, extended)
thisRow.setValue(AsyncStream.mk(row, nextRow.asyncStream))
(None, nextRow)
case (CommandComplete(_), StreamRows(complete, thisRow)) =>
case (PortalSuspended, StreamRows(complete, _, thisRow)) =>
thisRow.setValue(AsyncStream.empty)
(Some(StateMachine.Complete(complete, Return.Unit)), Connected)
case (CommandComplete(_), StreamRows(complete, extended, thisRow)) =>
thisRow.setValue(AsyncStream.empty)
(None, EmitOnReadyForQuery(StateMachine.Complete(complete, Return.Unit)))
case (ErrorResponse(details), StreamRows(complete, thisRow)) =>
val response = StateMachine.Complete(complete, Return.Unit)
if (extended) {
// in extended mode, we don't expect a ReadyForQuery, so we respond now
(Some(response), Connected)
} else {
(None, EmitOnReadyForQuery(response))
}
case (ErrorResponse(details), StreamRows(complete, extended, thisRow)) =>
val exn = Errors.server(Error(details), None)
thisRow.setValue(AsyncStream.exception(exn))
(None, EmitOnReadyForQuery(StateMachine.Complete(complete, Throw(exn))))
val response = StateMachine.Complete(complete, Throw(exn))
if (extended) {
// in extended mode, we don't expect a ReadyForQuery, so we respond now
(Some(response), Connected)
} else {
(None, EmitOnReadyForQuery(response))
}
}

fullTransition {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ case object ExecutePreparedStatement extends ExtendedQueryState

case object AwaitParamsDescription extends ExtendedQueryState

case class StreamRows(complete: Promise[Unit], nextRow: Promise[AsyncStream[DataRow]] = new Promise) extends ExtendedQueryState {
case class StreamRows(complete: Promise[Unit], extended: Boolean, nextRow: Promise[AsyncStream[DataRow]] = new Promise) extends ExtendedQueryState {
val asyncStream: AsyncStream[DataRow] = AsyncStream.fromFuture(nextRow).flatten
}

case class AggregateRowsWithoutFields(buff: ListBuffer[DataRow] = ListBuffer()) extends ExtendedQueryState

case class AwaitRowDescription(types: Array[Int]) extends ExtendedQueryState

case class EmitOnReadyForQuery[R <: PgResponse](emit: StateMachine.TransitionResult[R]) extends ExtendedQueryState
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ case class PasswordRequired(encoding: PasswordEncoding) extends PgResponse

case class AuthenticatedResponse(params: Map[String, String], processId: Int, secretKey: Int) extends PgResponse

case class Rows(rows: List[DataRow], completed: Boolean) extends PgResponse
case class Rows(rows: AsyncStream[DataRow])(private[finagle] val complete: Future[Unit]) extends AsyncPgResponse
object Rows {
val Empty = Rows(AsyncStream.empty)(Future.Done)
}

case class Field(name: String, format: Short, dataType: Int)

Expand Down

0 comments on commit 90b1c4f

Please sign in to comment.