Skip to content
This repository has been archived by the owner on Dec 3, 2019. It is now read-only.

Prepared statement encoder refactoring. It enables to use stages separately #151

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@

package com.github.mauricio.async.db.postgresql.codec

import java.nio.charset.Charset

import com.github.mauricio.async.db.column.ColumnEncoderRegistry
import com.github.mauricio.async.db.exceptions.EncoderNotAvailableException
import com.github.mauricio.async.db.postgresql.encoders._
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
import com.github.mauricio.async.db.postgresql.messages.frontend._
import com.github.mauricio.async.db.util.{BufferDumper, Log}
import java.nio.charset.Charset
import scala.annotation.switch
import io.netty.handler.codec.MessageToMessageEncoder
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.MessageToMessageEncoder

object MessageEncoder {
val log = Log.get[MessageEncoder]
Expand All @@ -44,22 +43,19 @@ class MessageEncoder(charset: Charset, encoderRegistry: ColumnEncoderRegistry) e
override def encode(ctx: ChannelHandlerContext, msg: AnyRef, out: java.util.List[Object]) = {

val buffer = msg match {
case message: ClientMessage => {
val encoder = (message.kind: @switch) match {
case ServerMessage.Close => CloseMessageEncoder
case ServerMessage.Execute => this.executeEncoder
case ServerMessage.Parse => this.openEncoder
case ServerMessage.Startup => this.startupEncoder
case ServerMessage.Query => this.queryEncoder
case ServerMessage.PasswordMessage => this.credentialEncoder
case message: ClientMessage =>
val encoder = message match {
case CloseMessage => CloseMessageEncoder
case _ : PreparedStatementOpeningMessage => this.openEncoder
case _ : StartupMessage => this.startupEncoder
case _ : QueryMessage => this.queryEncoder
case _ : CredentialMessage => this.credentialEncoder
case _ : PreparedStatementExecuteMessage => this.executeEncoder
case _ => throw new EncoderNotAvailableException(message)
}

encoder.encode(message)
}
case _ => {
case _ =>
throw new IllegalArgumentException("Can not encode message %s".format(msg))
}
}

if (log.isTraceEnabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,17 @@ trait PreparedStatementEncoderHelper {
writeDescribe: Boolean = false
): ByteBuf = {

val bindBuffer: ByteBuf = bind(statementIdBytes, query, values, encoder, charset, writeDescribe)
val executeBuffer: ByteBuf = execute(statementIdBytes, 0)
val closeBuffer: ByteBuf = closePortal(statementIdBytes)
val syncBuffer: ByteBuf = sync

Unpooled.wrappedBuffer(bindBuffer, executeBuffer, syncBuffer, closeBuffer)
}

def bind(statementIdBytes: Array[Byte], query: String, values: Seq[Any], encoder: ColumnEncoderRegistry, charset: Charset, writeDescribe: Boolean): ByteBuf = {
if (log.isDebugEnabled) {
log.debug(s"Preparing execute portal to statement ($query) - values (${values.mkString(", ")}) - ${charset}")
log.debug(s"Preparing execute portal to statement ($query) - values (${values.mkString(", ")}) - $charset")
}

val bindBuffer = Unpooled.buffer(1024)
Expand Down Expand Up @@ -106,31 +115,63 @@ trait PreparedStatementEncoderHelper {
describeBuffer.writeBytes(statementIdBytes)
describeBuffer.writeByte(0)
}
bindBuffer
}

def execute(statementIdBytes: Array[Byte], fetchSize: Int): ByteBuf = {
val executeLength = 1 + 4 + statementIdBytes.length + 1 + 4
val executeBuffer = Unpooled.buffer(executeLength)
executeBuffer.writeByte(ServerMessage.Execute)
executeBuffer.writeInt(executeLength - 1)
executeBuffer.writeBytes(statementIdBytes)
executeBuffer.writeByte(0)
executeBuffer.writeInt(0)
executeBuffer.writeInt(fetchSize)
executeBuffer
}

def sync: ByteBuf = {
val syncBuffer = Unpooled.buffer(5)
syncBuffer.writeByte(ServerMessage.Sync)
syncBuffer.writeInt(4)
syncBuffer
}

def closePortal(statementIdBytes: Array[Byte]): ByteBuf = {
val closeLength = 1 + 4 + 1 + statementIdBytes.length + 1
val closeBuffer = Unpooled.buffer(closeLength)
closeBuffer.writeByte(ServerMessage.CloseStatementOrPortal)
closeBuffer.writeInt(closeLength - 1)
closeBuffer.writeByte('P')
closeBuffer.writeBytes(statementIdBytes)
closeBuffer.writeByte(0)
closeBuffer
}

val syncBuffer = Unpooled.buffer(5)
syncBuffer.writeByte(ServerMessage.Sync)
syncBuffer.writeInt(4)
def isNull(value: Any): Boolean = value == null || value == None

Unpooled.wrappedBuffer(bindBuffer, executeBuffer, syncBuffer, closeBuffer)
def parse(statementIdBytes: Array[Byte], query: String, valueTypes: Seq[Int], charset: Charset): ByteBuf = {
val columnCount = valueTypes.size

}
val parseBuffer = Unpooled.buffer(1024)
parseBuffer.writeByte(ServerMessage.Parse)
parseBuffer.writeInt(0)

def isNull(value: Any): Boolean = value == null || value == None
parseBuffer.writeBytes(statementIdBytes)
parseBuffer.writeByte(0)
parseBuffer.writeBytes(query.getBytes(charset))
parseBuffer.writeByte(0)

parseBuffer.writeShort(columnCount)

if (log.isDebugEnabled) {
log.debug(s"Opening query ($query) - statement id (${statementIdBytes.mkString("-")}) - selected types (${valueTypes.mkString(", ")}))")
}

for (kind <- valueTypes) {
parseBuffer.writeInt(kind)
}

ByteBufferUtils.writeLength(parseBuffer)
parseBuffer
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package com.github.mauricio.async.db.postgresql.encoders

import java.nio.charset.Charset

import com.github.mauricio.async.db.column.ColumnEncoderRegistry
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementOpeningMessage}
import com.github.mauricio.async.db.util.{Log, ByteBufferUtils}
import java.nio.charset.Charset
import io.netty.buffer.{Unpooled, ByteBuf}
import com.github.mauricio.async.db.util.Log
import io.netty.buffer.{ByteBuf, Unpooled}

object PreparedStatementOpeningEncoder {
val log = Log.get[PreparedStatementOpeningEncoder]
Expand All @@ -32,40 +32,14 @@ class PreparedStatementOpeningEncoder(charset: Charset, encoder : ColumnEncoderR
with PreparedStatementEncoderHelper
{

import PreparedStatementOpeningEncoder.log

override def encode(message: ClientMessage): ByteBuf = {

val m = message.asInstanceOf[PreparedStatementOpeningMessage]

val statementIdBytes = m.statementId.toString.getBytes(charset)
val columnCount = m.valueTypes.size

val parseBuffer = Unpooled.buffer(1024)

parseBuffer.writeByte(ServerMessage.Parse)
parseBuffer.writeInt(0)

parseBuffer.writeBytes(statementIdBytes)
parseBuffer.writeByte(0)
parseBuffer.writeBytes(m.query.getBytes(charset))
parseBuffer.writeByte(0)

parseBuffer.writeShort(columnCount)

if ( log.isDebugEnabled ) {
log.debug(s"Opening query (${m.query}) - statement id (${statementIdBytes.mkString("-")}) - selected types (${m.valueTypes.mkString(", ")}) - values (${m.values.mkString(", ")})")
}

for (kind <- m.valueTypes) {
parseBuffer.writeInt(kind)
}

ByteBufferUtils.writeLength(parseBuffer)

val executeBuffer = writeExecutePortal(statementIdBytes, m.query, m.values, encoder, charset, true)
val parseBuffer: ByteBuf = parse(statementIdBytes, m.query, m.valueTypes, charset)
val executeBuffer = writeExecutePortal(statementIdBytes, m.query, m.values, encoder, charset, writeDescribe = true)

Unpooled.wrappedBuffer(parseBuffer, executeBuffer)
}

}