Skip to content

Commit

Permalink
quic: provide source conn ID when creating server conns
Browse files Browse the repository at this point in the history
New server-side conns need to know a variety of connection IDs,
such as the Initial DCID used to create Initial encryption keys.
We've been providing these as an ever-growing list of []byte
parameters to newConn. Bundle them all up into a struct.

Add the client's SCID to the set of IDs we pass to newConn.
Up until now, we've been setting this when processing the
first Initial packet from the client. Passing it to newConn
will makes it available when logging the connection_started event.

Update some test infrastructure to deal with the fact that
we need to know the peer's SCID earlier in the test now.

Change-Id: I760ee94af36125acf21c5bf135f1168830ba1ab8
Reviewed-on: https://go-review.googlesource.com/c/net/+/539341
Reviewed-by: Jonathan Amsterdam <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
neild committed Nov 6, 2023
1 parent 5791239 commit ec29a94
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 59 deletions.
22 changes: 15 additions & 7 deletions internal/quic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,15 @@ type connTestHooks interface {
timeNow() time.Time
}

func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []byte, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) {
// newServerConnIDs is connection IDs associated with a new server connection.
type newServerConnIDs struct {
srcConnID []byte // source from client's current Initial
dstConnID []byte // destination from client's current Initial
originalDstConnID []byte // destination from client's first Initial
retrySrcConnID []byte // source from server's Retry
}

func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) {
c := &Conn{
side: side,
listener: l,
Expand Down Expand Up @@ -115,11 +123,11 @@ func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []b
}
initialConnID, _ = c.connIDState.dstConnID()
} else {
initialConnID = originalDstConnID
if retrySrcConnID != nil {
initialConnID = retrySrcConnID
initialConnID = cids.originalDstConnID
if cids.retrySrcConnID != nil {
initialConnID = cids.retrySrcConnID
}
if err := c.connIDState.initServer(c, initialConnID); err != nil {
if err := c.connIDState.initServer(c, cids); err != nil {
return nil, err
}
}
Expand All @@ -134,8 +142,8 @@ func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []b

if err := c.startTLS(now, initialConnID, transportParameters{
initialSrcConnID: c.connIDState.srcConnID(),
originalDstConnID: originalDstConnID,
retrySrcConnID: retrySrcConnID,
originalDstConnID: cids.originalDstConnID,
retrySrcConnID: cids.retrySrcConnID,
ackDelayExponent: ackDelayExponent,
maxUDPPayloadSize: maxUDPPayloadSize,
maxAckDelay: maxAckDelay,
Expand Down
12 changes: 10 additions & 2 deletions internal/quic/conn_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ func (s *connIDState) initClient(c *Conn) error {
return nil
}

func (s *connIDState) initServer(c *Conn, dstConnID []byte) error {
dstConnID = cloneBytes(dstConnID)
func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error {
dstConnID := cloneBytes(cids.dstConnID)
// Client-chosen, transient connection ID received in the first Initial packet.
// The server will not use this as the Source Connection ID of packets it sends,
// but remembers it because it may receive packets sent to this destination.
Expand All @@ -121,6 +121,14 @@ func (s *connIDState) initServer(c *Conn, dstConnID []byte) error {
conns.addConnID(c, dstConnID)
conns.addConnID(c, locid)
})

// Client chose its own connection ID.
s.remote = append(s.remote, remoteConnID{
connID: connID{
seq: 0,
cid: cloneBytes(cids.srcConnID),
},
})
return nil
}

Expand Down
5 changes: 4 additions & 1 deletion internal/quic/conn_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,11 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) {
p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0")
p.preferredAddrConnID = testPeerConnID(1)
p.preferredAddrResetToken = make([]byte, 16)
}, func(cids *newServerConnIDs) {
cids.srcConnID = []byte{}
}, func(tc *testConn) {
tc.peerConnID = []byte{}
})
tc.peerConnID = []byte{}

tc.writeFrames(packetTypeInitial,
debugFrameCrypto{
Expand Down
24 changes: 16 additions & 8 deletions internal/quic/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,33 +193,38 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
TLSConfig: newTestTLSConfig(side),
StatelessResetKey: testStatelessResetKey,
}
var cids newServerConnIDs
if side == serverSide {
// The initial connection ID for the server is chosen by the client.
cids.srcConnID = testPeerConnID(0)
cids.dstConnID = testPeerConnID(-1)
}
var configTransportParams []func(*transportParameters)
var configTestConn []func(*testConn)
for _, o := range opts {
switch o := o.(type) {
case func(*Config):
o(config)
case func(*tls.Config):
o(config.TLSConfig)
case func(cids *newServerConnIDs):
o(&cids)
case func(p *transportParameters):
configTransportParams = append(configTransportParams, o)
case func(p *testConn):
configTestConn = append(configTestConn, o)
default:
t.Fatalf("unknown newTestConn option %T", o)
}
}

var initialConnID []byte
if side == serverSide {
// The initial connection ID for the server is chosen by the client.
initialConnID = testPeerConnID(-1)
}

listener := newTestListener(t, config)
listener.configTransportParams = configTransportParams
listener.configTestConn = configTestConn
conn, err := listener.l.newConn(
listener.now,
side,
initialConnID,
nil,
cids,
netip.MustParseAddrPort("127.0.0.1:443"))
if err != nil {
t.Fatal(err)
Expand All @@ -244,6 +249,9 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC
recvDatagram: make(chan *datagram),
}
t.Cleanup(tc.cleanup)
for _, f := range listener.configTestConn {
f(tc)
}
conn.testHooks = (*testConnHooks)(tc)

if listener.peerTLSConn != nil {
Expand Down
19 changes: 11 additions & 8 deletions internal/quic/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er
}
addr := u.AddrPort()
addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
c, err := l.newConn(time.Now(), clientSide, nil, nil, addr)
c, err := l.newConn(time.Now(), clientSide, newServerConnIDs{}, addr)
if err != nil {
return nil, err
}
Expand All @@ -151,13 +151,13 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er
return c, nil
}

func (l *Listener) newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []byte, peerAddr netip.AddrPort) (*Conn, error) {
func (l *Listener) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) {
l.connsMu.Lock()
defer l.connsMu.Unlock()
if l.closing {
return nil, errors.New("listener closed")
}
c, err := newConn(now, side, originalDstConnID, retrySrcConnID, peerAddr, l.config, l)
c, err := newConn(now, side, cids, peerAddr, l.config, l)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -296,19 +296,22 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
} else {
now = time.Now()
}
var originalDstConnID, retrySrcConnID []byte
cids := newServerConnIDs{
srcConnID: p.srcConnID,
dstConnID: p.dstConnID,
}
if l.config.RequireAddressValidation {
var ok bool
retrySrcConnID = p.dstConnID
originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr)
cids.retrySrcConnID = p.dstConnID
cids.originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr)
if !ok {
return
}
} else {
originalDstConnID = p.dstConnID
cids.originalDstConnID = p.dstConnID
}
var err error
c, err := l.newConn(now, serverSide, originalDstConnID, retrySrcConnID, m.addr)
c, err := l.newConn(now, serverSide, cids, m.addr)
if err != nil {
// The accept queue is probably full.
// We could send a CONNECTION_CLOSE to the peer to reject the connection.
Expand Down
42 changes: 9 additions & 33 deletions internal/quic/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ import (
)

func TestConnect(t *testing.T) {
newLocalConnPair(t, &Config{}, &Config{})
NewLocalConnPair(t, &Config{}, &Config{})
}

func TestStreamTransfer(t *testing.T) {
ctx := context.Background()
cli, srv := newLocalConnPair(t, &Config{}, &Config{})
cli, srv := NewLocalConnPair(t, &Config{}, &Config{})
data := makeTestData(1 << 20)

srvdone := make(chan struct{})
Expand Down Expand Up @@ -61,11 +61,11 @@ func TestStreamTransfer(t *testing.T) {
}
}

func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
func NewLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
t.Helper()
ctx := context.Background()
l1 := newLocalListener(t, serverSide, conf1)
l2 := newLocalListener(t, clientSide, conf2)
l1 := NewLocalListener(t, serverSide, conf1)
l2 := NewLocalListener(t, clientSide, conf2)
c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String())
if err != nil {
t.Fatal(err)
Expand All @@ -77,9 +77,11 @@ func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverCon
return c2, c1
}

func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener {
func NewLocalListener(t *testing.T, side connSide, conf *Config) *Listener {
t.Helper()
if conf.TLSConfig == nil {
newConf := *conf
conf = &newConf
conf.TLSConfig = newTestTLSConfig(side)
}
l, err := Listen("udp", "127.0.0.1:0", conf)
Expand All @@ -101,6 +103,7 @@ type testListener struct {
conns map[*Conn]*testConn
acceptQueue []*testConn
configTransportParams []func(*transportParameters)
configTestConn []func(*testConn)
sentDatagrams [][]byte
peerTLSConn *tls.QUICConn
lastInitialDstConnID []byte // for parsing Retry packets
Expand Down Expand Up @@ -251,33 +254,6 @@ func (tl *testListener) wantIdle(expectation string) {
}
}

func (tl *testListener) newClientTLS(srcConnID, dstConnID []byte) []byte {
peerProvidedParams := defaultTransportParameters()
peerProvidedParams.initialSrcConnID = srcConnID
peerProvidedParams.originalDstConnID = dstConnID
for _, f := range tl.configTransportParams {
f(&peerProvidedParams)
}

config := &tls.QUICConfig{TLSConfig: newTestTLSConfig(clientSide)}
tl.peerTLSConn = tls.QUICClient(config)
tl.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
tl.peerTLSConn.Start(context.Background())
var data []byte
for {
e := tl.peerTLSConn.NextEvent()
switch e.Kind {
case tls.QUICNoEvent:
return data
case tls.QUICWriteData:
if e.Level != tls.QUICEncryptionLevelInitial {
tl.t.Fatal("initial data at unexpected level")
}
data = append(data, e.Data...)
}
}
}

// advance causes time to pass.
func (tl *testListener) advance(d time.Duration) {
tl.t.Helper()
Expand Down

0 comments on commit ec29a94

Please sign in to comment.