Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix decoding of packet numbers in different packet number spaces #2903

Merged
merged 1 commit into from
Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion internal/handshake/aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
)

func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {
Expand Down Expand Up @@ -50,6 +51,7 @@ func (s *longHeaderSealer) Overhead() int {
type longHeaderOpener struct {
aead cipher.AEAD
headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)

// use a single slice to avoid allocations
nonceBuf []byte
Expand All @@ -65,12 +67,18 @@ func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) Long
}
}

func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
}

func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
if err != nil {
if err == nil {
o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn)
} else {
err = ErrDecryptionFailed
}
return dec, err
Expand Down
16 changes: 16 additions & 0 deletions internal/handshake/aead_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ var _ = Describe("Long Header AEAD", func() {
_, err := opener.Open(nil, encrypted, 0x42, ad)
Expect(err).To(MatchError(ErrDecryptionFailed))
})

It("decodes the packet number", func() {
sealer, opener := getSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x1337, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
})

It("ignores packets it can't decrypt for packet number derivation", func() {
sealer, opener := getSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad)
Expect(err).To(HaveOccurred())
Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
})
})

Context("header encryption", func() {
Expand Down
2 changes: 2 additions & 0 deletions internal/handshake/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ type headerDecryptor interface {
// LongHeaderOpener opens a long header packet
type LongHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
}

// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
}

Expand Down
8 changes: 8 additions & 0 deletions internal/handshake/updatable_aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type updatableAEAD struct {

firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
numRcvdWithCurrentKey uint64
numSentWithCurrentKey uint64
rcvAEAD cipher.AEAD
Expand Down Expand Up @@ -153,6 +154,10 @@ func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSu
}
}

func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
}

func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
if err == ErrDecryptionFailed {
Expand All @@ -161,6 +166,9 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
return nil, qerr.AEADLimitReached
}
}
if err == nil {
a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn)
}
return dec, err
}

Expand Down
14 changes: 14 additions & 0 deletions internal/handshake/updatable_aead_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,20 @@ var _ = Describe("Updatable AEAD", func() {
Expect(err).To(MatchError(ErrDecryptionFailed))
})

It("decodes the packet number", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
})

It("ignores packets it can't decrypt for packet number derivation", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(HaveOccurred())
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
})

It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() {
client.invalidPacketLimit = 10
for i := 0; i < 9; i++ {
Expand Down
14 changes: 14 additions & 0 deletions internal/mocks/long_header_opener.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions internal/mocks/short_header_opener.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 2 additions & 11 deletions packet_unpacker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)

Expand Down Expand Up @@ -44,8 +43,6 @@ type unpackedPacket struct {
type packetUnpacker struct {
cs handshake.CryptoSetup

largestRcvdPacketNumber protocol.PacketNumber

version protocol.VersionNumber
}

Expand Down Expand Up @@ -112,9 +109,6 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte
}
}

// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber)

return &unpackedPacket{
hdr: extHdr,
packetNumber: extHdr.PacketNumber,
Expand All @@ -132,6 +126,7 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
return nil, nil, parseErr
}
extHdrLen := extHdr.ParsedLen()
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
if err != nil {
return nil, nil, err
Expand All @@ -155,6 +150,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket(
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, nil, parseErr
}
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
extHdrLen := extHdr.ParsedLen()
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil {
Expand All @@ -172,11 +168,6 @@ func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err}
}
extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
return extHdr, err
}

Expand Down
91 changes: 46 additions & 45 deletions packet_unpacker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (
"errors"
"time"

"github.com/lucas-clemente/quic-go/internal/qerr"

"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire"

"github.com/golang/mock/gomock"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
Expand Down Expand Up @@ -97,9 +97,12 @@ var _ = Describe("Packet Unpacker", func() {
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetInitialOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
gomock.InOrder(
cs.EXPECT().GetInitialOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expand All @@ -115,20 +118,45 @@ var _ = Describe("Packet Unpacker", func() {
DestConnectionID: connID,
Version: version,
},
PacketNumber: 2,
PacketNumberLen: 3,
PacketNumber: 20,
PacketNumberLen: 2,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().Get0RTTOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
gomock.InOrder(
cs.EXPECT().Get0RTTOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
Expect(packet.data).To(Equal([]byte("decrypted")))
})

It("opens short header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
KeyPhase: protocol.KeyPhaseOne,
PacketNumber: 99,
PacketNumberLen: protocol.PacketNumberLen4,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
now := time.Now()
gomock.InOrder(
cs.EXPECT().Get1RTTOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT))
Expect(packet.data).To(Equal([]byte("decrypted")))
})

It("returns the error when getting the sealer fails", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
Expand Down Expand Up @@ -157,6 +185,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.CryptoBufferExceeded)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(qerr.CryptoBufferExceeded))
Expand All @@ -178,6 +207,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
Expand All @@ -194,6 +224,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
Expand All @@ -210,6 +241,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
Expand Down Expand Up @@ -247,46 +279,15 @@ var _ = Describe("Packet Unpacker", func() {
pnBytes[i] ^= 0xff // invert the packet number bytes
}
}),
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), origHdrRaw).Return([]byte{0}, nil),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)),
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil),
)
data := hdrRaw
for i := 1; i <= 100; i++ {
data = append(data, uint8(i))
}
packet, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
})

It("decodes the packet number", func() {
rcvTime := time.Now().Add(-time.Hour)
firstHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
KeyPhase: protocol.KeyPhaseOne,
}
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr)
packet, err := unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent
secondHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x38,
PacketNumberLen: 1,
KeyPhase: protocol.KeyPhaseZero,
}
// expect the call with the decoded packet number
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw = getHeader(secondHdr)
packet, err = unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x7331)))
})
})