Skip to content

Commit

Permalink
Merge pull request #206 from underdogio/dev/expand.csrf.support.sqwished
Browse files Browse the repository at this point in the history
Added support for custom token keys and making CSRF values URL safe
  • Loading branch information
lepture committed Dec 3, 2015
2 parents ec26917 + 178c93d commit fc36ddd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
24 changes: 13 additions & 11 deletions flask_wtf/csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
__all__ = ('generate_csrf', 'validate_csrf', 'CsrfProtect')


def generate_csrf(secret_key=None, time_limit=None):
def generate_csrf(secret_key=None, time_limit=None, token_key='csrf_token', url_safe=False):
"""Generate csrf token code.
:param secret_key: A secret key for mixing in the token,
Expand All @@ -45,25 +45,26 @@ def generate_csrf(secret_key=None, time_limit=None):
if time_limit is None:
time_limit = current_app.config.get('WTF_CSRF_TIME_LIMIT', 3600)

if 'csrf_token' not in session:
session['csrf_token'] = hashlib.sha1(os.urandom(64)).hexdigest()
if token_key not in session:
session[token_key] = hashlib.sha1(os.urandom(64)).hexdigest()

if time_limit:
expires = int(time.time() + time_limit)
csrf_build = '%s%s' % (session['csrf_token'], expires)
csrf_build = '%s%s' % (session[token_key], expires)
else:
expires = ''
csrf_build = session['csrf_token']
csrf_build = session[token_key]

hmac_csrf = hmac.new(
to_bytes(secret_key),
to_bytes(csrf_build),
digestmod=hashlib.sha1
).hexdigest()
return '%s##%s' % (expires, hmac_csrf)
delimiter = '--' if url_safe else '##'
return '%s%s%s' % (expires, delimiter, hmac_csrf)


def validate_csrf(data, secret_key=None, time_limit=None):
def validate_csrf(data, secret_key=None, time_limit=None, token_key='csrf_token', url_safe=False):
"""Check if the given data is a valid csrf token.
:param data: The csrf token value to be checked.
Expand All @@ -72,11 +73,12 @@ def validate_csrf(data, secret_key=None, time_limit=None):
:param time_limit: Check if the csrf token is expired.
default is True.
"""
if not data or '##' not in data:
delimiter = '--' if url_safe else '##'
if not data or delimiter not in data:
return False

try:
expires, hmac_csrf = data.split('##', 1)
expires, hmac_csrf = data.split(delimiter, 1)
except ValueError:
return False # unpack error

Expand All @@ -98,10 +100,10 @@ def validate_csrf(data, secret_key=None, time_limit=None):
'WTF_CSRF_SECRET_KEY', current_app.secret_key
)

if 'csrf_token' not in session:
if token_key not in session:
return False

csrf_build = '%s%s' % (session['csrf_token'], expires)
csrf_build = '%s%s' % (session[token_key], expires)
hmac_compare = hmac.new(
to_bytes(secret_key),
to_bytes(csrf_build),
Expand Down
26 changes: 26 additions & 0 deletions tests/test_csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,29 @@ def withtoken():

response = self.client.get('/token')
assert b'#' in response.data

def test_csrf_custom_token_key(self):
with self.app.test_request_context():
# Generate a normal and a custom CSRF token
default_csrf_token = generate_csrf()
custom_csrf_token = generate_csrf(token_key='oauth_state')

# Verify they are different due to using different session keys
assert default_csrf_token != custom_csrf_token

# However, the custom key can validate as well
assert validate_csrf(custom_csrf_token, token_key='oauth_state')

def test_csrf_url_safe(self):
with self.app.test_request_context():
# Generate a normal and URL safe CSRF token
default_csrf_token = generate_csrf()
url_safe_csrf_token = generate_csrf(url_safe=True)

# Verify they are not the same and the URL one is truly URL safe
assert default_csrf_token != url_safe_csrf_token
assert '#' not in url_safe_csrf_token
assert re.match(r'^[a-f0-9]+--[a-f0-9]+$', url_safe_csrf_token)

# Verify we can validate our URL safe key
assert validate_csrf(url_safe_csrf_token, url_safe=True)

0 comments on commit fc36ddd

Please sign in to comment.