diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index e44ac33252..281935d2a0 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -68,12 +68,9 @@ type transport struct { rcmgr network.ResourceManager gater connmgr.ConnectionGater - listenOnce sync.Once - listenOnceErr error - certManager *certManager - certManagerReady chan struct{} // Closed when the certManager has been instantiated. - staticTLSConf *tls.Config - tlsClientConf *tls.Config + certManager *certManager + staticTLSConf *tls.Config + tlsClientConf *tls.Config noise *noise.Transport @@ -97,21 +94,26 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater if err != nil { return nil, err } + t := &transport{ - pid: id, - privKey: key, - rcmgr: rcmgr, - gater: gater, - clock: clock.New(), - connManager: connManager, - conns: map[uint64]*conn{}, - certManagerReady: make(chan struct{}), + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + connManager: connManager, + conns: map[uint64]*conn{}, } for _, opt := range opts { if err := opt(t); err != nil { return nil, err } } + cm, err := newCertManager(key, t.clock) + if err != nil { + return nil, err + } + t.certManager = cm n, err := noise.New(noise.ID, key, nil) if err != nil { return nil, err @@ -297,16 +299,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if !isWebTransport { return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) } - if t.staticTLSConf == nil { - t.listenOnce.Do(func() { - t.certManager, t.listenOnceErr = newCertManager(t.privKey, t.clock) - close(t.certManagerReady) - }) - if t.listenOnceErr != nil { - return nil, t.listenOnceErr - } - } else { - close(t.certManagerReady) + if t.staticTLSConf != nil { return nil, errors.New("static TLS config not supported on WebTransport") } tlsConf := t.staticTLSConf.Clone() @@ -333,11 +326,7 @@ func (t *transport) Proxy() bool { } func (t *transport) Close() error { - t.listenOnce.Do(func() {}) - if t.certManager != nil { - return t.certManager.Close() - } - return nil + return t.certManager.Close() } func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool { @@ -406,9 +395,5 @@ func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiad } func (t *transport) AddCertHashes(m ma.Multiaddr) ma.Multiaddr { - <-t.certManagerReady - if t.certManager == nil { - return m - } return m.Encapsulate(t.certManager.AddrComponent()) }