Skip to content

Commit

Permalink
remove SetTimeout and SetProtocolID from protocol instances
Browse files Browse the repository at this point in the history
Summary:
remove SetTimeout and SetProtocolID from protocol instances

These can now only be set at construction time.

Reviewed By: yarikk

Differential Revision: D60106519

fbshipit-source-id: 60306124608d5ce14a01c07ddb4c77c2c764a30c
  • Loading branch information
awalterschulze authored and facebook-github-bot committed Jul 23, 2024
1 parent 6619d45 commit 8422529
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 187 deletions.
45 changes: 3 additions & 42 deletions thrift/lib/go/thrift/client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,24 +164,6 @@ func newOptions(opts ...ClientOption) (*clientOptions, error) {
return res, nil
}

func setOptions(protocol Protocol, options *clientOptions) error {
proto := protocol.(protocolClient)
for name, value := range options.persistentHeaders {
proto.SetPersistentHeader(name, value)
}
if err := proto.SetProtocolID(options.protocol); err != nil {
return err
}
proto.SetTimeout(options.timeout)
return nil
}

type protocolClient interface {
Protocol
SetProtocolID(protoID ProtocolID) error
SetTimeout(timeout time.Duration)
}

// NewClient will return a connected thrift protocol object.
// Effectively, this is an open thrift connection to a server.
// A thrift client can use this connection to communicate with a server.
Expand All @@ -192,32 +174,11 @@ func NewClient(opts ...ClientOption) (Protocol, error) {
}
switch options.transport {
case TransportIDHeader:
proto, err := newHeaderProtocol(options.conn)
if err != nil {
return nil, err
}
if err := setOptions(proto, options); err != nil {
return nil, err
}
return proto, nil
return newHeaderProtocol(options.conn, options.protocol, options.timeout, options.persistentHeaders)
case TransportIDRocket:
protocol, err := newRocketClient(options.conn)
if err != nil {
return nil, err
}
if err := setOptions(protocol, options); err != nil {
return nil, err
}
return protocol, nil
return newRocketClient(options.conn, options.protocol, options.timeout, options.persistentHeaders)
case TransportIDUpgradeToRocket:
protocol, err := newUpgradeToRocketClient(options.conn)
if err != nil {
return nil, err
}
if err := setOptions(protocol, options); err != nil {
return nil, err
}
return protocol, nil
return newUpgradeToRocketClient(options.conn, options.protocol, options.timeout, options.persistentHeaders)
default:
panic("framed and unframed transport are not supported")
}
Expand Down
12 changes: 6 additions & 6 deletions thrift/lib/go/thrift/context_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestHeaderProtocolSomeHeaders(t *testing.T) {
t.Fatal(err)
}
}
protocol, err := newHeaderProtocol(newMockSocket())
protocol, err := newHeaderProtocol(newMockSocket(), ProtocolIDCompact, 0, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -46,7 +46,7 @@ func TestHeaderProtocolSomeHeaders(t *testing.T) {

// somewhere we are still passing context as nil, so we need to support this for now
func TestHeaderProtocolSetNilHeaders(t *testing.T) {
protocol, err := newHeaderProtocol(newMockSocket())
protocol, err := newHeaderProtocol(newMockSocket(), ProtocolIDCompact, 0, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -65,7 +65,7 @@ func TestRocketProtocolSomeHeaders(t *testing.T) {
t.Fatal(err)
}
}
protocol, err := newRocketClient(newMockSocket())
protocol, err := newRocketClient(newMockSocket(), ProtocolIDCompact, 0, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -78,7 +78,7 @@ func TestRocketProtocolSomeHeaders(t *testing.T) {

// somewhere we are still passing context as nil, so we need to support this for now
func TestRocketProtocolSetNilHeaders(t *testing.T) {
protocol, err := newRocketClient(newMockSocket())
protocol, err := newRocketClient(newMockSocket(), ProtocolIDCompact, 0, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -97,7 +97,7 @@ func TestUpgradeToRocketProtocolSomeHeaders(t *testing.T) {
t.Fatal(err)
}
}
protocol, err := newUpgradeToRocketClient(newMockSocket())
protocol, err := newUpgradeToRocketClient(newMockSocket(), ProtocolIDCompact, 0, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -110,7 +110,7 @@ func TestUpgradeToRocketProtocolSomeHeaders(t *testing.T) {

// somewhere we are still passing context as nil, so we need to support this for now
func TestUpgradeToRocketProtocolSetNilHeaders(t *testing.T) {
protocol, err := newUpgradeToRocketClient(newMockSocket())
protocol, err := newUpgradeToRocketClient(newMockSocket(), ProtocolIDCompact, 0, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
31 changes: 11 additions & 20 deletions thrift/lib/go/thrift/header_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,29 @@ type headerProtocol struct {

// NewHeaderProtocol creates a new header protocol.
func NewHeaderProtocol(conn net.Conn) (Protocol, error) {
return newHeaderProtocol(conn)
return newHeaderProtocol(conn, ProtocolIDCompact, 0, nil)
}

func newHeaderProtocol(conn net.Conn) (Protocol, error) {
p := &headerProtocol{
protoID: ProtocolIDCompact,
}
p.trans = newHeaderTransport(conn)
func newHeaderProtocol(conn net.Conn, protoID ProtocolID, timeout time.Duration, persistentHeaders map[string]string) (Protocol, error) {
p := &headerProtocol{protoID: protoID}
p.trans = newHeaderTransport(conn, protoID)
p.trans.conn.readTimeout = timeout
p.trans.conn.writeTimeout = timeout
if err := p.resetProtocol(); err != nil {
return nil, err
}
for name, value := range persistentHeaders {
p.SetPersistentHeader(name, value)
}
return p, nil
}

func (p *headerProtocol) SetTimeout(timeout time.Duration) {
p.trans.conn.readTimeout = timeout
p.trans.conn.writeTimeout = timeout
}

func (p *headerProtocol) resetProtocol() error {
if p.Format != nil && p.protoID == p.trans.ProtocolID() {
if p.Format != nil && p.protoID == p.trans.protoID {
return nil
}

p.protoID = p.trans.ProtocolID()
p.protoID = p.trans.protoID
switch p.protoID {
case ProtocolIDBinary:
// These defaults match cpp implementation
Expand Down Expand Up @@ -154,13 +152,6 @@ func (p *headerProtocol) ProtocolID() ProtocolID {
return p.protoID
}

func (p *headerProtocol) SetProtocolID(protoID ProtocolID) error {
if err := p.trans.SetProtocolID(protoID); err != nil {
return err
}
return p.resetProtocol()
}

// Deprecated: GetFlags() is a deprecated method.
func (t *headerProtocol) GetFlags() HeaderFlags {
return t.trans.GetFlags()
Expand Down
18 changes: 2 additions & 16 deletions thrift/lib/go/thrift/header_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type headerTransport struct {
}

// newHeaderTransport creates a new transport with defaults.
func newHeaderTransport(c net.Conn) *headerTransport {
func newHeaderTransport(c net.Conn, protoID ProtocolID) *headerTransport {
conn := &connTimeout{Conn: c}
return &headerTransport{
conn: conn,
Expand All @@ -70,7 +70,7 @@ func newHeaderTransport(c net.Conn) *headerTransport {
writeInfoHeaders: map[string]string{},
persistentWriteInfoHeaders: map[string]string{},

protoID: DefaulprotoID,
protoID: protoID,
flags: 0,
clientType: DefaultClientType,
writeTransforms: []TransformID{},
Expand Down Expand Up @@ -140,20 +140,6 @@ func (t *headerTransport) GetResponseHeaders() map[string]string {
return res
}

func (t *headerTransport) ProtocolID() ProtocolID {
return t.protoID
}

func (t *headerTransport) SetProtocolID(protoID ProtocolID) error {
if !(protoID == ProtocolIDBinary || protoID == ProtocolIDCompact) {
return NewTransportException(
NOT_IMPLEMENTED, fmt.Sprintf("unimplemented proto ID: %s (%#x)", protoID.String(), int64(protoID)),
)
}
t.protoID = protoID
return nil
}

func (t *headerTransport) AddTransform(trans TransformID) error {
if sup, ok := supportedTransforms[trans]; !ok || !sup {
return NewTransportException(
Expand Down
54 changes: 9 additions & 45 deletions thrift/lib/go/thrift/header_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

func TestHeaderTransport(t *testing.T) {
trans := newHeaderTransport(newMockSocket())
trans := newHeaderTransport(newMockSocket(), ProtocolIDCompact)
TransportTest(t, trans, trans)
}

Expand Down Expand Up @@ -104,7 +104,7 @@ func TestHeaderFramedBinary(t *testing.T) {
testHeaderToProto(
t, FramedDeprecated, tmb,
NewBinaryProtocol(newFramedTransport(tmb), true, true),
newHeaderTransport(tmb),
newHeaderTransport(tmb, ProtocolIDCompact),
)
}

Expand All @@ -113,53 +113,17 @@ func TestHeaderFramedCompact(t *testing.T) {
testHeaderToProto(
t, FramedCompact, tmb,
NewCompactProtocol(newFramedTransport(tmb)),
newHeaderTransport(tmb),
newHeaderTransport(tmb, ProtocolIDCompact),
)
}

func TestHeaderProtoID(t *testing.T) {
n := 1
tmb := newMockSocket()
// write transport
trans1 := newHeaderTransport(tmb)
// read transport
trans2 := newHeaderTransport(tmb)
targetID := ProtocolIDBinary

assertEq(t, DefaulprotoID, trans1.ProtocolID())

err := trans1.SetProtocolID(targetID)
if err != nil {
t.Fatalf("failed to set binary protocol")
}

assertEq(t, targetID, trans1.ProtocolID())

_, err = trans1.Write([]byte("ASDF"))
if err != nil {
t.Fatalf("failed to write frame %d: %s", n, err)
}
err = trans1.Flush()
if err != nil {
t.Fatalf("failed to xmit frame %d: %s", n, err)
}

assertEq(t, DefaulprotoID, trans2.ProtocolID())
err = trans2.ResetProtocol()
if err != nil {
t.Fatalf("failed to reset proto for frame %d: %s", n, err)
}
// Make sure the protocol gets changed after recving the frame
assertEq(t, targetID, trans1.ProtocolID())
}

func TestHeaderHeaders(t *testing.T) {
n := 1
tmb := newMockSocket()
// write transport
trans1 := newHeaderTransport(tmb)
trans1 := newHeaderTransport(tmb, ProtocolIDCompact)
// read transport
trans2 := newHeaderTransport(tmb)
trans2 := newHeaderTransport(tmb, ProtocolIDCompact)

// make sure we don't barf reading header with no frame
_, ok := trans1.GetResponseHeaders()["something"]
Expand Down Expand Up @@ -227,7 +191,7 @@ func peerIdentity(t *headerTransport) string {
func TestHeaderRWSmall(t *testing.T) {
n := 1
tmb := newMockSocket()
trans := newHeaderTransport(tmb)
trans := newHeaderTransport(tmb, ProtocolIDCompact)
data := []byte("ASDFASDFASDF")

_, err := trans.Write(data)
Expand Down Expand Up @@ -290,7 +254,7 @@ func TestHeaderRWSmall(t *testing.T) {
func TestHeaderZlib(t *testing.T) {
n := 1
tmb := newMockSocket()
trans := newHeaderTransport(tmb)
trans := newHeaderTransport(tmb, ProtocolIDCompact)
data := []byte("ASDFASDFASDFASDFASDFASDFASDFASDFASDFASDFASDFASDFASDFASDFASDF")
uncompressedlen := 30

Expand Down Expand Up @@ -361,7 +325,7 @@ func testRWOnce(t *testing.T, n int, data []byte, trans *headerTransport) {

func TestHeaderTransportRWMultiple(t *testing.T) {
tmb := newMockSocket()
trans := newHeaderTransport(tmb)
trans := newHeaderTransport(tmb, ProtocolIDCompact)

// Test Junk Data
testRWOnce(t, 1, []byte("ASDF"), trans)
Expand All @@ -375,7 +339,7 @@ func BenchmarkHeaderFlush(b *testing.B) {

for n := 0; n < b.N; n++ {
tmb := newMockSocket()
trans1 := newHeaderTransport(tmb)
trans1 := newHeaderTransport(tmb, ProtocolIDCompact)

trans1.SetPersistentHeader(IDVersionHeader, IDVersion)
trans1.SetPersistentHeader(IdentityHeader, "localhost")
Expand Down
9 changes: 3 additions & 6 deletions thrift/lib/go/thrift/persistent_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,24 @@ import (
)

func TestHeaderProtocolSomePersistentHeaders(t *testing.T) {
protocol, err := newHeaderProtocol(newMockSocket())
protocol, err := newHeaderProtocol(newMockSocket(), ProtocolIDCompact, 0, map[string]string{"key": "value"})
assert.NoError(t, err)
protocol.SetPersistentHeader("key", "value")
v, ok := protocol.GetPersistentHeader("key")
assert.True(t, ok)
assert.Equal(t, "value", v)
}

func TestRocketProtocolSomePersistentHeaders(t *testing.T) {
protocol, err := newRocketClient(newMockSocket())
protocol, err := newRocketClient(newMockSocket(), ProtocolIDCompact, 0, map[string]string{"key": "value"})
assert.NoError(t, err)
protocol.SetPersistentHeader("key", "value")
v, ok := protocol.GetPersistentHeader("key")
assert.True(t, ok)
assert.Equal(t, "value", v)
}

func TestUpgradeToRocketProtocolSomePersistentHeaders(t *testing.T) {
protocol, err := newUpgradeToRocketClient(newMockSocket())
protocol, err := newUpgradeToRocketClient(newMockSocket(), ProtocolIDCompact, 0, map[string]string{"key": "value"})
assert.NoError(t, err)
protocol.SetPersistentHeader("key", "value")
v, ok := protocol.GetPersistentHeader("key")
assert.True(t, ok)
assert.Equal(t, "value", v)
Expand Down
Loading

0 comments on commit 8422529

Please sign in to comment.