Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added enhanced format support using non-blocking socket. #25

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@ for (token_hex, fail_time) in apns.feedback_server.items():
# do stuff with token_hex and fail_time
```

## Send a notification in enhanced format
```python
from apns import APNs, Payload, APNResponseError
from datetime import datetime, timedelta

apns = APNs(use_sandbox=True, cert_file='cert.pem', key_file='key.pem', enhanced=True)

token_hex = 'b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b87'
payload = Payload(alert="Hello World!", sound="default", badge=1)
identifier = 1234
expiry = datetime.utcnow() + timedelta(30) # undelivered notification expires after 30 seconds

try:
apns.gateway_server.send_notification(token_hex, payload)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the identifier and the expiry variables be used here?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked through a similar pull request here:
#21

I suppose ageron is okay with the fact that we are using non-blocking socket here..

except APNResponseError, err:
# handle apn's error response
# just tried notification is not sent and this response doesn't belong to that notification.
# formerly sent notifications should to be looked up with err.identifier to find one which caused this error.
# when error response is received, connection to APN server is closed.
```

For more complicated alerts including custom buttons etc, use the PayloadAlert
class. Example:

Expand Down
17 changes: 17 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import apns
import apnserrors

APNs = apns.APNs
Payload = apns.Payload

PayloadTooLargeError = apnserrors.PayloadTooLargeError
APNResponseError = apnserrors.APNResponseError
ProcessingError = apnserrors.ProcessingError
MissingDeviceTokenError = apnserrors.MissingDeviceTokenError
MissingTopicError = apnserrors.MissingTopicError
MissingPayloadError = apnserrors.MissingPayloadError
InvalidTokenSizeError = apnserrors.InvalidTokenSizeError
InvalidTopicSizeError = apnserrors.InvalidTopicSizeError
InvalidPayloadSizeError = apnserrors.InvalidPayloadSizeError
InvalidTokenError = apnserrors.InvalidTokenError
UnknownError = apnserrors.UnknownError
148 changes: 133 additions & 15 deletions apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,37 @@

from binascii import a2b_hex, b2a_hex
from datetime import datetime
from socket import socket, AF_INET, SOCK_STREAM
from time import mktime
from socket import socket, AF_INET, SOCK_STREAM, timeout, error as socket_error
from struct import pack, unpack

import select
import errno

support_enhanced = True

try:
from ssl import wrap_socket
from ssl import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE
except ImportError:
from socket import ssl as wrap_socket
support_enhanced = False

try:
import json
except ImportError:
import simplejson as json

from apnserrors import *

MAX_PAYLOAD_LENGTH = 256
TIMEOUT = 60
ERROR_RESPONSE_LENGTH = 6

class APNs(object):
"""A class representing an Apple Push Notification service connection"""

def __init__(self, use_sandbox=False, cert_file=None, key_file=None):
def __init__(self, use_sandbox=False, cert_file=None, key_file=None, enhanced=False):
"""
Set use_sandbox to True to use the sandbox (test) APNs servers.
Default is False.
Expand All @@ -52,9 +64,17 @@ def __init__(self, use_sandbox=False, cert_file=None, key_file=None):
self.use_sandbox = use_sandbox
self.cert_file = cert_file
self.key_file = key_file
self.enhanced = enhanced and support_enhanced
self._feedback_connection = None
self._gateway_connection = None

@staticmethod
def unpacked_uchar_big_endian(byte):
"""
Returns an unsigned char from a packed big-endian (network) byte
"""
return unpack('>B', byte)[0]

@staticmethod
def packed_ushort_big_endian(num):
"""
Expand Down Expand Up @@ -100,7 +120,8 @@ def gateway_server(self):
self._gateway_connection = GatewayConnection(
use_sandbox = self.use_sandbox,
cert_file = self.cert_file,
key_file = self.key_file
key_file = self.key_file,
enhanced = self.enhanced
)
return self._gateway_connection

Expand All @@ -109,10 +130,11 @@ class APNsConnection(object):
"""
A generic connection class for communicating with the APNs
"""
def __init__(self, cert_file=None, key_file=None):
def __init__(self, cert_file=None, key_file=None, enhanced=False):
super(APNsConnection, self).__init__()
self.cert_file = cert_file
self.key_file = key_file
self.enhanced = enhanced
self._socket = None
self._ssl = None

Expand All @@ -123,23 +145,99 @@ def _connect(self):
# Establish an SSL connection
self._socket = socket(AF_INET, SOCK_STREAM)
self._socket.connect((self.server, self.port))
self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file)

if self.enhanced:
self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file,
do_handshake_on_connect=False)
self._ssl.setblocking(0)
while True:
try:
self._ssl.do_handshake()
break
except SSLError, err:
if SSL_ERROR_WANT_READ == err.args[0]:
select.select([self._ssl], [], [])
elif SSL_ERROR_WANT_WRITE == err.args[0]:
select.select([], [self._ssl], [])
else:
raise
else:
self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file)

def _disconnect(self):
if self._socket:
self._socket.close()
self._ssl = None

def _connection(self):
if not self._ssl:
self._connect()
return self._ssl

def read(self, n=None):
return self._connection().read(n)
return self._connection().recv(n)

def recvall(self, n):
data = ""
while True:
more = self._connection().recv(n - len(data))
data += more
if len(data) >= n:
break
rlist, _, _ = select.select([self._connection()], [], [], TIMEOUT)
if not rlist:
raise timeout

return data

def write(self, string):
return self._connection().write(string)

if self.enhanced: # nonblocking socket
rlist, _, _ = select.select([self._connection()], [], [], 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I am trying to get enhanced message with error-response, too.
if select.select here with given timeout = 0s, it often get NO APNS's error-response in time,
when I set to 0.3s it got 90% of error-response, however, it's not guarantee 100% to get error-response for every failed notification, the network connection might also be a factor of delayed time.
Now I am trying to implement using another thread to constantly monitor the read descriptor and response to main thread for further handling. Refering to solution concept: http://redth.info/the-problem-with-apples-push-notification-ser/


if rlist: # there's error response from APNs
buff = self.recvall(ERROR_RESPONSE_LENGTH)
if len(buff) != ERROR_RESPONSE_LENGTH:
return None

command = APNs.unpacked_uchar_big_endian(buff[0])

if 8 != command:
self._disconnect()
raise UnknownError(0)

status = APNs.unpacked_uchar_big_endian(buff[1])
identifier = APNs.unpacked_uint_big_endian(buff[2:6])

self._disconnect()

raise { 1: ProcessingError,
2: MissingDeviceTokenError,
3: MissingTopicError,
4: MissingPayloadError,
5: InvalidTokenSizeError,
6: InvalidTopicSizeError,
7: InvalidPayloadSizeError,
8: InvalidTokenError }.get(status, UnknownError)(identifier)

_, wlist, _ = select.select([], [self._connection()], [], TIMEOUT)
if wlist:
return self._connection().sendall(string)
else:
self._disconnect()
raise timeout

else: # not-enhanced format using blocking socket
try:
return self._connection().write(string)
except socket_error, err:
try:
if errno.EPIPE == err.errno:
self._disconnect()
except AttributeError:
if errno.EPIPE == err.args[0]:
self._disconnect()
finally:
raise err

class PayloadAlert(object):
def __init__(self, body, action_loc_key=None, loc_key=None,
Expand All @@ -163,10 +261,6 @@ def dict(self):
d['launch-image'] = self.launch_image
return d

class PayloadTooLargeError(Exception):
def __init__(self):
super(PayloadTooLargeError, self).__init__()

class Payload(object):
"""A class representing an APNs message payload"""
def __init__(self, alert=None, badge=None, sound=None, custom={}):
Expand Down Expand Up @@ -285,10 +379,34 @@ def _get_notification(self, token_hex, payload):
payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json))

notification = ('\0' + token_length_bin + token_bin
+ payload_length_bin + payload_json)
+ payload_length_bin + payload_json)

return notification

def send_notification(self, token_hex, payload):
self.write(self._get_notification(token_hex, payload))
def _get_enhanced_notification(self, token_hex, payload, identifier, expiry):
"""
Takes a token as a hex string and a payload as a Python dict and sends
the notification in the enhanced format
"""
token_bin = a2b_hex(token_hex)
token_length_bin = APNs.packed_ushort_big_endian(len(token_bin))
payload_json = payload.json()
payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json))
identifier_bin = APNs.packed_uint_big_endian(identifier)

expiry_int = int(mktime(expiry.timetuple())) if isinstance(expiry, datetime) \
else int(expiry)

expiry_bin = APNs.packed_uint_big_endian(expiry_int)

notification = ('\1' + identifier_bin + expiry_bin + token_length_bin + token_bin
+ payload_length_bin + payload_json)

return notification

def send_notification(self, token_hex, payload, identifier=0, expiry=0):
if self.enhanced:
self.write(self._get_enhanced_notification(token_hex, payload, identifier,
expiry))
else:
self.write(self._get_notification(token_hex, payload))
50 changes: 50 additions & 0 deletions apnserrors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
class PayloadTooLargeError(Exception):
def __init__(self):
super(PayloadTooLargeError, self).__init__()

class APNResponseError(Exception):
def __init__(self, status, identifier):
self.status = status
self.identifier = identifier

def __repr__(self):
return "{}<identifier: {}>".format(self.__class__.__name__, self.identifier)

def __str__(self):
return self.__repr__()

class ProcessingError(APNResponseError):
def __init__(self, identifier):
super(ProcessingError, self).__init__(1, identifier)

class MissingDeviceTokenError(APNResponseError):
def __init__(self, identifier):
super(MissingDeviceTokenError, self).__init__(2, identifier)

class MissingTopicError(APNResponseError):
def __init__(self, identifier):
super(MissingTopicError, self).__init__(3, identifier)

class MissingPayloadError(APNResponseError):
def __init__(self, identifier):
super(MissingPayloadError, self).__init__(4, identifier)

class InvalidTokenSizeError(APNResponseError):
def __init__(self, identifier):
super(InvalidTokenSizeError, self).__init__(5, identifier)

class InvalidTopicSizeError(APNResponseError):
def __init__(self, identifier):
super(InvalidTopicSizeError, self).__init__(6, identifier)

class InvalidPayloadSizeError(APNResponseError):
def __init__(self, identifier):
super(InvalidPayloadSizeError, self).__init__(7, identifier)

class InvalidTokenError(APNResponseError):
def __init__(self, identifier):
super(InvalidTokenError, self).__init__(8, identifier)

class UnknownError(APNResponseError):
def __init__(self, identifier):
super(UnknownError, self).__init__(255, identifier)
32 changes: 32 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from apns import *
from binascii import a2b_hex
from random import random
from datetime import datetime, timedelta

import hashlib
import os
Expand Down Expand Up @@ -90,6 +91,37 @@ def testGatewayServer(self):
self.assertEqual(len(notification), expected_length)
self.assertEqual(notification[0], '\0')

def testEnhancedGatewayServer(self):
pem_file = TEST_CERTIFICATE
apns = APNs(use_sandbox=True, cert_file=pem_file, key_file=pem_file, enhanced=True)
gateway_server = apns.gateway_server

self.assertEqual(gateway_server.cert_file, apns.cert_file)
self.assertEqual(gateway_server.key_file, apns.key_file)

token_hex = 'b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c'
payload = Payload(
alert = "Hello World!",
sound = "default",
badge = 4
)
expiry = datetime.utcnow() + timedelta(30)
notification = gateway_server._get_enhanced_notification(token_hex, payload, 0,
expiry)

expected_length = (
1 # leading null byte
+ 4 # identifier as a packed int
+ 4 # expiry as a packed int
+ 2 # length of token as a packed short
+ len(token_hex) / 2 # length of token as binary string
+ 2 # length of payload as a packed short
+ len(payload.json()) # length of JSON-formatted payload
)

self.assertEqual(len(notification), expected_length)
self.assertEqual(notification[0], '\1')

def testFeedbackServer(self):
pem_file = TEST_CERTIFICATE
apns = APNs(use_sandbox=True, cert_file=pem_file, key_file=pem_file)
Expand Down