diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 1e82ea06e8e..c7d4c53785c 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -234,8 +234,12 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq }() if session.Options.Workload == querypb.ExecuteOptions_OLAP { - _, err := vh.vtg.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback) - return mysql.NewSQLErrorFromError(err) + session, err := vh.vtg.StreamExecute(ctx, session, query, make(map[string]*querypb.BindVariable), callback) + if err != nil { + return mysql.NewSQLErrorFromError(err) + } + fillInTxStatusFlags(c, session) + return nil } session, result, err := vh.vtg.Execute(ctx, session, query, make(map[string]*querypb.BindVariable)) @@ -338,12 +342,15 @@ func (vh *vtgateHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareDat if session.Options.Workload == querypb.ExecuteOptions_OLAP { _, err := vh.vtg.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback) - return mysql.NewSQLErrorFromError(err) + if err != nil { + return mysql.NewSQLErrorFromError(err) + } + fillInTxStatusFlags(c, session) + return nil } _, qr, err := vh.vtg.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars) if err != nil { - err = mysql.NewSQLErrorFromError(err) - return err + return mysql.NewSQLErrorFromError(err) } fillInTxStatusFlags(c, session)