type checking: Union return type in AES.pyi #629

Closed
Akuli opened this issue Jun 8, 2022 · 4 comments

Akuli commented Jun 8, 2022

This works at runtime, but fails to type check:

from Crypto.Cipher import AES
cipher = AES.new(b"k"*32, AES.MODE_GCM, nonce=b"n"*12)
print(cipher.encrypt_and_digest(b"hello world"))

The mypy errors are:

asd.py:3: error: Item "EcbMode" of "Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, CcmMode, EaxMode, GcmMode, SivMode, OcbMode]" has no attribute "encrypt_and_digest"  [union-attr]
asd.py:3: error: Item "CbcMode" of "Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, CcmMode, EaxMode, GcmMode, SivMode, OcbMode]" has no attribute "encrypt_and_digest"  [union-attr]
asd.py:3: error: Item "CfbMode" of "Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, CcmMode, EaxMode, GcmMode, SivMode, OcbMode]" has no attribute "encrypt_and_digest"  [union-attr]
asd.py:3: error: Item "OfbMode" of "Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, CcmMode, EaxMode, GcmMode, SivMode, OcbMode]" has no attribute "encrypt_and_digest"  [union-attr]
asd.py:3: error: Item "CtrMode" of "Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, CcmMode, EaxMode, GcmMode, SivMode, OcbMode]" has no attribute "encrypt_and_digest"  [union-attr]
asd.py:3: error: Item "OpenPgpMode" of "Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, CcmMode, EaxMode, GcmMode, SivMode, OcbMode]" has no attribute "encrypt_and_digest"  [union-attr]
Found 6 errors in 1 file (checked 1 source file)

The stub that causes this problem is:

def new(key: Buffer,
        mode: AESMode,
        iv : Buffer = ...,
        IV : Buffer = ...,
        nonce : Buffer = ...,
        segment_size : int = ...,
        mac_len : int = ...,
        assoc_len : int = ...,
        initial_value : Union[int, Buffer] = ...,
        counter : Dict = ...,
        use_aesni : bool = ...) -> \
        Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode,
              OpenPgpMode, CcmMode, EaxMode, GcmMode,
              SivMode, OcbMode]: ...

When a function returns a Union, it says "the code calling this function must be prepared for any of these". That's not how new() behaves: it returns one specific kind of AES cipher object, determined by the mode argument, so I shouldn't need to worry about the other modes when using it.
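A minimal self-contained sketch of what "prepared for any of these" means in practice. EcbMode and GcmMode below are hypothetical stand-ins with placeholder methods (not the pycryptodome classes), the factory new() is hypothetical, and the mode value 11 is an arbitrary choice for illustration. With a Union return type the checker only accepts attribute access that is valid for every member, so the caller has to narrow the value first, for example with assert isinstance(...):

```python
from typing import Tuple, Union


# Hypothetical stand-ins for two members of the Union; they do no real encryption.
class EcbMode:
    def encrypt(self, data: bytes) -> bytes:
        return data


class GcmMode:
    def encrypt_and_digest(self, data: bytes) -> Tuple[bytes, bytes]:
        # Placeholder "ciphertext" plus a fixed 16-byte "tag".
        return data, b"\x00" * 16


def new(mode: int) -> Union[EcbMode, GcmMode]:
    # Hypothetical factory with the same kind of Union return type as the stub.
    return GcmMode() if mode == 11 else EcbMode()


cipher = new(11)
# Calling cipher.encrypt_and_digest() right here would be a union-attr error,
# because the checker must assume cipher could be an EcbMode.
assert isinstance(cipher, GcmMode)  # narrows the Union to GcmMode for the checker
ciphertext, tag = cipher.encrypt_and_digest(b"hello world")
```

The isinstance assertion works, but it has to be repeated at every call site, which is exactly the burden the Union return type pushes onto callers.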

To fix this, we could use overloads with Literal:

MODE_ECB: Literal[1]
MODE_CBC: Literal[2]
...

@overload
def new(key: Buffer, mode: Literal[1], ...other args...) -> EcbMode: ...
@overload
def new(key: Buffer, mode: Literal[2], ...other args...) -> CbcMode: ...
...
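A self-contained sketch of the overload + Literal idea, written as an ordinary .py file so it can run (Python 3.8+ for typing.Literal). The three classes are hypothetical placeholders for the real cipher classes, only three overloads are shown, and the value 11 for MODE_GCM is assumed for illustration. A type checker such as mypy picks the overload whose Literal matches the declared type of the constant passed in, so cipher below is inferred as GcmMode rather than a Union:

```python
from typing import Literal, Tuple, Union, overload


# Hypothetical placeholder classes; they do no real encryption.
class EcbMode:
    def encrypt(self, data: bytes) -> bytes:
        return data


class CbcMode:
    def encrypt(self, data: bytes) -> bytes:
        return data


class GcmMode:
    def encrypt_and_digest(self, data: bytes) -> Tuple[bytes, bytes]:
        return data, b"\x00" * 16  # placeholder ciphertext and tag


# Literal-typed constants let the checker see the exact mode value.
MODE_ECB: Literal[1] = 1
MODE_CBC: Literal[2] = 2
MODE_GCM: Literal[11] = 11  # assumed value, for illustration only


@overload
def new(key: bytes, mode: Literal[1]) -> EcbMode: ...
@overload
def new(key: bytes, mode: Literal[2]) -> CbcMode: ...
@overload
def new(key: bytes, mode: Literal[11], nonce: bytes = b"") -> GcmMode: ...


def new(key: bytes, mode: int, nonce: bytes = b"") -> Union[EcbMode, CbcMode, GcmMode]:
    # The overloads above are only seen by the type checker; this body is what runs.
    if mode == MODE_ECB:
        return EcbMode()
    if mode == MODE_CBC:
        return CbcMode()
    if mode == MODE_GCM:
        return GcmMode()
    raise ValueError(f"unsupported mode: {mode}")


cipher = new(b"k" * 32, MODE_GCM, nonce=b"n" * 12)
# Inferred as GcmMode, so there is no union-attr error on this call.
print(cipher.encrypt_and_digest(b"hello world"))
```

The runtime behaviour is unchanged by the overloads; they only refine what the checker infers. That is why this kind of fix can live entirely in a .pyi stub, which would contain just the @overload declarations and the Literal-typed constants.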

If you want, I can submit a pull request that does this. It will be a lot of copy/paste though. python/typing#566 would also help, but it's just a proposal and far from being implemented in practice.

There are a few possible workarounds. I can make the type checker ignore the other union members with Any:

from Crypto.Cipher import AES
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Union
    from Crypto.Cipher import _mode_gcm
    Cipher = Union[_mode_gcm.GcmMode, Any]
else:
    Cipher = None

cipher: Cipher = AES.new(b"k"*32, AES.MODE_GCM, nonce=b"n"*12)
print(cipher.encrypt_and_digest(b"hello world"))

Or just use Any for the whole thing (but then I get no type checker errors for wrong code):

from Crypto.Cipher import AES
from typing import Any
cipher: Any = AES.new(b"k"*32, AES.MODE_GCM, nonce=b"n"*12)
print(cipher.encrypt_and_digest(b"hello world"))

Or just ignore the error:

from Crypto.Cipher import AES
cipher = AES.new(b"k"*32, AES.MODE_GCM, nonce=b"n"*12)
print(cipher.encrypt_and_digest(b"hello world"))  # type: ignore

Legrandin (Owner) commented

Thanks for flagging this error. There is an existing PR (#533) that addresses the problem by means of Literal (as you proposed), but it is not fully convincing yet.

Legrandin (Owner) commented

Fixed with b1794ca

Legrandin reopened this Dec 9, 2022

Akuli commented Dec 10, 2022

You shouldn't need the if sys.version_info >= (3, 8): check. In a stub file you can use from typing_extensions import Literal on any Python version, since stub files are only read by type checkers and never imported at runtime.

Legrandin (Owner) commented

You are right. I have just simplified it in that way, and it also works in Python 3.5 (although only with an older version of typing-extensions).

Akuli closed this as completed Dec 11, 2022