From 0c93e1ff9f91f0c63bf17b123de503d023434fdd Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Sat, 29 Aug 2015 15:53:49 -0700 Subject: [PATCH] golang.org/x/crypto/openssh: don't loop forever after a bad password. SymmetricKeyEncrypted cached the results of decryption so, if a bad password was given, ReadMessage would prompt forever because a later, correct password wouldn't override the cached decryption. The SymmetricKeyEncrypted object can't know whether a given passphrase is correct so it should never have been a mutable object in the first place. This change makes it so that it doesn't cache anything. Fixes #9315 Change-Id: Ic2b75f7f60a575e2182ac7e5c5d4198597c5d0a2 Reviewed-on: https://go-review.googlesource.com/14038 Reviewed-by: Andrew Gerrand Reviewed-by: Adam Langley --- openpgp/packet/symmetric_key_encrypted.go | 65 ++++++++----------- .../packet/symmetric_key_encrypted_test.go | 19 +++--- openpgp/read.go | 6 +- openpgp/read_test.go | 8 ++- 4 files changed, 48 insertions(+), 50 deletions(-) diff --git a/openpgp/packet/symmetric_key_encrypted.go b/openpgp/packet/symmetric_key_encrypted.go index 7dff655221..1deebcdfae 100644 --- a/openpgp/packet/symmetric_key_encrypted.go +++ b/openpgp/packet/symmetric_key_encrypted.go @@ -22,20 +22,17 @@ const maxSessionKeySizeInBytes = 64 // 4880, section 5.3. type SymmetricKeyEncrypted struct { CipherFunc CipherFunction - Encrypted bool - Key []byte // Empty unless Encrypted is false. s2k func(out, in []byte) encryptedKey []byte } const symmetricKeyEncryptedVersion = 4 -func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) { +func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error { // RFC 4880, section 5.3. var buf [2]byte - _, err = readFull(r, buf[:]) - if err != nil { - return + if _, err := readFull(r, buf[:]); err != nil { + return err } if buf[0] != symmetricKeyEncryptedVersion { return errors.UnsupportedError("SymmetricKeyEncrypted version") @@ -46,9 +43,10 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) { return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1]))) } + var err error ske.s2k, err = s2k.Parse(r) if err != nil { - return + return err } encryptedKey := make([]byte, maxSessionKeySizeInBytes) @@ -56,9 +54,9 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) { // out. If it exists then we limit it to maxSessionKeySizeInBytes. n, err := readFull(r, encryptedKey) if err != nil && err != io.ErrUnexpectedEOF { - return + return err } - err = nil + if n != 0 { if n == maxSessionKeySizeInBytes { return errors.UnsupportedError("oversized encrypted session key") @@ -66,42 +64,35 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) { ske.encryptedKey = encryptedKey[:n] } - ske.Encrypted = true - - return + return nil } -// Decrypt attempts to decrypt an encrypted session key. If it returns nil, -// ske.Key will contain the session key. -func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error { - if !ske.Encrypted { - return nil - } - +// Decrypt attempts to decrypt an encrypted session key and returns the key and +// the cipher to use when decrypting a subsequent Symmetrically Encrypted Data +// packet. +func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) ([]byte, CipherFunction, error) { key := make([]byte, ske.CipherFunc.KeySize()) ske.s2k(key, passphrase) if len(ske.encryptedKey) == 0 { - ske.Key = key - } else { - // the IV is all zeros - iv := make([]byte, ske.CipherFunc.blockSize()) - c := cipher.NewCFBDecrypter(ske.CipherFunc.new(key), iv) - c.XORKeyStream(ske.encryptedKey, ske.encryptedKey) - ske.CipherFunc = CipherFunction(ske.encryptedKey[0]) - if ske.CipherFunc.blockSize() == 0 { - return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc))) - } - ske.CipherFunc = CipherFunction(ske.encryptedKey[0]) - ske.Key = ske.encryptedKey[1:] - if len(ske.Key)%ske.CipherFunc.blockSize() != 0 { - ske.Key = nil - return errors.StructuralError("length of decrypted key not a multiple of block size") - } + return key, ske.CipherFunc, nil } - ske.Encrypted = false - return nil + // the IV is all zeros + iv := make([]byte, ske.CipherFunc.blockSize()) + c := cipher.NewCFBDecrypter(ske.CipherFunc.new(key), iv) + plaintextKey := make([]byte, len(ske.encryptedKey)) + c.XORKeyStream(plaintextKey, ske.encryptedKey) + cipherFunc := CipherFunction(plaintextKey[0]) + if cipherFunc.blockSize() == 0 { + return nil, ske.CipherFunc, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc))) + } + plaintextKey = plaintextKey[1:] + if l := len(plaintextKey); l == 0 || l%cipherFunc.blockSize() != 0 { + return nil, cipherFunc, errors.StructuralError("length of decrypted key not a multiple of block size") + } + + return plaintextKey, cipherFunc, nil } // SerializeSymmetricKeyEncrypted serializes a symmetric key packet to w. The diff --git a/openpgp/packet/symmetric_key_encrypted_test.go b/openpgp/packet/symmetric_key_encrypted_test.go index dd983cb387..19538df77c 100644 --- a/openpgp/packet/symmetric_key_encrypted_test.go +++ b/openpgp/packet/symmetric_key_encrypted_test.go @@ -24,7 +24,7 @@ func TestSymmetricKeyEncrypted(t *testing.T) { t.Error("didn't find SymmetricKeyEncrypted packet") return } - err = ske.Decrypt([]byte("password")) + key, cipherFunc, err := ske.Decrypt([]byte("password")) if err != nil { t.Error(err) return @@ -40,7 +40,7 @@ func TestSymmetricKeyEncrypted(t *testing.T) { t.Error("didn't find SymmetricallyEncrypted packet") return } - r, err := se.Decrypt(ske.CipherFunc, ske.Key) + r, err := se.Decrypt(cipherFunc, key) if err != nil { t.Error(err) return @@ -64,8 +64,9 @@ const symmetricallyEncryptedContentsHex = "cb1062004d14c4df636f6e74656e74732e0a" func TestSerializeSymmetricKeyEncrypted(t *testing.T) { buf := bytes.NewBuffer(nil) passphrase := []byte("testing") + const cipherFunc = CipherAES128 config := &Config{ - DefaultCipher: CipherAES128, + DefaultCipher: cipherFunc, } key, err := SerializeSymmetricKeyEncrypted(buf, passphrase, config) @@ -85,18 +86,18 @@ func TestSerializeSymmetricKeyEncrypted(t *testing.T) { return } - if !ske.Encrypted { - t.Errorf("SKE not encrypted but should be") - } if ske.CipherFunc != config.DefaultCipher { t.Errorf("SKE cipher function is %d (expected %d)", ske.CipherFunc, config.DefaultCipher) } - err = ske.Decrypt(passphrase) + parsedKey, parsedCipherFunc, err := ske.Decrypt(passphrase) if err != nil { t.Errorf("failed to decrypt reparsed SKE: %s", err) return } - if !bytes.Equal(key, ske.Key) { - t.Errorf("keys don't match after Decrpyt: %x (original) vs %x (parsed)", key, ske.Key) + if !bytes.Equal(key, parsedKey) { + t.Errorf("keys don't match after Decrypt: %x (original) vs %x (parsed)", key, parsedKey) + } + if parsedCipherFunc != cipherFunc { + t.Errorf("cipher function doesn't match after Decrypt: %d (original) vs %d (parsed)", cipherFunc, parsedCipherFunc) } } diff --git a/openpgp/read.go b/openpgp/read.go index a6cecc529b..dfffc398d5 100644 --- a/openpgp/read.go +++ b/openpgp/read.go @@ -196,9 +196,9 @@ FindKey: // Try the symmetric passphrase first if len(symKeys) != 0 && passphrase != nil { for _, s := range symKeys { - err = s.Decrypt(passphrase) - if err == nil && !s.Encrypted { - decrypted, err = se.Decrypt(s.CipherFunc, s.Key) + key, cipherFunc, err := s.Decrypt(passphrase) + if err == nil { + decrypted, err = se.Decrypt(cipherFunc, key) if err != nil && err != errors.ErrKeyIncorrect { return nil, err } diff --git a/openpgp/read_test.go b/openpgp/read_test.go index 52f942c71c..7524a02e56 100644 --- a/openpgp/read_test.go +++ b/openpgp/read_test.go @@ -243,7 +243,7 @@ func TestUnspecifiedRecipient(t *testing.T) { } func TestSymmetricallyEncrypted(t *testing.T) { - expected := "Symmetrically encrypted.\n" + firstTimeCalled := true prompt := func(keys []Key, symmetric bool) ([]byte, error) { if len(keys) != 0 { @@ -254,6 +254,11 @@ func TestSymmetricallyEncrypted(t *testing.T) { t.Errorf("symmetric is not set") } + if firstTimeCalled { + firstTimeCalled = false + return []byte("wrongpassword"), nil + } + return []byte("password"), nil } @@ -273,6 +278,7 @@ func TestSymmetricallyEncrypted(t *testing.T) { t.Errorf("LiteralData.Time is %d, want %d", md.LiteralData.Time, expectedCreationTime) } + const expected = "Symmetrically encrypted.\n" if string(contents) != expected { t.Errorf("contents got: %s want: %s", string(contents), expected) }