From 8b19040e456c6d1681f125eabdb1fd660f635904 Mon Sep 17 00:00:00 2001 From: Jay Lee Date: Wed, 30 Sep 2015 09:07:28 -0400 Subject: [PATCH] update googleapiclient, httplib2, oauth2client and passlib to latest versions --- googleapiclient/__init__.py | 2 +- googleapiclient/discovery.py | 72 +- googleapiclient/discovery_cache/__init__.py | 42 + .../discovery_cache/appengine_memcache.py | 52 + googleapiclient/discovery_cache/base.py | 45 + googleapiclient/discovery_cache/file_cache.py | 132 + googleapiclient/http.py | 5 +- httplib2/__init__.py | 11 +- oauth2client/__init__.py | 17 +- oauth2client/_helpers.py | 103 + oauth2client/_openssl_crypt.py | 139 + oauth2client/_pycrypto_crypt.py | 129 + oauth2client/appengine.py | 1512 ++++---- oauth2client/client.py | 3322 +++++++++-------- oauth2client/clientsecrets.py | 162 +- oauth2client/crypt.py | 536 +-- oauth2client/devshell.py | 159 +- oauth2client/django_orm.py | 177 +- oauth2client/file.py | 146 +- oauth2client/flask_util.py | 548 +++ oauth2client/gce.py | 121 +- oauth2client/keyring_storage.py | 130 +- oauth2client/locked_file.py | 565 +-- oauth2client/multistore_file.py | 666 ++-- oauth2client/service_account.py | 170 +- oauth2client/tools.py | 335 +- oauth2client/util.py | 265 +- oauth2client/xsrfutil.py | 144 +- passlib/__init__.py | 4 +- passlib/_setup/docdist.py | 2 +- passlib/_setup/stamp.py | 4 +- passlib/apache.py | 94 +- passlib/apps.py | 12 +- passlib/context.py | 104 +- passlib/exc.py | 29 +- passlib/ext/django/models.py | 33 +- passlib/ext/django/utils.py | 26 +- passlib/handlers/bcrypt.py | 294 +- passlib/handlers/cisco.py | 6 +- passlib/handlers/des_crypt.py | 6 +- passlib/handlers/digests.py | 8 +- passlib/handlers/django.py | 4 +- passlib/handlers/fshp.py | 4 +- passlib/handlers/ldap_digests.py | 6 +- passlib/handlers/md5_crypt.py | 2 +- passlib/handlers/misc.py | 2 +- passlib/handlers/mssql.py | 4 +- passlib/handlers/oracle.py | 2 +- passlib/handlers/pbkdf2.py | 10 +- passlib/handlers/phpass.py | 6 +- passlib/handlers/scram.py | 12 +- passlib/handlers/sha1_crypt.py | 4 +- passlib/handlers/sha2_crypt.py | 10 +- passlib/handlers/sun_md5_crypt.py | 6 +- passlib/handlers/windows.py | 6 +- passlib/hosts.py | 2 +- passlib/ifc.py | 12 +- passlib/registry.py | 10 +- passlib/tests/_test_bad_register.py | 2 +- passlib/tests/test_apache.py | 100 +- passlib/tests/test_apps.py | 2 +- passlib/tests/test_context.py | 81 +- passlib/tests/test_context_deprecated.py | 52 +- passlib/tests/test_ext_django.py | 119 +- passlib/tests/test_handlers.py | 49 +- passlib/tests/test_handlers_bcrypt.py | 66 +- passlib/tests/test_handlers_django.py | 22 +- passlib/tests/test_hosts.py | 2 +- passlib/tests/test_registry.py | 42 +- passlib/tests/test_utils.py | 70 +- passlib/tests/test_utils_crypto.py | 40 +- passlib/tests/test_utils_handlers.py | 57 +- passlib/tests/test_win32.py | 2 +- passlib/tests/tox_support.py | 6 +- passlib/tests/utils.py | 222 +- passlib/utils/__init__.py | 71 +- passlib/utils/_blowfish/__init__.py | 26 +- passlib/utils/_blowfish/_gen_files.py | 2 +- passlib/utils/_blowfish/base.py | 10 +- passlib/utils/compat.py | 10 +- passlib/utils/des.py | 12 +- passlib/utils/handlers.py | 41 +- passlib/utils/md4.py | 6 +- passlib/utils/pbkdf2.py | 6 +- passlib/win32.py | 4 +- 85 files changed, 6642 insertions(+), 4911 deletions(-) create mode 100644 googleapiclient/discovery_cache/__init__.py create mode 100644 googleapiclient/discovery_cache/appengine_memcache.py create mode 100644 googleapiclient/discovery_cache/base.py create mode 100644 googleapiclient/discovery_cache/file_cache.py create mode 100644 oauth2client/_helpers.py create mode 100644 oauth2client/_openssl_crypt.py create mode 100644 oauth2client/_pycrypto_crypt.py create mode 100644 oauth2client/flask_util.py diff --git a/googleapiclient/__init__.py b/googleapiclient/__init__.py index c000e971..ceeae8d1 100644 --- a/googleapiclient/__init__.py +++ b/googleapiclient/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.4.1" +__version__ = "1.4.2" diff --git a/googleapiclient/discovery.py b/googleapiclient/discovery.py index 4109865f..be62cf73 100644 --- a/googleapiclient/discovery.py +++ b/googleapiclient/discovery.py @@ -29,6 +29,7 @@ __all__ = [ ] from six import StringIO +from six.moves import http_client from six.moves.urllib.parse import urlencode, urlparse, urljoin, \ urlunparse, parse_qsl @@ -149,7 +150,9 @@ def build(serviceName, developerKey=None, model=None, requestBuilder=HttpRequest, - credentials=None): + credentials=None, + cache_discovery=True, + cache=None): """Construct a Resource for interacting with an API. Construct a Resource object for interacting with an API. The serviceName and @@ -171,6 +174,9 @@ def build(serviceName, request. credentials: oauth2client.Credentials, credentials to be used for authentication. + cache_discovery: Boolean, whether or not to cache the discovery doc. + cache: googleapiclient.discovery_cache.base.CacheBase, an optional + cache object for the discovery documents. Returns: A Resource object with methods for interacting with the service. @@ -185,22 +191,58 @@ def build(serviceName, requested_url = uritemplate.expand(discoveryServiceUrl, params) + try: + content = _retrieve_discovery_doc(requested_url, http, cache_discovery, + cache) + except HttpError as e: + if e.resp.status == http_client.NOT_FOUND: + raise UnknownApiNameOrVersion("name: %s version: %s" % (serviceName, + version)) + else: + raise e + + return build_from_document(content, base=discoveryServiceUrl, http=http, + developerKey=developerKey, model=model, requestBuilder=requestBuilder, + credentials=credentials) + + +def _retrieve_discovery_doc(url, http, cache_discovery, cache=None): + """Retrieves the discovery_doc from cache or the internet. + + Args: + url: string, the URL of the discovery document. + http: httplib2.Http, An instance of httplib2.Http or something that acts + like it through which HTTP requests will be made. + cache_discovery: Boolean, whether or not to cache the discovery doc. + cache: googleapiclient.discovery_cache.base.Cache, an optional cache + object for the discovery documents. + + Returns: + A unicode string representation of the discovery document. + """ + if cache_discovery: + from . import discovery_cache + from .discovery_cache import base + if cache is None: + cache = discovery_cache.autodetect() + if cache: + content = cache.get(url) + if content: + return content + + actual_url = url # REMOTE_ADDR is defined by the CGI spec [RFC3875] as the environment # variable that contains the network address of the client sending the # request. If it exists then add that to the request for the discovery # document to avoid exceeding the quota on discovery requests. if 'REMOTE_ADDR' in os.environ: - requested_url = _add_query_parameter(requested_url, 'userIp', - os.environ['REMOTE_ADDR']) - logger.info('URL being requested: GET %s' % requested_url) + actual_url = _add_query_parameter(url, 'userIp', os.environ['REMOTE_ADDR']) + logger.info('URL being requested: GET %s', actual_url) - resp, content = http.request(requested_url) + resp, content = http.request(actual_url) - if resp.status == 404: - raise UnknownApiNameOrVersion("name: %s version: %s" % (serviceName, - version)) if resp.status >= 400: - raise HttpError(resp, content, uri=requested_url) + raise HttpError(resp, content, uri=actual_url) try: content = content.decode('utf-8') @@ -212,10 +254,9 @@ def build(serviceName, except ValueError as e: logger.error('Failed to parse as JSON: ' + content) raise InvalidJsonError() - - return build_from_document(content, base=discoveryServiceUrl, http=http, - developerKey=developerKey, model=model, requestBuilder=requestBuilder, - credentials=credentials) + if cache_discovery and cache: + cache.set(url, content) + return content @positional(1) @@ -254,6 +295,9 @@ def build_from_document( A Resource object with methods for interacting with the service. """ + if http is None: + http = httplib2.Http() + # future is no longer used. future = {} @@ -854,7 +898,7 @@ Returns: # Retrieve nextPageToken from previous_response # Use as pageToken in previous_request to create new request. - if 'nextPageToken' not in previous_response: + if 'nextPageToken' not in previous_response or not previous_response['nextPageToken']: return None request = copy.copy(previous_request) diff --git a/googleapiclient/discovery_cache/__init__.py b/googleapiclient/discovery_cache/__init__.py new file mode 100644 index 00000000..c56fd659 --- /dev/null +++ b/googleapiclient/discovery_cache/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2014 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Caching utility for the discovery document.""" + +from __future__ import absolute_import + +import logging +import datetime + +DISCOVERY_DOC_MAX_AGE = 60 * 60 * 24 # 1 day + + +def autodetect(): + """Detects an appropriate cache module and returns it. + + Returns: + googleapiclient.discovery_cache.base.Cache, a cache object which + is auto detected, or None if no cache object is available. + """ + try: + from google.appengine.api import memcache + from . import appengine_memcache + return appengine_memcache.cache + except Exception: + try: + from . import file_cache + return file_cache.cache + except Exception as e: + logging.warning(e, exc_info=True) + return None diff --git a/googleapiclient/discovery_cache/appengine_memcache.py b/googleapiclient/discovery_cache/appengine_memcache.py new file mode 100644 index 00000000..a521fc39 --- /dev/null +++ b/googleapiclient/discovery_cache/appengine_memcache.py @@ -0,0 +1,52 @@ +# Copyright 2014 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""App Engine memcache based cache for the discovery document.""" + +import logging + +# This is only an optional dependency because we only import this +# module when google.appengine.api.memcache is available. +from google.appengine.api import memcache + +from . import base +from ..discovery_cache import DISCOVERY_DOC_MAX_AGE + +NAMESPACE = 'google-api-client' + + +class Cache(base.Cache): + """A cache with app engine memcache API.""" + + def __init__(self, max_age): + """Constructor. + + Args: + max_age: Cache expiration in seconds. + """ + self._max_age = max_age + + def get(self, url): + try: + return memcache.get(url, namespace=NAMESPACE) + except Exception as e: + logging.warning(e, exc_info=True) + + def set(self, url, content): + try: + memcache.set(url, content, time=int(self._max_age), namespace=NAMESPACE) + except Exception as e: + logging.warning(e, exc_info=True) + +cache = Cache(max_age=DISCOVERY_DOC_MAX_AGE) diff --git a/googleapiclient/discovery_cache/base.py b/googleapiclient/discovery_cache/base.py new file mode 100644 index 00000000..00e466d1 --- /dev/null +++ b/googleapiclient/discovery_cache/base.py @@ -0,0 +1,45 @@ +# Copyright 2014 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An abstract class for caching the discovery document.""" + +import abc + + +class Cache(object): + """A base abstract cache class.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def get(self, url): + """Gets the content from the memcache with a given key. + + Args: + url: string, the key for the cache. + + Returns: + object, the value in the cache for the given key, or None if the key is + not in the cache. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set(self, url, content): + """Sets the given key and content in the cache. + + Args: + url: string, the key for the cache. + content: string, the discovery document. + """ + raise NotImplementedError() diff --git a/googleapiclient/discovery_cache/file_cache.py b/googleapiclient/discovery_cache/file_cache.py new file mode 100644 index 00000000..ce540f02 --- /dev/null +++ b/googleapiclient/discovery_cache/file_cache.py @@ -0,0 +1,132 @@ +# Copyright 2014 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""File based cache for the discovery document. + +The cache is stored in a single file so that multiple processes can +share the same cache. It locks the file whenever accesing to the +file. When the cache content is corrupted, it will be initialized with +an empty cache. +""" + +from __future__ import division + +import datetime +import json +import logging +import os +import tempfile +import threading + +from oauth2client.locked_file import LockedFile + +from . import base +from ..discovery_cache import DISCOVERY_DOC_MAX_AGE + +logger = logging.getLogger(__name__) + +FILENAME = 'google-api-python-client-discovery-doc.cache' +EPOCH = datetime.datetime.utcfromtimestamp(0) + + +def _to_timestamp(date): + try: + return (date - EPOCH).total_seconds() + except AttributeError: + # The following is the equivalent of total_seconds() in Python2.6. + # See also: https://docs.python.org/2/library/datetime.html + delta = date - EPOCH + return ((delta.microseconds + (delta.seconds + delta.days * 24 * 3600) + * 10**6) / 10**6) + + +def _read_or_initialize_cache(f): + f.file_handle().seek(0) + try: + cache = json.load(f.file_handle()) + except Exception: + # This means it opens the file for the first time, or the cache is + # corrupted, so initializing the file with an empty dict. + cache = {} + f.file_handle().truncate(0) + f.file_handle().seek(0) + json.dump(cache, f.file_handle()) + return cache + + +class Cache(base.Cache): + """A file based cache for the discovery documents.""" + + def __init__(self, max_age): + """Constructor. + + Args: + max_age: Cache expiration in seconds. + """ + self._max_age = max_age + self._file = os.path.join(tempfile.gettempdir(), FILENAME) + f = LockedFile(self._file, 'a+', 'r') + try: + f.open_and_lock() + if f.is_locked(): + _read_or_initialize_cache(f) + # If we can not obtain the lock, other process or thread must + # have initialized the file. + except Exception as e: + logging.warning(e, exc_info=True) + finally: + f.unlock_and_close() + + def get(self, url): + f = LockedFile(self._file, 'r+', 'r') + try: + f.open_and_lock() + if f.is_locked(): + cache = _read_or_initialize_cache(f) + if url in cache: + content, t = cache.get(url, (None, 0)) + if _to_timestamp(datetime.datetime.now()) < t + self._max_age: + return content + return None + else: + logger.debug('Could not obtain a lock for the cache file.') + return None + except Exception as e: + logger.warning(e, exc_info=True) + finally: + f.unlock_and_close() + + def set(self, url, content): + f = LockedFile(self._file, 'r+', 'r') + try: + f.open_and_lock() + if f.is_locked(): + cache = _read_or_initialize_cache(f) + cache[url] = (content, _to_timestamp(datetime.datetime.now())) + # Remove stale cache. + for k, (_, timestamp) in list(cache.items()): + if _to_timestamp(datetime.datetime.now()) >= timestamp + self._max_age: + del cache[k] + f.file_handle().truncate(0) + f.file_handle().seek(0) + json.dump(cache, f.file_handle()) + else: + logger.debug('Could not obtain a lock for the cache file.') + except Exception as e: + logger.warning(e, exc_info=True) + finally: + f.unlock_and_close() + + +cache = Cache(max_age=DISCOVERY_DOC_MAX_AGE) diff --git a/googleapiclient/http.py b/googleapiclient/http.py index f272ba8c..5fcd7a11 100644 --- a/googleapiclient/http.py +++ b/googleapiclient/http.py @@ -1288,6 +1288,9 @@ class BatchHttpRequest(object): httplib2.HttpLib2Error if a transport error has occured. googleapiclient.errors.BatchError if the response is the wrong format. """ + # If we have no requests return + if len(self._order) == 0: + return None # If http is not supplied use the first valid one given in the requests. if http is None: @@ -1460,7 +1463,7 @@ class HttpMock(object): if headers is None: headers = {'status': '200'} if filename: - f = open(filename, 'r') + f = open(filename, 'rb') self.data = f.read() f.close() else: diff --git a/httplib2/__init__.py b/httplib2/__init__.py index 9d2a2d43..6fa3cc60 100644 --- a/httplib2/__init__.py +++ b/httplib2/__init__.py @@ -22,7 +22,7 @@ __contributors__ = ["Thomas Broyer (t.broyer@ltgt.net)", "Sam Ruby", "Louis Nyffenegger"] __license__ = "MIT" -__version__ = "0.9.1" +__version__ = "0.9.2" import re import sys @@ -255,8 +255,8 @@ def safename(filename): filename = re_slash.sub(",", filename) # limit length of filename - if len(filename)>64: - filename=filename[:64] + if len(filename)>200: + filename=filename[:200] return ",".join((filename, filemd5)) NORMALIZE_SPACE = re.compile(r'(?:\r\n)?[ \t]+') @@ -1285,8 +1285,9 @@ class Http(object): err = getattr(e, 'args')[0] else: err = e.errno - if err == errno.ECONNREFUSED: # Connection refused - raise + if err in (errno.ENETUNREACH, errno.EADDRNOTAVAIL) and i < RETRIES: + continue # retry on potentially transient socket errors + raise except httplib.HTTPException: # Just because the server closed the connection doesn't apparently mean # that the server didn't send a response. diff --git a/oauth2client/__init__.py b/oauth2client/__init__.py index f992cfff..4dda02a4 100644 --- a/oauth2client/__init__.py +++ b/oauth2client/__init__.py @@ -1,8 +1,23 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Client library for using OAuth2, especially with Google APIs.""" -__version__ = '1.4.7' +__version__ = '1.5.1' GOOGLE_AUTH_URI = 'https://accounts.google.com/o/oauth2/auth' GOOGLE_DEVICE_URI = 'https://accounts.google.com/o/oauth2/device/code' GOOGLE_REVOKE_URI = 'https://accounts.google.com/o/oauth2/revoke' GOOGLE_TOKEN_URI = 'https://accounts.google.com/o/oauth2/token' +GOOGLE_TOKEN_INFO_URI = 'https://www.googleapis.com/oauth2/v2/tokeninfo' diff --git a/oauth2client/_helpers.py b/oauth2client/_helpers.py new file mode 100644 index 00000000..39bfeb6a --- /dev/null +++ b/oauth2client/_helpers.py @@ -0,0 +1,103 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions for commonly used utilities.""" + +import base64 +import json +import six + + +def _parse_pem_key(raw_key_input): + """Identify and extract PEM keys. + + Determines whether the given key is in the format of PEM key, and extracts + the relevant part of the key if it is. + + Args: + raw_key_input: The contents of a private key file (either PEM or + PKCS12). + + Returns: + string, The actual key if the contents are from a PEM file, or + else None. + """ + offset = raw_key_input.find(b'-----BEGIN ') + if offset != -1: + return raw_key_input[offset:] + + +def _json_encode(data): + return json.dumps(data, separators=(',', ':')) + + +def _to_bytes(value, encoding='ascii'): + """Converts a string value to bytes, if necessary. + + Unfortunately, ``six.b`` is insufficient for this task since in + Python2 it does not modify ``unicode`` objects. + + Args: + value: The string/bytes value to be converted. + encoding: The encoding to use to convert unicode to bytes. Defaults + to "ascii", which will not allow any characters from ordinals + larger than 127. Other useful values are "latin-1", which + which will only allows byte ordinals (up to 255) and "utf-8", + which will encode any unicode that needs to be. + + Returns: + The original value converted to bytes (if unicode) or as passed in + if it started out as bytes. + + Raises: + ValueError if the value could not be converted to bytes. + """ + result = (value.encode(encoding) + if isinstance(value, six.text_type) else value) + if isinstance(result, six.binary_type): + return result + else: + raise ValueError('%r could not be converted to bytes' % (value,)) + + +def _from_bytes(value): + """Converts bytes to a string value, if necessary. + + Args: + value: The string/bytes value to be converted. + + Returns: + The original value converted to unicode (if bytes) or as passed in + if it started out as unicode. + + Raises: + ValueError if the value could not be converted to unicode. + """ + result = (value.decode('utf-8') + if isinstance(value, six.binary_type) else value) + if isinstance(result, six.text_type): + return result + else: + raise ValueError('%r could not be converted to unicode' % (value,)) + + +def _urlsafe_b64encode(raw_bytes): + raw_bytes = _to_bytes(raw_bytes, encoding='utf-8') + return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=') + + +def _urlsafe_b64decode(b64string): + # Guard against unicode strings, which base64 can't handle. + b64string = _to_bytes(b64string) + padded = b64string + b'=' * (4 - len(b64string) % 4) + return base64.urlsafe_b64decode(padded) diff --git a/oauth2client/_openssl_crypt.py b/oauth2client/_openssl_crypt.py new file mode 100644 index 00000000..aca35054 --- /dev/null +++ b/oauth2client/_openssl_crypt.py @@ -0,0 +1,139 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenSSL Crypto-related routines for oauth2client.""" + +import base64 +import six +from OpenSSL import crypto + +from oauth2client._helpers import _parse_pem_key +from oauth2client._helpers import _to_bytes + + +class OpenSSLVerifier(object): + """Verifies the signature on a message.""" + + def __init__(self, pubkey): + """Constructor. + + Args: + pubkey: OpenSSL.crypto.PKey, The public key to verify with. + """ + self._pubkey = pubkey + + def verify(self, message, signature): + """Verifies a message against a signature. + + Args: + message: string or bytes, The message to verify. If string, will be + encoded to bytes as utf-8. + signature: string or bytes, The signature on the message. If string, + will be encoded to bytes as utf-8. + + Returns: + True if message was signed by the private key associated with the + public key that this object was constructed with. + """ + message = _to_bytes(message, encoding='utf-8') + signature = _to_bytes(signature, encoding='utf-8') + try: + crypto.verify(self._pubkey, signature, message, 'sha256') + return True + except crypto.Error: + return False + + @staticmethod + def from_string(key_pem, is_x509_cert): + """Construct a Verified instance from a string. + + Args: + key_pem: string, public key in PEM format. + is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it + is expected to be an RSA key in PEM format. + + Returns: + Verifier instance. + + Raises: + OpenSSL.crypto.Error: if the key_pem can't be parsed. + """ + if is_x509_cert: + pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem) + else: + pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem) + return OpenSSLVerifier(pubkey) + + +class OpenSSLSigner(object): + """Signs messages with a private key.""" + + def __init__(self, pkey): + """Constructor. + + Args: + pkey: OpenSSL.crypto.PKey (or equiv), The private key to sign with. + """ + self._key = pkey + + def sign(self, message): + """Signs a message. + + Args: + message: bytes, Message to be signed. + + Returns: + string, The signature of the message for the given key. + """ + message = _to_bytes(message, encoding='utf-8') + return crypto.sign(self._key, message, 'sha256') + + @staticmethod + def from_string(key, password=b'notasecret'): + """Construct a Signer instance from a string. + + Args: + key: string, private key in PKCS12 or PEM format. + password: string, password for the private key file. + + Returns: + Signer instance. + + Raises: + OpenSSL.crypto.Error if the key can't be parsed. + """ + parsed_pem_key = _parse_pem_key(key) + if parsed_pem_key: + pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key) + else: + password = _to_bytes(password, encoding='utf-8') + pkey = crypto.load_pkcs12(key, password).get_privatekey() + return OpenSSLSigner(pkey) + + +def pkcs12_key_as_pem(private_key_text, private_key_password): + """Convert the contents of a PKCS12 key to PEM using OpenSSL. + + Args: + private_key_text: String. Private key. + private_key_password: String. Password for PKCS12. + + Returns: + String. PEM contents of ``private_key_text``. + """ + decoded_body = base64.b64decode(private_key_text) + private_key_password = _to_bytes(private_key_password) + + pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password) + return crypto.dump_privatekey(crypto.FILETYPE_PEM, + pkcs12.get_privatekey()) diff --git a/oauth2client/_pycrypto_crypt.py b/oauth2client/_pycrypto_crypt.py new file mode 100644 index 00000000..fa025753 --- /dev/null +++ b/oauth2client/_pycrypto_crypt.py @@ -0,0 +1,129 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""pyCrypto Crypto-related routines for oauth2client.""" + +from Crypto.PublicKey import RSA +from Crypto.Hash import SHA256 +from Crypto.Signature import PKCS1_v1_5 +from Crypto.Util.asn1 import DerSequence +import six + +from oauth2client._helpers import _parse_pem_key +from oauth2client._helpers import _to_bytes +from oauth2client._helpers import _urlsafe_b64decode + + +class PyCryptoVerifier(object): + """Verifies the signature on a message.""" + + def __init__(self, pubkey): + """Constructor. + + Args: + pubkey: OpenSSL.crypto.PKey (or equiv), The public key to verify + with. + """ + self._pubkey = pubkey + + def verify(self, message, signature): + """Verifies a message against a signature. + + Args: + message: string or bytes, The message to verify. If string, will be + encoded to bytes as utf-8. + signature: string or bytes, The signature on the message. + + Returns: + True if message was signed by the private key associated with the + public key that this object was constructed with. + """ + message = _to_bytes(message, encoding='utf-8') + return PKCS1_v1_5.new(self._pubkey).verify( + SHA256.new(message), signature) + + @staticmethod + def from_string(key_pem, is_x509_cert): + """Construct a Verified instance from a string. + + Args: + key_pem: string, public key in PEM format. + is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it + is expected to be an RSA key in PEM format. + + Returns: + Verifier instance. + """ + if is_x509_cert: + key_pem = _to_bytes(key_pem) + pemLines = key_pem.replace(b' ', b'').split() + certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1])) + certSeq = DerSequence() + certSeq.decode(certDer) + tbsSeq = DerSequence() + tbsSeq.decode(certSeq[0]) + pubkey = RSA.importKey(tbsSeq[6]) + else: + pubkey = RSA.importKey(key_pem) + return PyCryptoVerifier(pubkey) + + +class PyCryptoSigner(object): + """Signs messages with a private key.""" + + def __init__(self, pkey): + """Constructor. + + Args: + pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with. + """ + self._key = pkey + + def sign(self, message): + """Signs a message. + + Args: + message: string, Message to be signed. + + Returns: + string, The signature of the message for the given key. + """ + message = _to_bytes(message, encoding='utf-8') + return PKCS1_v1_5.new(self._key).sign(SHA256.new(message)) + + @staticmethod + def from_string(key, password='notasecret'): + """Construct a Signer instance from a string. + + Args: + key: string, private key in PEM format. + password: string, password for private key file. Unused for PEM + files. + + Returns: + Signer instance. + + Raises: + NotImplementedError if the key isn't in PEM format. + """ + parsed_pem_key = _parse_pem_key(key) + if parsed_pem_key: + pkey = RSA.importKey(parsed_pem_key) + else: + raise NotImplementedError( + 'PKCS12 format is not supported by the PyCrypto library. ' + 'Try converting to a "PEM" ' + '(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > ' + 'privatekey.pem) ' + 'or using PyOpenSSL if native code is an option.') + return PyCryptoSigner(pkey) diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py index 00fe9855..47d390c3 100644 --- a/oauth2client/appengine.py +++ b/oauth2client/appengine.py @@ -17,8 +17,6 @@ Utilities for making it easier to use OAuth 2.0 on Google App Engine. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - import cgi import json import logging @@ -27,12 +25,12 @@ import pickle import threading import httplib2 +import webapp2 as webapp from google.appengine.api import app_identity from google.appengine.api import memcache from google.appengine.api import users from google.appengine.ext import db -from google.appengine.ext import webapp from google.appengine.ext.webapp.util import login_required from google.appengine.ext.webapp.util import run_wsgi_app from oauth2client import GOOGLE_AUTH_URI @@ -51,11 +49,13 @@ from oauth2client.client import Storage # TODO(dhermes): Resolve import issue. # This is a temporary fix for a Google internal issue. try: - from google.appengine.ext import ndb + from google.appengine.ext import ndb except ImportError: - ndb = None + ndb = None +__author__ = 'jcgregorio@google.com (Joe Gregorio)' + logger = logging.getLogger(__name__) OAUTH2CLIENT_NAMESPACE = 'oauth2client#ns' @@ -64,924 +64,954 @@ XSRF_MEMCACHE_ID = 'xsrf_secret_key' def _safe_html(s): - """Escape text to make it safe to display. + """Escape text to make it safe to display. - Args: - s: string, The text to escape. + Args: + s: string, The text to escape. - Returns: - The escaped text as a string. - """ - return cgi.escape(s, quote=1).replace("'", ''') + Returns: + The escaped text as a string. + """ + return cgi.escape(s, quote=1).replace("'", ''') class InvalidClientSecretsError(Exception): - """The client_secrets.json file is malformed or missing required fields.""" + """The client_secrets.json file is malformed or missing required fields.""" class InvalidXsrfTokenError(Exception): - """The XSRF token is invalid or expired.""" + """The XSRF token is invalid or expired.""" class SiteXsrfSecretKey(db.Model): - """Storage for the sites XSRF secret key. + """Storage for the sites XSRF secret key. - There will only be one instance stored of this model, the one used for the - site. - """ - secret = db.StringProperty() - -if ndb is not None: - class SiteXsrfSecretKeyNDB(ndb.Model): - """NDB Model for storage for the sites XSRF secret key. - - Since this model uses the same kind as SiteXsrfSecretKey, it can be used - interchangeably. This simply provides an NDB model for interacting with the - same data the DB model interacts with. - - There should only be one instance stored of this model, the one used for the + There will only be one instance stored of this model, the one used for the site. """ - secret = ndb.StringProperty() + secret = db.StringProperty() - @classmethod - def _get_kind(cls): - """Return the kind name for this class.""" - return 'SiteXsrfSecretKey' +if ndb is not None: + class SiteXsrfSecretKeyNDB(ndb.Model): + """NDB Model for storage for the sites XSRF secret key. + + Since this model uses the same kind as SiteXsrfSecretKey, it can be + used interchangeably. This simply provides an NDB model for interacting + with the same data the DB model interacts with. + + There should only be one instance stored of this model, the one used + for the site. + """ + secret = ndb.StringProperty() + + @classmethod + def _get_kind(cls): + """Return the kind name for this class.""" + return 'SiteXsrfSecretKey' def _generate_new_xsrf_secret_key(): - """Returns a random XSRF secret key. - """ - return os.urandom(16).encode("hex") + """Returns a random XSRF secret key.""" + return os.urandom(16).encode("hex") def xsrf_secret_key(): - """Return the secret key for use for XSRF protection. + """Return the secret key for use for XSRF protection. - If the Site entity does not have a secret key, this method will also create - one and persist it. + If the Site entity does not have a secret key, this method will also create + one and persist it. - Returns: - The secret key. - """ - secret = memcache.get(XSRF_MEMCACHE_ID, namespace=OAUTH2CLIENT_NAMESPACE) - if not secret: - # Load the one and only instance of SiteXsrfSecretKey. - model = SiteXsrfSecretKey.get_or_insert(key_name='site') - if not model.secret: - model.secret = _generate_new_xsrf_secret_key() - model.put() - secret = model.secret - memcache.add(XSRF_MEMCACHE_ID, secret, namespace=OAUTH2CLIENT_NAMESPACE) + Returns: + The secret key. + """ + secret = memcache.get(XSRF_MEMCACHE_ID, namespace=OAUTH2CLIENT_NAMESPACE) + if not secret: + # Load the one and only instance of SiteXsrfSecretKey. + model = SiteXsrfSecretKey.get_or_insert(key_name='site') + if not model.secret: + model.secret = _generate_new_xsrf_secret_key() + model.put() + secret = model.secret + memcache.add(XSRF_MEMCACHE_ID, secret, + namespace=OAUTH2CLIENT_NAMESPACE) - return str(secret) + return str(secret) class AppAssertionCredentials(AssertionCredentials): - """Credentials object for App Engine Assertion Grants + """Credentials object for App Engine Assertion Grants - This object will allow an App Engine application to identify itself to Google - and other OAuth 2.0 servers that can verify assertions. It can be used for the - purpose of accessing data stored under an account assigned to the App Engine - application itself. + This object will allow an App Engine application to identify itself to + Google and other OAuth 2.0 servers that can verify assertions. It can be + used for the purpose of accessing data stored under an account assigned to + the App Engine application itself. - This credential does not require a flow to instantiate because it represents - a two legged flow, and therefore has all of the required information to - generate and refresh its own access tokens. - """ - - @util.positional(2) - def __init__(self, scope, **kwargs): - """Constructor for AppAssertionCredentials - - Args: - scope: string or iterable of strings, scope(s) of the credentials being - requested. - **kwargs: optional keyword args, including: - service_account_id: service account id of the application. If None or - unspecified, the default service account for the app is used. + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. """ - self.scope = util.scopes_to_string(scope) - self._kwargs = kwargs - self.service_account_id = kwargs.get('service_account_id', None) - # Assertion type is no longer used, but still in the parent class signature. - super(AppAssertionCredentials, self).__init__(None) + @util.positional(2) + def __init__(self, scope, **kwargs): + """Constructor for AppAssertionCredentials - @classmethod - def from_json(cls, json_data): - data = json.loads(json_data) - return AppAssertionCredentials(data['scope']) + Args: + scope: string or iterable of strings, scope(s) of the credentials + being requested. + **kwargs: optional keyword args, including: + service_account_id: service account id of the application. If None + or unspecified, the default service account for + the app is used. + """ + self.scope = util.scopes_to_string(scope) + self._kwargs = kwargs + self.service_account_id = kwargs.get('service_account_id', None) - def _refresh(self, http_request): - """Refreshes the access_token. + # Assertion type is no longer used, but still in the + # parent class signature. + super(AppAssertionCredentials, self).__init__(None) - Since the underlying App Engine app_identity implementation does its own - caching we can skip all the storage hoops and just to a refresh using the - API. + @classmethod + def from_json(cls, json_data): + data = json.loads(json_data) + return AppAssertionCredentials(data['scope']) - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the refresh request. + def _refresh(self, http_request): + """Refreshes the access_token. - Raises: - AccessTokenRefreshError: When the refresh fails. - """ - try: - scopes = self.scope.split() - (token, _) = app_identity.get_access_token( - scopes, service_account_id=self.service_account_id) - except app_identity.Error as e: - raise AccessTokenRefreshError(str(e)) - self.access_token = token + Since the underlying App Engine app_identity implementation does its + own caching we can skip all the storage hoops and just to a refresh + using the API. - @property - def serialization_data(self): - raise NotImplementedError('Cannot serialize credentials for AppEngine.') + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + refresh request. - def create_scoped_required(self): - return not self.scope + Raises: + AccessTokenRefreshError: When the refresh fails. + """ + try: + scopes = self.scope.split() + (token, _) = app_identity.get_access_token( + scopes, service_account_id=self.service_account_id) + except app_identity.Error as e: + raise AccessTokenRefreshError(str(e)) + self.access_token = token - def create_scoped(self, scopes): - return AppAssertionCredentials(scopes, **self._kwargs) + @property + def serialization_data(self): + raise NotImplementedError('Cannot serialize credentials ' + 'for Google App Engine.') + + def create_scoped_required(self): + return not self.scope + + def create_scoped(self, scopes): + return AppAssertionCredentials(scopes, **self._kwargs) class FlowProperty(db.Property): - """App Engine datastore Property for Flow. - - Utility property that allows easy storage and retrieval of an - oauth2client.Flow""" - - # Tell what the user type is. - data_type = Flow - - # For writing to datastore. - def get_value_for_datastore(self, model_instance): - flow = super(FlowProperty, - self).get_value_for_datastore(model_instance) - return db.Blob(pickle.dumps(flow)) - - # For reading from datastore. - def make_value_from_datastore(self, value): - if value is None: - return None - return pickle.loads(value) - - def validate(self, value): - if value is not None and not isinstance(value, Flow): - raise db.BadValueError('Property %s must be convertible ' - 'to a FlowThreeLegged instance (%s)' % - (self.name, value)) - return super(FlowProperty, self).validate(value) - - def empty(self, value): - return not value - - -if ndb is not None: - class FlowNDBProperty(ndb.PickleProperty): - """App Engine NDB datastore Property for Flow. - - Serves the same purpose as the DB FlowProperty, but for NDB models. Since - PickleProperty inherits from BlobProperty, the underlying representation of - the data in the datastore will be the same as in the DB case. + """App Engine datastore Property for Flow. Utility property that allows easy storage and retrieval of an oauth2client.Flow """ - def _validate(self, value): - """Validates a value as a proper Flow object. + # Tell what the user type is. + data_type = Flow - Args: - value: A value to be set on the property. + # For writing to datastore. + def get_value_for_datastore(self, model_instance): + flow = super(FlowProperty, self).get_value_for_datastore( + model_instance) + return db.Blob(pickle.dumps(flow)) - Raises: - TypeError if the value is not an instance of Flow. - """ - logger.info('validate: Got type %s', type(value)) - if value is not None and not isinstance(value, Flow): - raise TypeError('Property %s must be convertible to a flow ' - 'instance; received: %s.' % (self._name, value)) + # For reading from datastore. + def make_value_from_datastore(self, value): + if value is None: + return None + return pickle.loads(value) + + def validate(self, value): + if value is not None and not isinstance(value, Flow): + raise db.BadValueError('Property %s must be convertible ' + 'to a FlowThreeLegged instance (%s)' % + (self.name, value)) + return super(FlowProperty, self).validate(value) + + def empty(self, value): + return not value + + +if ndb is not None: + class FlowNDBProperty(ndb.PickleProperty): + """App Engine NDB datastore Property for Flow. + + Serves the same purpose as the DB FlowProperty, but for NDB models. + Since PickleProperty inherits from BlobProperty, the underlying + representation of the data in the datastore will be the same as in the + DB case. + + Utility property that allows easy storage and retrieval of an + oauth2client.Flow + """ + + def _validate(self, value): + """Validates a value as a proper Flow object. + + Args: + value: A value to be set on the property. + + Raises: + TypeError if the value is not an instance of Flow. + """ + logger.info('validate: Got type %s', type(value)) + if value is not None and not isinstance(value, Flow): + raise TypeError('Property %s must be convertible to a flow ' + 'instance; received: %s.' % (self._name, + value)) class CredentialsProperty(db.Property): - """App Engine datastore Property for Credentials. + """App Engine datastore Property for Credentials. - Utility property that allows easy storage and retrieval of - oath2client.Credentials - """ + Utility property that allows easy storage and retrieval of + oath2client.Credentials + """ - # Tell what the user type is. - data_type = Credentials + # Tell what the user type is. + data_type = Credentials - # For writing to datastore. - def get_value_for_datastore(self, model_instance): - logger.info("get: Got type " + str(type(model_instance))) - cred = super(CredentialsProperty, - self).get_value_for_datastore(model_instance) - if cred is None: - cred = '' - else: - cred = cred.to_json() - return db.Blob(cred) + # For writing to datastore. + def get_value_for_datastore(self, model_instance): + logger.info("get: Got type " + str(type(model_instance))) + cred = super(CredentialsProperty, self).get_value_for_datastore( + model_instance) + if cred is None: + cred = '' + else: + cred = cred.to_json() + return db.Blob(cred) - # For reading from datastore. - def make_value_from_datastore(self, value): - logger.info("make: Got type " + str(type(value))) - if value is None: - return None - if len(value) == 0: - return None - try: - credentials = Credentials.new_from_json(value) - except ValueError: - credentials = None - return credentials + # For reading from datastore. + def make_value_from_datastore(self, value): + logger.info("make: Got type " + str(type(value))) + if value is None: + return None + if len(value) == 0: + return None + try: + credentials = Credentials.new_from_json(value) + except ValueError: + credentials = None + return credentials - def validate(self, value): - value = super(CredentialsProperty, self).validate(value) - logger.info("validate: Got type " + str(type(value))) - if value is not None and not isinstance(value, Credentials): - raise db.BadValueError('Property %s must be convertible ' - 'to a Credentials instance (%s)' % - (self.name, value)) - #if value is not None and not isinstance(value, Credentials): - # return None - return value + def validate(self, value): + value = super(CredentialsProperty, self).validate(value) + logger.info("validate: Got type " + str(type(value))) + if value is not None and not isinstance(value, Credentials): + raise db.BadValueError('Property %s must be convertible ' + 'to a Credentials instance (%s)' % + (self.name, value)) + return value if ndb is not None: - # TODO(dhermes): Turn this into a JsonProperty and overhaul the Credentials - # and subclass mechanics to use new_from_dict, to_dict, - # from_dict, etc. - class CredentialsNDBProperty(ndb.BlobProperty): - """App Engine NDB datastore Property for Credentials. + # TODO(dhermes): Turn this into a JsonProperty and overhaul the Credentials + # and subclass mechanics to use new_from_dict, to_dict, + # from_dict, etc. + class CredentialsNDBProperty(ndb.BlobProperty): + """App Engine NDB datastore Property for Credentials. - Serves the same purpose as the DB CredentialsProperty, but for NDB models. - Since CredentialsProperty stores data as a blob and this inherits from - BlobProperty, the data in the datastore will be the same as in the DB case. + Serves the same purpose as the DB CredentialsProperty, but for NDB + models. Since CredentialsProperty stores data as a blob and this + inherits from BlobProperty, the data in the datastore will be the same + as in the DB case. - Utility property that allows easy storage and retrieval of Credentials and - subclasses. - """ - def _validate(self, value): - """Validates a value as a proper credentials object. + Utility property that allows easy storage and retrieval of Credentials + and subclasses. + """ - Args: - value: A value to be set on the property. + def _validate(self, value): + """Validates a value as a proper credentials object. - Raises: - TypeError if the value is not an instance of Credentials. - """ - logger.info('validate: Got type %s', type(value)) - if value is not None and not isinstance(value, Credentials): - raise TypeError('Property %s must be convertible to a credentials ' - 'instance; received: %s.' % (self._name, value)) + Args: + value: A value to be set on the property. - def _to_base_type(self, value): - """Converts our validated value to a JSON serialized string. + Raises: + TypeError if the value is not an instance of Credentials. + """ + logger.info('validate: Got type %s', type(value)) + if value is not None and not isinstance(value, Credentials): + raise TypeError('Property %s must be convertible to a ' + 'credentials instance; received: %s.' % + (self._name, value)) - Args: - value: A value to be set in the datastore. + def _to_base_type(self, value): + """Converts our validated value to a JSON serialized string. - Returns: - A JSON serialized version of the credential, else '' if value is None. - """ - if value is None: - return '' - else: - return value.to_json() + Args: + value: A value to be set in the datastore. - def _from_base_type(self, value): - """Converts our stored JSON string back to the desired type. + Returns: + A JSON serialized version of the credential, else '' if value + is None. + """ + if value is None: + return '' + else: + return value.to_json() - Args: - value: A value from the datastore to be converted to the desired type. + def _from_base_type(self, value): + """Converts our stored JSON string back to the desired type. - Returns: - A deserialized Credentials (or subclass) object, else None if the - value can't be parsed. - """ - if not value: - return None - try: - # Uses the from_json method of the implied class of value - credentials = Credentials.new_from_json(value) - except ValueError: - credentials = None - return credentials + Args: + value: A value from the datastore to be converted to the + desired type. + + Returns: + A deserialized Credentials (or subclass) object, else None if + the value can't be parsed. + """ + if not value: + return None + try: + # Uses the from_json method of the implied class of value + credentials = Credentials.new_from_json(value) + except ValueError: + credentials = None + return credentials class StorageByKeyName(Storage): - """Store and retrieve a credential to and from the App Engine datastore. + """Store and retrieve a credential to and from the App Engine datastore. - This Storage helper presumes the Credentials have been stored as a - CredentialsProperty or CredentialsNDBProperty on a datastore model class, and - that entities are stored by key_name. - """ - - @util.positional(4) - def __init__(self, model, key_name, property_name, cache=None, user=None): - """Constructor for Storage. - - Args: - model: db.Model or ndb.Model, model class - key_name: string, key name for the entity that has the credentials - property_name: string, name of the property that is a CredentialsProperty - or CredentialsNDBProperty. - cache: memcache, a write-through cache to put in front of the datastore. - If the model you are using is an NDB model, using a cache will be - redundant since the model uses an instance cache and memcache for you. - user: users.User object, optional. Can be used to grab user ID as a - key_name if no key name is specified. + This Storage helper presumes the Credentials have been stored as a + CredentialsProperty or CredentialsNDBProperty on a datastore model class, + and that entities are stored by key_name. """ - if key_name is None: - if user is None: - raise ValueError('StorageByKeyName called with no key name or user.') - key_name = user.user_id() - self._model = model - self._key_name = key_name - self._property_name = property_name - self._cache = cache + @util.positional(4) + def __init__(self, model, key_name, property_name, cache=None, user=None): + """Constructor for Storage. - def _is_ndb(self): - """Determine whether the model of the instance is an NDB model. + Args: + model: db.Model or ndb.Model, model class + key_name: string, key name for the entity that has the credentials + property_name: string, name of the property that is a + CredentialsProperty or CredentialsNDBProperty. + cache: memcache, a write-through cache to put in front of the + datastore. If the model you are using is an NDB model, using + a cache will be redundant since the model uses an instance + cache and memcache for you. + user: users.User object, optional. Can be used to grab user ID as a + key_name if no key name is specified. + """ + if key_name is None: + if user is None: + raise ValueError('StorageByKeyName called with no ' + 'key name or user.') + key_name = user.user_id() - Returns: - Boolean indicating whether or not the model is an NDB or DB model. - """ - # issubclass will fail if one of the arguments is not a class, only need - # worry about new-style classes since ndb and db models are new-style - if isinstance(self._model, type): - if ndb is not None and issubclass(self._model, ndb.Model): - return True - elif issubclass(self._model, db.Model): - return False + self._model = model + self._key_name = key_name + self._property_name = property_name + self._cache = cache - raise TypeError('Model class not an NDB or DB model: %s.' % (self._model,)) + def _is_ndb(self): + """Determine whether the model of the instance is an NDB model. - def _get_entity(self): - """Retrieve entity from datastore. + Returns: + Boolean indicating whether or not the model is an NDB or DB model. + """ + # issubclass will fail if one of the arguments is not a class, only + # need worry about new-style classes since ndb and db models are + # new-style + if isinstance(self._model, type): + if ndb is not None and issubclass(self._model, ndb.Model): + return True + elif issubclass(self._model, db.Model): + return False - Uses a different model method for db or ndb models. + raise TypeError('Model class not an NDB or DB model: %s.' % + (self._model,)) - Returns: - Instance of the model corresponding to the current storage object - and stored using the key name of the storage object. - """ - if self._is_ndb(): - return self._model.get_by_id(self._key_name) - else: - return self._model.get_by_key_name(self._key_name) + def _get_entity(self): + """Retrieve entity from datastore. - def _delete_entity(self): - """Delete entity from datastore. + Uses a different model method for db or ndb models. - Attempts to delete using the key_name stored on the object, whether or not - the given key is in the datastore. - """ - if self._is_ndb(): - ndb.Key(self._model, self._key_name).delete() - else: - entity_key = db.Key.from_path(self._model.kind(), self._key_name) - db.delete(entity_key) + Returns: + Instance of the model corresponding to the current storage object + and stored using the key name of the storage object. + """ + if self._is_ndb(): + return self._model.get_by_id(self._key_name) + else: + return self._model.get_by_key_name(self._key_name) - @db.non_transactional(allow_existing=True) - def locked_get(self): - """Retrieve Credential from datastore. + def _delete_entity(self): + """Delete entity from datastore. - Returns: - oauth2client.Credentials - """ - credentials = None - if self._cache: - json = self._cache.get(self._key_name) - if json: - credentials = Credentials.new_from_json(json) - if credentials is None: - entity = self._get_entity() - if entity is not None: - credentials = getattr(entity, self._property_name) + Attempts to delete using the key_name stored on the object, whether or + not the given key is in the datastore. + """ + if self._is_ndb(): + ndb.Key(self._model, self._key_name).delete() + else: + entity_key = db.Key.from_path(self._model.kind(), self._key_name) + db.delete(entity_key) + + @db.non_transactional(allow_existing=True) + def locked_get(self): + """Retrieve Credential from datastore. + + Returns: + oauth2client.Credentials + """ + credentials = None if self._cache: - self._cache.set(self._key_name, credentials.to_json()) + json = self._cache.get(self._key_name) + if json: + credentials = Credentials.new_from_json(json) + if credentials is None: + entity = self._get_entity() + if entity is not None: + credentials = getattr(entity, self._property_name) + if self._cache: + self._cache.set(self._key_name, credentials.to_json()) - if credentials and hasattr(credentials, 'set_store'): - credentials.set_store(self) - return credentials + if credentials and hasattr(credentials, 'set_store'): + credentials.set_store(self) + return credentials - @db.non_transactional(allow_existing=True) - def locked_put(self, credentials): - """Write a Credentials to the datastore. + @db.non_transactional(allow_existing=True) + def locked_put(self, credentials): + """Write a Credentials to the datastore. - Args: - credentials: Credentials, the credentials to store. - """ - entity = self._model.get_or_insert(self._key_name) - setattr(entity, self._property_name, credentials) - entity.put() - if self._cache: - self._cache.set(self._key_name, credentials.to_json()) + Args: + credentials: Credentials, the credentials to store. + """ + entity = self._model.get_or_insert(self._key_name) + setattr(entity, self._property_name, credentials) + entity.put() + if self._cache: + self._cache.set(self._key_name, credentials.to_json()) - @db.non_transactional(allow_existing=True) - def locked_delete(self): - """Delete Credential from datastore.""" + @db.non_transactional(allow_existing=True) + def locked_delete(self): + """Delete Credential from datastore.""" - if self._cache: - self._cache.delete(self._key_name) + if self._cache: + self._cache.delete(self._key_name) - self._delete_entity() + self._delete_entity() class CredentialsModel(db.Model): - """Storage for OAuth 2.0 Credentials - - Storage of the model is keyed by the user.user_id(). - """ - credentials = CredentialsProperty() - - -if ndb is not None: - class CredentialsNDBModel(ndb.Model): - """NDB Model for storage of OAuth 2.0 Credentials - - Since this model uses the same kind as CredentialsModel and has a property - which can serialize and deserialize Credentials correctly, it can be used - interchangeably with a CredentialsModel to access, insert and delete the - same entities. This simply provides an NDB model for interacting with the - same data the DB model interacts with. + """Storage for OAuth 2.0 Credentials Storage of the model is keyed by the user.user_id(). """ - credentials = CredentialsNDBProperty() + credentials = CredentialsProperty() - @classmethod - def _get_kind(cls): - """Return the kind name for this class.""" - return 'CredentialsModel' + +if ndb is not None: + class CredentialsNDBModel(ndb.Model): + """NDB Model for storage of OAuth 2.0 Credentials + + Since this model uses the same kind as CredentialsModel and has a + property which can serialize and deserialize Credentials correctly, it + can be used interchangeably with a CredentialsModel to access, insert + and delete the same entities. This simply provides an NDB model for + interacting with the same data the DB model interacts with. + + Storage of the model is keyed by the user.user_id(). + """ + credentials = CredentialsNDBProperty() + + @classmethod + def _get_kind(cls): + """Return the kind name for this class.""" + return 'CredentialsModel' def _build_state_value(request_handler, user): - """Composes the value for the 'state' parameter. + """Composes the value for the 'state' parameter. - Packs the current request URI and an XSRF token into an opaque string that - can be passed to the authentication server via the 'state' parameter. + Packs the current request URI and an XSRF token into an opaque string that + can be passed to the authentication server via the 'state' parameter. - Args: - request_handler: webapp.RequestHandler, The request. - user: google.appengine.api.users.User, The current user. + Args: + request_handler: webapp.RequestHandler, The request. + user: google.appengine.api.users.User, The current user. - Returns: - The state value as a string. - """ - uri = request_handler.request.url - token = xsrfutil.generate_token(xsrf_secret_key(), user.user_id(), - action_id=str(uri)) - return uri + ':' + token + Returns: + The state value as a string. + """ + uri = request_handler.request.url + token = xsrfutil.generate_token(xsrf_secret_key(), user.user_id(), + action_id=str(uri)) + return uri + ':' + token def _parse_state_value(state, user): - """Parse the value of the 'state' parameter. + """Parse the value of the 'state' parameter. - Parses the value and validates the XSRF token in the state parameter. + Parses the value and validates the XSRF token in the state parameter. - Args: - state: string, The value of the state parameter. - user: google.appengine.api.users.User, The current user. + Args: + state: string, The value of the state parameter. + user: google.appengine.api.users.User, The current user. - Raises: - InvalidXsrfTokenError: if the XSRF token is invalid. + Raises: + InvalidXsrfTokenError: if the XSRF token is invalid. - Returns: - The redirect URI. - """ - uri, token = state.rsplit(':', 1) - if not xsrfutil.validate_token(xsrf_secret_key(), token, user.user_id(), - action_id=uri): - raise InvalidXsrfTokenError() + Returns: + The redirect URI. + """ + uri, token = state.rsplit(':', 1) + if not xsrfutil.validate_token(xsrf_secret_key(), token, user.user_id(), + action_id=uri): + raise InvalidXsrfTokenError() - return uri + return uri class OAuth2Decorator(object): - """Utility for making OAuth 2.0 easier. + """Utility for making OAuth 2.0 easier. - Instantiate and then use with oauth_required or oauth_aware - as decorators on webapp.RequestHandler methods. + Instantiate and then use with oauth_required or oauth_aware + as decorators on webapp.RequestHandler methods. - :: + :: - decorator = OAuth2Decorator( - client_id='837...ent.com', - client_secret='Qh...wwI', - scope='https://www.googleapis.com/auth/plus') + decorator = OAuth2Decorator( + client_id='837...ent.com', + client_secret='Qh...wwI', + scope='https://www.googleapis.com/auth/plus') - class MainHandler(webapp.RequestHandler): - @decorator.oauth_required - def get(self): - http = decorator.http() - # http is authorized with the user's Credentials and can be used - # in API calls - - """ - - def set_credentials(self, credentials): - self._tls.credentials = credentials - - def get_credentials(self): - """A thread local Credentials object. - - Returns: - A client.Credentials object, or None if credentials hasn't been set in - this thread yet, which may happen when calling has_credentials inside - oauth_aware. - """ - return getattr(self._tls, 'credentials', None) - - credentials = property(get_credentials, set_credentials) - - def set_flow(self, flow): - self._tls.flow = flow - - def get_flow(self): - """A thread local Flow object. - - Returns: - A credentials.Flow object, or None if the flow hasn't been set in this - thread yet, which happens in _create_flow() since Flows are created - lazily. - """ - return getattr(self._tls, 'flow', None) - - flow = property(get_flow, set_flow) - - - @util.positional(4) - def __init__(self, client_id, client_secret, scope, - auth_uri=GOOGLE_AUTH_URI, - token_uri=GOOGLE_TOKEN_URI, - revoke_uri=GOOGLE_REVOKE_URI, - user_agent=None, - message=None, - callback_path='/oauth2callback', - token_response_param=None, - _storage_class=StorageByKeyName, - _credentials_class=CredentialsModel, - _credentials_property_name='credentials', - **kwargs): - - """Constructor for OAuth2Decorator - - Args: - client_id: string, client identifier. - client_secret: string client secret. - scope: string or iterable of strings, scope(s) of the credentials being - requested. - auth_uri: string, URI for authorization endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - token_uri: string, URI for token endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - revoke_uri: string, URI for revoke endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - user_agent: string, User agent of your application, default to None. - message: Message to display if there are problems with the OAuth 2.0 - configuration. The message may contain HTML and will be presented on the - web interface for any method that uses the decorator. - callback_path: string, The absolute path to use as the callback URI. Note - that this must match up with the URI given when registering the - application in the APIs Console. - token_response_param: string. If provided, the full JSON response - to the access token request will be encoded and included in this query - parameter in the callback URI. This is useful with providers (e.g. - wordpress.com) that include extra fields that the client may want. - _storage_class: "Protected" keyword argument not typically provided to - this constructor. A storage class to aid in storing a Credentials object - for a user in the datastore. Defaults to StorageByKeyName. - _credentials_class: "Protected" keyword argument not typically provided to - this constructor. A db or ndb Model class to hold credentials. Defaults - to CredentialsModel. - _credentials_property_name: "Protected" keyword argument not typically - provided to this constructor. A string indicating the name of the field - on the _credentials_class where a Credentials object will be stored. - Defaults to 'credentials'. - **kwargs: dict, Keyword arguments are passed along as kwargs to - the OAuth2WebServerFlow constructor. + class MainHandler(webapp.RequestHandler): + @decorator.oauth_required + def get(self): + http = decorator.http() + # http is authorized with the user's Credentials and can be + # used in API calls """ - self._tls = threading.local() - self.flow = None - self.credentials = None - self._client_id = client_id - self._client_secret = client_secret - self._scope = util.scopes_to_string(scope) - self._auth_uri = auth_uri - self._token_uri = token_uri - self._revoke_uri = revoke_uri - self._user_agent = user_agent - self._kwargs = kwargs - self._message = message - self._in_error = False - self._callback_path = callback_path - self._token_response_param = token_response_param - self._storage_class = _storage_class - self._credentials_class = _credentials_class - self._credentials_property_name = _credentials_property_name - def _display_error_message(self, request_handler): - request_handler.response.out.write('') - request_handler.response.out.write(_safe_html(self._message)) - request_handler.response.out.write('') + def set_credentials(self, credentials): + self._tls.credentials = credentials - def oauth_required(self, method): - """Decorator that starts the OAuth 2.0 dance. + def get_credentials(self): + """A thread local Credentials object. - Starts the OAuth dance for the logged in user if they haven't already - granted access for this application. + Returns: + A client.Credentials object, or None if credentials hasn't been set + in this thread yet, which may happen when calling has_credentials + inside oauth_aware. + """ + return getattr(self._tls, 'credentials', None) - Args: - method: callable, to be decorated method of a webapp.RequestHandler - instance. - """ + credentials = property(get_credentials, set_credentials) - def check_oauth(request_handler, *args, **kwargs): - if self._in_error: - self._display_error_message(request_handler) - return + def set_flow(self, flow): + self._tls.flow = flow - user = users.get_current_user() - # Don't use @login_decorator as this could be used in a POST request. - if not user: - request_handler.redirect(users.create_login_url( - request_handler.request.uri)) - return + def get_flow(self): + """A thread local Flow object. - self._create_flow(request_handler) + Returns: + A credentials.Flow object, or None if the flow hasn't been set in + this thread yet, which happens in _create_flow() since Flows are + created lazily. + """ + return getattr(self._tls, 'flow', None) - # Store the request URI in 'state' so we can use it later - self.flow.params['state'] = _build_state_value(request_handler, user) - self.credentials = self._storage_class( - self._credentials_class, None, - self._credentials_property_name, user=user).get() + flow = property(get_flow, set_flow) - if not self.has_credentials(): - return request_handler.redirect(self.authorize_url()) - try: - resp = method(request_handler, *args, **kwargs) - except AccessTokenRefreshError: - return request_handler.redirect(self.authorize_url()) - finally: + @util.positional(4) + def __init__(self, client_id, client_secret, scope, + auth_uri=GOOGLE_AUTH_URI, + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI, + user_agent=None, + message=None, + callback_path='/oauth2callback', + token_response_param=None, + _storage_class=StorageByKeyName, + _credentials_class=CredentialsModel, + _credentials_property_name='credentials', + **kwargs): + """Constructor for OAuth2Decorator + + Args: + client_id: string, client identifier. + client_secret: string client secret. + scope: string or iterable of strings, scope(s) of the credentials + being requested. + auth_uri: string, URI for authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + user_agent: string, User agent of your application, default to + None. + message: Message to display if there are problems with the + OAuth 2.0 configuration. The message may contain HTML and + will be presented on the web interface for any method that + uses the decorator. + callback_path: string, The absolute path to use as the callback + URI. Note that this must match up with the URI given + when registering the application in the APIs + Console. + token_response_param: string. If provided, the full JSON response + to the access token request will be encoded + and included in this query parameter in the + callback URI. This is useful with providers + (e.g. wordpress.com) that include extra + fields that the client may want. + _storage_class: "Protected" keyword argument not typically provided + to this constructor. A storage class to aid in + storing a Credentials object for a user in the + datastore. Defaults to StorageByKeyName. + _credentials_class: "Protected" keyword argument not typically + provided to this constructor. A db or ndb Model + class to hold credentials. Defaults to + CredentialsModel. + _credentials_property_name: "Protected" keyword argument not + typically provided to this constructor. + A string indicating the name of the + field on the _credentials_class where a + Credentials object will be stored. + Defaults to 'credentials'. + **kwargs: dict, Keyword arguments are passed along as kwargs to + the OAuth2WebServerFlow constructor. + """ + self._tls = threading.local() + self.flow = None self.credentials = None - return resp + self._client_id = client_id + self._client_secret = client_secret + self._scope = util.scopes_to_string(scope) + self._auth_uri = auth_uri + self._token_uri = token_uri + self._revoke_uri = revoke_uri + self._user_agent = user_agent + self._kwargs = kwargs + self._message = message + self._in_error = False + self._callback_path = callback_path + self._token_response_param = token_response_param + self._storage_class = _storage_class + self._credentials_class = _credentials_class + self._credentials_property_name = _credentials_property_name - return check_oauth + def _display_error_message(self, request_handler): + request_handler.response.out.write('') + request_handler.response.out.write(_safe_html(self._message)) + request_handler.response.out.write('') - def _create_flow(self, request_handler): - """Create the Flow object. + def oauth_required(self, method): + """Decorator that starts the OAuth 2.0 dance. - The Flow is calculated lazily since we don't know where this app is - running until it receives a request, at which point redirect_uri can be - calculated and then the Flow object can be constructed. + Starts the OAuth dance for the logged in user if they haven't already + granted access for this application. - Args: - request_handler: webapp.RequestHandler, the request handler. - """ - if self.flow is None: - redirect_uri = request_handler.request.relative_url( - self._callback_path) # Usually /oauth2callback - self.flow = OAuth2WebServerFlow(self._client_id, self._client_secret, - self._scope, redirect_uri=redirect_uri, - user_agent=self._user_agent, - auth_uri=self._auth_uri, - token_uri=self._token_uri, - revoke_uri=self._revoke_uri, - **self._kwargs) + Args: + method: callable, to be decorated method of a webapp.RequestHandler + instance. + """ - def oauth_aware(self, method): - """Decorator that sets up for OAuth 2.0 dance, but doesn't do it. + def check_oauth(request_handler, *args, **kwargs): + if self._in_error: + self._display_error_message(request_handler) + return - Does all the setup for the OAuth dance, but doesn't initiate it. - This decorator is useful if you want to create a page that knows - whether or not the user has granted access to this application. - From within a method decorated with @oauth_aware the has_credentials() - and authorize_url() methods can be called. + user = users.get_current_user() + # Don't use @login_decorator as this could be used in a + # POST request. + if not user: + request_handler.redirect(users.create_login_url( + request_handler.request.uri)) + return - Args: - method: callable, to be decorated method of a webapp.RequestHandler - instance. - """ + self._create_flow(request_handler) - def setup_oauth(request_handler, *args, **kwargs): - if self._in_error: - self._display_error_message(request_handler) - return + # Store the request URI in 'state' so we can use it later + self.flow.params['state'] = _build_state_value( + request_handler, user) + self.credentials = self._storage_class( + self._credentials_class, None, + self._credentials_property_name, user=user).get() - user = users.get_current_user() - # Don't use @login_decorator as this could be used in a POST request. - if not user: - request_handler.redirect(users.create_login_url( - request_handler.request.uri)) - return + if not self.has_credentials(): + return request_handler.redirect(self.authorize_url()) + try: + resp = method(request_handler, *args, **kwargs) + except AccessTokenRefreshError: + return request_handler.redirect(self.authorize_url()) + finally: + self.credentials = None + return resp - self._create_flow(request_handler) + return check_oauth - self.flow.params['state'] = _build_state_value(request_handler, user) - self.credentials = self._storage_class( - self._credentials_class, None, - self._credentials_property_name, user=user).get() - try: - resp = method(request_handler, *args, **kwargs) - finally: - self.credentials = None - return resp - return setup_oauth + def _create_flow(self, request_handler): + """Create the Flow object. + The Flow is calculated lazily since we don't know where this app is + running until it receives a request, at which point redirect_uri can be + calculated and then the Flow object can be constructed. - def has_credentials(self): - """True if for the logged in user there are valid access Credentials. + Args: + request_handler: webapp.RequestHandler, the request handler. + """ + if self.flow is None: + redirect_uri = request_handler.request.relative_url( + self._callback_path) # Usually /oauth2callback + self.flow = OAuth2WebServerFlow( + self._client_id, self._client_secret, self._scope, + redirect_uri=redirect_uri, user_agent=self._user_agent, + auth_uri=self._auth_uri, token_uri=self._token_uri, + revoke_uri=self._revoke_uri, **self._kwargs) - Must only be called from with a webapp.RequestHandler subclassed method - that had been decorated with either @oauth_required or @oauth_aware. - """ - return self.credentials is not None and not self.credentials.invalid + def oauth_aware(self, method): + """Decorator that sets up for OAuth 2.0 dance, but doesn't do it. - def authorize_url(self): - """Returns the URL to start the OAuth dance. + Does all the setup for the OAuth dance, but doesn't initiate it. + This decorator is useful if you want to create a page that knows + whether or not the user has granted access to this application. + From within a method decorated with @oauth_aware the has_credentials() + and authorize_url() methods can be called. - Must only be called from with a webapp.RequestHandler subclassed method - that had been decorated with either @oauth_required or @oauth_aware. - """ - url = self.flow.step1_get_authorize_url() - return str(url) + Args: + method: callable, to be decorated method of a webapp.RequestHandler + instance. + """ - def http(self, *args, **kwargs): - """Returns an authorized http instance. + def setup_oauth(request_handler, *args, **kwargs): + if self._in_error: + self._display_error_message(request_handler) + return - Must only be called from within an @oauth_required decorated method, or - from within an @oauth_aware decorated method where has_credentials() - returns True. + user = users.get_current_user() + # Don't use @login_decorator as this could be used in a + # POST request. + if not user: + request_handler.redirect(users.create_login_url( + request_handler.request.uri)) + return - Args: - *args: Positional arguments passed to httplib2.Http constructor. - **kwargs: Positional arguments passed to httplib2.Http constructor. - """ - return self.credentials.authorize(httplib2.Http(*args, **kwargs)) + self._create_flow(request_handler) - @property - def callback_path(self): - """The absolute path where the callback will occur. + self.flow.params['state'] = _build_state_value(request_handler, + user) + self.credentials = self._storage_class( + self._credentials_class, None, + self._credentials_property_name, user=user).get() + try: + resp = method(request_handler, *args, **kwargs) + finally: + self.credentials = None + return resp + return setup_oauth - Note this is the absolute path, not the absolute URI, that will be - calculated by the decorator at runtime. See callback_handler() for how this - should be used. + def has_credentials(self): + """True if for the logged in user there are valid access Credentials. - Returns: - The callback path as a string. - """ - return self._callback_path + Must only be called from with a webapp.RequestHandler subclassed method + that had been decorated with either @oauth_required or @oauth_aware. + """ + return self.credentials is not None and not self.credentials.invalid + def authorize_url(self): + """Returns the URL to start the OAuth dance. - def callback_handler(self): - """RequestHandler for the OAuth 2.0 redirect callback. + Must only be called from with a webapp.RequestHandler subclassed method + that had been decorated with either @oauth_required or @oauth_aware. + """ + url = self.flow.step1_get_authorize_url() + return str(url) - Usage:: + def http(self, *args, **kwargs): + """Returns an authorized http instance. - app = webapp.WSGIApplication([ - ('/index', MyIndexHandler), - ..., - (decorator.callback_path, decorator.callback_handler()) - ]) + Must only be called from within an @oauth_required decorated method, or + from within an @oauth_aware decorated method where has_credentials() + returns True. - Returns: - A webapp.RequestHandler that handles the redirect back from the - server during the OAuth 2.0 dance. - """ - decorator = self + Args: + *args: Positional arguments passed to httplib2.Http constructor. + **kwargs: Positional arguments passed to httplib2.Http constructor. + """ + return self.credentials.authorize(httplib2.Http(*args, **kwargs)) - class OAuth2Handler(webapp.RequestHandler): - """Handler for the redirect_uri of the OAuth 2.0 dance.""" + @property + def callback_path(self): + """The absolute path where the callback will occur. - @login_required - def get(self): - error = self.request.get('error') - if error: - errormsg = self.request.get('error_description', error) - self.response.out.write( - 'The authorization request failed: %s' % _safe_html(errormsg)) - else: - user = users.get_current_user() - decorator._create_flow(self) - credentials = decorator.flow.step2_exchange(self.request.params) - decorator._storage_class( - decorator._credentials_class, None, - decorator._credentials_property_name, user=user).put(credentials) - redirect_uri = _parse_state_value(str(self.request.get('state')), - user) + Note this is the absolute path, not the absolute URI, that will be + calculated by the decorator at runtime. See callback_handler() for how + this should be used. - if decorator._token_response_param and credentials.token_response: - resp_json = json.dumps(credentials.token_response) - redirect_uri = util._add_query_parameter( - redirect_uri, decorator._token_response_param, resp_json) + Returns: + The callback path as a string. + """ + return self._callback_path - self.redirect(redirect_uri) + def callback_handler(self): + """RequestHandler for the OAuth 2.0 redirect callback. - return OAuth2Handler + Usage:: - def callback_application(self): - """WSGI application for handling the OAuth 2.0 redirect callback. + app = webapp.WSGIApplication([ + ('/index', MyIndexHandler), + ..., + (decorator.callback_path, decorator.callback_handler()) + ]) - If you need finer grained control use `callback_handler` which returns just - the webapp.RequestHandler. + Returns: + A webapp.RequestHandler that handles the redirect back from the + server during the OAuth 2.0 dance. + """ + decorator = self - Returns: - A webapp.WSGIApplication that handles the redirect back from the - server during the OAuth 2.0 dance. - """ - return webapp.WSGIApplication([ - (self.callback_path, self.callback_handler()) + class OAuth2Handler(webapp.RequestHandler): + """Handler for the redirect_uri of the OAuth 2.0 dance.""" + + @login_required + def get(self): + error = self.request.get('error') + if error: + errormsg = self.request.get('error_description', error) + self.response.out.write( + 'The authorization request failed: %s' % + _safe_html(errormsg)) + else: + user = users.get_current_user() + decorator._create_flow(self) + credentials = decorator.flow.step2_exchange( + self.request.params) + decorator._storage_class( + decorator._credentials_class, None, + decorator._credentials_property_name, + user=user).put(credentials) + redirect_uri = _parse_state_value( + str(self.request.get('state')), user) + + if (decorator._token_response_param and + credentials.token_response): + resp_json = json.dumps(credentials.token_response) + redirect_uri = util._add_query_parameter( + redirect_uri, decorator._token_response_param, + resp_json) + + self.redirect(redirect_uri) + + return OAuth2Handler + + def callback_application(self): + """WSGI application for handling the OAuth 2.0 redirect callback. + + If you need finer grained control use `callback_handler` which returns + just the webapp.RequestHandler. + + Returns: + A webapp.WSGIApplication that handles the redirect back from the + server during the OAuth 2.0 dance. + """ + return webapp.WSGIApplication([ + (self.callback_path, self.callback_handler()) ]) class OAuth2DecoratorFromClientSecrets(OAuth2Decorator): - """An OAuth2Decorator that builds from a clientsecrets file. + """An OAuth2Decorator that builds from a clientsecrets file. - Uses a clientsecrets file as the source for all the information when - constructing an OAuth2Decorator. + Uses a clientsecrets file as the source for all the information when + constructing an OAuth2Decorator. - :: + :: - decorator = OAuth2DecoratorFromClientSecrets( - os.path.join(os.path.dirname(__file__), 'client_secrets.json') - scope='https://www.googleapis.com/auth/plus') + decorator = OAuth2DecoratorFromClientSecrets( + os.path.join(os.path.dirname(__file__), 'client_secrets.json') + scope='https://www.googleapis.com/auth/plus') - class MainHandler(webapp.RequestHandler): - @decorator.oauth_required - def get(self): - http = decorator.http() - # http is authorized with the user's Credentials and can be used - # in API calls + class MainHandler(webapp.RequestHandler): + @decorator.oauth_required + def get(self): + http = decorator.http() + # http is authorized with the user's Credentials and can be + # used in API calls - """ - - @util.positional(3) - def __init__(self, filename, scope, message=None, cache=None, **kwargs): - """Constructor - - Args: - filename: string, File name of client secrets. - scope: string or iterable of strings, scope(s) of the credentials being - requested. - message: string, A friendly string to display to the user if the - clientsecrets file is missing or invalid. The message may contain HTML - and will be presented on the web interface for any method that uses the - decorator. - cache: An optional cache service client that implements get() and set() - methods. See clientsecrets.loadfile() for details. - **kwargs: dict, Keyword arguments are passed along as kwargs to - the OAuth2WebServerFlow constructor. """ - client_type, client_info = clientsecrets.loadfile(filename, cache=cache) - if client_type not in [ - clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED]: - raise InvalidClientSecretsError( - "OAuth2Decorator doesn't support this OAuth 2.0 flow.") - constructor_kwargs = dict(kwargs) - constructor_kwargs.update({ - 'auth_uri': client_info['auth_uri'], - 'token_uri': client_info['token_uri'], - 'message': message, - }) - revoke_uri = client_info.get('revoke_uri') - if revoke_uri is not None: - constructor_kwargs['revoke_uri'] = revoke_uri - super(OAuth2DecoratorFromClientSecrets, self).__init__( - client_info['client_id'], client_info['client_secret'], - scope, **constructor_kwargs) - if message is not None: - self._message = message - else: - self._message = 'Please configure your application for OAuth 2.0.' + + @util.positional(3) + def __init__(self, filename, scope, message=None, cache=None, **kwargs): + """Constructor + + Args: + filename: string, File name of client secrets. + scope: string or iterable of strings, scope(s) of the credentials + being requested. + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. The message may + contain HTML and will be presented on the web interface + for any method that uses the decorator. + cache: An optional cache service client that implements get() and + set() + methods. See clientsecrets.loadfile() for details. + **kwargs: dict, Keyword arguments are passed along as kwargs to + the OAuth2WebServerFlow constructor. + """ + client_type, client_info = clientsecrets.loadfile(filename, + cache=cache) + if client_type not in (clientsecrets.TYPE_WEB, + clientsecrets.TYPE_INSTALLED): + raise InvalidClientSecretsError( + "OAuth2Decorator doesn't support this OAuth 2.0 flow.") + + constructor_kwargs = dict(kwargs) + constructor_kwargs.update({ + 'auth_uri': client_info['auth_uri'], + 'token_uri': client_info['token_uri'], + 'message': message, + }) + revoke_uri = client_info.get('revoke_uri') + if revoke_uri is not None: + constructor_kwargs['revoke_uri'] = revoke_uri + super(OAuth2DecoratorFromClientSecrets, self).__init__( + client_info['client_id'], client_info['client_secret'], + scope, **constructor_kwargs) + if message is not None: + self._message = message + else: + self._message = 'Please configure your application for OAuth 2.0.' @util.positional(2) def oauth2decorator_from_clientsecrets(filename, scope, message=None, cache=None): - """Creates an OAuth2Decorator populated from a clientsecrets file. + """Creates an OAuth2Decorator populated from a clientsecrets file. - Args: - filename: string, File name of client secrets. - scope: string or list of strings, scope(s) of the credentials being - requested. - message: string, A friendly string to display to the user if the - clientsecrets file is missing or invalid. The message may contain HTML and - will be presented on the web interface for any method that uses the - decorator. - cache: An optional cache service client that implements get() and set() - methods. See clientsecrets.loadfile() for details. + Args: + filename: string, File name of client secrets. + scope: string or list of strings, scope(s) of the credentials being + requested. + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. The message may + contain HTML and will be presented on the web interface for + any method that uses the decorator. + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. - Returns: An OAuth2Decorator - - """ - return OAuth2DecoratorFromClientSecrets(filename, scope, - message=message, cache=cache) + Returns: An OAuth2Decorator + """ + return OAuth2DecoratorFromClientSecrets(filename, scope, + message=message, cache=cache) diff --git a/oauth2client/client.py b/oauth2client/client.py index f3ec3a34..8485c57e 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -17,8 +17,6 @@ Tools for interacting with OAuth 2.0 protected resources. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - import base64 import collections import copy @@ -28,27 +26,37 @@ import logging import os import socket import sys +import tempfile import time +import shutil import six from six.moves import urllib import httplib2 -from oauth2client import clientsecrets from oauth2client import GOOGLE_AUTH_URI from oauth2client import GOOGLE_DEVICE_URI from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI +from oauth2client import GOOGLE_TOKEN_INFO_URI +from oauth2client._helpers import _from_bytes +from oauth2client._helpers import _to_bytes +from oauth2client._helpers import _urlsafe_b64decode +from oauth2client import clientsecrets from oauth2client import util + +__author__ = 'jcgregorio@google.com (Joe Gregorio)' + HAS_OPENSSL = False HAS_CRYPTO = False try: - from oauth2client import crypt - HAS_CRYPTO = True - if crypt.OpenSSLVerifier is not None: - HAS_OPENSSL = True + from oauth2client import crypt + HAS_CRYPTO = True + if crypt.OpenSSLVerifier is not None: + HAS_OPENSSL = True except ImportError: - pass + pass + logger = logging.getLogger(__name__) @@ -76,16 +84,22 @@ SERVICE_ACCOUNT = 'service_account' # The environment variable pointing the file with local # Application Default Credentials. GOOGLE_APPLICATION_CREDENTIALS = 'GOOGLE_APPLICATION_CREDENTIALS' +# The ~/.config subdirectory containing gcloud credentials. Intended +# to be swapped out in tests. +_CLOUDSDK_CONFIG_DIRECTORY = 'gcloud' +# The environment variable name which can replace ~/.config if set. +_CLOUDSDK_CONFIG_ENV_VAR = 'CLOUDSDK_CONFIG' # The error message we show users when we can't find the Application # Default Credentials. ADC_HELP_MSG = ( - 'The Application Default Credentials are not available. They are available ' - 'if running in Google Compute Engine. Otherwise, the environment variable ' - + GOOGLE_APPLICATION_CREDENTIALS + + 'The Application Default Credentials are not available. They are ' + 'available if running in Google Compute Engine. Otherwise, the ' + 'environment variable ' + + GOOGLE_APPLICATION_CREDENTIALS + ' must be defined pointing to a file defining the credentials. See ' - 'https://developers.google.com/accounts/docs/application-default-credentials' # pylint:disable=line-too-long - ' for more information.') + 'https://developers.google.com/accounts/docs/' + 'application-default-credentials for more information.') # The access token along with the seconds in which it expires. AccessTokenInfo = collections.namedtuple( @@ -96,1515 +110,1655 @@ DEFAULT_ENV_NAME = 'UNKNOWN' # If set to True _get_environment avoid GCE check (_detect_gce_environment) NO_GCE_CHECK = os.environ.setdefault('NO_GCE_CHECK', 'False') +_SERVER_SOFTWARE = 'SERVER_SOFTWARE' +_GCE_METADATA_HOST = '169.254.169.254' +_METADATA_FLAVOR_HEADER = 'Metadata-Flavor' +_DESIRED_METADATA_FLAVOR = 'Google' + + class SETTINGS(object): - """Settings namespace for globally defined values.""" - env_name = None + """Settings namespace for globally defined values.""" + env_name = None class Error(Exception): - """Base error for this module.""" + """Base error for this module.""" class FlowExchangeError(Error): - """Error trying to exchange an authorization grant for an access token.""" + """Error trying to exchange an authorization grant for an access token.""" class AccessTokenRefreshError(Error): - """Error trying to refresh an expired access token.""" + """Error trying to refresh an expired access token.""" class TokenRevokeError(Error): - """Error trying to revoke a token.""" + """Error trying to revoke a token.""" class UnknownClientSecretsFlowError(Error): - """The client secrets file called for an unknown type of OAuth 2.0 flow. """ + """The client secrets file called for an unknown type of OAuth 2.0 flow.""" class AccessTokenCredentialsError(Error): - """Having only the access_token means no refresh is possible.""" + """Having only the access_token means no refresh is possible.""" class VerifyJwtTokenError(Error): - """Could not retrieve certificates for validation.""" + """Could not retrieve certificates for validation.""" class NonAsciiHeaderError(Error): - """Header names and values must be ASCII strings.""" + """Header names and values must be ASCII strings.""" class ApplicationDefaultCredentialsError(Error): - """Error retrieving the Application Default Credentials.""" + """Error retrieving the Application Default Credentials.""" class OAuth2DeviceCodeError(Error): - """Error trying to retrieve a device code.""" + """Error trying to retrieve a device code.""" class CryptoUnavailableError(Error, NotImplementedError): - """Raised when a crypto library is required, but none is available.""" + """Raised when a crypto library is required, but none is available.""" def _abstract(): - raise NotImplementedError('You need to override this function') + raise NotImplementedError('You need to override this function') class MemoryCache(object): - """httplib2 Cache implementation which only caches locally.""" + """httplib2 Cache implementation which only caches locally.""" - def __init__(self): - self.cache = {} + def __init__(self): + self.cache = {} - def get(self, key): - return self.cache.get(key) + def get(self, key): + return self.cache.get(key) - def set(self, key, value): - self.cache[key] = value + def set(self, key, value): + self.cache[key] = value - def delete(self, key): - self.cache.pop(key, None) + def delete(self, key): + self.cache.pop(key, None) class Credentials(object): - """Base class for all Credentials objects. + """Base class for all Credentials objects. - Subclasses must define an authorize() method that applies the credentials to - an HTTP transport. + Subclasses must define an authorize() method that applies the credentials + to an HTTP transport. - Subclasses must also specify a classmethod named 'from_json' that takes a JSON - string as input and returns an instantiated Credentials object. - """ - - NON_SERIALIZED_MEMBERS = ['store'] - - - def authorize(self, http): - """Take an httplib2.Http instance (or equivalent) and authorizes it. - - Authorizes it for the set of credentials, usually by replacing - http.request() with a method that adds in the appropriate headers and then - delegates to the original Http.request() method. - - Args: - http: httplib2.Http, an http object to be used to make the refresh - request. + Subclasses must also specify a classmethod named 'from_json' that takes a + JSON string as input and returns an instantiated Credentials object. """ - _abstract() + NON_SERIALIZED_MEMBERS = ['store'] - def refresh(self, http): - """Forces a refresh of the access_token. + def authorize(self, http): + """Take an httplib2.Http instance (or equivalent) and authorizes it. - Args: - http: httplib2.Http, an http object to be used to make the refresh - request. - """ - _abstract() + Authorizes it for the set of credentials, usually by replacing + http.request() with a method that adds in the appropriate headers and + then delegates to the original Http.request() method. + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + _abstract() - def revoke(self, http): - """Revokes a refresh_token and makes the credentials void. + def refresh(self, http): + """Forces a refresh of the access_token. - Args: - http: httplib2.Http, an http object to be used to make the revoke - request. - """ - _abstract() + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + _abstract() + def revoke(self, http): + """Revokes a refresh_token and makes the credentials void. - def apply(self, headers): - """Add the authorization to the headers. + Args: + http: httplib2.Http, an http object to be used to make the revoke + request. + """ + _abstract() - Args: - headers: dict, the headers to add the Authorization header to. - """ - _abstract() + def apply(self, headers): + """Add the authorization to the headers. - def _to_json(self, strip): - """Utility function that creates JSON repr. of a Credentials object. + Args: + headers: dict, the headers to add the Authorization header to. + """ + _abstract() - Args: - strip: array, An array of names of members to not include in the JSON. + def _to_json(self, strip): + """Utility function that creates JSON repr. of a Credentials object. - Returns: - string, a JSON representation of this instance, suitable to pass to - from_json(). - """ - t = type(self) - d = copy.copy(self.__dict__) - for member in strip: - if member in d: - del d[member] - if (d.get('token_expiry') and - isinstance(d['token_expiry'], datetime.datetime)): - d['token_expiry'] = d['token_expiry'].strftime(EXPIRY_FORMAT) - # Add in information we will need later to reconsistitue this instance. - d['_class'] = t.__name__ - d['_module'] = t.__module__ - for key, val in d.items(): - if isinstance(val, bytes): - d[key] = val.decode('utf-8') - return json.dumps(d) + Args: + strip: array, An array of names of members to not include in the + JSON. - def to_json(self): - """Creating a JSON representation of an instance of Credentials. + Returns: + string, a JSON representation of this instance, suitable to pass to + from_json(). + """ + t = type(self) + d = copy.copy(self.__dict__) + for member in strip: + if member in d: + del d[member] + if (d.get('token_expiry') and + isinstance(d['token_expiry'], datetime.datetime)): + d['token_expiry'] = d['token_expiry'].strftime(EXPIRY_FORMAT) + # Add in information we will need later to reconsistitue this instance. + d['_class'] = t.__name__ + d['_module'] = t.__module__ + for key, val in d.items(): + if isinstance(val, bytes): + d[key] = val.decode('utf-8') + if isinstance(val, set): + d[key] = list(val) + return json.dumps(d) - Returns: - string, a JSON representation of this instance, suitable to pass to - from_json(). - """ - return self._to_json(Credentials.NON_SERIALIZED_MEMBERS) + def to_json(self): + """Creating a JSON representation of an instance of Credentials. - @classmethod - def new_from_json(cls, s): - """Utility class method to instantiate a Credentials subclass from a JSON - representation produced by to_json(). + Returns: + string, a JSON representation of this instance, suitable to pass to + from_json(). + """ + return self._to_json(Credentials.NON_SERIALIZED_MEMBERS) - Args: - s: string, JSON from to_json(). + @classmethod + def new_from_json(cls, s): + """Utility class method to instantiate a Credentials subclass from JSON. - Returns: - An instance of the subclass of Credentials that was serialized with - to_json(). - """ - if six.PY3 and isinstance(s, bytes): - s = s.decode('utf-8') - data = json.loads(s) - # Find and call the right classmethod from_json() to restore the object. - module = data['_module'] - try: - m = __import__(module) - except ImportError: - # In case there's an object from the old package structure, update it - module = module.replace('.googleapiclient', '') - m = __import__(module) + Expects the JSON string to have been produced by to_json(). - m = __import__(module, fromlist=module.split('.')[:-1]) - kls = getattr(m, data['_class']) - from_json = getattr(kls, 'from_json') - return from_json(s) + Args: + s: string or bytes, JSON from to_json(). - @classmethod - def from_json(cls, unused_data): - """Instantiate a Credentials object from a JSON description of it. + Returns: + An instance of the subclass of Credentials that was serialized with + to_json(). + """ + json_string_as_unicode = _from_bytes(s) + data = json.loads(json_string_as_unicode) + # Find and call the right classmethod from_json() to restore + # the object. + module_name = data['_module'] + try: + module_obj = __import__(module_name) + except ImportError: + # In case there's an object from the old package structure, + # update it + module_name = module_name.replace('.googleapiclient', '') + module_obj = __import__(module_name) - The JSON should have been produced by calling .to_json() on the object. + module_obj = __import__(module_name, + fromlist=module_name.split('.')[:-1]) + kls = getattr(module_obj, data['_class']) + from_json = getattr(kls, 'from_json') + return from_json(json_string_as_unicode) - Args: - unused_data: dict, A deserialized JSON object. + @classmethod + def from_json(cls, unused_data): + """Instantiate a Credentials object from a JSON description of it. - Returns: - An instance of a Credentials subclass. - """ - return Credentials() + The JSON should have been produced by calling .to_json() on the object. + + Args: + unused_data: dict, A deserialized JSON object. + + Returns: + An instance of a Credentials subclass. + """ + return Credentials() class Flow(object): - """Base class for all Flow objects.""" - pass + """Base class for all Flow objects.""" + pass class Storage(object): - """Base class for all Storage objects. + """Base class for all Storage objects. - Store and retrieve a single credential. This class supports locking - such that multiple processes and threads can operate on a single - store. - """ - - def acquire_lock(self): - """Acquires any lock necessary to access this Storage. - - This lock is not reentrant. + Store and retrieve a single credential. This class supports locking + such that multiple processes and threads can operate on a single + store. """ - pass - def release_lock(self): - """Release the Storage lock. + def acquire_lock(self): + """Acquires any lock necessary to access this Storage. - Trying to release a lock that isn't held will result in a - RuntimeError. - """ - pass + This lock is not reentrant. + """ + pass - def locked_get(self): - """Retrieve credential. + def release_lock(self): + """Release the Storage lock. - The Storage lock must be held when this is called. + Trying to release a lock that isn't held will result in a + RuntimeError. + """ + pass - Returns: - oauth2client.client.Credentials - """ - _abstract() + def locked_get(self): + """Retrieve credential. - def locked_put(self, credentials): - """Write a credential. + The Storage lock must be held when this is called. - The Storage lock must be held when this is called. + Returns: + oauth2client.client.Credentials + """ + _abstract() - Args: - credentials: Credentials, the credentials to store. - """ - _abstract() + def locked_put(self, credentials): + """Write a credential. - def locked_delete(self): - """Delete a credential. + The Storage lock must be held when this is called. - The Storage lock must be held when this is called. - """ - _abstract() + Args: + credentials: Credentials, the credentials to store. + """ + _abstract() - def get(self): - """Retrieve credential. + def locked_delete(self): + """Delete a credential. - The Storage lock must *not* be held when this is called. + The Storage lock must be held when this is called. + """ + _abstract() - Returns: - oauth2client.client.Credentials - """ - self.acquire_lock() - try: - return self.locked_get() - finally: - self.release_lock() + def get(self): + """Retrieve credential. - def put(self, credentials): - """Write a credential. + The Storage lock must *not* be held when this is called. - The Storage lock must be held when this is called. + Returns: + oauth2client.client.Credentials + """ + self.acquire_lock() + try: + return self.locked_get() + finally: + self.release_lock() - Args: - credentials: Credentials, the credentials to store. - """ - self.acquire_lock() - try: - self.locked_put(credentials) - finally: - self.release_lock() + def put(self, credentials): + """Write a credential. - def delete(self): - """Delete credential. + The Storage lock must be held when this is called. - Frees any resources associated with storing the credential. - The Storage lock must *not* be held when this is called. + Args: + credentials: Credentials, the credentials to store. + """ + self.acquire_lock() + try: + self.locked_put(credentials) + finally: + self.release_lock() - Returns: - None - """ - self.acquire_lock() - try: - return self.locked_delete() - finally: - self.release_lock() + def delete(self): + """Delete credential. + + Frees any resources associated with storing the credential. + The Storage lock must *not* be held when this is called. + + Returns: + None + """ + self.acquire_lock() + try: + return self.locked_delete() + finally: + self.release_lock() def clean_headers(headers): - """Forces header keys and values to be strings, i.e not unicode. + """Forces header keys and values to be strings, i.e not unicode. - The httplib module just concats the header keys and values in a way that may - make the message header a unicode string, which, if it then tries to - contatenate to a binary request body may result in a unicode decode error. + The httplib module just concats the header keys and values in a way that + may make the message header a unicode string, which, if it then tries to + contatenate to a binary request body may result in a unicode decode error. - Args: - headers: dict, A dictionary of headers. + Args: + headers: dict, A dictionary of headers. - Returns: - The same dictionary but with all the keys converted to strings. - """ - clean = {} - try: - for k, v in six.iteritems(headers): - clean_k = k if isinstance(k, bytes) else str(k).encode('ascii') - clean_v = v if isinstance(v, bytes) else str(v).encode('ascii') - clean[clean_k] = clean_v - except UnicodeEncodeError: - raise NonAsciiHeaderError(k + ': ' + v) - return clean + Returns: + The same dictionary but with all the keys converted to strings. + """ + clean = {} + try: + for k, v in six.iteritems(headers): + if not isinstance(k, six.binary_type): + k = str(k) + if not isinstance(v, six.binary_type): + v = str(v) + clean[_to_bytes(k)] = _to_bytes(v) + except UnicodeEncodeError: + raise NonAsciiHeaderError(k, ': ', v) + return clean def _update_query_params(uri, params): - """Updates a URI with new query parameters. + """Updates a URI with new query parameters. - Args: - uri: string, A valid URI, with potential existing query parameters. - params: dict, A dictionary of query parameters. + Args: + uri: string, A valid URI, with potential existing query parameters. + params: dict, A dictionary of query parameters. - Returns: - The same URI but with the new query parameters added. - """ - parts = urllib.parse.urlparse(uri) - query_params = dict(urllib.parse.parse_qsl(parts.query)) - query_params.update(params) - new_parts = parts._replace(query=urllib.parse.urlencode(query_params)) - return urllib.parse.urlunparse(new_parts) + Returns: + The same URI but with the new query parameters added. + """ + parts = urllib.parse.urlparse(uri) + query_params = dict(urllib.parse.parse_qsl(parts.query)) + query_params.update(params) + new_parts = parts._replace(query=urllib.parse.urlencode(query_params)) + return urllib.parse.urlunparse(new_parts) class OAuth2Credentials(Credentials): - """Credentials object for OAuth 2.0. + """Credentials object for OAuth 2.0. - Credentials can be applied to an httplib2.Http object using the authorize() - method, which then adds the OAuth 2.0 access token to each request. + Credentials can be applied to an httplib2.Http object using the authorize() + method, which then adds the OAuth 2.0 access token to each request. - OAuth2Credentials objects may be safely pickled and unpickled. - """ - - @util.positional(8) - def __init__(self, access_token, client_id, client_secret, refresh_token, - token_expiry, token_uri, user_agent, revoke_uri=None, - id_token=None, token_response=None): - """Create an instance of OAuth2Credentials. - - This constructor is not usually called by the user, instead - OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow. - - Args: - access_token: string, access token. - client_id: string, client identifier. - client_secret: string, client secret. - refresh_token: string, refresh token. - token_expiry: datetime, when the access_token expires. - token_uri: string, URI of token endpoint. - user_agent: string, The HTTP User-Agent to provide for this application. - revoke_uri: string, URI for revoke endpoint. Defaults to None; a token - can't be revoked if this is None. - id_token: object, The identity of the resource owner. - token_response: dict, the decoded response to the token request. None - if a token hasn't been requested yet. Stored because some providers - (e.g. wordpress.com) include extra fields that clients may want. - - Notes: - store: callable, A callable that when passed a Credential - will store the credential back to where it came from. - This is needed to store the latest access_token if it - has expired and been refreshed. + OAuth2Credentials objects may be safely pickled and unpickled. """ - self.access_token = access_token - self.client_id = client_id - self.client_secret = client_secret - self.refresh_token = refresh_token - self.store = None - self.token_expiry = token_expiry - self.token_uri = token_uri - self.user_agent = user_agent - self.revoke_uri = revoke_uri - self.id_token = id_token - self.token_response = token_response - # True if the credentials have been revoked or expired and can't be - # refreshed. - self.invalid = False + @util.positional(8) + def __init__(self, access_token, client_id, client_secret, refresh_token, + token_expiry, token_uri, user_agent, revoke_uri=None, + id_token=None, token_response=None, scopes=None, + token_info_uri=None): + """Create an instance of OAuth2Credentials. - def authorize(self, http): - """Authorize an httplib2.Http instance with these credentials. + This constructor is not usually called by the user, instead + OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow. - The modified http.request method will add authentication headers to each - request and will refresh access_tokens when a 401 is received on a - request. In addition the http.request method has a credentials property, - http.request.credentials, which is the Credentials object that authorized - it. + Args: + access_token: string, access token. + client_id: string, client identifier. + client_secret: string, client secret. + refresh_token: string, refresh token. + token_expiry: datetime, when the access_token expires. + token_uri: string, URI of token endpoint. + user_agent: string, The HTTP User-Agent to provide for this + application. + revoke_uri: string, URI for revoke endpoint. Defaults to None; a + token can't be revoked if this is None. + id_token: object, The identity of the resource owner. + token_response: dict, the decoded response to the token request. + None if a token hasn't been requested yet. Stored + because some providers (e.g. wordpress.com) include + extra fields that clients may want. + scopes: list, authorized scopes for these credentials. + token_info_uri: string, the URI for the token info endpoint. Defaults + to None; scopes can not be refreshed if this is None. - Args: - http: An instance of ``httplib2.Http`` or something that acts - like it. + Notes: + store: callable, A callable that when passed a Credential + will store the credential back to where it came from. + This is needed to store the latest access_token if it + has expired and been refreshed. + """ + self.access_token = access_token + self.client_id = client_id + self.client_secret = client_secret + self.refresh_token = refresh_token + self.store = None + self.token_expiry = token_expiry + self.token_uri = token_uri + self.user_agent = user_agent + self.revoke_uri = revoke_uri + self.id_token = id_token + self.token_response = token_response + self.scopes = set(util.string_to_scopes(scopes or [])) + self.token_info_uri = token_info_uri - Returns: - A modified instance of http that was passed in. + # True if the credentials have been revoked or expired and can't be + # refreshed. + self.invalid = False - Example:: + def authorize(self, http): + """Authorize an httplib2.Http instance with these credentials. - h = httplib2.Http() - h = credentials.authorize(h) + The modified http.request method will add authentication headers to + each request and will refresh access_tokens when a 401 is received on a + request. In addition the http.request method has a credentials + property, http.request.credentials, which is the Credentials object + that authorized it. - You can't create a new OAuth subclass of httplib2.Authentication - because it never gets passed the absolute URI, which is needed for - signing. So instead we have to overload 'request' with a closure - that adds in the Authorization header and then calls the original - version of 'request()'. + Args: + http: An instance of ``httplib2.Http`` or something that acts + like it. - """ - request_orig = http.request + Returns: + A modified instance of http that was passed in. - # The closure that will replace 'httplib2.Http.request'. - @util.positional(1) - def new_request(uri, method='GET', body=None, headers=None, - redirections=httplib2.DEFAULT_MAX_REDIRECTS, - connection_type=None): - if not self.access_token: - logger.info('Attempting refresh to obtain initial access_token') - self._refresh(request_orig) + Example:: - # Clone and modify the request headers to add the appropriate - # Authorization header. - if headers is None: - headers = {} - else: - headers = dict(headers) - self.apply(headers) + h = httplib2.Http() + h = credentials.authorize(h) - if self.user_agent is not None: - if 'user-agent' in headers: - headers['user-agent'] = self.user_agent + ' ' + headers['user-agent'] - else: - headers['user-agent'] = self.user_agent + You can't create a new OAuth subclass of httplib2.Authentication + because it never gets passed the absolute URI, which is needed for + signing. So instead we have to overload 'request' with a closure + that adds in the Authorization header and then calls the original + version of 'request()'. + """ + request_orig = http.request - resp, content = request_orig(uri, method, body, clean_headers(headers), - redirections, connection_type) + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + if not self.access_token: + logger.info('Attempting refresh to obtain ' + 'initial access_token') + self._refresh(request_orig) - if resp.status in REFRESH_STATUS_CODES: - logger.info('Refreshing due to a %s', resp.status) - self._refresh(request_orig) - self.apply(headers) - return request_orig(uri, method, body, clean_headers(headers), - redirections, connection_type) - else: - return (resp, content) + # Clone and modify the request headers to add the appropriate + # Authorization header. + if headers is None: + headers = {} + else: + headers = dict(headers) + self.apply(headers) - # Replace the request method with our own closure. - http.request = new_request + if self.user_agent is not None: + if 'user-agent' in headers: + headers['user-agent'] = (self.user_agent + ' ' + + headers['user-agent']) + else: + headers['user-agent'] = self.user_agent - # Set credentials as a property of the request method. - setattr(http.request, 'credentials', self) + body_stream_position = None + if all(getattr(body, stream_prop, None) for stream_prop in + ('read', 'seek', 'tell')): + body_stream_position = body.tell() - return http + resp, content = request_orig(uri, method, body, + clean_headers(headers), + redirections, connection_type) - def refresh(self, http): - """Forces a refresh of the access_token. + # A stored token may expire between the time it is retrieved and + # the time the request is made, so we may need to try twice. + max_refresh_attempts = 2 + for refresh_attempt in range(max_refresh_attempts): + if resp.status not in REFRESH_STATUS_CODES: + break + logger.info('Refreshing due to a %s (attempt %s/%s)', + resp.status, refresh_attempt + 1, + max_refresh_attempts) + self._refresh(request_orig) + self.apply(headers) + if body_stream_position is not None: + body.seek(body_stream_position) - Args: - http: httplib2.Http, an http object to be used to make the refresh - request. - """ - self._refresh(http.request) + resp, content = request_orig(uri, method, body, + clean_headers(headers), + redirections, connection_type) - def revoke(self, http): - """Revokes a refresh_token and makes the credentials void. + return (resp, content) - Args: - http: httplib2.Http, an http object to be used to make the revoke - request. - """ - self._revoke(http.request) + # Replace the request method with our own closure. + http.request = new_request - def apply(self, headers): - """Add the authorization to the headers. + # Set credentials as a property of the request method. + setattr(http.request, 'credentials', self) - Args: - headers: dict, the headers to add the Authorization header to. - """ - headers['Authorization'] = 'Bearer ' + self.access_token + return http - def to_json(self): - return self._to_json(Credentials.NON_SERIALIZED_MEMBERS) + def refresh(self, http): + """Forces a refresh of the access_token. - @classmethod - def from_json(cls, s): - """Instantiate a Credentials object from a JSON description of it. The JSON - should have been produced by calling .to_json() on the object. + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + self._refresh(http.request) - Args: - data: dict, A deserialized JSON object. + def revoke(self, http): + """Revokes a refresh_token and makes the credentials void. - Returns: - An instance of a Credentials subclass. - """ - if six.PY3 and isinstance(s, bytes): - s = s.decode('utf-8') - data = json.loads(s) - if (data.get('token_expiry') and - not isinstance(data['token_expiry'], datetime.datetime)): - try: - data['token_expiry'] = datetime.datetime.strptime( - data['token_expiry'], EXPIRY_FORMAT) - except ValueError: - data['token_expiry'] = None - retval = cls( - data['access_token'], - data['client_id'], - data['client_secret'], - data['refresh_token'], - data['token_expiry'], - data['token_uri'], - data['user_agent'], - revoke_uri=data.get('revoke_uri', None), - id_token=data.get('id_token', None), - token_response=data.get('token_response', None)) - retval.invalid = data['invalid'] - return retval + Args: + http: httplib2.Http, an http object to be used to make the revoke + request. + """ + self._revoke(http.request) - @property - def access_token_expired(self): - """True if the credential is expired or invalid. + def apply(self, headers): + """Add the authorization to the headers. - If the token_expiry isn't set, we assume the token doesn't expire. - """ - if self.invalid: - return True + Args: + headers: dict, the headers to add the Authorization header to. + """ + headers['Authorization'] = 'Bearer ' + self.access_token - if not self.token_expiry: - return False + def has_scopes(self, scopes): + """Verify that the credentials are authorized for the given scopes. - now = datetime.datetime.utcnow() - if now >= self.token_expiry: - logger.info('access_token is expired. Now: %s, token_expiry: %s', - now, self.token_expiry) - return True - return False + Returns True if the credentials authorized scopes contain all of the + scopes given. - def get_access_token(self, http=None): - """Return the access token and its expiration information. + Args: + scopes: list or string, the scopes to check. - If the token does not exist, get one. - If the token expired, refresh it. - """ - if not self.access_token or self.access_token_expired: - if not http: - http = httplib2.Http() - self.refresh(http) - return AccessTokenInfo(access_token=self.access_token, - expires_in=self._expires_in()) + Notes: + There are cases where the credentials are unaware of which scopes + are authorized. Notably, credentials obtained and stored before + this code was added will not have scopes, AccessTokenCredentials do + not have scopes. In both cases, you can use refresh_scopes() to + obtain the canonical set of scopes. + """ + scopes = util.string_to_scopes(scopes) + return set(scopes).issubset(self.scopes) - def set_store(self, store): - """Set the Storage for the credential. + def retrieve_scopes(self, http): + """Retrieves the canonical list of scopes for this access token. - Args: - store: Storage, an implementation of Storage object. - This is needed to store the latest access_token if it - has expired and been refreshed. This implementation uses - locking to check for updates before updating the - access_token. - """ - self.store = store + Gets the scopes from the OAuth2 provider. - def _expires_in(self): - """Return the number of seconds until this token expires. + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. - If token_expiry is in the past, this method will return 0, meaning the - token has already expired. - If token_expiry is None, this method will return None. Note that returning - 0 in such a case would not be fair: the token may still be valid; - we just don't know anything about it. - """ - if self.token_expiry: - now = datetime.datetime.utcnow() - if self.token_expiry > now: - time_delta = self.token_expiry - now - # TODO(orestica): return time_delta.total_seconds() - # once dropping support for Python 2.6 - return time_delta.days * 86400 + time_delta.seconds - else: - return 0 + Returns: + A set of strings containing the canonical list of scopes. + """ + self._retrieve_scopes(http.request) + return self.scopes - def _updateFromCredential(self, other): - """Update this Credential from another instance.""" - self.__dict__.update(other.__getstate__()) + def to_json(self): + return self._to_json(Credentials.NON_SERIALIZED_MEMBERS) - def __getstate__(self): - """Trim the state down to something that can be pickled.""" - d = copy.copy(self.__dict__) - del d['store'] - return d + @classmethod + def from_json(cls, s): + """Instantiate a Credentials object from a JSON description of it. - def __setstate__(self, state): - """Reconstitute the state of the object from being pickled.""" - self.__dict__.update(state) - self.store = None + The JSON should have been produced by calling .to_json() on the object. - def _generate_refresh_request_body(self): - """Generate the body that will be used in the refresh request.""" - body = urllib.parse.urlencode({ - 'grant_type': 'refresh_token', - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'refresh_token': self.refresh_token, + Args: + data: dict, A deserialized JSON object. + + Returns: + An instance of a Credentials subclass. + """ + s = _from_bytes(s) + data = json.loads(s) + if (data.get('token_expiry') and + not isinstance(data['token_expiry'], datetime.datetime)): + try: + data['token_expiry'] = datetime.datetime.strptime( + data['token_expiry'], EXPIRY_FORMAT) + except ValueError: + data['token_expiry'] = None + retval = cls( + data['access_token'], + data['client_id'], + data['client_secret'], + data['refresh_token'], + data['token_expiry'], + data['token_uri'], + data['user_agent'], + revoke_uri=data.get('revoke_uri', None), + id_token=data.get('id_token', None), + token_response=data.get('token_response', None), + scopes=data.get('scopes', None), + token_info_uri=data.get('token_info_uri', None)) + retval.invalid = data['invalid'] + return retval + + @property + def access_token_expired(self): + """True if the credential is expired or invalid. + + If the token_expiry isn't set, we assume the token doesn't expire. + """ + if self.invalid: + return True + + if not self.token_expiry: + return False + + now = datetime.datetime.utcnow() + if now >= self.token_expiry: + logger.info('access_token is expired. Now: %s, token_expiry: %s', + now, self.token_expiry) + return True + return False + + def get_access_token(self, http=None): + """Return the access token and its expiration information. + + If the token does not exist, get one. + If the token expired, refresh it. + """ + if not self.access_token or self.access_token_expired: + if not http: + http = httplib2.Http() + self.refresh(http) + return AccessTokenInfo(access_token=self.access_token, + expires_in=self._expires_in()) + + def set_store(self, store): + """Set the Storage for the credential. + + Args: + store: Storage, an implementation of Storage object. + This is needed to store the latest access_token if it + has expired and been refreshed. This implementation uses + locking to check for updates before updating the + access_token. + """ + self.store = store + + def _expires_in(self): + """Return the number of seconds until this token expires. + + If token_expiry is in the past, this method will return 0, meaning the + token has already expired. + + If token_expiry is None, this method will return None. Note that + returning 0 in such a case would not be fair: the token may still be + valid; we just don't know anything about it. + """ + if self.token_expiry: + now = datetime.datetime.utcnow() + if self.token_expiry > now: + time_delta = self.token_expiry - now + # TODO(orestica): return time_delta.total_seconds() + # once dropping support for Python 2.6 + return time_delta.days * 86400 + time_delta.seconds + else: + return 0 + + def _updateFromCredential(self, other): + """Update this Credential from another instance.""" + self.__dict__.update(other.__getstate__()) + + def __getstate__(self): + """Trim the state down to something that can be pickled.""" + d = copy.copy(self.__dict__) + del d['store'] + return d + + def __setstate__(self, state): + """Reconstitute the state of the object from being pickled.""" + self.__dict__.update(state) + self.store = None + + def _generate_refresh_request_body(self): + """Generate the body that will be used in the refresh request.""" + body = urllib.parse.urlencode({ + 'grant_type': 'refresh_token', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self.refresh_token, }) - return body + return body - def _generate_refresh_request_headers(self): - """Generate the headers that will be used in the refresh request.""" - headers = { - 'content-type': 'application/x-www-form-urlencoded', - } + def _generate_refresh_request_headers(self): + """Generate the headers that will be used in the refresh request.""" + headers = { + 'content-type': 'application/x-www-form-urlencoded', + } - if self.user_agent is not None: - headers['user-agent'] = self.user_agent + if self.user_agent is not None: + headers['user-agent'] = self.user_agent - return headers + return headers - def _refresh(self, http_request): - """Refreshes the access_token. + def _refresh(self, http_request): + """Refreshes the access_token. - This method first checks by reading the Storage object if available. - If a refresh is still needed, it holds the Storage lock until the - refresh is completed. + This method first checks by reading the Storage object if available. + If a refresh is still needed, it holds the Storage lock until the + refresh is completed. - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the refresh request. + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + refresh request. - Raises: - AccessTokenRefreshError: When the refresh fails. - """ - if not self.store: - self._do_refresh_request(http_request) - else: - self.store.acquire_lock() - try: - new_cred = self.store.locked_get() - if (new_cred and not new_cred.invalid and - new_cred.access_token != self.access_token): - logger.info('Updated access_token read from Storage') - self._updateFromCredential(new_cred) + Raises: + AccessTokenRefreshError: When the refresh fails. + """ + if not self.store: + self._do_refresh_request(http_request) else: - self._do_refresh_request(http_request) - finally: - self.store.release_lock() + self.store.acquire_lock() + try: + new_cred = self.store.locked_get() - def _do_refresh_request(self, http_request): - """Refresh the access_token using the refresh_token. + if (new_cred and not new_cred.invalid and + new_cred.access_token != self.access_token and + not new_cred.access_token_expired): + logger.info('Updated access_token read from Storage') + self._updateFromCredential(new_cred) + else: + self._do_refresh_request(http_request) + finally: + self.store.release_lock() - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the refresh request. + def _do_refresh_request(self, http_request): + """Refresh the access_token using the refresh_token. - Raises: - AccessTokenRefreshError: When the refresh fails. - """ - body = self._generate_refresh_request_body() - headers = self._generate_refresh_request_headers() + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + refresh request. - logger.info('Refreshing access_token') - resp, content = http_request( - self.token_uri, method='POST', body=body, headers=headers) - if six.PY3 and isinstance(content, bytes): - content = content.decode('utf-8') - if resp.status == 200: - d = json.loads(content) - self.token_response = d - self.access_token = d['access_token'] - self.refresh_token = d.get('refresh_token', self.refresh_token) - if 'expires_in' in d: - self.token_expiry = datetime.timedelta( - seconds=int(d['expires_in'])) + datetime.datetime.utcnow() - else: - self.token_expiry = None - # On temporary refresh errors, the user does not actually have to - # re-authorize, so we unflag here. - self.invalid = False - if self.store: - self.store.locked_put(self) - else: - # An {'error':...} response body means the token is expired or revoked, - # so we flag the credentials as such. - logger.info('Failed to retrieve access token: %s', content) - error_msg = 'Invalid response %s.' % resp['status'] - try: - d = json.loads(content) - if 'error' in d: - error_msg = d['error'] - if 'error_description' in d: - error_msg += ': ' + d['error_description'] - self.invalid = True - if self.store: - self.store.locked_put(self) - except (TypeError, ValueError): - pass - raise AccessTokenRefreshError(error_msg) + Raises: + AccessTokenRefreshError: When the refresh fails. + """ + body = self._generate_refresh_request_body() + headers = self._generate_refresh_request_headers() - def _revoke(self, http_request): - """Revokes this credential and deletes the stored copy (if it exists). + logger.info('Refreshing access_token') + resp, content = http_request( + self.token_uri, method='POST', body=body, headers=headers) + content = _from_bytes(content) + if resp.status == 200: + d = json.loads(content) + self.token_response = d + self.access_token = d['access_token'] + self.refresh_token = d.get('refresh_token', self.refresh_token) + if 'expires_in' in d: + self.token_expiry = datetime.timedelta( + seconds=int(d['expires_in'])) + datetime.datetime.utcnow() + else: + self.token_expiry = None + # On temporary refresh errors, the user does not actually have to + # re-authorize, so we unflag here. + self.invalid = False + if self.store: + self.store.locked_put(self) + else: + # An {'error':...} response body means the token is expired or + # revoked, so we flag the credentials as such. + logger.info('Failed to retrieve access token: %s', content) + error_msg = 'Invalid response %s.' % resp['status'] + try: + d = json.loads(content) + if 'error' in d: + error_msg = d['error'] + if 'error_description' in d: + error_msg += ': ' + d['error_description'] + self.invalid = True + if self.store: + self.store.locked_put(self) + except (TypeError, ValueError): + pass + raise AccessTokenRefreshError(error_msg) - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the revoke request. - """ - self._do_revoke(http_request, self.refresh_token or self.access_token) + def _revoke(self, http_request): + """Revokes this credential and deletes the stored copy (if it exists). - def _do_revoke(self, http_request, token): - """Revokes this credential and deletes the stored copy (if it exists). + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + revoke request. + """ + self._do_revoke(http_request, self.refresh_token or self.access_token) - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the refresh request. - token: A string used as the token to be revoked. Can be either an - access_token or refresh_token. + def _do_revoke(self, http_request, token): + """Revokes this credential and deletes the stored copy (if it exists). - Raises: - TokenRevokeError: If the revoke request does not return with a 200 OK. - """ - logger.info('Revoking token') - query_params = {'token': token} - token_revoke_uri = _update_query_params(self.revoke_uri, query_params) - resp, content = http_request(token_revoke_uri) - if resp.status == 200: - self.invalid = True - else: - error_msg = 'Invalid response %s.' % resp.status - try: - d = json.loads(content) - if 'error' in d: - error_msg = d['error'] - except (TypeError, ValueError): - pass - raise TokenRevokeError(error_msg) + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + refresh request. + token: A string used as the token to be revoked. Can be either an + access_token or refresh_token. - if self.store: - self.store.delete() + Raises: + TokenRevokeError: If the revoke request does not return with a + 200 OK. + """ + logger.info('Revoking token') + query_params = {'token': token} + token_revoke_uri = _update_query_params(self.revoke_uri, query_params) + resp, content = http_request(token_revoke_uri) + if resp.status == 200: + self.invalid = True + else: + error_msg = 'Invalid response %s.' % resp.status + try: + d = json.loads(_from_bytes(content)) + if 'error' in d: + error_msg = d['error'] + except (TypeError, ValueError): + pass + raise TokenRevokeError(error_msg) + + if self.store: + self.store.delete() + + def _retrieve_scopes(self, http_request): + """Retrieves the list of authorized scopes from the OAuth2 provider. + + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + revoke request. + """ + self._do_retrieve_scopes(http_request, self.access_token) + + def _do_retrieve_scopes(self, http_request, token): + """Retrieves the list of authorized scopes from the OAuth2 provider. + + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + refresh request. + token: A string used as the token to identify the credentials to + the provider. + + Raises: + Error: When refresh fails, indicating the the access token is + invalid. + """ + logger.info('Refreshing scopes') + query_params = {'access_token': token, 'fields': 'scope'} + token_info_uri = _update_query_params(self.token_info_uri, + query_params) + resp, content = http_request(token_info_uri) + content = _from_bytes(content) + if resp.status == 200: + d = json.loads(content) + self.scopes = set(util.string_to_scopes(d.get('scope', ''))) + else: + error_msg = 'Invalid response %s.' % (resp.status,) + try: + d = json.loads(content) + if 'error_description' in d: + error_msg = d['error_description'] + except (TypeError, ValueError): + pass + raise Error(error_msg) class AccessTokenCredentials(OAuth2Credentials): - """Credentials object for OAuth 2.0. + """Credentials object for OAuth 2.0. - Credentials can be applied to an httplib2.Http object using the - authorize() method, which then signs each request from that object - with the OAuth 2.0 access token. This set of credentials is for the - use case where you have acquired an OAuth 2.0 access_token from - another place such as a JavaScript client or another web - application, and wish to use it from Python. Because only the - access_token is present it can not be refreshed and will in time - expire. + Credentials can be applied to an httplib2.Http object using the + authorize() method, which then signs each request from that object + with the OAuth 2.0 access token. This set of credentials is for the + use case where you have acquired an OAuth 2.0 access_token from + another place such as a JavaScript client or another web + application, and wish to use it from Python. Because only the + access_token is present it can not be refreshed and will in time + expire. - AccessTokenCredentials objects may be safely pickled and unpickled. + AccessTokenCredentials objects may be safely pickled and unpickled. - Usage:: + Usage:: - credentials = AccessTokenCredentials('', - 'my-user-agent/1.0') - http = httplib2.Http() - http = credentials.authorize(http) + credentials = AccessTokenCredentials('', + 'my-user-agent/1.0') + http = httplib2.Http() + http = credentials.authorize(http) - Exceptions: - AccessTokenCredentialsExpired: raised when the access_token expires or is - revoked. - """ - - def __init__(self, access_token, user_agent, revoke_uri=None): - """Create an instance of OAuth2Credentials - - This is one of the few types if Credentials that you should contrust, - Credentials objects are usually instantiated by a Flow. - - Args: - access_token: string, access token. - user_agent: string, The HTTP User-Agent to provide for this application. - revoke_uri: string, URI for revoke endpoint. Defaults to None; a token - can't be revoked if this is None. + Raises: + AccessTokenCredentialsExpired: raised when the access_token expires or + is revoked. """ - super(AccessTokenCredentials, self).__init__( - access_token, - None, - None, - None, - None, - None, - user_agent, - revoke_uri=revoke_uri) + + def __init__(self, access_token, user_agent, revoke_uri=None): + """Create an instance of OAuth2Credentials + + This is one of the few types if Credentials that you should contrust, + Credentials objects are usually instantiated by a Flow. + + Args: + access_token: string, access token. + user_agent: string, The HTTP User-Agent to provide for this + application. + revoke_uri: string, URI for revoke endpoint. Defaults to None; a + token can't be revoked if this is None. + """ + super(AccessTokenCredentials, self).__init__( + access_token, + None, + None, + None, + None, + None, + user_agent, + revoke_uri=revoke_uri) + + @classmethod + def from_json(cls, s): + data = json.loads(_from_bytes(s)) + retval = AccessTokenCredentials( + data['access_token'], + data['user_agent']) + return retval + + def _refresh(self, http_request): + raise AccessTokenCredentialsError( + 'The access_token is expired or invalid and can\'t be refreshed.') + + def _revoke(self, http_request): + """Revokes the access_token and deletes the store if available. + + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + revoke request. + """ + self._do_revoke(http_request, self.access_token) - @classmethod - def from_json(cls, s): - if six.PY3 and isinstance(s, bytes): - s = s.decode('utf-8') - data = json.loads(s) - retval = AccessTokenCredentials( - data['access_token'], - data['user_agent']) - return retval +def _detect_gce_environment(): + """Determine if the current environment is Compute Engine. - def _refresh(self, http_request): - raise AccessTokenCredentialsError( - 'The access_token is expired or invalid and can\'t be refreshed.') - - def _revoke(self, http_request): - """Revokes the access_token and deletes the store if available. - - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the revoke request. + Returns: + Boolean indicating whether or not the current environment is Google + Compute Engine. """ - self._do_revoke(http_request, self.access_token) + # NOTE: The explicit ``timeout`` is a workaround. The underlying + # issue is that resolving an unknown host on some networks will take + # 20-30 seconds; making this timeout short fixes the issue, but + # could lead to false negatives in the event that we are on GCE, but + # the metadata resolution was particularly slow. The latter case is + # "unlikely". + connection = six.moves.http_client.HTTPConnection( + _GCE_METADATA_HOST, timeout=1) + + try: + headers = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR} + connection.request('GET', '/', headers=headers) + response = connection.getresponse() + if response.status == 200: + return (response.getheader(_METADATA_FLAVOR_HEADER) == + _DESIRED_METADATA_FLAVOR) + except socket.error: # socket.timeout or socket.error(64, 'Host is down') + logger.info('Timeout attempting to reach GCE metadata service.') + return False + finally: + connection.close() -def _detect_gce_environment(urlopen=None): - """Determine if the current environment is Compute Engine. +def _in_gae_environment(): + """Detects if the code is running in the App Engine environment. - Args: - urlopen: Optional argument. Function used to open a connection to a URL. + Returns: + True if running in the GAE environment, False otherwise. + """ + if SETTINGS.env_name is not None: + return SETTINGS.env_name in ('GAE_PRODUCTION', 'GAE_LOCAL') + + try: + import google.appengine + server_software = os.environ.get(_SERVER_SOFTWARE, '') + if server_software.startswith('Google App Engine/'): + SETTINGS.env_name = 'GAE_PRODUCTION' + return True + elif server_software.startswith('Development/'): + SETTINGS.env_name = 'GAE_LOCAL' + return True + except ImportError: + pass - Returns: - Boolean indicating whether or not the current environment is Google - Compute Engine. - """ - urlopen = urlopen or urllib.request.urlopen - # Note: the explicit `timeout` below is a workaround. The underlying - # issue is that resolving an unknown host on some networks will take - # 20-30 seconds; making this timeout short fixes the issue, but - # could lead to false negatives in the event that we are on GCE, but - # the metadata resolution was particularly slow. The latter case is - # "unlikely". - try: - response = urlopen('http://169.254.169.254/', timeout=1) - return response.info().get('Metadata-Flavor', '') == 'Google' - except socket.timeout: - logger.info('Timeout attempting to reach GCE metadata service.') - return False - except urllib.error.URLError as e: - if isinstance(getattr(e, 'reason', None), socket.timeout): - logger.info('Timeout attempting to reach GCE metadata service.') return False -def _get_environment(urlopen=None): - """Detect the environment the code is being run on. +def _in_gce_environment(): + """Detect if the code is running in the Compute Engine environment. - Args: - urlopen: Optional argument. Function used to open a connection to a URL. + Returns: + True if running in the GCE environment, False otherwise. + """ + if SETTINGS.env_name is not None: + return SETTINGS.env_name == 'GCE_PRODUCTION' - Returns: - The value of SETTINGS.env_name after being set. If already - set, simply returns the value. - """ - if SETTINGS.env_name is not None: - return SETTINGS.env_name - - # None is an unset value, not the default. - SETTINGS.env_name = DEFAULT_ENV_NAME - - server_software = os.environ.get('SERVER_SOFTWARE', '') - if server_software.startswith('Google App Engine/'): - SETTINGS.env_name = 'GAE_PRODUCTION' - elif server_software.startswith('Development/'): - SETTINGS.env_name = 'GAE_LOCAL' - elif NO_GCE_CHECK != 'True' and _detect_gce_environment(urlopen=urlopen): - SETTINGS.env_name = 'GCE_PRODUCTION' - - return SETTINGS.env_name + if NO_GCE_CHECK != 'True' and _detect_gce_environment(): + SETTINGS.env_name = 'GCE_PRODUCTION' + return True + return False class GoogleCredentials(OAuth2Credentials): - """Application Default Credentials for use in calling Google APIs. + """Application Default Credentials for use in calling Google APIs. - The Application Default Credentials are being constructed as a function of - the environment where the code is being run. - More details can be found on this page: - https://developers.google.com/accounts/docs/application-default-credentials + The Application Default Credentials are being constructed as a function of + the environment where the code is being run. + More details can be found on this page: + https://developers.google.com/accounts/docs/application-default-credentials - Here is an example of how to use the Application Default Credentials for a - service that requires authentication: + Here is an example of how to use the Application Default Credentials for a + service that requires authentication:: - from googleapiclient.discovery import build - from oauth2client.client import GoogleCredentials + from googleapiclient.discovery import build + from oauth2client.client import GoogleCredentials - credentials = GoogleCredentials.get_application_default() - service = build('compute', 'v1', credentials=credentials) + credentials = GoogleCredentials.get_application_default() + service = build('compute', 'v1', credentials=credentials) - PROJECT = 'bamboo-machine-422' - ZONE = 'us-central1-a' - request = service.instances().list(project=PROJECT, zone=ZONE) - response = request.execute() + PROJECT = 'bamboo-machine-422' + ZONE = 'us-central1-a' + request = service.instances().list(project=PROJECT, zone=ZONE) + response = request.execute() - print(response) - """ + print(response) + """ - def __init__(self, access_token, client_id, client_secret, refresh_token, - token_expiry, token_uri, user_agent, - revoke_uri=GOOGLE_REVOKE_URI): - """Create an instance of GoogleCredentials. + def __init__(self, access_token, client_id, client_secret, refresh_token, + token_expiry, token_uri, user_agent, + revoke_uri=GOOGLE_REVOKE_URI): + """Create an instance of GoogleCredentials. - This constructor is not usually called by the user, instead - GoogleCredentials objects are instantiated by - GoogleCredentials.from_stream() or - GoogleCredentials.get_application_default(). + This constructor is not usually called by the user, instead + GoogleCredentials objects are instantiated by + GoogleCredentials.from_stream() or + GoogleCredentials.get_application_default(). + + Args: + access_token: string, access token. + client_id: string, client identifier. + client_secret: string, client secret. + refresh_token: string, refresh token. + token_expiry: datetime, when the access_token expires. + token_uri: string, URI of token endpoint. + user_agent: string, The HTTP User-Agent to provide for this + application. + revoke_uri: string, URI for revoke endpoint. Defaults to + GOOGLE_REVOKE_URI; a token can't be revoked if this + is None. + """ + super(GoogleCredentials, self).__init__( + access_token, client_id, client_secret, refresh_token, + token_expiry, token_uri, user_agent, revoke_uri=revoke_uri) + + def create_scoped_required(self): + """Whether this Credentials object is scopeless. + + create_scoped(scopes) method needs to be called in order to create + a Credentials object for API calls. + """ + return False + + def create_scoped(self, scopes): + """Create a Credentials object for the given scopes. + + The Credentials type is preserved. + """ + return self + + @property + def serialization_data(self): + """Get the fields and values identifying the current credentials.""" + return { + 'type': 'authorized_user', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self.refresh_token + } + + @staticmethod + def _implicit_credentials_from_gae(): + """Attempts to get implicit credentials in Google App Engine env. + + If the current environment is not detected as App Engine, returns None, + indicating no Google App Engine credentials can be detected from the + current environment. + + Returns: + None, if not in GAE, else an appengine.AppAssertionCredentials + object. + """ + if not _in_gae_environment(): + return None + + return _get_application_default_credential_GAE() + + @staticmethod + def _implicit_credentials_from_gce(): + """Attempts to get implicit credentials in Google Compute Engine env. + + If the current environment is not detected as Compute Engine, returns + None, indicating no Google Compute Engine credentials can be detected + from the current environment. + + Returns: + None, if not in GCE, else a gce.AppAssertionCredentials object. + """ + if not _in_gce_environment(): + return None + + return _get_application_default_credential_GCE() + + @staticmethod + def _implicit_credentials_from_files(): + """Attempts to get implicit credentials from local credential files. + + First checks if the environment variable GOOGLE_APPLICATION_CREDENTIALS + is set with a filename and then falls back to a configuration file (the + "well known" file) associated with the 'gcloud' command line tool. + + Returns: + Credentials object associated with the + GOOGLE_APPLICATION_CREDENTIALS file or the "well known" file if + either exist. If neither file is define, returns None, indicating + no credentials from a file can detected from the current + environment. + """ + credentials_filename = _get_environment_variable_file() + if not credentials_filename: + credentials_filename = _get_well_known_file() + if os.path.isfile(credentials_filename): + extra_help = (' (produced automatically when running' + ' "gcloud auth login" command)') + else: + credentials_filename = None + else: + extra_help = (' (pointed to by ' + GOOGLE_APPLICATION_CREDENTIALS + + ' environment variable)') + + if not credentials_filename: + return + + # If we can read the credentials from a file, we don't need to know + # what environment we are in. + SETTINGS.env_name = DEFAULT_ENV_NAME + + try: + return _get_application_default_credential_from_file( + credentials_filename) + except (ApplicationDefaultCredentialsError, ValueError) as error: + _raise_exception_for_reading_json(credentials_filename, + extra_help, error) + + @classmethod + def _get_implicit_credentials(cls): + """Gets credentials implicitly from the environment. + + Checks environment in order of precedence: + - Google App Engine (production and testing) + - Environment variable GOOGLE_APPLICATION_CREDENTIALS pointing to + a file with stored credentials information. + - Stored "well known" file associated with `gcloud` command line tool. + - Google Compute Engine production environment. + + Raises: + ApplicationDefaultCredentialsError: raised when the credentials + fail to be retrieved. + """ + # Environ checks (in order). + environ_checkers = [ + cls._implicit_credentials_from_gae, + cls._implicit_credentials_from_files, + cls._implicit_credentials_from_gce, + ] + + for checker in environ_checkers: + credentials = checker() + if credentials is not None: + return credentials + + # If no credentials, fail. + raise ApplicationDefaultCredentialsError(ADC_HELP_MSG) + + @staticmethod + def get_application_default(): + """Get the Application Default Credentials for the current environment. + + Raises: + ApplicationDefaultCredentialsError: raised when the credentials + fail to be retrieved. + """ + return GoogleCredentials._get_implicit_credentials() + + @staticmethod + def from_stream(credential_filename): + """Create a Credentials object by reading information from a file. + + It returns an object of type GoogleCredentials. + + Args: + credential_filename: the path to the file from where the + credentials are to be read + + Raises: + ApplicationDefaultCredentialsError: raised when the credentials + fail to be retrieved. + """ + if credential_filename and os.path.isfile(credential_filename): + try: + return _get_application_default_credential_from_file( + credential_filename) + except (ApplicationDefaultCredentialsError, ValueError) as error: + extra_help = (' (provided as parameter to the ' + 'from_stream() method)') + _raise_exception_for_reading_json(credential_filename, + extra_help, + error) + else: + raise ApplicationDefaultCredentialsError( + 'The parameter passed to the from_stream() ' + 'method should point to a file.') + + +def _save_private_file(filename, json_contents): + """Saves a file with read-write permissions on for the owner. Args: - access_token: string, access token. - client_id: string, client identifier. - client_secret: string, client secret. - refresh_token: string, refresh token. - token_expiry: datetime, when the access_token expires. - token_uri: string, URI of token endpoint. - user_agent: string, The HTTP User-Agent to provide for this application. - revoke_uri: string, URI for revoke endpoint. - Defaults to GOOGLE_REVOKE_URI; a token can't be revoked if this is None. + filename: String. Absolute path to file. + json_contents: JSON serializable object to be saved. """ - super(GoogleCredentials, self).__init__( - access_token, client_id, client_secret, refresh_token, token_expiry, - token_uri, user_agent, revoke_uri=revoke_uri) - - def create_scoped_required(self): - """Whether this Credentials object is scopeless. - - create_scoped(scopes) method needs to be called in order to create - a Credentials object for API calls. - """ - return False - - def create_scoped(self, scopes): - """Create a Credentials object for the given scopes. - - The Credentials type is preserved. - """ - return self - - @property - def serialization_data(self): - """Get the fields and their values identifying the current credentials.""" - return { - 'type': 'authorized_user', - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'refresh_token': self.refresh_token - } - - @staticmethod - def _implicit_credentials_from_gae(env_name=None): - """Attempts to get implicit credentials in Google App Engine env. - - If the current environment is not detected as App Engine, returns None, - indicating no Google App Engine credentials can be detected from the - current environment. - - Args: - env_name: String, indicating current environment. - - Returns: - None, if not in GAE, else an appengine.AppAssertionCredentials object. - """ - env_name = env_name or _get_environment() - if env_name not in ('GAE_PRODUCTION', 'GAE_LOCAL'): - return None - - return _get_application_default_credential_GAE() - - @staticmethod - def _implicit_credentials_from_gce(env_name=None): - """Attempts to get implicit credentials in Google Compute Engine env. - - If the current environment is not detected as Compute Engine, returns None, - indicating no Google Compute Engine credentials can be detected from the - current environment. - - Args: - env_name: String, indicating current environment. - - Returns: - None, if not in GCE, else a gce.AppAssertionCredentials object. - """ - env_name = env_name or _get_environment() - if env_name != 'GCE_PRODUCTION': - return None - - return _get_application_default_credential_GCE() - - @staticmethod - def _implicit_credentials_from_files(env_name=None): - """Attempts to get implicit credentials from local credential files. - - First checks if the environment variable GOOGLE_APPLICATION_CREDENTIALS - is set with a filename and then falls back to a configuration file (the - "well known" file) associated with the 'gcloud' command line tool. - - Args: - env_name: Unused argument. - - Returns: - Credentials object associated with the GOOGLE_APPLICATION_CREDENTIALS - file or the "well known" file if either exist. If neither file is - define, returns None, indicating no credentials from a file can - detected from the current environment. - """ - credentials_filename = _get_environment_variable_file() - if not credentials_filename: - credentials_filename = _get_well_known_file() - if os.path.isfile(credentials_filename): - extra_help = (' (produced automatically when running' - ' "gcloud auth login" command)') - else: - credentials_filename = None - else: - extra_help = (' (pointed to by ' + GOOGLE_APPLICATION_CREDENTIALS + - ' environment variable)') - - if not credentials_filename: - return - - try: - return _get_application_default_credential_from_file(credentials_filename) - except (ApplicationDefaultCredentialsError, ValueError) as error: - _raise_exception_for_reading_json(credentials_filename, extra_help, error) - - @classmethod - def _get_implicit_credentials(cls): - """Gets credentials implicitly from the environment. - - Checks environment in order of precedence: - - Google App Engine (production and testing) - - Environment variable GOOGLE_APPLICATION_CREDENTIALS pointing to - a file with stored credentials information. - - Stored "well known" file associated with `gcloud` command line tool. - - Google Compute Engine production environment. - - Exceptions: - ApplicationDefaultCredentialsError: raised when the credentials fail - to be retrieved. - """ - env_name = _get_environment() - - # Environ checks (in order). Assumes each checker takes `env_name` - # as a kwarg. - environ_checkers = [ - cls._implicit_credentials_from_gae, - cls._implicit_credentials_from_files, - cls._implicit_credentials_from_gce, - ] - - for checker in environ_checkers: - credentials = checker(env_name=env_name) - if credentials is not None: - return credentials - - # If no credentials, fail. - raise ApplicationDefaultCredentialsError(ADC_HELP_MSG) - - @staticmethod - def get_application_default(): - """Get the Application Default Credentials for the current environment. - - Exceptions: - ApplicationDefaultCredentialsError: raised when the credentials fail - to be retrieved. - """ - return GoogleCredentials._get_implicit_credentials() - - @staticmethod - def from_stream(credential_filename): - """Create a Credentials object by reading the information from a given file. - - It returns an object of type GoogleCredentials. - - Args: - credential_filename: the path to the file from where the credentials - are to be read - - Exceptions: - ApplicationDefaultCredentialsError: raised when the credentials fail - to be retrieved. - """ - - if credential_filename and os.path.isfile(credential_filename): - try: - return _get_application_default_credential_from_file( - credential_filename) - except (ApplicationDefaultCredentialsError, ValueError) as error: - extra_help = ' (provided as parameter to the from_stream() method)' - _raise_exception_for_reading_json(credential_filename, - extra_help, - error) - else: - raise ApplicationDefaultCredentialsError( - 'The parameter passed to the from_stream() ' - 'method should point to a file.') + temp_filename = tempfile.mktemp() + file_desc = os.open(temp_filename, os.O_WRONLY | os.O_CREAT, 0o600) + with os.fdopen(file_desc, 'w') as file_handle: + json.dump(json_contents, file_handle, sort_keys=True, + indent=2, separators=(',', ': ')) + shutil.move(temp_filename, filename) def save_to_well_known_file(credentials, well_known_file=None): - """Save the provided GoogleCredentials to the well known file. + """Save the provided GoogleCredentials to the well known file. - Args: - credentials: - the credentials to be saved to the well known file; - it should be an instance of GoogleCredentials - well_known_file: - the name of the file where the credentials are to be saved; - this parameter is supposed to be used for testing only - """ - # TODO(orestica): move this method to tools.py - # once the argparse import gets fixed (it is not present in Python 2.6) + Args: + credentials: the credentials to be saved to the well known file; + it should be an instance of GoogleCredentials + well_known_file: the name of the file where the credentials are to be + saved; this parameter is supposed to be used for + testing only + """ + # TODO(orestica): move this method to tools.py + # once the argparse import gets fixed (it is not present in Python 2.6) - if well_known_file is None: - well_known_file = _get_well_known_file() + if well_known_file is None: + well_known_file = _get_well_known_file() - credentials_data = credentials.serialization_data + config_dir = os.path.dirname(well_known_file) + if not os.path.isdir(config_dir): + raise OSError('Config directory does not exist: %s' % config_dir) - with open(well_known_file, 'w') as f: - json.dump(credentials_data, f, sort_keys=True, indent=2, separators=(',', ': ')) + credentials_data = credentials.serialization_data + _save_private_file(well_known_file, credentials_data) def _get_environment_variable_file(): - application_default_credential_filename = ( + application_default_credential_filename = ( os.environ.get(GOOGLE_APPLICATION_CREDENTIALS, None)) - if application_default_credential_filename: - if os.path.isfile(application_default_credential_filename): - return application_default_credential_filename - else: - raise ApplicationDefaultCredentialsError( - 'File ' + application_default_credential_filename + ' (pointed by ' + - GOOGLE_APPLICATION_CREDENTIALS + - ' environment variable) does not exist!') + if application_default_credential_filename: + if os.path.isfile(application_default_credential_filename): + return application_default_credential_filename + else: + raise ApplicationDefaultCredentialsError( + 'File ' + application_default_credential_filename + + ' (pointed by ' + + GOOGLE_APPLICATION_CREDENTIALS + + ' environment variable) does not exist!') def _get_well_known_file(): - """Get the well known file produced by command 'gcloud auth login'.""" - # TODO(orestica): Revisit this method once gcloud provides a better way - # of pinpointing the exact location of the file. + """Get the well known file produced by command 'gcloud auth login'.""" + # TODO(orestica): Revisit this method once gcloud provides a better way + # of pinpointing the exact location of the file. - WELL_KNOWN_CREDENTIALS_FILE = 'application_default_credentials.json' - CLOUDSDK_CONFIG_DIRECTORY = 'gcloud' + WELL_KNOWN_CREDENTIALS_FILE = 'application_default_credentials.json' - if os.name == 'nt': - try: - default_config_path = os.path.join(os.environ['APPDATA'], - CLOUDSDK_CONFIG_DIRECTORY) - except KeyError: - # This should never happen unless someone is really messing with things. - drive = os.environ.get('SystemDrive', 'C:') - default_config_path = os.path.join(drive, '\\', CLOUDSDK_CONFIG_DIRECTORY) - else: - default_config_path = os.path.join(os.path.expanduser('~'), - '.config', - CLOUDSDK_CONFIG_DIRECTORY) + default_config_dir = os.getenv(_CLOUDSDK_CONFIG_ENV_VAR) + if default_config_dir is None: + if os.name == 'nt': + try: + default_config_dir = os.path.join(os.environ['APPDATA'], + _CLOUDSDK_CONFIG_DIRECTORY) + except KeyError: + # This should never happen unless someone is really + # messing with things. + drive = os.environ.get('SystemDrive', 'C:') + default_config_dir = os.path.join(drive, '\\', + _CLOUDSDK_CONFIG_DIRECTORY) + else: + default_config_dir = os.path.join(os.path.expanduser('~'), + '.config', + _CLOUDSDK_CONFIG_DIRECTORY) - default_config_path = os.path.join(default_config_path, - WELL_KNOWN_CREDENTIALS_FILE) - - return default_config_path + return os.path.join(default_config_dir, WELL_KNOWN_CREDENTIALS_FILE) def _get_application_default_credential_from_file(filename): - """Build the Application Default Credentials from file.""" + """Build the Application Default Credentials from file.""" - from oauth2client import service_account + from oauth2client import service_account - # read the credentials from the file - with open(filename) as file_obj: - client_credentials = json.load(file_obj) + # read the credentials from the file + with open(filename) as file_obj: + client_credentials = json.load(file_obj) - credentials_type = client_credentials.get('type') - if credentials_type == AUTHORIZED_USER: - required_fields = set(['client_id', 'client_secret', 'refresh_token']) - elif credentials_type == SERVICE_ACCOUNT: - required_fields = set(['client_id', 'client_email', 'private_key_id', - 'private_key']) - else: - raise ApplicationDefaultCredentialsError( - "'type' field should be defined (and have one of the '" + - AUTHORIZED_USER + "' or '" + SERVICE_ACCOUNT + "' values)") + credentials_type = client_credentials.get('type') + if credentials_type == AUTHORIZED_USER: + required_fields = set(['client_id', 'client_secret', 'refresh_token']) + elif credentials_type == SERVICE_ACCOUNT: + required_fields = set(['client_id', 'client_email', 'private_key_id', + 'private_key']) + else: + raise ApplicationDefaultCredentialsError( + "'type' field should be defined (and have one of the '" + + AUTHORIZED_USER + "' or '" + SERVICE_ACCOUNT + "' values)") - missing_fields = required_fields.difference(client_credentials.keys()) + missing_fields = required_fields.difference(client_credentials.keys()) - if missing_fields: - _raise_exception_for_missing_fields(missing_fields) + if missing_fields: + _raise_exception_for_missing_fields(missing_fields) - if client_credentials['type'] == AUTHORIZED_USER: - return GoogleCredentials( - access_token=None, - client_id=client_credentials['client_id'], - client_secret=client_credentials['client_secret'], - refresh_token=client_credentials['refresh_token'], - token_expiry=None, - token_uri=GOOGLE_TOKEN_URI, - user_agent='Python client library') - else: # client_credentials['type'] == SERVICE_ACCOUNT - return service_account._ServiceAccountCredentials( - service_account_id=client_credentials['client_id'], - service_account_email=client_credentials['client_email'], - private_key_id=client_credentials['private_key_id'], - private_key_pkcs8_text=client_credentials['private_key'], - scopes=[]) + if client_credentials['type'] == AUTHORIZED_USER: + return GoogleCredentials( + access_token=None, + client_id=client_credentials['client_id'], + client_secret=client_credentials['client_secret'], + refresh_token=client_credentials['refresh_token'], + token_expiry=None, + token_uri=GOOGLE_TOKEN_URI, + user_agent='Python client library') + else: # client_credentials['type'] == SERVICE_ACCOUNT + return service_account._ServiceAccountCredentials( + service_account_id=client_credentials['client_id'], + service_account_email=client_credentials['client_email'], + private_key_id=client_credentials['private_key_id'], + private_key_pkcs8_text=client_credentials['private_key'], + scopes=[]) def _raise_exception_for_missing_fields(missing_fields): - raise ApplicationDefaultCredentialsError( - 'The following field(s) must be defined: ' + ', '.join(missing_fields)) + raise ApplicationDefaultCredentialsError( + 'The following field(s) must be defined: ' + ', '.join(missing_fields)) def _raise_exception_for_reading_json(credential_file, extra_help, error): - raise ApplicationDefaultCredentialsError( - 'An error was encountered while reading json file: '+ + raise ApplicationDefaultCredentialsError( + 'An error was encountered while reading json file: ' + credential_file + extra_help + ': ' + str(error)) def _get_application_default_credential_GAE(): - from oauth2client.appengine import AppAssertionCredentials + from oauth2client.appengine import AppAssertionCredentials - return AppAssertionCredentials([]) + return AppAssertionCredentials([]) def _get_application_default_credential_GCE(): - from oauth2client.gce import AppAssertionCredentials + from oauth2client.gce import AppAssertionCredentials - return AppAssertionCredentials([]) + return AppAssertionCredentials([]) class AssertionCredentials(GoogleCredentials): - """Abstract Credentials object used for OAuth 2.0 assertion grants. + """Abstract Credentials object used for OAuth 2.0 assertion grants. - This credential does not require a flow to instantiate because it - represents a two legged flow, and therefore has all of the required - information to generate and refresh its own access tokens. It must - be subclassed to generate the appropriate assertion string. + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. It must + be subclassed to generate the appropriate assertion string. - AssertionCredentials objects may be safely pickled and unpickled. - """ - - @util.positional(2) - def __init__(self, assertion_type, user_agent=None, - token_uri=GOOGLE_TOKEN_URI, - revoke_uri=GOOGLE_REVOKE_URI, - **unused_kwargs): - """Constructor for AssertionFlowCredentials. - - Args: - assertion_type: string, assertion type that will be declared to the auth - server - user_agent: string, The HTTP User-Agent to provide for this application. - token_uri: string, URI for token endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - revoke_uri: string, URI for revoke endpoint. + AssertionCredentials objects may be safely pickled and unpickled. """ - super(AssertionCredentials, self).__init__( - None, - None, - None, - None, - None, - token_uri, - user_agent, - revoke_uri=revoke_uri) - self.assertion_type = assertion_type - def _generate_refresh_request_body(self): - assertion = self._generate_assertion() + @util.positional(2) + def __init__(self, assertion_type, user_agent=None, + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI, + **unused_kwargs): + """Constructor for AssertionFlowCredentials. - body = urllib.parse.urlencode({ - 'assertion': assertion, - 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', + Args: + assertion_type: string, assertion type that will be declared to the + auth server + user_agent: string, The HTTP User-Agent to provide for this + application. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. + """ + super(AssertionCredentials, self).__init__( + None, + None, + None, + None, + None, + token_uri, + user_agent, + revoke_uri=revoke_uri) + self.assertion_type = assertion_type + + def _generate_refresh_request_body(self): + assertion = self._generate_assertion() + + body = urllib.parse.urlencode({ + 'assertion': assertion, + 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', }) - return body + return body - def _generate_assertion(self): - """Generate the assertion string that will be used in the access token - request. - """ - _abstract() + def _generate_assertion(self): + """Generate assertion string to be used in the access token request.""" + _abstract() - def _revoke(self, http_request): - """Revokes the access_token and deletes the store if available. + def _revoke(self, http_request): + """Revokes the access_token and deletes the store if available. - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the revoke request. - """ - self._do_revoke(http_request, self.access_token) + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make the + revoke request. + """ + self._do_revoke(http_request, self.access_token) def _RequireCryptoOrDie(): - """Ensure we have a crypto library, or throw CryptoUnavailableError. + """Ensure we have a crypto library, or throw CryptoUnavailableError. - The oauth2client.crypt module requires either PyCrypto or PyOpenSSL - to be available in order to function, but these are optional - dependencies. - """ - if not HAS_CRYPTO: - raise CryptoUnavailableError('No crypto library available') + The oauth2client.crypt module requires either PyCrypto or PyOpenSSL + to be available in order to function, but these are optional + dependencies. + """ + if not HAS_CRYPTO: + raise CryptoUnavailableError('No crypto library available') class SignedJwtAssertionCredentials(AssertionCredentials): - """Credentials object used for OAuth 2.0 Signed JWT assertion grants. + """Credentials object used for OAuth 2.0 Signed JWT assertion grants. - This credential does not require a flow to instantiate because it - represents a two legged flow, and therefore has all of the required - information to generate and refresh its own access tokens. + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. - SignedJwtAssertionCredentials requires either PyOpenSSL, or PyCrypto - 2.6 or later. For App Engine you may also consider using - AppAssertionCredentials. - """ - - MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds - - @util.positional(4) - def __init__(self, - service_account_name, - private_key, - scope, - private_key_password='notasecret', - user_agent=None, - token_uri=GOOGLE_TOKEN_URI, - revoke_uri=GOOGLE_REVOKE_URI, - **kwargs): - """Constructor for SignedJwtAssertionCredentials. - - Args: - service_account_name: string, id for account, usually an email address. - private_key: string, private key in PKCS12 or PEM format. - scope: string or iterable of strings, scope(s) of the credentials being - requested. - private_key_password: string, password for private_key, unused if - private_key is in PEM format. - user_agent: string, HTTP User-Agent to provide for this application. - token_uri: string, URI for token endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - revoke_uri: string, URI for revoke endpoint. - kwargs: kwargs, Additional parameters to add to the JWT token, for - example sub=joe@xample.org. - - Raises: - CryptoUnavailableError if no crypto library is available. + SignedJwtAssertionCredentials requires either PyOpenSSL, or PyCrypto + 2.6 or later. For App Engine you may also consider using + AppAssertionCredentials. """ - _RequireCryptoOrDie() - super(SignedJwtAssertionCredentials, self).__init__( - None, - user_agent=user_agent, - token_uri=token_uri, - revoke_uri=revoke_uri, + + MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds + + @util.positional(4) + def __init__(self, + service_account_name, + private_key, + scope, + private_key_password='notasecret', + user_agent=None, + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI, + **kwargs): + """Constructor for SignedJwtAssertionCredentials. + + Args: + service_account_name: string, id for account, usually an email + address. + private_key: string, private key in PKCS12 or PEM format. + scope: string or iterable of strings, scope(s) of the credentials + being requested. + private_key_password: string, password for private_key, unused if + private_key is in PEM format. + user_agent: string, HTTP User-Agent to provide for this + application. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. + kwargs: kwargs, Additional parameters to add to the JWT token, for + example sub=joe@xample.org. + + Raises: + CryptoUnavailableError if no crypto library is available. + """ + _RequireCryptoOrDie() + super(SignedJwtAssertionCredentials, self).__init__( + None, + user_agent=user_agent, + token_uri=token_uri, + revoke_uri=revoke_uri, ) - self.scope = util.scopes_to_string(scope) + self.scope = util.scopes_to_string(scope) - # Keep base64 encoded so it can be stored in JSON. - self.private_key = base64.b64encode(private_key) - if isinstance(self.private_key, six.text_type): - self.private_key = self.private_key.encode('utf-8') + # Keep base64 encoded so it can be stored in JSON. + self.private_key = base64.b64encode(private_key) + self.private_key = _to_bytes(self.private_key, encoding='utf-8') + self.private_key_password = private_key_password + self.service_account_name = service_account_name + self.kwargs = kwargs - self.private_key_password = private_key_password - self.service_account_name = service_account_name - self.kwargs = kwargs - - @classmethod - def from_json(cls, s): - data = json.loads(s) - retval = SignedJwtAssertionCredentials( - data['service_account_name'], - base64.b64decode(data['private_key']), - data['scope'], - private_key_password=data['private_key_password'], - user_agent=data['user_agent'], - token_uri=data['token_uri'], - **data['kwargs'] + @classmethod + def from_json(cls, s): + data = json.loads(_from_bytes(s)) + retval = SignedJwtAssertionCredentials( + data['service_account_name'], + base64.b64decode(data['private_key']), + data['scope'], + private_key_password=data['private_key_password'], + user_agent=data['user_agent'], + token_uri=data['token_uri'], + **data['kwargs'] ) - retval.invalid = data['invalid'] - retval.access_token = data['access_token'] - return retval + retval.invalid = data['invalid'] + retval.access_token = data['access_token'] + return retval - def _generate_assertion(self): - """Generate the assertion that will be used in the request.""" - now = int(time.time()) - payload = { - 'aud': self.token_uri, - 'scope': self.scope, - 'iat': now, - 'exp': now + SignedJwtAssertionCredentials.MAX_TOKEN_LIFETIME_SECS, - 'iss': self.service_account_name - } - payload.update(self.kwargs) - logger.debug(str(payload)) + def _generate_assertion(self): + """Generate the assertion that will be used in the request.""" + now = int(time.time()) + payload = { + 'aud': self.token_uri, + 'scope': self.scope, + 'iat': now, + 'exp': now + SignedJwtAssertionCredentials.MAX_TOKEN_LIFETIME_SECS, + 'iss': self.service_account_name + } + payload.update(self.kwargs) + logger.debug(str(payload)) - private_key = base64.b64decode(self.private_key) - return crypt.make_signed_jwt(crypt.Signer.from_string( - private_key, self.private_key_password), payload) + private_key = base64.b64decode(self.private_key) + return crypt.make_signed_jwt(crypt.Signer.from_string( + private_key, self.private_key_password), payload) # Only used in verify_id_token(), which is always calling to the same URI # for the certs. _cached_http = httplib2.Http(MemoryCache()) + @util.positional(2) def verify_id_token(id_token, audience, http=None, cert_uri=ID_TOKEN_VERIFICATION_CERTS): - """Verifies a signed JWT id_token. + """Verifies a signed JWT id_token. - This function requires PyOpenSSL and because of that it does not work on - App Engine. + This function requires PyOpenSSL and because of that it does not work on + App Engine. - Args: - id_token: string, A Signed JWT. - audience: string, The audience 'aud' that the token should be for. - http: httplib2.Http, instance to use to make the HTTP request. Callers - should supply an instance that has caching enabled. - cert_uri: string, URI of the certificates in JSON format to - verify the JWT against. + Args: + id_token: string, A Signed JWT. + audience: string, The audience 'aud' that the token should be for. + http: httplib2.Http, instance to use to make the HTTP request. Callers + should supply an instance that has caching enabled. + cert_uri: string, URI of the certificates in JSON format to + verify the JWT against. - Returns: - The deserialized JSON in the JWT. + Returns: + The deserialized JSON in the JWT. - Raises: - oauth2client.crypt.AppIdentityError: if the JWT fails to verify. - CryptoUnavailableError: if no crypto library is available. - """ - _RequireCryptoOrDie() - if http is None: - http = _cached_http + Raises: + oauth2client.crypt.AppIdentityError: if the JWT fails to verify. + CryptoUnavailableError: if no crypto library is available. + """ + _RequireCryptoOrDie() + if http is None: + http = _cached_http - resp, content = http.request(cert_uri) - - if resp.status == 200: - certs = json.loads(content.decode('utf-8')) - return crypt.verify_signed_jwt_with_certs(id_token, certs, audience) - else: - raise VerifyJwtTokenError('Status code: %d' % resp.status) - - -def _urlsafe_b64decode(b64string): - # Guard against unicode strings, which base64 can't handle. - if isinstance(b64string, six.text_type): - b64string = b64string.encode('ascii') - padded = b64string + b'=' * (4 - len(b64string) % 4) - return base64.urlsafe_b64decode(padded) + resp, content = http.request(cert_uri) + if resp.status == 200: + certs = json.loads(_from_bytes(content)) + return crypt.verify_signed_jwt_with_certs(id_token, certs, audience) + else: + raise VerifyJwtTokenError('Status code: %d' % resp.status) def _extract_id_token(id_token): - """Extract the JSON payload from a JWT. + """Extract the JSON payload from a JWT. - Does the extraction w/o checking the signature. + Does the extraction w/o checking the signature. - Args: - id_token: string or bytestring, OAuth 2.0 id_token. + Args: + id_token: string or bytestring, OAuth 2.0 id_token. - Returns: - object, The deserialized JSON payload. - """ - if type(id_token) == bytes: - segments = id_token.split(b'.') - else: - segments = id_token.split(u'.') + Returns: + object, The deserialized JSON payload. + """ + if type(id_token) == bytes: + segments = id_token.split(b'.') + else: + segments = id_token.split(u'.') - if len(segments) != 3: - raise VerifyJwtTokenError( - 'Wrong number of segments in token: %s' % id_token) + if len(segments) != 3: + raise VerifyJwtTokenError( + 'Wrong number of segments in token: %s' % id_token) - return json.loads(_urlsafe_b64decode(segments[1]).decode('utf-8')) + return json.loads(_from_bytes(_urlsafe_b64decode(segments[1]))) def _parse_exchange_token_response(content): - """Parses response of an exchange token request. + """Parses response of an exchange token request. - Most providers return JSON but some (e.g. Facebook) return a - url-encoded string. + Most providers return JSON but some (e.g. Facebook) return a + url-encoded string. - Args: - content: The body of a response + Args: + content: The body of a response - Returns: - Content as a dictionary object. Note that the dict could be empty, - i.e. {}. That basically indicates a failure. - """ - resp = {} - try: - resp = json.loads(content.decode('utf-8')) - except Exception: - # different JSON libs raise different exceptions, - # so we just do a catch-all here - content = content.decode('utf-8') - resp = dict(urllib.parse.parse_qsl(content)) + Returns: + Content as a dictionary object. Note that the dict could be empty, + i.e. {}. That basically indicates a failure. + """ + resp = {} + content = _from_bytes(content) + try: + resp = json.loads(content) + except Exception: + # different JSON libs raise different exceptions, + # so we just do a catch-all here + resp = dict(urllib.parse.parse_qsl(content)) - # some providers respond with 'expires', others with 'expires_in' - if resp and 'expires' in resp: - resp['expires_in'] = resp.pop('expires') + # some providers respond with 'expires', others with 'expires_in' + if resp and 'expires' in resp: + resp['expires_in'] = resp.pop('expires') - return resp + return resp @util.positional(4) @@ -1613,419 +1767,469 @@ def credentials_from_code(client_id, client_secret, scope, code, user_agent=None, token_uri=GOOGLE_TOKEN_URI, auth_uri=GOOGLE_AUTH_URI, revoke_uri=GOOGLE_REVOKE_URI, - device_uri=GOOGLE_DEVICE_URI): - """Exchanges an authorization code for an OAuth2Credentials object. + device_uri=GOOGLE_DEVICE_URI, + token_info_uri=GOOGLE_TOKEN_INFO_URI): + """Exchanges an authorization code for an OAuth2Credentials object. - Args: - client_id: string, client identifier. - client_secret: string, client secret. - scope: string or iterable of strings, scope(s) to request. - code: string, An authorization code, most likely passed down from - the client - redirect_uri: string, this is generally set to 'postmessage' to match the - redirect_uri that the client specified - http: httplib2.Http, optional http instance to use to do the fetch - token_uri: string, URI for token endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - auth_uri: string, URI for authorization endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - revoke_uri: string, URI for revoke endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - device_uri: string, URI for device authorization endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. + Args: + client_id: string, client identifier. + client_secret: string, client secret. + scope: string or iterable of strings, scope(s) to request. + code: string, An authorization code, most likely passed down from + the client + redirect_uri: string, this is generally set to 'postmessage' to match + the redirect_uri that the client specified + http: httplib2.Http, optional http instance to use to do the fetch + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + auth_uri: string, URI for authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + device_uri: string, URI for device authorization endpoint. For + convenience defaults to Google's endpoints but any OAuth + 2.0 provider can be used. - Returns: - An OAuth2Credentials object. + Returns: + An OAuth2Credentials object. - Raises: - FlowExchangeError if the authorization code cannot be exchanged for an - access token - """ - flow = OAuth2WebServerFlow(client_id, client_secret, scope, - redirect_uri=redirect_uri, user_agent=user_agent, - auth_uri=auth_uri, token_uri=token_uri, - revoke_uri=revoke_uri, device_uri=device_uri) + Raises: + FlowExchangeError if the authorization code cannot be exchanged for an + access token + """ + flow = OAuth2WebServerFlow(client_id, client_secret, scope, + redirect_uri=redirect_uri, + user_agent=user_agent, auth_uri=auth_uri, + token_uri=token_uri, revoke_uri=revoke_uri, + device_uri=device_uri, + token_info_uri=token_info_uri) - credentials = flow.step2_exchange(code, http=http) - return credentials + credentials = flow.step2_exchange(code, http=http) + return credentials @util.positional(3) def credentials_from_clientsecrets_and_code(filename, scope, code, - message = None, + message=None, redirect_uri='postmessage', http=None, cache=None, device_uri=None): - """Returns OAuth2Credentials from a clientsecrets file and an auth code. + """Returns OAuth2Credentials from a clientsecrets file and an auth code. - Will create the right kind of Flow based on the contents of the clientsecrets - file or will raise InvalidClientSecretsError for unknown types of Flows. + Will create the right kind of Flow based on the contents of the + clientsecrets file or will raise InvalidClientSecretsError for unknown + types of Flows. - Args: - filename: string, File name of clientsecrets. - scope: string or iterable of strings, scope(s) to request. - code: string, An authorization code, most likely passed down from - the client - message: string, A friendly string to display to the user if the - clientsecrets file is missing or invalid. If message is provided then - sys.exit will be called in the case of an error. If message in not - provided then clientsecrets.InvalidClientSecretsError will be raised. - redirect_uri: string, this is generally set to 'postmessage' to match the - redirect_uri that the client specified - http: httplib2.Http, optional http instance to use to do the fetch - cache: An optional cache service client that implements get() and set() - methods. See clientsecrets.loadfile() for details. - device_uri: string, OAuth 2.0 device authorization endpoint + Args: + filename: string, File name of clientsecrets. + scope: string or iterable of strings, scope(s) to request. + code: string, An authorization code, most likely passed down from + the client + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. If message is + provided then sys.exit will be called in the case of an error. + If message in not provided then + clientsecrets.InvalidClientSecretsError will be raised. + redirect_uri: string, this is generally set to 'postmessage' to match + the redirect_uri that the client specified + http: httplib2.Http, optional http instance to use to do the fetch + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. + device_uri: string, OAuth 2.0 device authorization endpoint - Returns: - An OAuth2Credentials object. + Returns: + An OAuth2Credentials object. - Raises: - FlowExchangeError if the authorization code cannot be exchanged for an - access token - UnknownClientSecretsFlowError if the file describes an unknown kind of Flow. - clientsecrets.InvalidClientSecretsError if the clientsecrets file is - invalid. - """ - flow = flow_from_clientsecrets(filename, scope, message=message, cache=cache, - redirect_uri=redirect_uri, - device_uri=device_uri) - credentials = flow.step2_exchange(code, http=http) - return credentials + Raises: + FlowExchangeError: if the authorization code cannot be exchanged for an + access token + UnknownClientSecretsFlowError: if the file describes an unknown kind + of Flow. + clientsecrets.InvalidClientSecretsError: if the clientsecrets file is + invalid. + """ + flow = flow_from_clientsecrets(filename, scope, message=message, + cache=cache, redirect_uri=redirect_uri, + device_uri=device_uri) + credentials = flow.step2_exchange(code, http=http) + return credentials class DeviceFlowInfo(collections.namedtuple('DeviceFlowInfo', ( - 'device_code', 'user_code', 'interval', 'verification_url', - 'user_code_expiry'))): - """Intermediate information the OAuth2 for devices flow.""" + 'device_code', 'user_code', 'interval', 'verification_url', + 'user_code_expiry'))): + """Intermediate information the OAuth2 for devices flow.""" - @classmethod - def FromResponse(cls, response): - """Create a DeviceFlowInfo from a server response. + @classmethod + def FromResponse(cls, response): + """Create a DeviceFlowInfo from a server response. - The response should be a dict containing entries as described here: + The response should be a dict containing entries as described here: - http://tools.ietf.org/html/draft-ietf-oauth-v2-05#section-3.7.1 - """ - # device_code, user_code, and verification_url are required. - kwargs = { - 'device_code': response['device_code'], - 'user_code': response['user_code'], - } - # The response may list the verification address as either - # verification_url or verification_uri, so we check for both. - verification_url = response.get( - 'verification_url', response.get('verification_uri')) - if verification_url is None: - raise OAuth2DeviceCodeError( - 'No verification_url provided in server response') - kwargs['verification_url'] = verification_url - # expires_in and interval are optional. - kwargs.update({ - 'interval': response.get('interval'), - 'user_code_expiry': None, - }) - if 'expires_in' in response: - kwargs['user_code_expiry'] = datetime.datetime.now() + datetime.timedelta( - seconds=int(response['expires_in'])) + http://tools.ietf.org/html/draft-ietf-oauth-v2-05#section-3.7.1 + """ + # device_code, user_code, and verification_url are required. + kwargs = { + 'device_code': response['device_code'], + 'user_code': response['user_code'], + } + # The response may list the verification address as either + # verification_url or verification_uri, so we check for both. + verification_url = response.get( + 'verification_url', response.get('verification_uri')) + if verification_url is None: + raise OAuth2DeviceCodeError( + 'No verification_url provided in server response') + kwargs['verification_url'] = verification_url + # expires_in and interval are optional. + kwargs.update({ + 'interval': response.get('interval'), + 'user_code_expiry': None, + }) + if 'expires_in' in response: + kwargs['user_code_expiry'] = ( + datetime.datetime.now() + + datetime.timedelta(seconds=int(response['expires_in']))) + return cls(**kwargs) - return cls(**kwargs) class OAuth2WebServerFlow(Flow): - """Does the Web Server Flow for OAuth 2.0. + """Does the Web Server Flow for OAuth 2.0. - OAuth2WebServerFlow objects may be safely pickled and unpickled. - """ - - @util.positional(4) - def __init__(self, client_id, client_secret, scope, - redirect_uri=None, - user_agent=None, - auth_uri=GOOGLE_AUTH_URI, - token_uri=GOOGLE_TOKEN_URI, - revoke_uri=GOOGLE_REVOKE_URI, - login_hint=None, - device_uri=GOOGLE_DEVICE_URI, - **kwargs): - """Constructor for OAuth2WebServerFlow. - - The kwargs argument is used to set extra query parameters on the - auth_uri. For example, the access_type and approval_prompt - query parameters can be set via kwargs. - - Args: - client_id: string, client identifier. - client_secret: string client secret. - scope: string or iterable of strings, scope(s) of the credentials being - requested. - redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for - a non-web-based application, or a URI that handles the callback from - the authorization server. - user_agent: string, HTTP User-Agent to provide for this application. - auth_uri: string, URI for authorization endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - token_uri: string, URI for token endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - revoke_uri: string, URI for revoke endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - login_hint: string, Either an email address or domain. Passing this hint - will either pre-fill the email box on the sign-in form or select the - proper multi-login session, thereby simplifying the login flow. - device_uri: string, URI for device authorization endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. - **kwargs: dict, The keyword arguments are all optional and required - parameters for the OAuth calls. + OAuth2WebServerFlow objects may be safely pickled and unpickled. """ - self.client_id = client_id - self.client_secret = client_secret - self.scope = util.scopes_to_string(scope) - self.redirect_uri = redirect_uri - self.login_hint = login_hint - self.user_agent = user_agent - self.auth_uri = auth_uri - self.token_uri = token_uri - self.revoke_uri = revoke_uri - self.device_uri = device_uri - self.params = { - 'access_type': 'offline', - 'response_type': 'code', - } - self.params.update(kwargs) - @util.positional(1) - def step1_get_authorize_url(self, redirect_uri=None): - """Returns a URI to redirect to the provider. + @util.positional(4) + def __init__(self, client_id, + client_secret=None, + scope=None, + redirect_uri=None, + user_agent=None, + auth_uri=GOOGLE_AUTH_URI, + token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI, + login_hint=None, + device_uri=GOOGLE_DEVICE_URI, + token_info_uri=GOOGLE_TOKEN_INFO_URI, + authorization_header=None, + **kwargs): + """Constructor for OAuth2WebServerFlow. - Args: - redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for - a non-web-based application, or a URI that handles the callback from - the authorization server. This parameter is deprecated, please move to - passing the redirect_uri in via the constructor. + The kwargs argument is used to set extra query parameters on the + auth_uri. For example, the access_type and approval_prompt + query parameters can be set via kwargs. - Returns: - A URI as a string to redirect the user to begin the authorization flow. - """ - if redirect_uri is not None: - logger.warning(( - 'The redirect_uri parameter for ' - 'OAuth2WebServerFlow.step1_get_authorize_url is deprecated. Please ' - 'move to passing the redirect_uri in via the constructor.')) - self.redirect_uri = redirect_uri + Args: + client_id: string, client identifier. + client_secret: string client secret. + scope: string or iterable of strings, scope(s) of the credentials + being requested. + redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' + for a non-web-based application, or a URI that + handles the callback from the authorization server. + user_agent: string, HTTP User-Agent to provide for this + application. + auth_uri: string, URI for authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + token_uri: string, URI for token endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + login_hint: string, Either an email address or domain. Passing this + hint will either pre-fill the email box on the sign-in + form or select the proper multi-login session, thereby + simplifying the login flow. + device_uri: string, URI for device authorization endpoint. For + convenience defaults to Google's endpoints but any + OAuth 2.0 provider can be used. + authorization_header: string, For use with OAuth 2.0 providers that + require a client to authenticate using a + header value instead of passing client_secret + in the POST body. + **kwargs: dict, The keyword arguments are all optional and required + parameters for the OAuth calls. + """ + # scope is a required argument, but to preserve backwards-compatibility + # we don't want to rearrange the positional arguments + if scope is None: + raise TypeError("The value of scope must not be None") + self.client_id = client_id + self.client_secret = client_secret + self.scope = util.scopes_to_string(scope) + self.redirect_uri = redirect_uri + self.login_hint = login_hint + self.user_agent = user_agent + self.auth_uri = auth_uri + self.token_uri = token_uri + self.revoke_uri = revoke_uri + self.device_uri = device_uri + self.token_info_uri = token_info_uri + self.authorization_header = authorization_header + self.params = { + 'access_type': 'offline', + 'response_type': 'code', + } + self.params.update(kwargs) - if self.redirect_uri is None: - raise ValueError('The value of redirect_uri must not be None.') + @util.positional(1) + def step1_get_authorize_url(self, redirect_uri=None, state=None): + """Returns a URI to redirect to the provider. - query_params = { - 'client_id': self.client_id, - 'redirect_uri': self.redirect_uri, - 'scope': self.scope, - } - if self.login_hint is not None: - query_params['login_hint'] = self.login_hint - query_params.update(self.params) - return _update_query_params(self.auth_uri, query_params) + Args: + redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' + for a non-web-based application, or a URI that + handles the callback from the authorization server. + This parameter is deprecated, please move to passing + the redirect_uri in via the constructor. + state: string, Opaque state string which is passed through the + OAuth2 flow and returned to the client as a query parameter + in the callback. - @util.positional(1) - def step1_get_device_and_user_codes(self, http=None): - """Returns a user code and the verification URL where to enter it + Returns: + A URI as a string to redirect the user to begin the authorization + flow. + """ + if redirect_uri is not None: + logger.warning(( + 'The redirect_uri parameter for ' + 'OAuth2WebServerFlow.step1_get_authorize_url is deprecated. ' + 'Please move to passing the redirect_uri in via the ' + 'constructor.')) + self.redirect_uri = redirect_uri - Returns: - A user code as a string for the user to authorize the application - An URL as a string where the user has to enter the code - """ - if self.device_uri is None: - raise ValueError('The value of device_uri must not be None.') + if self.redirect_uri is None: + raise ValueError('The value of redirect_uri must not be None.') - body = urllib.parse.urlencode({ - 'client_id': self.client_id, - 'scope': self.scope, - }) - headers = { - 'content-type': 'application/x-www-form-urlencoded', - } + query_params = { + 'client_id': self.client_id, + 'redirect_uri': self.redirect_uri, + 'scope': self.scope, + } + if state is not None: + query_params['state'] = state + if self.login_hint is not None: + query_params['login_hint'] = self.login_hint + query_params.update(self.params) + return _update_query_params(self.auth_uri, query_params) - if self.user_agent is not None: - headers['user-agent'] = self.user_agent + @util.positional(1) + def step1_get_device_and_user_codes(self, http=None): + """Returns a user code and the verification URL where to enter it - if http is None: - http = httplib2.Http() + Returns: + A user code as a string for the user to authorize the application + An URL as a string where the user has to enter the code + """ + if self.device_uri is None: + raise ValueError('The value of device_uri must not be None.') - resp, content = http.request(self.device_uri, method='POST', body=body, - headers=headers) - if resp.status == 200: - try: - flow_info = json.loads(content) - except ValueError as e: - raise OAuth2DeviceCodeError( - 'Could not parse server response as JSON: "%s", error: "%s"' % ( - content, e)) - return DeviceFlowInfo.FromResponse(flow_info) - else: - error_msg = 'Invalid response %s.' % resp.status - try: - d = json.loads(content) - if 'error' in d: - error_msg += ' Error: %s' % d['error'] - except ValueError: - # Couldn't decode a JSON response, stick with the default message. - pass - raise OAuth2DeviceCodeError(error_msg) + body = urllib.parse.urlencode({ + 'client_id': self.client_id, + 'scope': self.scope, + }) + headers = { + 'content-type': 'application/x-www-form-urlencoded', + } - @util.positional(2) - def step2_exchange(self, code=None, http=None, device_flow_info=None): - """Exchanges a code for OAuth2Credentials. + if self.user_agent is not None: + headers['user-agent'] = self.user_agent - Args: + if http is None: + http = httplib2.Http() - code: string, a dict-like object, or None. For a non-device - flow, this is either the response code as a string, or a - dictionary of query parameters to the redirect_uri. For a - device flow, this should be None. - http: httplib2.Http, optional http instance to use when fetching - credentials. - device_flow_info: DeviceFlowInfo, return value from step1 in the - case of a device flow. + resp, content = http.request(self.device_uri, method='POST', body=body, + headers=headers) + content = _from_bytes(content) + if resp.status == 200: + try: + flow_info = json.loads(content) + except ValueError as e: + raise OAuth2DeviceCodeError( + 'Could not parse server response as JSON: "%s", ' + 'error: "%s"' % (content, e)) + return DeviceFlowInfo.FromResponse(flow_info) + else: + error_msg = 'Invalid response %s.' % resp.status + try: + d = json.loads(content) + if 'error' in d: + error_msg += ' Error: %s' % d['error'] + except ValueError: + # Couldn't decode a JSON response, stick with the + # default message. + pass + raise OAuth2DeviceCodeError(error_msg) - Returns: - An OAuth2Credentials object that can be used to authorize requests. + @util.positional(2) + def step2_exchange(self, code=None, http=None, device_flow_info=None): + """Exchanges a code for OAuth2Credentials. - Raises: - FlowExchangeError: if a problem occurred exchanging the code for a - refresh_token. - ValueError: if code and device_flow_info are both provided or both - missing. + Args: + code: string, a dict-like object, or None. For a non-device + flow, this is either the response code as a string, or a + dictionary of query parameters to the redirect_uri. For a + device flow, this should be None. + http: httplib2.Http, optional http instance to use when fetching + credentials. + device_flow_info: DeviceFlowInfo, return value from step1 in the + case of a device flow. - """ - if code is None and device_flow_info is None: - raise ValueError('No code or device_flow_info provided.') - if code is not None and device_flow_info is not None: - raise ValueError('Cannot provide both code and device_flow_info.') + Returns: + An OAuth2Credentials object that can be used to authorize requests. - if code is None: - code = device_flow_info.device_code - elif not isinstance(code, six.string_types): - if 'code' not in code: - raise FlowExchangeError(code.get( - 'error', 'No code was supplied in the query parameters.')) - code = code['code'] + Raises: + FlowExchangeError: if a problem occurred exchanging the code for a + refresh_token. + ValueError: if code and device_flow_info are both provided or both + missing. + """ + if code is None and device_flow_info is None: + raise ValueError('No code or device_flow_info provided.') + if code is not None and device_flow_info is not None: + raise ValueError('Cannot provide both code and device_flow_info.') - post_data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'scope': self.scope, - } - if device_flow_info is not None: - post_data['grant_type'] = 'http://oauth.net/grant_type/device/1.0' - else: - post_data['grant_type'] = 'authorization_code' - post_data['redirect_uri'] = self.redirect_uri - body = urllib.parse.urlencode(post_data) - headers = { - 'content-type': 'application/x-www-form-urlencoded', - } + if code is None: + code = device_flow_info.device_code + elif not isinstance(code, six.string_types): + if 'code' not in code: + raise FlowExchangeError(code.get( + 'error', 'No code was supplied in the query parameters.')) + code = code['code'] - if self.user_agent is not None: - headers['user-agent'] = self.user_agent + post_data = { + 'client_id': self.client_id, + 'code': code, + 'scope': self.scope, + } + if self.client_secret is not None: + post_data['client_secret'] = self.client_secret + if device_flow_info is not None: + post_data['grant_type'] = 'http://oauth.net/grant_type/device/1.0' + else: + post_data['grant_type'] = 'authorization_code' + post_data['redirect_uri'] = self.redirect_uri + body = urllib.parse.urlencode(post_data) + headers = { + 'content-type': 'application/x-www-form-urlencoded', + } + if self.authorization_header is not None: + headers['Authorization'] = self.authorization_header + if self.user_agent is not None: + headers['user-agent'] = self.user_agent - if http is None: - http = httplib2.Http() + if http is None: + http = httplib2.Http() - resp, content = http.request(self.token_uri, method='POST', body=body, - headers=headers) - d = _parse_exchange_token_response(content) - if resp.status == 200 and 'access_token' in d: - access_token = d['access_token'] - refresh_token = d.get('refresh_token', None) - if not refresh_token: - logger.info( - 'Received token response with no refresh_token. Consider ' - "reauthenticating with approval_prompt='force'.") - token_expiry = None - if 'expires_in' in d: - token_expiry = datetime.datetime.utcnow() + datetime.timedelta( - seconds=int(d['expires_in'])) + resp, content = http.request(self.token_uri, method='POST', body=body, + headers=headers) + d = _parse_exchange_token_response(content) + if resp.status == 200 and 'access_token' in d: + access_token = d['access_token'] + refresh_token = d.get('refresh_token', None) + if not refresh_token: + logger.info( + 'Received token response with no refresh_token. Consider ' + "reauthenticating with approval_prompt='force'.") + token_expiry = None + if 'expires_in' in d: + token_expiry = ( + datetime.datetime.utcnow() + + datetime.timedelta(seconds=int(d['expires_in']))) - extracted_id_token = None - if 'id_token' in d: - extracted_id_token = _extract_id_token(d['id_token']) + extracted_id_token = None + if 'id_token' in d: + extracted_id_token = _extract_id_token(d['id_token']) - logger.info('Successfully retrieved access token') - return OAuth2Credentials(access_token, self.client_id, - self.client_secret, refresh_token, token_expiry, - self.token_uri, self.user_agent, - revoke_uri=self.revoke_uri, - id_token=extracted_id_token, - token_response=d) - else: - logger.info('Failed to retrieve access token: %s', content) - if 'error' in d: - # you never know what those providers got to say - error_msg = str(d['error']) + str(d.get('error_description', '')) - else: - error_msg = 'Invalid response: %s.' % str(resp.status) - raise FlowExchangeError(error_msg) + logger.info('Successfully retrieved access token') + return OAuth2Credentials( + access_token, self.client_id, self.client_secret, + refresh_token, token_expiry, self.token_uri, self.user_agent, + revoke_uri=self.revoke_uri, id_token=extracted_id_token, + token_response=d, scopes=self.scope, + token_info_uri=self.token_info_uri) + else: + logger.info('Failed to retrieve access token: %s', content) + if 'error' in d: + # you never know what those providers got to say + error_msg = (str(d['error']) + + str(d.get('error_description', ''))) + else: + error_msg = 'Invalid response: %s.' % str(resp.status) + raise FlowExchangeError(error_msg) @util.positional(2) def flow_from_clientsecrets(filename, scope, redirect_uri=None, message=None, cache=None, login_hint=None, device_uri=None): - """Create a Flow from a clientsecrets file. + """Create a Flow from a clientsecrets file. - Will create the right kind of Flow based on the contents of the clientsecrets - file or will raise InvalidClientSecretsError for unknown types of Flows. + Will create the right kind of Flow based on the contents of the + clientsecrets file or will raise InvalidClientSecretsError for unknown + types of Flows. - Args: - filename: string, File name of client secrets. - scope: string or iterable of strings, scope(s) to request. - redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for - a non-web-based application, or a URI that handles the callback from - the authorization server. - message: string, A friendly string to display to the user if the - clientsecrets file is missing or invalid. If message is provided then - sys.exit will be called in the case of an error. If message in not - provided then clientsecrets.InvalidClientSecretsError will be raised. - cache: An optional cache service client that implements get() and set() - methods. See clientsecrets.loadfile() for details. - login_hint: string, Either an email address or domain. Passing this hint - will either pre-fill the email box on the sign-in form or select the - proper multi-login session, thereby simplifying the login flow. - device_uri: string, URI for device authorization endpoint. For convenience - defaults to Google's endpoints but any OAuth 2.0 provider can be used. + Args: + filename: string, File name of client secrets. + scope: string or iterable of strings, scope(s) to request. + redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for + a non-web-based application, or a URI that handles the + callback from the authorization server. + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. If message is + provided then sys.exit will be called in the case of an error. + If message in not provided then + clientsecrets.InvalidClientSecretsError will be raised. + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. + login_hint: string, Either an email address or domain. Passing this + hint will either pre-fill the email box on the sign-in form + or select the proper multi-login session, thereby + simplifying the login flow. + device_uri: string, URI for device authorization endpoint. For + convenience defaults to Google's endpoints but any + OAuth 2.0 provider can be used. - Returns: - A Flow object. + Returns: + A Flow object. - Raises: - UnknownClientSecretsFlowError if the file describes an unknown kind of Flow. - clientsecrets.InvalidClientSecretsError if the clientsecrets file is - invalid. - """ - try: - client_type, client_info = clientsecrets.loadfile(filename, cache=cache) - if client_type in (clientsecrets.TYPE_WEB, clientsecrets.TYPE_INSTALLED): - constructor_kwargs = { - 'redirect_uri': redirect_uri, - 'auth_uri': client_info['auth_uri'], - 'token_uri': client_info['token_uri'], - 'login_hint': login_hint, - } - revoke_uri = client_info.get('revoke_uri') - if revoke_uri is not None: - constructor_kwargs['revoke_uri'] = revoke_uri - if device_uri is not None: - constructor_kwargs['device_uri'] = device_uri - return OAuth2WebServerFlow( - client_info['client_id'], client_info['client_secret'], - scope, **constructor_kwargs) + Raises: + UnknownClientSecretsFlowError: if the file describes an unknown kind of + Flow. + clientsecrets.InvalidClientSecretsError: if the clientsecrets file is + invalid. + """ + try: + client_type, client_info = clientsecrets.loadfile(filename, + cache=cache) + if client_type in (clientsecrets.TYPE_WEB, + clientsecrets.TYPE_INSTALLED): + constructor_kwargs = { + 'redirect_uri': redirect_uri, + 'auth_uri': client_info['auth_uri'], + 'token_uri': client_info['token_uri'], + 'login_hint': login_hint, + } + revoke_uri = client_info.get('revoke_uri') + if revoke_uri is not None: + constructor_kwargs['revoke_uri'] = revoke_uri + if device_uri is not None: + constructor_kwargs['device_uri'] = device_uri + return OAuth2WebServerFlow( + client_info['client_id'], client_info['client_secret'], + scope, **constructor_kwargs) - except clientsecrets.InvalidClientSecretsError: - if message: - sys.exit(message) + except clientsecrets.InvalidClientSecretsError: + if message: + sys.exit(message) + else: + raise else: - raise - else: - raise UnknownClientSecretsFlowError( - 'This OAuth 2.0 flow is unsupported: %r' % client_type) + raise UnknownClientSecretsFlowError( + 'This OAuth 2.0 flow is unsupported: %r' % client_type) diff --git a/oauth2client/clientsecrets.py b/oauth2client/clientsecrets.py index 08a17020..eba1fd9d 100644 --- a/oauth2client/clientsecrets.py +++ b/oauth2client/clientsecrets.py @@ -18,12 +18,12 @@ A client_secrets.json file contains all the information needed to interact with an OAuth 2.0 protected service. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - import json import six +__author__ = 'jcgregorio@google.com (Joe Gregorio)' + # Properties that make a client_secrets.json file valid. TYPE_WEB = 'web' TYPE_INSTALLED = 'installed' @@ -59,105 +59,115 @@ VALID_CLIENT = { class Error(Exception): - """Base error for this module.""" - pass + """Base error for this module.""" class InvalidClientSecretsError(Error): - """Format of ClientSecrets file is invalid.""" - pass + """Format of ClientSecrets file is invalid.""" -def _validate_clientsecrets(obj): - _INVALID_FILE_FORMAT_MSG = ( - 'Invalid file format. See ' - 'https://developers.google.com/api-client-library/' - 'python/guide/aaa_client_secrets') +def _validate_clientsecrets(clientsecrets_dict): + """Validate parsed client secrets from a file. - if obj is None: - raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG) - if len(obj) != 1: - raise InvalidClientSecretsError( - _INVALID_FILE_FORMAT_MSG + ' ' - 'Expected a JSON object with a single property for a "web" or ' - '"installed" application') - client_type = tuple(obj)[0] - if client_type not in VALID_CLIENT: - raise InvalidClientSecretsError('Unknown client type: %s.' % (client_type,)) - client_info = obj[client_type] - for prop_name in VALID_CLIENT[client_type]['required']: - if prop_name not in client_info: - raise InvalidClientSecretsError( - 'Missing property "%s" in a client type of "%s".' % (prop_name, - client_type)) - for prop_name in VALID_CLIENT[client_type]['string']: - if client_info[prop_name].startswith('[['): - raise InvalidClientSecretsError( - 'Property "%s" is not configured.' % prop_name) - return client_type, client_info + Args: + clientsecrets_dict: dict, a dictionary holding the client secrets. + + Returns: + tuple, a string of the client type and the information parsed + from the file. + """ + _INVALID_FILE_FORMAT_MSG = ( + 'Invalid file format. See ' + 'https://developers.google.com/api-client-library/' + 'python/guide/aaa_client_secrets') + + if clientsecrets_dict is None: + raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG) + try: + (client_type, client_info), = clientsecrets_dict.items() + except (ValueError, AttributeError): + raise InvalidClientSecretsError( + _INVALID_FILE_FORMAT_MSG + ' ' + 'Expected a JSON object with a single property for a "web" or ' + '"installed" application') + + if client_type not in VALID_CLIENT: + raise InvalidClientSecretsError( + 'Unknown client type: %s.' % (client_type,)) + + for prop_name in VALID_CLIENT[client_type]['required']: + if prop_name not in client_info: + raise InvalidClientSecretsError( + 'Missing property "%s" in a client type of "%s".' % + (prop_name, client_type)) + for prop_name in VALID_CLIENT[client_type]['string']: + if client_info[prop_name].startswith('[['): + raise InvalidClientSecretsError( + 'Property "%s" is not configured.' % prop_name) + return client_type, client_info def load(fp): - obj = json.load(fp) - return _validate_clientsecrets(obj) + obj = json.load(fp) + return _validate_clientsecrets(obj) def loads(s): - obj = json.loads(s) - return _validate_clientsecrets(obj) + obj = json.loads(s) + return _validate_clientsecrets(obj) def _loadfile(filename): - try: - with open(filename, 'r') as fp: - obj = json.load(fp) - except IOError: - raise InvalidClientSecretsError('File not found: "%s"' % filename) - return _validate_clientsecrets(obj) + try: + with open(filename, 'r') as fp: + obj = json.load(fp) + except IOError: + raise InvalidClientSecretsError('File not found: "%s"' % filename) + return _validate_clientsecrets(obj) def loadfile(filename, cache=None): - """Loading of client_secrets JSON file, optionally backed by a cache. + """Loading of client_secrets JSON file, optionally backed by a cache. - Typical cache storage would be App Engine memcache service, - but you can pass in any other cache client that implements - these methods: + Typical cache storage would be App Engine memcache service, + but you can pass in any other cache client that implements + these methods: - * ``get(key, namespace=ns)`` - * ``set(key, value, namespace=ns)`` + * ``get(key, namespace=ns)`` + * ``set(key, value, namespace=ns)`` - Usage:: + Usage:: - # without caching - client_type, client_info = loadfile('secrets.json') - # using App Engine memcache service - from google.appengine.api import memcache - client_type, client_info = loadfile('secrets.json', cache=memcache) + # without caching + client_type, client_info = loadfile('secrets.json') + # using App Engine memcache service + from google.appengine.api import memcache + client_type, client_info = loadfile('secrets.json', cache=memcache) - Args: - filename: string, Path to a client_secrets.json file on a filesystem. - cache: An optional cache service client that implements get() and set() - methods. If not specified, the file is always being loaded from - a filesystem. + Args: + filename: string, Path to a client_secrets.json file on a filesystem. + cache: An optional cache service client that implements get() and set() + methods. If not specified, the file is always being loaded from + a filesystem. - Raises: - InvalidClientSecretsError: In case of a validation error or some - I/O failure. Can happen only on cache miss. + Raises: + InvalidClientSecretsError: In case of a validation error or some + I/O failure. Can happen only on cache miss. - Returns: - (client_type, client_info) tuple, as _loadfile() normally would. - JSON contents is validated only during first load. Cache hits are not - validated. - """ - _SECRET_NAMESPACE = 'oauth2client:secrets#ns' + Returns: + (client_type, client_info) tuple, as _loadfile() normally would. + JSON contents is validated only during first load. Cache hits are not + validated. + """ + _SECRET_NAMESPACE = 'oauth2client:secrets#ns' - if not cache: - return _loadfile(filename) + if not cache: + return _loadfile(filename) - obj = cache.get(filename, namespace=_SECRET_NAMESPACE) - if obj is None: - client_type, client_info = _loadfile(filename) - obj = {client_type: client_info} - cache.set(filename, obj, namespace=_SECRET_NAMESPACE) + obj = cache.get(filename, namespace=_SECRET_NAMESPACE) + if obj is None: + client_type, client_info = _loadfile(filename) + obj = {client_type: client_info} + cache.set(filename, obj, namespace=_SECRET_NAMESPACE) - return next(six.iteritems(obj)) + return next(six.iteritems(obj)) diff --git a/oauth2client/crypt.py b/oauth2client/crypt.py index 381f389e..c450c5c6 100644 --- a/oauth2client/crypt.py +++ b/oauth2client/crypt.py @@ -15,415 +15,229 @@ # limitations under the License. """Crypto-related routines for oauth2client.""" -import base64 import json import logging -import sys import time -import six +from oauth2client._helpers import _from_bytes +from oauth2client._helpers import _json_encode +from oauth2client._helpers import _to_bytes +from oauth2client._helpers import _urlsafe_b64decode +from oauth2client._helpers import _urlsafe_b64encode CLOCK_SKEW_SECS = 300 # 5 minutes in seconds AUTH_TOKEN_LIFETIME_SECS = 300 # 5 minutes in seconds MAX_TOKEN_LIFETIME_SECS = 86400 # 1 day in seconds - logger = logging.getLogger(__name__) class AppIdentityError(Exception): - pass + """Error to indicate crypto failure.""" -try: - from OpenSSL import crypto - - class OpenSSLVerifier(object): - """Verifies the signature on a message.""" - - def __init__(self, pubkey): - """Constructor. - - Args: - pubkey, OpenSSL.crypto.PKey, The public key to verify with. - """ - self._pubkey = pubkey - - def verify(self, message, signature): - """Verifies a message against a signature. - - Args: - message: string, The message to verify. - signature: string, The signature on the message. - - Returns: - True if message was signed by the private key associated with the public - key that this object was constructed with. - """ - try: - if isinstance(message, six.text_type): - message = message.encode('utf-8') - crypto.verify(self._pubkey, signature, message, 'sha256') - return True - except: - return False - - @staticmethod - def from_string(key_pem, is_x509_cert): - """Construct a Verified instance from a string. - - Args: - key_pem: string, public key in PEM format. - is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it is - expected to be an RSA key in PEM format. - - Returns: - Verifier instance. - - Raises: - OpenSSL.crypto.Error if the key_pem can't be parsed. - """ - if is_x509_cert: - pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem) - else: - pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem) - return OpenSSLVerifier(pubkey) - - - class OpenSSLSigner(object): - """Signs messages with a private key.""" - - def __init__(self, pkey): - """Constructor. - - Args: - pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with. - """ - self._key = pkey - - def sign(self, message): - """Signs a message. - - Args: - message: bytes, Message to be signed. - - Returns: - string, The signature of the message for the given key. - """ - if isinstance(message, six.text_type): - message = message.encode('utf-8') - return crypto.sign(self._key, message, 'sha256') - - @staticmethod - def from_string(key, password=b'notasecret'): - """Construct a Signer instance from a string. - - Args: - key: string, private key in PKCS12 or PEM format. - password: string, password for the private key file. - - Returns: - Signer instance. - - Raises: - OpenSSL.crypto.Error if the key can't be parsed. - """ - parsed_pem_key = _parse_pem_key(key) - if parsed_pem_key: - pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key) - else: - if isinstance(password, six.text_type): - password = password.encode('utf-8') - pkey = crypto.load_pkcs12(key, password).get_privatekey() - return OpenSSLSigner(pkey) - - - def pkcs12_key_as_pem(private_key_text, private_key_password): - """Convert the contents of a PKCS12 key to PEM using OpenSSL. - - Args: - private_key_text: String. Private key. - private_key_password: String. Password for PKCS12. - - Returns: - String. PEM contents of ``private_key_text``. - """ - decoded_body = base64.b64decode(private_key_text) - if isinstance(private_key_password, six.string_types): - private_key_password = private_key_password.encode('ascii') - - pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password) - return crypto.dump_privatekey(crypto.FILETYPE_PEM, - pkcs12.get_privatekey()) -except ImportError: - OpenSSLVerifier = None - OpenSSLSigner = None - def pkcs12_key_as_pem(*args, **kwargs): +def _bad_pkcs12_key_as_pem(*args, **kwargs): raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.') try: - from Crypto.PublicKey import RSA - from Crypto.Hash import SHA256 - from Crypto.Signature import PKCS1_v1_5 - from Crypto.Util.asn1 import DerSequence + from oauth2client._openssl_crypt import OpenSSLVerifier + from oauth2client._openssl_crypt import OpenSSLSigner + from oauth2client._openssl_crypt import pkcs12_key_as_pem +except ImportError: # pragma: NO COVER + OpenSSLVerifier = None + OpenSSLSigner = None + pkcs12_key_as_pem = _bad_pkcs12_key_as_pem - - class PyCryptoVerifier(object): - """Verifies the signature on a message.""" - - def __init__(self, pubkey): - """Constructor. - - Args: - pubkey, OpenSSL.crypto.PKey (or equiv), The public key to verify with. - """ - self._pubkey = pubkey - - def verify(self, message, signature): - """Verifies a message against a signature. - - Args: - message: string, The message to verify. - signature: string, The signature on the message. - - Returns: - True if message was signed by the private key associated with the public - key that this object was constructed with. - """ - try: - return PKCS1_v1_5.new(self._pubkey).verify( - SHA256.new(message), signature) - except: - return False - - @staticmethod - def from_string(key_pem, is_x509_cert): - """Construct a Verified instance from a string. - - Args: - key_pem: string, public key in PEM format. - is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it is - expected to be an RSA key in PEM format. - - Returns: - Verifier instance. - """ - if is_x509_cert: - if isinstance(key_pem, six.text_type): - key_pem = key_pem.encode('ascii') - pemLines = key_pem.replace(b' ', b'').split() - certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1])) - certSeq = DerSequence() - certSeq.decode(certDer) - tbsSeq = DerSequence() - tbsSeq.decode(certSeq[0]) - pubkey = RSA.importKey(tbsSeq[6]) - else: - pubkey = RSA.importKey(key_pem) - return PyCryptoVerifier(pubkey) - - - class PyCryptoSigner(object): - """Signs messages with a private key.""" - - def __init__(self, pkey): - """Constructor. - - Args: - pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with. - """ - self._key = pkey - - def sign(self, message): - """Signs a message. - - Args: - message: string, Message to be signed. - - Returns: - string, The signature of the message for the given key. - """ - if isinstance(message, six.text_type): - message = message.encode('utf-8') - return PKCS1_v1_5.new(self._key).sign(SHA256.new(message)) - - @staticmethod - def from_string(key, password='notasecret'): - """Construct a Signer instance from a string. - - Args: - key: string, private key in PEM format. - password: string, password for private key file. Unused for PEM files. - - Returns: - Signer instance. - - Raises: - NotImplementedError if they key isn't in PEM format. - """ - parsed_pem_key = _parse_pem_key(key) - if parsed_pem_key: - pkey = RSA.importKey(parsed_pem_key) - else: - raise NotImplementedError( - 'PKCS12 format is not supported by the PyCrypto library. ' - 'Try converting to a "PEM" ' - '(openssl pkcs12 -in xxxxx.p12 -nodes -nocerts > privatekey.pem) ' - 'or using PyOpenSSL if native code is an option.') - return PyCryptoSigner(pkey) - -except ImportError: - PyCryptoVerifier = None - PyCryptoSigner = None +try: + from oauth2client._pycrypto_crypt import PyCryptoVerifier + from oauth2client._pycrypto_crypt import PyCryptoSigner +except ImportError: # pragma: NO COVER + PyCryptoVerifier = None + PyCryptoSigner = None if OpenSSLSigner: - Signer = OpenSSLSigner - Verifier = OpenSSLVerifier -elif PyCryptoSigner: - Signer = PyCryptoSigner - Verifier = PyCryptoVerifier -else: - raise ImportError('No encryption library found. Please install either ' - 'PyOpenSSL, or PyCrypto 2.6 or later') - - -def _parse_pem_key(raw_key_input): - """Identify and extract PEM keys. - - Determines whether the given key is in the format of PEM key, and extracts - the relevant part of the key if it is. - - Args: - raw_key_input: The contents of a private key file (either PEM or PKCS12). - - Returns: - string, The actual key if the contents are from a PEM file, or else None. - """ - offset = raw_key_input.find(b'-----BEGIN ') - if offset != -1: - return raw_key_input[offset:] - - -def _urlsafe_b64encode(raw_bytes): - if isinstance(raw_bytes, six.text_type): - raw_bytes = raw_bytes.encode('utf-8') - return base64.urlsafe_b64encode(raw_bytes).decode('ascii').rstrip('=') - - -def _urlsafe_b64decode(b64string): - # Guard against unicode strings, which base64 can't handle. - if isinstance(b64string, six.text_type): - b64string = b64string.encode('ascii') - padded = b64string + b'=' * (4 - len(b64string) % 4) - return base64.urlsafe_b64decode(padded) - - -def _json_encode(data): - return json.dumps(data, separators=(',', ':')) + Signer = OpenSSLSigner + Verifier = OpenSSLVerifier +elif PyCryptoSigner: # pragma: NO COVER + Signer = PyCryptoSigner + Verifier = PyCryptoVerifier +else: # pragma: NO COVER + raise ImportError('No encryption library found. Please install either ' + 'PyOpenSSL, or PyCrypto 2.6 or later') def make_signed_jwt(signer, payload): - """Make a signed JWT. + """Make a signed JWT. - See http://self-issued.info/docs/draft-jones-json-web-token.html. + See http://self-issued.info/docs/draft-jones-json-web-token.html. - Args: - signer: crypt.Signer, Cryptographic signer. - payload: dict, Dictionary of data to convert to JSON and then sign. + Args: + signer: crypt.Signer, Cryptographic signer. + payload: dict, Dictionary of data to convert to JSON and then sign. - Returns: - string, The JWT for the payload. - """ - header = {'typ': 'JWT', 'alg': 'RS256'} + Returns: + string, The JWT for the payload. + """ + header = {'typ': 'JWT', 'alg': 'RS256'} - segments = [ + segments = [ _urlsafe_b64encode(_json_encode(header)), _urlsafe_b64encode(_json_encode(payload)), - ] - signing_input = '.'.join(segments) + ] + signing_input = b'.'.join(segments) - signature = signer.sign(signing_input) - segments.append(_urlsafe_b64encode(signature)) + signature = signer.sign(signing_input) + segments.append(_urlsafe_b64encode(signature)) - logger.debug(str(segments)) + logger.debug(str(segments)) - return '.'.join(segments) + return b'.'.join(segments) -def verify_signed_jwt_with_certs(jwt, certs, audience): - """Verify a JWT against public certs. +def _verify_signature(message, signature, certs): + """Verifies signed content using a list of certificates. - See http://self-issued.info/docs/draft-jones-json-web-token.html. + Args: + message: string or bytes, The message to verify. + signature: string or bytes, The signature on the message. + certs: iterable, certificates in PEM format. - Args: - jwt: string, A JWT. - certs: dict, Dictionary where values of public keys in PEM format. - audience: string, The audience, 'aud', that this JWT should contain. If - None then the JWT's 'aud' parameter is not verified. + Raises: + AppIdentityError: If none of the certificates can verify the message + against the signature. + """ + for pem in certs: + verifier = Verifier.from_string(pem, is_x509_cert=True) + if verifier.verify(message, signature): + return - Returns: - dict, The deserialized JSON payload in the JWT. + # If we have not returned, no certificate confirms the signature. + raise AppIdentityError('Invalid token signature') - Raises: - AppIdentityError if any checks are failed. - """ - segments = jwt.split('.') - if len(segments) != 3: - raise AppIdentityError('Wrong number of segments in token: %s' % jwt) - signed = '%s.%s' % (segments[0], segments[1]) +def _check_audience(payload_dict, audience): + """Checks audience field from a JWT payload. - signature = _urlsafe_b64decode(segments[2]) + Does nothing if the passed in ``audience`` is null. - # Parse token. - json_body = _urlsafe_b64decode(segments[1]) - try: - parsed = json.loads(json_body.decode('utf-8')) - except: - raise AppIdentityError('Can\'t parse token: %s' % json_body) + Args: + payload_dict: dict, A dictionary containing a JWT payload. + audience: string or NoneType, an audience to check for in + the JWT payload. - # Check signature. - verified = False - for pem in certs.values(): - verifier = Verifier.from_string(pem, True) - if verifier.verify(signed, signature): - verified = True - break - if not verified: - raise AppIdentityError('Invalid token signature: %s' % jwt) + Raises: + AppIdentityError: If there is no ``'aud'`` field in the payload + dictionary but there is an ``audience`` to check. + AppIdentityError: If the ``'aud'`` field in the payload dictionary + does not match the ``audience``. + """ + if audience is None: + return - # Check creation timestamp. - iat = parsed.get('iat') - if iat is None: - raise AppIdentityError('No iat field in token: %s' % json_body) - earliest = iat - CLOCK_SKEW_SECS + audience_in_payload = payload_dict.get('aud') + if audience_in_payload is None: + raise AppIdentityError('No aud field in token: %s' % + (payload_dict,)) + if audience_in_payload != audience: + raise AppIdentityError('Wrong recipient, %s != %s: %s' % + (audience_in_payload, audience, payload_dict)) - # Check expiration timestamp. - now = int(time.time()) - exp = parsed.get('exp') - if exp is None: - raise AppIdentityError('No exp field in token: %s' % json_body) - if exp >= now + MAX_TOKEN_LIFETIME_SECS: - raise AppIdentityError('exp field too far in future: %s' % json_body) - latest = exp + CLOCK_SKEW_SECS - if now < earliest: - raise AppIdentityError('Token used too early, %d < %d: %s' % - (now, earliest, json_body)) - if now > latest: - raise AppIdentityError('Token used too late, %d > %d: %s' % - (now, latest, json_body)) +def _verify_time_range(payload_dict): + """Verifies the issued at and expiration from a JWT payload. - # Check audience. - if audience is not None: - aud = parsed.get('aud') - if aud is None: - raise AppIdentityError('No aud field in token: %s' % json_body) - if aud != audience: - raise AppIdentityError('Wrong recipient, %s != %s: %s' % - (aud, audience, json_body)) + Makes sure the current time (in UTC) falls between the issued at and + expiration for the JWT (with some skew allowed for via + ``CLOCK_SKEW_SECS``). - return parsed + Args: + payload_dict: dict, A dictionary containing a JWT payload. + + Raises: + AppIdentityError: If there is no ``'iat'`` field in the payload + dictionary. + AppIdentityError: If there is no ``'exp'`` field in the payload + dictionary. + AppIdentityError: If the JWT expiration is too far in the future (i.e. + if the expiration would imply a token lifetime + longer than what is allowed.) + AppIdentityError: If the token appears to have been issued in the + future (up to clock skew). + AppIdentityError: If the token appears to have expired in the past + (up to clock skew). + """ + # Get the current time to use throughout. + now = int(time.time()) + + # Make sure issued at and expiration are in the payload. + issued_at = payload_dict.get('iat') + if issued_at is None: + raise AppIdentityError('No iat field in token: %s' % (payload_dict,)) + expiration = payload_dict.get('exp') + if expiration is None: + raise AppIdentityError('No exp field in token: %s' % (payload_dict,)) + + # Make sure the expiration gives an acceptable token lifetime. + if expiration >= now + MAX_TOKEN_LIFETIME_SECS: + raise AppIdentityError('exp field too far in future: %s' % + (payload_dict,)) + + # Make sure (up to clock skew) that the token wasn't issued in the future. + earliest = issued_at - CLOCK_SKEW_SECS + if now < earliest: + raise AppIdentityError('Token used too early, %d < %d: %s' % + (now, earliest, payload_dict)) + # Make sure (up to clock skew) that the token isn't already expired. + latest = expiration + CLOCK_SKEW_SECS + if now > latest: + raise AppIdentityError('Token used too late, %d > %d: %s' % + (now, latest, payload_dict)) + + +def verify_signed_jwt_with_certs(jwt, certs, audience=None): + """Verify a JWT against public certs. + + See http://self-issued.info/docs/draft-jones-json-web-token.html. + + Args: + jwt: string, A JWT. + certs: dict, Dictionary where values of public keys in PEM format. + audience: string, The audience, 'aud', that this JWT should contain. If + None then the JWT's 'aud' parameter is not verified. + + Returns: + dict, The deserialized JSON payload in the JWT. + + Raises: + AppIdentityError: if any checks are failed. + """ + jwt = _to_bytes(jwt) + + if jwt.count(b'.') != 2: + raise AppIdentityError( + 'Wrong number of segments in token: %s' % (jwt,)) + + header, payload, signature = jwt.split(b'.') + message_to_sign = header + b'.' + payload + signature = _urlsafe_b64decode(signature) + + # Parse token. + payload_bytes = _urlsafe_b64decode(payload) + try: + payload_dict = json.loads(_from_bytes(payload_bytes)) + except: + raise AppIdentityError('Can\'t parse token: %s' % (payload_bytes,)) + + # Verify that the signature matches the message. + _verify_signature(message_to_sign, signature, certs.values()) + + # Verify the issued at and created times in the payload. + _verify_time_range(payload_dict) + + # Check audience. + _check_audience(payload_dict, audience) + + return payload_dict diff --git a/oauth2client/devshell.py b/oauth2client/devshell.py index a33de871..8131affe 100644 --- a/oauth2client/devshell.py +++ b/oauth2client/devshell.py @@ -16,121 +16,122 @@ import json import os +import socket +from oauth2client._helpers import _to_bytes from oauth2client import client - DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT' class Error(Exception): - """Errors for this module.""" - pass + """Errors for this module.""" + pass class CommunicationError(Error): - """Errors for communication with the Developer Shell server.""" + """Errors for communication with the Developer Shell server.""" class NoDevshellServer(Error): - """Error when no Developer Shell server can be contacted.""" + """Error when no Developer Shell server can be contacted.""" - -# The request for credential information to the Developer Shell client socket is -# always an empty PBLite-formatted JSON object, so just define it as a constant. +# The request for credential information to the Developer Shell client socket +# is always an empty PBLite-formatted JSON object, so just define it as a +# constant. CREDENTIAL_INFO_REQUEST_JSON = '[]' class CredentialInfoResponse(object): - """Credential information response from Developer Shell server. + """Credential information response from Developer Shell server. - The credential information response from Developer Shell socket is a - PBLite-formatted JSON array with fields encoded by their index in the array: - * Index 0 - user email - * Index 1 - default project ID. None if the project context is not known. - * Index 2 - OAuth2 access token. None if there is no valid auth context. - """ + The credential information response from Developer Shell socket is a + PBLite-formatted JSON array with fields encoded by their index in the + array: - def __init__(self, json_string): - """Initialize the response data from JSON PBLite array.""" - pbl = json.loads(json_string) - if not isinstance(pbl, list): - raise ValueError('Not a list: ' + str(pbl)) - pbl_len = len(pbl) - self.user_email = pbl[0] if pbl_len > 0 else None - self.project_id = pbl[1] if pbl_len > 1 else None - self.access_token = pbl[2] if pbl_len > 2 else None + * Index 0 - user email + * Index 1 - default project ID. None if the project context is not known. + * Index 2 - OAuth2 access token. None if there is no valid auth context. + """ + + def __init__(self, json_string): + """Initialize the response data from JSON PBLite array.""" + pbl = json.loads(json_string) + if not isinstance(pbl, list): + raise ValueError('Not a list: ' + str(pbl)) + pbl_len = len(pbl) + self.user_email = pbl[0] if pbl_len > 0 else None + self.project_id = pbl[1] if pbl_len > 1 else None + self.access_token = pbl[2] if pbl_len > 2 else None def _SendRecv(): - """Communicate with the Developer Shell server socket.""" + """Communicate with the Developer Shell server socket.""" - port = int(os.getenv(DEVSHELL_ENV, 0)) - if port == 0: - raise NoDevshellServer() + port = int(os.getenv(DEVSHELL_ENV, 0)) + if port == 0: + raise NoDevshellServer() - import socket + sock = socket.socket() + sock.connect(('localhost', port)) - sock = socket.socket() - sock.connect(('localhost', port)) + data = CREDENTIAL_INFO_REQUEST_JSON + msg = '%s\n%s' % (len(data), data) + sock.sendall(_to_bytes(msg, encoding='utf-8')) - data = CREDENTIAL_INFO_REQUEST_JSON - msg = '%s\n%s' % (len(data), data) - sock.sendall(msg.encode()) + header = sock.recv(6).decode() + if '\n' not in header: + raise CommunicationError('saw no newline in the first 6 bytes') + len_str, json_str = header.split('\n', 1) + to_read = int(len_str) - len(json_str) + if to_read > 0: + json_str += sock.recv(to_read, socket.MSG_WAITALL).decode() - header = sock.recv(6).decode() - if '\n' not in header: - raise CommunicationError('saw no newline in the first 6 bytes') - len_str, json_str = header.split('\n', 1) - to_read = int(len_str) - len(json_str) - if to_read > 0: - json_str += sock.recv(to_read, socket.MSG_WAITALL).decode() - - return CredentialInfoResponse(json_str) + return CredentialInfoResponse(json_str) class DevshellCredentials(client.GoogleCredentials): - """Credentials object for Google Developer Shell environment. + """Credentials object for Google Developer Shell environment. - This object will allow a Google Developer Shell session to identify its user - to Google and other OAuth 2.0 servers that can verify assertions. It can be - used for the purpose of accessing data stored under the user account. + This object will allow a Google Developer Shell session to identify its + user to Google and other OAuth 2.0 servers that can verify assertions. It + can be used for the purpose of accessing data stored under the user + account. - This credential does not require a flow to instantiate because it represents - a two legged flow, and therefore has all of the required information to - generate and refresh its own access tokens. - """ + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. + """ - def __init__(self, user_agent=None): - super(DevshellCredentials, self).__init__( - None, # access_token, initialized below - None, # client_id - None, # client_secret - None, # refresh_token - None, # token_expiry - None, # token_uri - user_agent) - self._refresh(None) + def __init__(self, user_agent=None): + super(DevshellCredentials, self).__init__( + None, # access_token, initialized below + None, # client_id + None, # client_secret + None, # refresh_token + None, # token_expiry + None, # token_uri + user_agent) + self._refresh(None) - def _refresh(self, http_request): - self.devshell_response = _SendRecv() - self.access_token = self.devshell_response.access_token + def _refresh(self, http_request): + self.devshell_response = _SendRecv() + self.access_token = self.devshell_response.access_token - @property - def user_email(self): - return self.devshell_response.user_email + @property + def user_email(self): + return self.devshell_response.user_email - @property - def project_id(self): - return self.devshell_response.project_id + @property + def project_id(self): + return self.devshell_response.project_id - @classmethod - def from_json(cls, json_data): - raise NotImplementedError( - 'Cannot load Developer Shell credentials from JSON.') - - @property - def serialization_data(self): - raise NotImplementedError( - 'Cannot serialize Developer Shell credentials.') + @classmethod + def from_json(cls, json_data): + raise NotImplementedError( + 'Cannot load Developer Shell credentials from JSON.') + @property + def serialization_data(self): + raise NotImplementedError( + 'Cannot serialize Developer Shell credentials.') diff --git a/oauth2client/django_orm.py b/oauth2client/django_orm.py index 65c5d201..d119f042 100644 --- a/oauth2client/django_orm.py +++ b/oauth2client/django_orm.py @@ -18,8 +18,6 @@ Utilities for using OAuth 2.0 in conjunction with the Django datastore. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - import oauth2client import base64 import pickle @@ -27,115 +25,122 @@ import pickle from django.db import models from oauth2client.client import Storage as BaseStorage + +__author__ = 'jcgregorio@google.com (Joe Gregorio)' + + class CredentialsField(models.Field): - __metaclass__ = models.SubfieldBase + __metaclass__ = models.SubfieldBase - def __init__(self, *args, **kwargs): - if 'null' not in kwargs: - kwargs['null'] = True - super(CredentialsField, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs): + if 'null' not in kwargs: + kwargs['null'] = True + super(CredentialsField, self).__init__(*args, **kwargs) - def get_internal_type(self): - return "TextField" + def get_internal_type(self): + return "TextField" - def to_python(self, value): - if value is None: - return None - if isinstance(value, oauth2client.client.Credentials): - return value - return pickle.loads(base64.b64decode(value)) + def to_python(self, value): + if value is None: + return None + if isinstance(value, oauth2client.client.Credentials): + return value + return pickle.loads(base64.b64decode(value)) - def get_db_prep_value(self, value, connection, prepared=False): - if value is None: - return None - return base64.b64encode(pickle.dumps(value)) + def get_db_prep_value(self, value, connection, prepared=False): + if value is None: + return None + return base64.b64encode(pickle.dumps(value)) class FlowField(models.Field): - __metaclass__ = models.SubfieldBase + __metaclass__ = models.SubfieldBase - def __init__(self, *args, **kwargs): - if 'null' not in kwargs: - kwargs['null'] = True - super(FlowField, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs): + if 'null' not in kwargs: + kwargs['null'] = True + super(FlowField, self).__init__(*args, **kwargs) - def get_internal_type(self): - return "TextField" + def get_internal_type(self): + return "TextField" - def to_python(self, value): - if value is None: - return None - if isinstance(value, oauth2client.client.Flow): - return value - return pickle.loads(base64.b64decode(value)) + def to_python(self, value): + if value is None: + return None + if isinstance(value, oauth2client.client.Flow): + return value + return pickle.loads(base64.b64decode(value)) - def get_db_prep_value(self, value, connection, prepared=False): - if value is None: - return None - return base64.b64encode(pickle.dumps(value)) + def get_db_prep_value(self, value, connection, prepared=False): + if value is None: + return None + return base64.b64encode(pickle.dumps(value)) class Storage(BaseStorage): - """Store and retrieve a single credential to and from - the datastore. + """Store and retrieve a single credential to and from the datastore. - This Storage helper presumes the Credentials - have been stored as a CredenialsField - on a db model class. - """ - - def __init__(self, model_class, key_name, key_value, property_name): - """Constructor for Storage. - - Args: - model: db.Model, model class - key_name: string, key name for the entity that has the credentials - key_value: string, key value for the entity that has the credentials - property_name: string, name of the property that is an CredentialsProperty + This Storage helper presumes the Credentials + have been stored as a CredenialsField + on a db model class. """ - self.model_class = model_class - self.key_name = key_name - self.key_value = key_value - self.property_name = property_name - def locked_get(self): - """Retrieve Credential from datastore. + def __init__(self, model_class, key_name, key_value, property_name): + """Constructor for Storage. - Returns: - oauth2client.Credentials - """ - credential = None + Args: + model: db.Model, model class + key_name: string, key name for the entity that has the credentials + key_value: string, key value for the entity that has the + credentials + property_name: string, name of the property that is an + CredentialsProperty + """ + self.model_class = model_class + self.key_name = key_name + self.key_value = key_value + self.property_name = property_name - query = {self.key_name: self.key_value} - entities = self.model_class.objects.filter(**query) - if len(entities) > 0: - credential = getattr(entities[0], self.property_name) - if credential and hasattr(credential, 'set_store'): - credential.set_store(self) - return credential + def locked_get(self): + """Retrieve Credential from datastore. - def locked_put(self, credentials, overwrite=False): - """Write a Credentials to the datastore. + Returns: + oauth2client.Credentials + """ + credential = None - Args: - credentials: Credentials, the credentials to store. - overwrite: Boolean, indicates whether you would like these credentials to - overwrite any existing stored credentials. - """ - args = {self.key_name: self.key_value} + query = {self.key_name: self.key_value} + entities = self.model_class.objects.filter(**query) + if len(entities) > 0: + credential = getattr(entities[0], self.property_name) + if credential and hasattr(credential, 'set_store'): + credential.set_store(self) + return credential - if overwrite: - entity, unused_is_new = self.model_class.objects.get_or_create(**args) - else: - entity = self.model_class(**args) + def locked_put(self, credentials, overwrite=False): + """Write a Credentials to the datastore. - setattr(entity, self.property_name, credentials) - entity.save() + Args: + credentials: Credentials, the credentials to store. + overwrite: Boolean, indicates whether you would like these + credentials to overwrite any existing stored + credentials. + """ + args = {self.key_name: self.key_value} - def locked_delete(self): - """Delete Credentials from the datastore.""" + if overwrite: + (entity, + unused_is_new) = self.model_class.objects.get_or_create(**args) + else: + entity = self.model_class(**args) - query = {self.key_name: self.key_value} - entities = self.model_class.objects.filter(**query).delete() + setattr(entity, self.property_name, credentials) + entity.save() + + def locked_delete(self): + """Delete Credentials from the datastore.""" + + query = {self.key_name: self.key_value} + entities = self.model_class.objects.filter(**query).delete() diff --git a/oauth2client/file.py b/oauth2client/file.py index 9d0ae7fa..d0dd174f 100644 --- a/oauth2client/file.py +++ b/oauth2client/file.py @@ -18,8 +18,6 @@ Utilities for making it easier to work with OAuth 2.0 credentials. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - import os import threading @@ -27,96 +25,98 @@ from oauth2client.client import Credentials from oauth2client.client import Storage as BaseStorage +__author__ = 'jcgregorio@google.com (Joe Gregorio)' + + class CredentialsFileSymbolicLinkError(Exception): - """Credentials files must not be symbolic links.""" + """Credentials files must not be symbolic links.""" class Storage(BaseStorage): - """Store and retrieve a single credential to and from a file.""" + """Store and retrieve a single credential to and from a file.""" - def __init__(self, filename): - self._filename = filename - self._lock = threading.Lock() + def __init__(self, filename): + self._filename = filename + self._lock = threading.Lock() - def _validate_file(self): - if os.path.islink(self._filename): - raise CredentialsFileSymbolicLinkError( - 'File: %s is a symbolic link.' % self._filename) + def _validate_file(self): + if os.path.islink(self._filename): + raise CredentialsFileSymbolicLinkError( + 'File: %s is a symbolic link.' % self._filename) - def acquire_lock(self): - """Acquires any lock necessary to access this Storage. + def acquire_lock(self): + """Acquires any lock necessary to access this Storage. - This lock is not reentrant.""" - self._lock.acquire() + This lock is not reentrant. + """ + self._lock.acquire() - def release_lock(self): - """Release the Storage lock. + def release_lock(self): + """Release the Storage lock. - Trying to release a lock that isn't held will result in a - RuntimeError. - """ - self._lock.release() + Trying to release a lock that isn't held will result in a + RuntimeError. + """ + self._lock.release() - def locked_get(self): - """Retrieve Credential from file. + def locked_get(self): + """Retrieve Credential from file. - Returns: - oauth2client.client.Credentials + Returns: + oauth2client.client.Credentials - Raises: - CredentialsFileSymbolicLinkError if the file is a symbolic link. - """ - credentials = None - self._validate_file() - try: - f = open(self._filename, 'rb') - content = f.read() - f.close() - except IOError: - return credentials + Raises: + CredentialsFileSymbolicLinkError if the file is a symbolic link. + """ + credentials = None + self._validate_file() + try: + f = open(self._filename, 'rb') + content = f.read() + f.close() + except IOError: + return credentials - try: - credentials = Credentials.new_from_json(content) - credentials.set_store(self) - except ValueError: - pass + try: + credentials = Credentials.new_from_json(content) + credentials.set_store(self) + except ValueError: + pass - return credentials + return credentials - def _create_file_if_needed(self): - """Create an empty file if necessary. + def _create_file_if_needed(self): + """Create an empty file if necessary. - This method will not initialize the file. Instead it implements a - simple version of "touch" to ensure the file has been created. - """ - if not os.path.exists(self._filename): - old_umask = os.umask(0o177) - try: - open(self._filename, 'a+b').close() - finally: - os.umask(old_umask) + This method will not initialize the file. Instead it implements a + simple version of "touch" to ensure the file has been created. + """ + if not os.path.exists(self._filename): + old_umask = os.umask(0o177) + try: + open(self._filename, 'a+b').close() + finally: + os.umask(old_umask) - def locked_put(self, credentials): - """Write Credentials to file. + def locked_put(self, credentials): + """Write Credentials to file. - Args: - credentials: Credentials, the credentials to store. + Args: + credentials: Credentials, the credentials to store. - Raises: - CredentialsFileSymbolicLinkError if the file is a symbolic link. - """ + Raises: + CredentialsFileSymbolicLinkError if the file is a symbolic link. + """ + self._create_file_if_needed() + self._validate_file() + f = open(self._filename, 'w') + f.write(credentials.to_json()) + f.close() - self._create_file_if_needed() - self._validate_file() - f = open(self._filename, 'w') - f.write(credentials.to_json()) - f.close() + def locked_delete(self): + """Delete Credentials file. - def locked_delete(self): - """Delete Credentials file. - - Args: - credentials: Credentials, the credentials to store. - """ - - os.unlink(self._filename) + Args: + credentials: Credentials, the credentials to store. + """ + os.unlink(self._filename) diff --git a/oauth2client/flask_util.py b/oauth2client/flask_util.py new file mode 100644 index 00000000..2c455acf --- /dev/null +++ b/oauth2client/flask_util.py @@ -0,0 +1,548 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for the Flask web framework + +Provides a Flask extension that makes using OAuth2 web server flow easier. +The extension includes views that handle the entire auth flow and a +``@required`` decorator to automatically ensure that user credentials are +available. + + +Configuration +============= + +To configure, you'll need a set of OAuth2 client ID from the +`Google Developer's Console `__. + +.. code-block:: python + + from oauth2client.flask_util import UserOAuth2 + + app = Flask(__name__) + + app.config['SECRET_KEY'] = 'your-secret-key' + + app.config['OAUTH2_CLIENT_SECRETS_JSON'] = 'client_secrets.json' + # or, specify the client id and secret separately + app.config['OAUTH2_CLIENT_ID'] = 'your-client-id' + app.config['OAUTH2_CLIENT_SECRET'] = 'your-client-secret' + + oauth2 = UserOAuth2(app) + + +Usage +===== + +Once configured, you can use the :meth:`UserOAuth2.required` decorator to +ensure that credentials are available within a view. + +.. code-block:: python + :emphasize-lines: 3,7,10 + + # Note that app.route should be the outermost decorator. + @app.route('/needs_credentials') + @oauth2.required + def example(): + # http is authorized with the user's credentials and can be used + # to make http calls. + http = oauth2.http() + + # Or, you can access the credentials directly + credentials = oauth2.credentials + +If you want credentials to be optional for a view, you can leave the decorator +off and use :meth:`UserOAuth2.has_credentials` to check. + +.. code-block:: python + :emphasize-lines: 3 + + @app.route('/optional') + def optional(): + if oauth2.has_credentials(): + return 'Credentials found!' + else: + return 'No credentials!' + + +When credentials are available, you can use :attr:`UserOAuth2.email` and +:attr:`UserOAuth2.user_id` to access information from the `ID Token +`__, if +available. + +.. code-block:: python + :emphasize-lines: 4 + + @app.route('/info') + @oauth2.required + def info(): + return "Hello, {} ({})".format(oauth2.email, oauth2.user_id) + + +URLs & Trigging Authorization +============================= + +The extension will add two new routes to your application: + + * ``"oauth2.authorize"`` -> ``/oauth2authorize`` + * ``"oauth2.callback"`` -> ``/oauth2callback`` + +When configuring your OAuth2 credentials on the Google Developer's Console, be +sure to add ``http[s]://[your-app-url]/oauth2callback`` as an authorized +callback url. + +Typically you don't not need to use these routes directly, just be sure to +decorate any views that require credentials with ``@oauth2.required``. If +needed, you can trigger authorization at any time by redirecting the user +to the URL returned by :meth:`UserOAuth2.authorize_url`. + +.. code-block:: python + :emphasize-lines: 3 + + @app.route('/login') + def login(): + return oauth2.authorize_url("/") + + +Incremental Auth +================ + +This extension also supports `Incremental Auth `__. To enable it, +configure the extension with ``include_granted_scopes``. + +.. code-block:: python + + oauth2 = UserOAuth2(app, include_granted_scopes=True) + +Then specify any additional scopes needed on the decorator, for example: + +.. code-block:: python + :emphasize-lines: 2,7 + + @app.route('/drive') + @oauth2.required(scopes=["https://www.googleapis.com/auth/drive"]) + def requires_drive(): + ... + + @app.route('/calendar') + @oauth2.required(scopes=["https://www.googleapis.com/auth/calendar"]) + def requires_calendar(): + ... + +The decorator will ensure that the the user has authorized all specified scopes +before allowing them to access the view, and will also ensure that credentials +do not lose any previously authorized scopes. + + +Storage +======= + +By default, the extension uses a Flask session-based storage solution. This +means that credentials are only available for the duration of a session. It +also means that with Flask's default configuration, the credentials will be +visible in the session cookie. It's highly recommended to use database-backed +session and to use https whenever handling user credentials. + +If you need the credentials to be available longer than a user session or +available outside of a request context, you will need to implement your own +:class:`oauth2client.Storage`. +""" + +import hashlib +import json +import os +from functools import wraps + +import six.moves.http_client as httplib +import httplib2 + +try: + from flask import Blueprint + from flask import _app_ctx_stack + from flask import current_app + from flask import redirect + from flask import request + from flask import session + from flask import url_for +except ImportError: # pragma: NO COVER + raise ImportError('The flask utilities require flask 0.9 or newer.') + +from oauth2client.client import FlowExchangeError +from oauth2client.client import OAuth2Credentials +from oauth2client.client import OAuth2WebServerFlow +from oauth2client.client import Storage +from oauth2client import clientsecrets +from oauth2client import util + + +__author__ = 'jonwayne@google.com (Jon Wayne Parrott)' + +DEFAULT_SCOPES = ('email',) + + +class UserOAuth2(object): + """Flask extension for making OAuth 2.0 easier. + + Configuration values: + + * ``GOOGLE_OAUTH2_CLIENT_SECRETS_JSON`` path to a client secrets json + file, obtained from the credentials screen in the Google Developers + console. + * ``GOOGLE_OAUTH2_CLIENT_ID`` the oauth2 credentials' client ID. This + is only needed if ``OAUTH2_CLIENT_SECRETS_JSON`` is not specified. + * ``GOOGLE_OAUTH2_CLIENT_SECRET`` the oauth2 credentials' client + secret. This is only needed if ``OAUTH2_CLIENT_SECRETS_JSON`` is not + specified. + + If app is specified, all arguments will be passed along to init_app. + + If no app is specified, then you should call init_app in your application + factory to finish initialization. + """ + + def __init__(self, app=None, *args, **kwargs): + self.app = app + if app is not None: + self.init_app(app, *args, **kwargs) + + def init_app(self, app, scopes=None, client_secrets_file=None, + client_id=None, client_secret=None, authorize_callback=None, + storage=None, **kwargs): + """Initialize this extension for the given app. + + Arguments: + app: A Flask application. + scopes: Optional list of scopes to authorize. + client_secrets_file: Path to a file containing client secrets. You + can also specify the OAUTH2_CLIENT_SECRETS_JSON config value. + client_id: If not specifying a client secrets file, specify the + OAuth2 client id. You can also specify the + GOOGLE_OAUTH2_CLIENT_ID config value. You must also provide a + client secret. + client_secret: The OAuth2 client secret. You can also specify the + GOOGLE_OAUTH2_CLIENT_SECRET config value. + authorize_callback: A function that is executed after successful + user authorization. + storage: A oauth2client.client.Storage subclass for storing the + credentials. By default, this is a Flask session based storage. + kwargs: Any additional args are passed along to the Flow + constructor. + """ + self.app = app + self.authorize_callback = authorize_callback + self.flow_kwargs = kwargs + + if storage is None: + storage = FlaskSessionStorage() + self.storage = storage + + if scopes is None: + scopes = app.config.get('GOOGLE_OAUTH2_SCOPES', DEFAULT_SCOPES) + self.scopes = scopes + + self._load_config(client_secrets_file, client_id, client_secret) + + app.register_blueprint(self._create_blueprint()) + + def _load_config(self, client_secrets_file, client_id, client_secret): + """Loads oauth2 configuration in order of priority. + + Priority: + 1. Config passed to the constructor or init_app. + 2. Config passed via the GOOGLE_OAUTH2_CLIENT_SECRETS_FILE app + config. + 3. Config passed via the GOOGLE_OAUTH2_CLIENT_ID and + GOOGLE_OAUTH2_CLIENT_SECRET app config. + + Raises: + ValueError if no config could be found. + """ + if client_id and client_secret: + self.client_id, self.client_secret = client_id, client_secret + return + + if client_secrets_file: + self._load_client_secrets(client_secrets_file) + return + + if 'GOOGLE_OAUTH2_CLIENT_SECRETS_FILE' in self.app.config: + self._load_client_secrets( + self.app.config['GOOGLE_OAUTH2_CLIENT_SECRETS_FILE']) + return + + try: + self.client_id, self.client_secret = ( + self.app.config['GOOGLE_OAUTH2_CLIENT_ID'], + self.app.config['GOOGLE_OAUTH2_CLIENT_SECRET']) + except KeyError: + raise ValueError( + 'OAuth2 configuration could not be found. Either specify the ' + 'client_secrets_file or client_id and client_secret or set the' + 'app configuration variables ' + 'GOOGLE_OAUTH2_CLIENT_SECRETS_FILE or ' + 'GOOGLE_OAUTH2_CLIENT_ID and GOOGLE_OAUTH2_CLIENT_SECRET.') + + def _load_client_secrets(self, filename): + """Loads client secrets from the given filename.""" + client_type, client_info = clientsecrets.loadfile(filename) + if client_type != clientsecrets.TYPE_WEB: + raise ValueError( + 'The flow specified in %s is not supported.' % client_type) + + self.client_id = client_info['client_id'] + self.client_secret = client_info['client_secret'] + + def _make_flow(self, return_url=None, **kwargs): + """Creates a Web Server Flow""" + # Generate a CSRF token to prevent malicious requests. + csrf_token = hashlib.sha256(os.urandom(1024)).hexdigest() + + session['google_oauth2_csrf_token'] = csrf_token + + state = json.dumps({ + 'csrf_token': csrf_token, + 'return_url': return_url + }) + + kw = self.flow_kwargs.copy() + kw.update(kwargs) + + extra_scopes = util.scopes_to_string(kw.pop('scopes', '')) + scopes = ' '.join([util.scopes_to_string(self.scopes), extra_scopes]) + + return OAuth2WebServerFlow( + client_id=self.client_id, + client_secret=self.client_secret, + scope=scopes, + state=state, + redirect_uri=url_for('oauth2.callback', _external=True), + **kw) + + def _create_blueprint(self): + bp = Blueprint('oauth2', __name__) + bp.add_url_rule('/oauth2authorize', 'authorize', self.authorize_view) + bp.add_url_rule('/oauth2callback', 'callback', self.callback_view) + + return bp + + def authorize_view(self): + """Flask view that starts the authorization flow. + + Starts flow by redirecting the user to the OAuth2 provider. + """ + args = request.args.to_dict() + + # Scopes will be passed as mutliple args, and to_dict() will only + # return one. So, we use getlist() to get all of the scopes. + args['scopes'] = request.args.getlist('scopes') + + return_url = args.pop('return_url', None) + if return_url is None: + return_url = request.referrer or '/' + + flow = self._make_flow(return_url=return_url, **args) + auth_url = flow.step1_get_authorize_url() + + return redirect(auth_url) + + def callback_view(self): + """Flask view that handles the user's return from OAuth2 provider. + + On return, exchanges the authorization code for credentials and stores + the credentials. + """ + if 'error' in request.args: + reason = request.args.get( + 'error_description', request.args.get('error', '')) + return 'Authorization failed: %s' % reason, httplib.BAD_REQUEST + + try: + encoded_state = request.args['state'] + server_csrf = session['google_oauth2_csrf_token'] + code = request.args['code'] + except KeyError: + return 'Invalid request', httplib.BAD_REQUEST + + try: + state = json.loads(encoded_state) + client_csrf = state['csrf_token'] + return_url = state['return_url'] + except (ValueError, KeyError): + return 'Invalid request state', httplib.BAD_REQUEST + + if client_csrf != server_csrf: + return 'Invalid request state', httplib.BAD_REQUEST + + flow = self._make_flow() + + # Exchange the auth code for credentials. + try: + credentials = flow.step2_exchange(code) + except FlowExchangeError as exchange_error: + current_app.logger.exception(exchange_error) + content = 'An error occurred: %s' % (exchange_error,) + return content, httplib.BAD_REQUEST + + # Save the credentials to the storage. + self.storage.put(credentials) + + if self.authorize_callback: + self.authorize_callback(credentials) + + return redirect(return_url) + + @property + def credentials(self): + """The credentials for the current user or None if unavailable.""" + ctx = _app_ctx_stack.top + + if not hasattr(ctx, 'google_oauth2_credentials'): + ctx.google_oauth2_credentials = self.storage.get() + + return ctx.google_oauth2_credentials + + def has_credentials(self): + """Returns True if there are valid credentials for the current user.""" + return self.credentials and not self.credentials.invalid + + @property + def email(self): + """Returns the user's email address or None if there are no credentials. + + The email address is provided by the current credentials' id_token. + This should not be used as unique identifier as the user can change + their email. If you need a unique identifier, use user_id. + """ + if not self.credentials: + return None + try: + return self.credentials.id_token['email'] + except KeyError: + current_app.logger.error( + 'Invalid id_token %s', self.credentials.id_token) + + @property + def user_id(self): + """Returns the a unique identifier for the user + + Returns None if there are no credentials. + + The id is provided by the current credentials' id_token. + """ + if not self.credentials: + return None + try: + return self.credentials.id_token['sub'] + except KeyError: + current_app.logger.error( + 'Invalid id_token %s', self.credentials.id_token) + + def authorize_url(self, return_url, **kwargs): + """Creates a URL that can be used to start the authorization flow. + + When the user is directed to the URL, the authorization flow will + begin. Once complete, the user will be redirected to the specified + return URL. + + Any kwargs are passed into the flow constructor. + """ + return url_for('oauth2.authorize', return_url=return_url, **kwargs) + + def required(self, decorated_function=None, scopes=None, + **decorator_kwargs): + """Decorator to require OAuth2 credentials for a view. + + If credentials are not available for the current user, then they will + be redirected to the authorization flow. Once complete, the user will + be redirected back to the original page. + """ + + def curry_wrapper(wrapped_function): + @wraps(wrapped_function) + def required_wrapper(*args, **kwargs): + + return_url = decorator_kwargs.pop('return_url', request.url) + + # No credentials, redirect for new authorization. + if not self.has_credentials(): + auth_url = self.authorize_url( + return_url, + scopes=scopes, + **decorator_kwargs) + return redirect(auth_url) + + # Existing credentials but mismatching scopes, redirect for + # incremental authorization. + if scopes and not self.credentials.has_scopes(scopes): + auth_url = self.authorize_url( + return_url, + scopes=list(self.credentials.scopes) + scopes, + **decorator_kwargs) + return redirect(auth_url) + + return wrapped_function(*args, **kwargs) + + return required_wrapper + + if decorated_function: + return curry_wrapper(decorated_function) + else: + return curry_wrapper + + def http(self, *args, **kwargs): + """Returns an authorized http instance. + + Can only be called if there are valid credentials for the user, such + as inside of a view that is decorated with @required. + + Args: + *args: Positional arguments passed to httplib2.Http constructor. + **kwargs: Positional arguments passed to httplib2.Http constructor. + + Raises: + ValueError if no credentials are available. + """ + if not self.credentials: + raise ValueError('No credentials available.') + return self.credentials.authorize(httplib2.Http(*args, **kwargs)) + + +class FlaskSessionStorage(Storage): + """Storage implementation that uses Flask sessions. + + Note that flask's default sessions are signed but not encrypted. Users + can see their own credentials and non-https connections can intercept user + credentials. We strongly recommend using a server-side session + implementation. + """ + + def locked_get(self): + serialized = session.get('google_oauth2_credentials') + + if serialized is None: + return None + + credentials = OAuth2Credentials.from_json(serialized) + credentials.set_store(self) + + return credentials + + def locked_put(self, credentials): + session['google_oauth2_credentials'] = credentials.to_json() + + def locked_delete(self): + if 'google_oauth2_credentials' in session: + del session['google_oauth2_credentials'] diff --git a/oauth2client/gce.py b/oauth2client/gce.py index fc3bd77b..4b0b7efa 100644 --- a/oauth2client/gce.py +++ b/oauth2client/gce.py @@ -17,16 +17,18 @@ Utilities for making it easier to use OAuth 2.0 on Google Compute Engine. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - import json import logging from six.moves import urllib +from oauth2client._helpers import _from_bytes from oauth2client import util from oauth2client.client import AccessTokenRefreshError from oauth2client.client import AssertionCredentials + +__author__ = 'jcgregorio@google.com (Joe Gregorio)' + logger = logging.getLogger(__name__) # URI Template for the endpoint that returns access_tokens. @@ -35,71 +37,74 @@ META = ('http://metadata.google.internal/0.1/meta-data/service-accounts/' class AppAssertionCredentials(AssertionCredentials): - """Credentials object for Compute Engine Assertion Grants + """Credentials object for Compute Engine Assertion Grants - This object will allow a Compute Engine instance to identify itself to - Google and other OAuth 2.0 servers that can verify assertions. It can be used - for the purpose of accessing data stored under an account assigned to the - Compute Engine instance itself. + This object will allow a Compute Engine instance to identify itself to + Google and other OAuth 2.0 servers that can verify assertions. It can be + used for the purpose of accessing data stored under an account assigned to + the Compute Engine instance itself. - This credential does not require a flow to instantiate because it represents - a two legged flow, and therefore has all of the required information to - generate and refresh its own access tokens. - """ - - @util.positional(2) - def __init__(self, scope, **kwargs): - """Constructor for AppAssertionCredentials - - Args: - scope: string or iterable of strings, scope(s) of the credentials being - requested. + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. """ - self.scope = util.scopes_to_string(scope) - self.kwargs = kwargs - # Assertion type is no longer used, but still in the parent class signature. - super(AppAssertionCredentials, self).__init__(None) + @util.positional(2) + def __init__(self, scope, **kwargs): + """Constructor for AppAssertionCredentials - @classmethod - def from_json(cls, json_data): - data = json.loads(json_data) - return AppAssertionCredentials(data['scope']) + Args: + scope: string or iterable of strings, scope(s) of the credentials + being requested. + """ + self.scope = util.scopes_to_string(scope) + self.kwargs = kwargs - def _refresh(self, http_request): - """Refreshes the access_token. + # Assertion type is no longer used, but still in the + # parent class signature. + super(AppAssertionCredentials, self).__init__(None) - Skip all the storage hoops and just refresh using the API. + @classmethod + def from_json(cls, json_data): + data = json.loads(_from_bytes(json_data)) + return AppAssertionCredentials(data['scope']) - Args: - http_request: callable, a callable that matches the method signature of - httplib2.Http.request, used to make the refresh request. + def _refresh(self, http_request): + """Refreshes the access_token. - Raises: - AccessTokenRefreshError: When the refresh fails. - """ - query = '?scope=%s' % urllib.parse.quote(self.scope, '') - uri = META.replace('{?scope}', query) - response, content = http_request(uri) - if response.status == 200: - try: - d = json.loads(content) - except Exception as e: - raise AccessTokenRefreshError(str(e)) - self.access_token = d['accessToken'] - else: - if response.status == 404: - content += (' This can occur if a VM was created' - ' with no service account or scopes.') - raise AccessTokenRefreshError(content) + Skip all the storage hoops and just refresh using the API. - @property - def serialization_data(self): - raise NotImplementedError( - 'Cannot serialize credentials for GCE service accounts.') + Args: + http_request: callable, a callable that matches the method + signature of httplib2.Http.request, used to make + the refresh request. - def create_scoped_required(self): - return not self.scope + Raises: + AccessTokenRefreshError: When the refresh fails. + """ + query = '?scope=%s' % urllib.parse.quote(self.scope, '') + uri = META.replace('{?scope}', query) + response, content = http_request(uri) + content = _from_bytes(content) + if response.status == 200: + try: + d = json.loads(content) + except Exception as e: + raise AccessTokenRefreshError(str(e)) + self.access_token = d['accessToken'] + else: + if response.status == 404: + content += (' This can occur if a VM was created' + ' with no service account or scopes.') + raise AccessTokenRefreshError(content) - def create_scoped(self, scopes): - return AppAssertionCredentials(scopes, **self.kwargs) + @property + def serialization_data(self): + raise NotImplementedError( + 'Cannot serialize credentials for GCE service accounts.') + + def create_scoped_required(self): + return not self.scope + + def create_scoped(self, scopes): + return AppAssertionCredentials(scopes, **self.kwargs) diff --git a/oauth2client/keyring_storage.py b/oauth2client/keyring_storage.py index cda1d9a3..0a4c2857 100644 --- a/oauth2client/keyring_storage.py +++ b/oauth2client/keyring_storage.py @@ -17,8 +17,6 @@ A Storage for Credentials that uses the keyring module. """ -__author__ = 'jcgregorio@google.com (Joe Gregorio)' - import threading import keyring @@ -27,84 +25,90 @@ from oauth2client.client import Credentials from oauth2client.client import Storage as BaseStorage +__author__ = 'jcgregorio@google.com (Joe Gregorio)' + + class Storage(BaseStorage): - """Store and retrieve a single credential to and from the keyring. + """Store and retrieve a single credential to and from the keyring. - To use this module you must have the keyring module installed. See - . This is an optional module and is not - installed with oauth2client by default because it does not work on all the - platforms that oauth2client supports, such as Google App Engine. + To use this module you must have the keyring module installed. See + . This is an optional module and is + not installed with oauth2client by default because it does not work on all + the platforms that oauth2client supports, such as Google App Engine. - The keyring module is a cross-platform - library for access the keyring capabilities of the local system. The user will - be prompted for their keyring password when this module is used, and the - manner in which the user is prompted will vary per platform. + The keyring module is a + cross-platform library for access the keyring capabilities of the local + system. The user will be prompted for their keyring password when this + module is used, and the manner in which the user is prompted will vary per + platform. - Usage: - from oauth2client.keyring_storage import Storage + Usage:: - s = Storage('name_of_application', 'user1') - credentials = s.get() + from oauth2client.keyring_storage import Storage - """ + s = Storage('name_of_application', 'user1') + credentials = s.get() - def __init__(self, service_name, user_name): - """Constructor. - - Args: - service_name: string, The name of the service under which the credentials - are stored. - user_name: string, The name of the user to store credentials for. """ - self._service_name = service_name - self._user_name = user_name - self._lock = threading.Lock() - def acquire_lock(self): - """Acquires any lock necessary to access this Storage. + def __init__(self, service_name, user_name): + """Constructor. - This lock is not reentrant.""" - self._lock.acquire() + Args: + service_name: string, The name of the service under which the + credentials are stored. + user_name: string, The name of the user to store credentials for. + """ + self._service_name = service_name + self._user_name = user_name + self._lock = threading.Lock() - def release_lock(self): - """Release the Storage lock. + def acquire_lock(self): + """Acquires any lock necessary to access this Storage. - Trying to release a lock that isn't held will result in a - RuntimeError. - """ - self._lock.release() + This lock is not reentrant. + """ + self._lock.acquire() - def locked_get(self): - """Retrieve Credential from file. + def release_lock(self): + """Release the Storage lock. - Returns: - oauth2client.client.Credentials - """ - credentials = None - content = keyring.get_password(self._service_name, self._user_name) + Trying to release a lock that isn't held will result in a + RuntimeError. + """ + self._lock.release() - if content is not None: - try: - credentials = Credentials.new_from_json(content) - credentials.set_store(self) - except ValueError: - pass + def locked_get(self): + """Retrieve Credential from file. - return credentials + Returns: + oauth2client.client.Credentials + """ + credentials = None + content = keyring.get_password(self._service_name, self._user_name) - def locked_put(self, credentials): - """Write Credentials to file. + if content is not None: + try: + credentials = Credentials.new_from_json(content) + credentials.set_store(self) + except ValueError: + pass - Args: - credentials: Credentials, the credentials to store. - """ - keyring.set_password(self._service_name, self._user_name, - credentials.to_json()) + return credentials - def locked_delete(self): - """Delete Credentials file. + def locked_put(self, credentials): + """Write Credentials to file. - Args: - credentials: Credentials, the credentials to store. - """ - keyring.set_password(self._service_name, self._user_name, '') + Args: + credentials: Credentials, the credentials to store. + """ + keyring.set_password(self._service_name, self._user_name, + credentials.to_json()) + + def locked_delete(self): + """Delete Credentials file. + + Args: + credentials: Credentials, the credentials to store. + """ + keyring.set_password(self._service_name, self._user_name, '') diff --git a/oauth2client/locked_file.py b/oauth2client/locked_file.py index af92398e..1028a7e0 100644 --- a/oauth2client/locked_file.py +++ b/oauth2client/locked_file.py @@ -32,8 +32,6 @@ Usage:: from __future__ import print_function -__author__ = 'cache@google.com (David T McWherter)' - import errno import logging import os @@ -41,338 +39,349 @@ import time from oauth2client import util + +__author__ = 'cache@google.com (David T McWherter)' + logger = logging.getLogger(__name__) class CredentialsFileSymbolicLinkError(Exception): - """Credentials files must not be symbolic links.""" + """Credentials files must not be symbolic links.""" class AlreadyLockedException(Exception): - """Trying to lock a file that has already been locked by the LockedFile.""" - pass + """Trying to lock a file that has already been locked by the LockedFile.""" + pass def validate_file(filename): - if os.path.islink(filename): - raise CredentialsFileSymbolicLinkError( - 'File: %s is a symbolic link.' % filename) + if os.path.islink(filename): + raise CredentialsFileSymbolicLinkError( + 'File: %s is a symbolic link.' % filename) + class _Opener(object): - """Base class for different locking primitives.""" + """Base class for different locking primitives.""" - def __init__(self, filename, mode, fallback_mode): - """Create an Opener. + def __init__(self, filename, mode, fallback_mode): + """Create an Opener. - Args: - filename: string, The pathname of the file. - mode: string, The preferred mode to access the file with. - fallback_mode: string, The mode to use if locking fails. - """ - self._locked = False - self._filename = filename - self._mode = mode - self._fallback_mode = fallback_mode - self._fh = None - self._lock_fd = None + Args: + filename: string, The pathname of the file. + mode: string, The preferred mode to access the file with. + fallback_mode: string, The mode to use if locking fails. + """ + self._locked = False + self._filename = filename + self._mode = mode + self._fallback_mode = fallback_mode + self._fh = None + self._lock_fd = None - def is_locked(self): - """Was the file locked.""" - return self._locked + def is_locked(self): + """Was the file locked.""" + return self._locked - def file_handle(self): - """The file handle to the file. Valid only after opened.""" - return self._fh + def file_handle(self): + """The file handle to the file. Valid only after opened.""" + return self._fh - def filename(self): - """The filename that is being locked.""" - return self._filename + def filename(self): + """The filename that is being locked.""" + return self._filename - def open_and_lock(self, timeout, delay): - """Open the file and lock it. + def open_and_lock(self, timeout, delay): + """Open the file and lock it. - Args: - timeout: float, How long to try to lock for. - delay: float, How long to wait between retries. - """ - pass + Args: + timeout: float, How long to try to lock for. + delay: float, How long to wait between retries. + """ + pass - def unlock_and_close(self): - """Unlock and close the file.""" - pass + def unlock_and_close(self): + """Unlock and close the file.""" + pass class _PosixOpener(_Opener): - """Lock files using Posix advisory lock files.""" - - def open_and_lock(self, timeout, delay): - """Open the file and lock it. - - Tries to create a .lock file next to the file we're trying to open. - - Args: - timeout: float, How long to try to lock for. - delay: float, How long to wait between retries. - - Raises: - AlreadyLockedException: if the lock is already acquired. - IOError: if the open fails. - CredentialsFileSymbolicLinkError if the file is a symbolic link. - """ - if self._locked: - raise AlreadyLockedException('File %s is already locked' % - self._filename) - self._locked = False - - validate_file(self._filename) - try: - self._fh = open(self._filename, self._mode) - except IOError as e: - # If we can't access with _mode, try _fallback_mode and don't lock. - if e.errno == errno.EACCES: - self._fh = open(self._filename, self._fallback_mode) - return - - lock_filename = self._posix_lockfile(self._filename) - start_time = time.time() - while True: - try: - self._lock_fd = os.open(lock_filename, - os.O_CREAT|os.O_EXCL|os.O_RDWR) - self._locked = True - break - - except OSError as e: - if e.errno != errno.EEXIST: - raise - if (time.time() - start_time) >= timeout: - logger.warn('Could not acquire lock %s in %s seconds', - lock_filename, timeout) - # Close the file and open in fallback_mode. - if self._fh: - self._fh.close() - self._fh = open(self._filename, self._fallback_mode) - return - time.sleep(delay) - - def unlock_and_close(self): - """Unlock a file by removing the .lock file, and close the handle.""" - if self._locked: - lock_filename = self._posix_lockfile(self._filename) - os.close(self._lock_fd) - os.unlink(lock_filename) - self._locked = False - self._lock_fd = None - if self._fh: - self._fh.close() - - def _posix_lockfile(self, filename): - """The name of the lock file to use for posix locking.""" - return '%s.lock' % filename - - -try: - import fcntl - - class _FcntlOpener(_Opener): - """Open, lock, and unlock a file using fcntl.lockf.""" + """Lock files using Posix advisory lock files.""" def open_and_lock(self, timeout, delay): - """Open the file and lock it. + """Open the file and lock it. - Args: - timeout: float, How long to try to lock for. - delay: float, How long to wait between retries + Tries to create a .lock file next to the file we're trying to open. - Raises: - AlreadyLockedException: if the lock is already acquired. - IOError: if the open fails. - CredentialsFileSymbolicLinkError if the file is a symbolic link. - """ - if self._locked: - raise AlreadyLockedException('File %s is already locked' % - self._filename) - start_time = time.time() + Args: + timeout: float, How long to try to lock for. + delay: float, How long to wait between retries. - validate_file(self._filename) - try: - self._fh = open(self._filename, self._mode) - except IOError as e: - # If we can't access with _mode, try _fallback_mode and don't lock. - if e.errno in (errno.EPERM, errno.EACCES): - self._fh = open(self._filename, self._fallback_mode) - return + Raises: + AlreadyLockedException: if the lock is already acquired. + IOError: if the open fails. + CredentialsFileSymbolicLinkError if the file is a symbolic link. + """ + if self._locked: + raise AlreadyLockedException('File %s is already locked' % + self._filename) + self._locked = False - # We opened in _mode, try to lock the file. - while True: + validate_file(self._filename) try: - fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX) - self._locked = True - return + self._fh = open(self._filename, self._mode) except IOError as e: - # If not retrying, then just pass on the error. - if timeout == 0: - raise - if e.errno != errno.EACCES: - raise - # We could not acquire the lock. Try again. - if (time.time() - start_time) >= timeout: - logger.warn('Could not lock %s in %s seconds', - self._filename, timeout) - if self._fh: - self._fh.close() - self._fh = open(self._filename, self._fallback_mode) - return - time.sleep(delay) + # If we can't access with _mode, try _fallback_mode and don't lock. + if e.errno == errno.EACCES: + self._fh = open(self._filename, self._fallback_mode) + return + + lock_filename = self._posix_lockfile(self._filename) + start_time = time.time() + while True: + try: + self._lock_fd = os.open(lock_filename, + os.O_CREAT | os.O_EXCL | os.O_RDWR) + self._locked = True + break + + except OSError as e: + if e.errno != errno.EEXIST: + raise + if (time.time() - start_time) >= timeout: + logger.warn('Could not acquire lock %s in %s seconds', + lock_filename, timeout) + # Close the file and open in fallback_mode. + if self._fh: + self._fh.close() + self._fh = open(self._filename, self._fallback_mode) + return + time.sleep(delay) def unlock_and_close(self): - """Close and unlock the file using the fcntl.lockf primitive.""" - if self._locked: - fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN) - self._locked = False - if self._fh: - self._fh.close() -except ImportError: - _FcntlOpener = None + """Unlock a file by removing the .lock file, and close the handle.""" + if self._locked: + lock_filename = self._posix_lockfile(self._filename) + os.close(self._lock_fd) + os.unlink(lock_filename) + self._locked = False + self._lock_fd = None + if self._fh: + self._fh.close() + + def _posix_lockfile(self, filename): + """The name of the lock file to use for posix locking.""" + return '%s.lock' % filename try: - import pywintypes - import win32con - import win32file + import fcntl - class _Win32Opener(_Opener): - """Open, lock, and unlock a file using windows primitives.""" + class _FcntlOpener(_Opener): + """Open, lock, and unlock a file using fcntl.lockf.""" - # Error #33: - # 'The process cannot access the file because another process' - FILE_IN_USE_ERROR = 33 + def open_and_lock(self, timeout, delay): + """Open the file and lock it. - # Error #158: - # 'The segment is already unlocked.' - FILE_ALREADY_UNLOCKED_ERROR = 158 + Args: + timeout: float, How long to try to lock for. + delay: float, How long to wait between retries - def open_and_lock(self, timeout, delay): - """Open the file and lock it. + Raises: + AlreadyLockedException: if the lock is already acquired. + IOError: if the open fails. + CredentialsFileSymbolicLinkError: if the file is a symbolic + link. + """ + if self._locked: + raise AlreadyLockedException('File %s is already locked' % + self._filename) + start_time = time.time() - Args: - timeout: float, How long to try to lock for. - delay: float, How long to wait between retries + validate_file(self._filename) + try: + self._fh = open(self._filename, self._mode) + except IOError as e: + # If we can't access with _mode, try _fallback_mode and + # don't lock. + if e.errno in (errno.EPERM, errno.EACCES): + self._fh = open(self._filename, self._fallback_mode) + return - Raises: - AlreadyLockedException: if the lock is already acquired. - IOError: if the open fails. - CredentialsFileSymbolicLinkError if the file is a symbolic link. - """ - if self._locked: - raise AlreadyLockedException('File %s is already locked' % - self._filename) - start_time = time.time() + # We opened in _mode, try to lock the file. + while True: + try: + fcntl.lockf(self._fh.fileno(), fcntl.LOCK_EX) + self._locked = True + return + except IOError as e: + # If not retrying, then just pass on the error. + if timeout == 0: + raise + if e.errno != errno.EACCES: + raise + # We could not acquire the lock. Try again. + if (time.time() - start_time) >= timeout: + logger.warn('Could not lock %s in %s seconds', + self._filename, timeout) + if self._fh: + self._fh.close() + self._fh = open(self._filename, self._fallback_mode) + return + time.sleep(delay) - validate_file(self._filename) - try: - self._fh = open(self._filename, self._mode) - except IOError as e: - # If we can't access with _mode, try _fallback_mode and don't lock. - if e.errno == errno.EACCES: - self._fh = open(self._filename, self._fallback_mode) - return - - # We opened in _mode, try to lock the file. - while True: - try: - hfile = win32file._get_osfhandle(self._fh.fileno()) - win32file.LockFileEx( - hfile, - (win32con.LOCKFILE_FAIL_IMMEDIATELY| - win32con.LOCKFILE_EXCLUSIVE_LOCK), 0, -0x10000, - pywintypes.OVERLAPPED()) - self._locked = True - return - except pywintypes.error as e: - if timeout == 0: - raise - - # If the error is not that the file is already in use, raise. - if e[0] != _Win32Opener.FILE_IN_USE_ERROR: - raise - - # We could not acquire the lock. Try again. - if (time.time() - start_time) >= timeout: - logger.warn('Could not lock %s in %s seconds' % ( - self._filename, timeout)) + def unlock_and_close(self): + """Close and unlock the file using the fcntl.lockf primitive.""" + if self._locked: + fcntl.lockf(self._fh.fileno(), fcntl.LOCK_UN) + self._locked = False if self._fh: - self._fh.close() - self._fh = open(self._filename, self._fallback_mode) - return - time.sleep(delay) - - def unlock_and_close(self): - """Close and unlock the file using the win32 primitive.""" - if self._locked: - try: - hfile = win32file._get_osfhandle(self._fh.fileno()) - win32file.UnlockFileEx(hfile, 0, -0x10000, pywintypes.OVERLAPPED()) - except pywintypes.error as e: - if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR: - raise - self._locked = False - if self._fh: - self._fh.close() + self._fh.close() except ImportError: - _Win32Opener = None + _FcntlOpener = None + + +try: + import pywintypes + import win32con + import win32file + + class _Win32Opener(_Opener): + """Open, lock, and unlock a file using windows primitives.""" + + # Error #33: + # 'The process cannot access the file because another process' + FILE_IN_USE_ERROR = 33 + + # Error #158: + # 'The segment is already unlocked.' + FILE_ALREADY_UNLOCKED_ERROR = 158 + + def open_and_lock(self, timeout, delay): + """Open the file and lock it. + + Args: + timeout: float, How long to try to lock for. + delay: float, How long to wait between retries + + Raises: + AlreadyLockedException: if the lock is already acquired. + IOError: if the open fails. + CredentialsFileSymbolicLinkError: if the file is a symbolic + link. + """ + if self._locked: + raise AlreadyLockedException('File %s is already locked' % + self._filename) + start_time = time.time() + + validate_file(self._filename) + try: + self._fh = open(self._filename, self._mode) + except IOError as e: + # If we can't access with _mode, try _fallback_mode + # and don't lock. + if e.errno == errno.EACCES: + self._fh = open(self._filename, self._fallback_mode) + return + + # We opened in _mode, try to lock the file. + while True: + try: + hfile = win32file._get_osfhandle(self._fh.fileno()) + win32file.LockFileEx( + hfile, + (win32con.LOCKFILE_FAIL_IMMEDIATELY | + win32con.LOCKFILE_EXCLUSIVE_LOCK), 0, -0x10000, + pywintypes.OVERLAPPED()) + self._locked = True + return + except pywintypes.error as e: + if timeout == 0: + raise + + # If the error is not that the file is already + # in use, raise. + if e[0] != _Win32Opener.FILE_IN_USE_ERROR: + raise + + # We could not acquire the lock. Try again. + if (time.time() - start_time) >= timeout: + logger.warn('Could not lock %s in %s seconds' % ( + self._filename, timeout)) + if self._fh: + self._fh.close() + self._fh = open(self._filename, self._fallback_mode) + return + time.sleep(delay) + + def unlock_and_close(self): + """Close and unlock the file using the win32 primitive.""" + if self._locked: + try: + hfile = win32file._get_osfhandle(self._fh.fileno()) + win32file.UnlockFileEx(hfile, 0, -0x10000, + pywintypes.OVERLAPPED()) + except pywintypes.error as e: + if e[0] != _Win32Opener.FILE_ALREADY_UNLOCKED_ERROR: + raise + self._locked = False + if self._fh: + self._fh.close() +except ImportError: + _Win32Opener = None class LockedFile(object): - """Represent a file that has exclusive access.""" + """Represent a file that has exclusive access.""" - @util.positional(4) - def __init__(self, filename, mode, fallback_mode, use_native_locking=True): - """Construct a LockedFile. + @util.positional(4) + def __init__(self, filename, mode, fallback_mode, use_native_locking=True): + """Construct a LockedFile. - Args: - filename: string, The path of the file to open. - mode: string, The mode to try to open the file with. - fallback_mode: string, The mode to use if locking fails. - use_native_locking: bool, Whether or not fcntl/win32 locking is used. - """ - opener = None - if not opener and use_native_locking: - if _Win32Opener: - opener = _Win32Opener(filename, mode, fallback_mode) - if _FcntlOpener: - opener = _FcntlOpener(filename, mode, fallback_mode) + Args: + filename: string, The path of the file to open. + mode: string, The mode to try to open the file with. + fallback_mode: string, The mode to use if locking fails. + use_native_locking: bool, Whether or not fcntl/win32 locking is + used. + """ + opener = None + if not opener and use_native_locking: + if _Win32Opener: + opener = _Win32Opener(filename, mode, fallback_mode) + if _FcntlOpener: + opener = _FcntlOpener(filename, mode, fallback_mode) - if not opener: - opener = _PosixOpener(filename, mode, fallback_mode) + if not opener: + opener = _PosixOpener(filename, mode, fallback_mode) - self._opener = opener + self._opener = opener - def filename(self): - """Return the filename we were constructed with.""" - return self._opener._filename + def filename(self): + """Return the filename we were constructed with.""" + return self._opener._filename - def file_handle(self): - """Return the file_handle to the opened file.""" - return self._opener.file_handle() + def file_handle(self): + """Return the file_handle to the opened file.""" + return self._opener.file_handle() - def is_locked(self): - """Return whether we successfully locked the file.""" - return self._opener.is_locked() + def is_locked(self): + """Return whether we successfully locked the file.""" + return self._opener.is_locked() - def open_and_lock(self, timeout=0, delay=0.05): - """Open the file, trying to lock it. + def open_and_lock(self, timeout=0, delay=0.05): + """Open the file, trying to lock it. - Args: - timeout: float, The number of seconds to try to acquire the lock. - delay: float, The number of seconds to wait between retry attempts. + Args: + timeout: float, The number of seconds to try to acquire the lock. + delay: float, The number of seconds to wait between retry attempts. - Raises: - AlreadyLockedException: if the lock is already acquired. - IOError: if the open fails. - """ - self._opener.open_and_lock(timeout, delay) + Raises: + AlreadyLockedException: if the lock is already acquired. + IOError: if the open fails. + """ + self._opener.open_and_lock(timeout, delay) - def unlock_and_close(self): - """Unlock and close a file.""" - self._opener.unlock_and_close() + def unlock_and_close(self): + """Unlock and close a file.""" + self._opener.unlock_and_close() diff --git a/oauth2client/multistore_file.py b/oauth2client/multistore_file.py index f4ba4a70..f5a85930 100644 --- a/oauth2client/multistore_file.py +++ b/oauth2client/multistore_file.py @@ -26,26 +26,24 @@ The credential themselves are keyed off of: The format of the stored data is like so:: - { - 'file_version': 1, - 'data': [ - { - 'key': { - 'clientId': '', - 'userAgent': '', - 'scope': '' - }, - 'credential': { - # JSON serialized Credentials. - } - } - ] - } + { + 'file_version': 1, + 'data': [ + { + 'key': { + 'clientId': '', + 'userAgent': '', + 'scope': '' + }, + 'credential': { + # JSON serialized Credentials. + } + } + ] + } """ -__author__ = 'jbeda@google.com (Joe Beda)' - import errno import json import logging @@ -57,6 +55,9 @@ from oauth2client.client import Storage as BaseStorage from oauth2client import util from oauth2client.locked_file import LockedFile + +__author__ = 'jbeda@google.com (Joe Beda)' + logger = logging.getLogger(__name__) # A dict from 'filename'->_MultiStore instances @@ -65,411 +66,416 @@ _multistores_lock = threading.Lock() class Error(Exception): - """Base error for this module.""" + """Base error for this module.""" class NewerCredentialStoreError(Error): - """The credential store is a newer version than supported.""" + """The credential store is a newer version than supported.""" @util.positional(4) def get_credential_storage(filename, client_id, user_agent, scope, warn_on_readonly=True): - """Get a Storage instance for a credential. + """Get a Storage instance for a credential. - Args: - filename: The JSON file storing a set of credentials - client_id: The client_id for the credential - user_agent: The user agent for the credential - scope: string or iterable of strings, Scope(s) being requested - warn_on_readonly: if True, log a warning if the store is readonly + Args: + filename: The JSON file storing a set of credentials + client_id: The client_id for the credential + user_agent: The user agent for the credential + scope: string or iterable of strings, Scope(s) being requested + warn_on_readonly: if True, log a warning if the store is readonly - Returns: - An object derived from client.Storage for getting/setting the - credential. - """ - # Recreate the legacy key with these specific parameters - key = {'clientId': client_id, 'userAgent': user_agent, - 'scope': util.scopes_to_string(scope)} - return get_credential_storage_custom_key( + Returns: + An object derived from client.Storage for getting/setting the + credential. + """ + # Recreate the legacy key with these specific parameters + key = {'clientId': client_id, 'userAgent': user_agent, + 'scope': util.scopes_to_string(scope)} + return get_credential_storage_custom_key( filename, key, warn_on_readonly=warn_on_readonly) @util.positional(2) -def get_credential_storage_custom_string_key( - filename, key_string, warn_on_readonly=True): - """Get a Storage instance for a credential using a single string as a key. +def get_credential_storage_custom_string_key(filename, key_string, + warn_on_readonly=True): + """Get a Storage instance for a credential using a single string as a key. - Allows you to provide a string as a custom key that will be used for - credential storage and retrieval. + Allows you to provide a string as a custom key that will be used for + credential storage and retrieval. - Args: - filename: The JSON file storing a set of credentials - key_string: A string to use as the key for storing this credential. - warn_on_readonly: if True, log a warning if the store is readonly + Args: + filename: The JSON file storing a set of credentials + key_string: A string to use as the key for storing this credential. + warn_on_readonly: if True, log a warning if the store is readonly - Returns: - An object derived from client.Storage for getting/setting the - credential. - """ - # Create a key dictionary that can be used - key_dict = {'key': key_string} - return get_credential_storage_custom_key( + Returns: + An object derived from client.Storage for getting/setting the + credential. + """ + # Create a key dictionary that can be used + key_dict = {'key': key_string} + return get_credential_storage_custom_key( filename, key_dict, warn_on_readonly=warn_on_readonly) @util.positional(2) -def get_credential_storage_custom_key( - filename, key_dict, warn_on_readonly=True): - """Get a Storage instance for a credential using a dictionary as a key. +def get_credential_storage_custom_key(filename, key_dict, + warn_on_readonly=True): + """Get a Storage instance for a credential using a dictionary as a key. - Allows you to provide a dictionary as a custom key that will be used for - credential storage and retrieval. + Allows you to provide a dictionary as a custom key that will be used for + credential storage and retrieval. - Args: - filename: The JSON file storing a set of credentials - key_dict: A dictionary to use as the key for storing this credential. There - is no ordering of the keys in the dictionary. Logically equivalent - dictionaries will produce equivalent storage keys. - warn_on_readonly: if True, log a warning if the store is readonly + Args: + filename: The JSON file storing a set of credentials + key_dict: A dictionary to use as the key for storing this credential. + There is no ordering of the keys in the dictionary. Logically + equivalent dictionaries will produce equivalent storage keys. + warn_on_readonly: if True, log a warning if the store is readonly - Returns: - An object derived from client.Storage for getting/setting the - credential. - """ - multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) - key = util.dict_to_tuple_key(key_dict) - return multistore._get_storage(key) + Returns: + An object derived from client.Storage for getting/setting the + credential. + """ + multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) + key = util.dict_to_tuple_key(key_dict) + return multistore._get_storage(key) @util.positional(1) def get_all_credential_keys(filename, warn_on_readonly=True): - """Gets all the registered credential keys in the given Multistore. + """Gets all the registered credential keys in the given Multistore. - Args: - filename: The JSON file storing a set of credentials - warn_on_readonly: if True, log a warning if the store is readonly + Args: + filename: The JSON file storing a set of credentials + warn_on_readonly: if True, log a warning if the store is readonly - Returns: - A list of the credential keys present in the file. They are returned as - dictionaries that can be passed into get_credential_storage_custom_key to - get the actual credentials. - """ - multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) - multistore._lock() - try: - return multistore._get_all_credential_keys() - finally: - multistore._unlock() + Returns: + A list of the credential keys present in the file. They are returned + as dictionaries that can be passed into + get_credential_storage_custom_key to get the actual credentials. + """ + multistore = _get_multistore(filename, warn_on_readonly=warn_on_readonly) + multistore._lock() + try: + return multistore._get_all_credential_keys() + finally: + multistore._unlock() @util.positional(1) def _get_multistore(filename, warn_on_readonly=True): - """A helper method to initialize the multistore with proper locking. + """A helper method to initialize the multistore with proper locking. - Args: - filename: The JSON file storing a set of credentials - warn_on_readonly: if True, log a warning if the store is readonly + Args: + filename: The JSON file storing a set of credentials + warn_on_readonly: if True, log a warning if the store is readonly - Returns: - A multistore object - """ - filename = os.path.expanduser(filename) - _multistores_lock.acquire() - try: - multistore = _multistores.setdefault( - filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly)) - finally: - _multistores_lock.release() - return multistore + Returns: + A multistore object + """ + filename = os.path.expanduser(filename) + _multistores_lock.acquire() + try: + multistore = _multistores.setdefault( + filename, _MultiStore(filename, warn_on_readonly=warn_on_readonly)) + finally: + _multistores_lock.release() + return multistore class _MultiStore(object): - """A file backed store for multiple credentials.""" + """A file backed store for multiple credentials.""" - @util.positional(2) - def __init__(self, filename, warn_on_readonly=True): - """Initialize the class. + @util.positional(2) + def __init__(self, filename, warn_on_readonly=True): + """Initialize the class. - This will create the file if necessary. - """ - self._file = LockedFile(filename, 'r+', 'r') - self._thread_lock = threading.Lock() - self._read_only = False - self._warn_on_readonly = warn_on_readonly + This will create the file if necessary. + """ + self._file = LockedFile(filename, 'r+', 'r') + self._thread_lock = threading.Lock() + self._read_only = False + self._warn_on_readonly = warn_on_readonly - self._create_file_if_needed() + self._create_file_if_needed() - # Cache of deserialized store. This is only valid after the - # _MultiStore is locked or _refresh_data_cache is called. This is - # of the form of: - # - # ((key, value), (key, value)...) -> OAuth2Credential - # - # If this is None, then the store hasn't been read yet. - self._data = None + # Cache of deserialized store. This is only valid after the + # _MultiStore is locked or _refresh_data_cache is called. This is + # of the form of: + # + # ((key, value), (key, value)...) -> OAuth2Credential + # + # If this is None, then the store hasn't been read yet. + self._data = None - class _Storage(BaseStorage): - """A Storage object that knows how to read/write a single credential.""" + class _Storage(BaseStorage): + """A Storage object that can read/write a single credential.""" - def __init__(self, multistore, key): - self._multistore = multistore - self._key = key + def __init__(self, multistore, key): + self._multistore = multistore + self._key = key - def acquire_lock(self): - """Acquires any lock necessary to access this Storage. + def acquire_lock(self): + """Acquires any lock necessary to access this Storage. - This lock is not reentrant. - """ - self._multistore._lock() + This lock is not reentrant. + """ + self._multistore._lock() - def release_lock(self): - """Release the Storage lock. + def release_lock(self): + """Release the Storage lock. - Trying to release a lock that isn't held will result in a - RuntimeError. - """ - self._multistore._unlock() + Trying to release a lock that isn't held will result in a + RuntimeError. + """ + self._multistore._unlock() - def locked_get(self): - """Retrieve credential. + def locked_get(self): + """Retrieve credential. - The Storage lock must be held when this is called. + The Storage lock must be held when this is called. - Returns: - oauth2client.client.Credentials - """ - credential = self._multistore._get_credential(self._key) - if credential: - credential.set_store(self) - return credential + Returns: + oauth2client.client.Credentials + """ + credential = self._multistore._get_credential(self._key) + if credential: + credential.set_store(self) + return credential - def locked_put(self, credentials): - """Write a credential. + def locked_put(self, credentials): + """Write a credential. - The Storage lock must be held when this is called. + The Storage lock must be held when this is called. - Args: - credentials: Credentials, the credentials to store. - """ - self._multistore._update_credential(self._key, credentials) + Args: + credentials: Credentials, the credentials to store. + """ + self._multistore._update_credential(self._key, credentials) - def locked_delete(self): - """Delete a credential. + def locked_delete(self): + """Delete a credential. - The Storage lock must be held when this is called. + The Storage lock must be held when this is called. - Args: - credentials: Credentials, the credentials to store. - """ - self._multistore._delete_credential(self._key) + Args: + credentials: Credentials, the credentials to store. + """ + self._multistore._delete_credential(self._key) - def _create_file_if_needed(self): - """Create an empty file if necessary. + def _create_file_if_needed(self): + """Create an empty file if necessary. - This method will not initialize the file. Instead it implements a - simple version of "touch" to ensure the file has been created. - """ - if not os.path.exists(self._file.filename()): - old_umask = os.umask(0o177) - try: - open(self._file.filename(), 'a+b').close() - finally: - os.umask(old_umask) + This method will not initialize the file. Instead it implements a + simple version of "touch" to ensure the file has been created. + """ + if not os.path.exists(self._file.filename()): + old_umask = os.umask(0o177) + try: + open(self._file.filename(), 'a+b').close() + finally: + os.umask(old_umask) - def _lock(self): - """Lock the entire multistore.""" - self._thread_lock.acquire() - try: - self._file.open_and_lock() - except IOError as e: - if e.errno == errno.ENOSYS: - logger.warn('File system does not support locking the credentials ' - 'file.') - elif e.errno == errno.ENOLCK: - logger.warn('File system is out of resources for writing the ' - 'credentials file (is your disk full?).') - else: - raise - if not self._file.is_locked(): - self._read_only = True - if self._warn_on_readonly: - logger.warn('The credentials file (%s) is not writable. Opening in ' - 'read-only mode. Any refreshed credentials will only be ' - 'valid for this run.', self._file.filename()) - if os.path.getsize(self._file.filename()) == 0: - logger.debug('Initializing empty multistore file') - # The multistore is empty so write out an empty file. - self._data = {} - self._write() - elif not self._read_only or self._data is None: - # Only refresh the data if we are read/write or we haven't - # cached the data yet. If we are readonly, we assume is isn't - # changing out from under us and that we only have to read it - # once. This prevents us from whacking any new access keys that - # we have cached in memory but were unable to write out. - self._refresh_data_cache() + def _lock(self): + """Lock the entire multistore.""" + self._thread_lock.acquire() + try: + self._file.open_and_lock() + except IOError as e: + if e.errno == errno.ENOSYS: + logger.warn('File system does not support locking the ' + 'credentials file.') + elif e.errno == errno.ENOLCK: + logger.warn('File system is out of resources for writing the ' + 'credentials file (is your disk full?).') + else: + raise + if not self._file.is_locked(): + self._read_only = True + if self._warn_on_readonly: + logger.warn('The credentials file (%s) is not writable. ' + 'Opening in read-only mode. Any refreshed ' + 'credentials will only be ' + 'valid for this run.', self._file.filename()) + if os.path.getsize(self._file.filename()) == 0: + logger.debug('Initializing empty multistore file') + # The multistore is empty so write out an empty file. + self._data = {} + self._write() + elif not self._read_only or self._data is None: + # Only refresh the data if we are read/write or we haven't + # cached the data yet. If we are readonly, we assume is isn't + # changing out from under us and that we only have to read it + # once. This prevents us from whacking any new access keys that + # we have cached in memory but were unable to write out. + self._refresh_data_cache() - def _unlock(self): - """Release the lock on the multistore.""" - self._file.unlock_and_close() - self._thread_lock.release() + def _unlock(self): + """Release the lock on the multistore.""" + self._file.unlock_and_close() + self._thread_lock.release() - def _locked_json_read(self): - """Get the raw content of the multistore file. + def _locked_json_read(self): + """Get the raw content of the multistore file. - The multistore must be locked when this is called. + The multistore must be locked when this is called. - Returns: - The contents of the multistore decoded as JSON. - """ - assert self._thread_lock.locked() - self._file.file_handle().seek(0) - return json.load(self._file.file_handle()) + Returns: + The contents of the multistore decoded as JSON. + """ + assert self._thread_lock.locked() + self._file.file_handle().seek(0) + return json.load(self._file.file_handle()) - def _locked_json_write(self, data): - """Write a JSON serializable data structure to the multistore. + def _locked_json_write(self, data): + """Write a JSON serializable data structure to the multistore. - The multistore must be locked when this is called. + The multistore must be locked when this is called. - Args: - data: The data to be serialized and written. - """ - assert self._thread_lock.locked() - if self._read_only: - return - self._file.file_handle().seek(0) - json.dump(data, self._file.file_handle(), sort_keys=True, indent=2, separators=(',', ': ')) - self._file.file_handle().truncate() + Args: + data: The data to be serialized and written. + """ + assert self._thread_lock.locked() + if self._read_only: + return + self._file.file_handle().seek(0) + json.dump(data, self._file.file_handle(), + sort_keys=True, indent=2, separators=(',', ': ')) + self._file.file_handle().truncate() - def _refresh_data_cache(self): - """Refresh the contents of the multistore. + def _refresh_data_cache(self): + """Refresh the contents of the multistore. - The multistore must be locked when this is called. + The multistore must be locked when this is called. - Raises: - NewerCredentialStoreError: Raised when a newer client has written the - store. - """ - self._data = {} - try: - raw_data = self._locked_json_read() - except Exception: - logger.warn('Credential data store could not be loaded. ' - 'Will ignore and overwrite.') - return + Raises: + NewerCredentialStoreError: Raised when a newer client has written + the store. + """ + self._data = {} + try: + raw_data = self._locked_json_read() + except Exception: + logger.warn('Credential data store could not be loaded. ' + 'Will ignore and overwrite.') + return - version = 0 - try: - version = raw_data['file_version'] - except Exception: - logger.warn('Missing version for credential data store. It may be ' - 'corrupt or an old version. Overwriting.') - if version > 1: - raise NewerCredentialStoreError( - 'Credential file has file_version of %d. ' - 'Only file_version of 1 is supported.' % version) + version = 0 + try: + version = raw_data['file_version'] + except Exception: + logger.warn('Missing version for credential data store. It may be ' + 'corrupt or an old version. Overwriting.') + if version > 1: + raise NewerCredentialStoreError( + 'Credential file has file_version of %d. ' + 'Only file_version of 1 is supported.' % version) - credentials = [] - try: - credentials = raw_data['data'] - except (TypeError, KeyError): - pass + credentials = [] + try: + credentials = raw_data['data'] + except (TypeError, KeyError): + pass - for cred_entry in credentials: - try: - (key, credential) = self._decode_credential_from_json(cred_entry) - self._data[key] = credential - except: - # If something goes wrong loading a credential, just ignore it - logger.info('Error decoding credential, skipping', exc_info=True) + for cred_entry in credentials: + try: + key, credential = self._decode_credential_from_json(cred_entry) + self._data[key] = credential + except: + # If something goes wrong loading a credential, just ignore it + logger.info('Error decoding credential, skipping', + exc_info=True) - def _decode_credential_from_json(self, cred_entry): - """Load a credential from our JSON serialization. + def _decode_credential_from_json(self, cred_entry): + """Load a credential from our JSON serialization. - Args: - cred_entry: A dict entry from the data member of our format + Args: + cred_entry: A dict entry from the data member of our format - Returns: - (key, cred) where the key is the key tuple and the cred is the - OAuth2Credential object. - """ - raw_key = cred_entry['key'] - key = util.dict_to_tuple_key(raw_key) - credential = None - credential = Credentials.new_from_json(json.dumps(cred_entry['credential'])) - return (key, credential) + Returns: + (key, cred) where the key is the key tuple and the cred is the + OAuth2Credential object. + """ + raw_key = cred_entry['key'] + key = util.dict_to_tuple_key(raw_key) + credential = None + credential = Credentials.new_from_json( + json.dumps(cred_entry['credential'])) + return (key, credential) - def _write(self): - """Write the cached data back out. + def _write(self): + """Write the cached data back out. - The multistore must be locked. - """ - raw_data = {'file_version': 1} - raw_creds = [] - raw_data['data'] = raw_creds - for (cred_key, cred) in self._data.items(): - raw_key = dict(cred_key) - raw_cred = json.loads(cred.to_json()) - raw_creds.append({'key': raw_key, 'credential': raw_cred}) - self._locked_json_write(raw_data) + The multistore must be locked. + """ + raw_data = {'file_version': 1} + raw_creds = [] + raw_data['data'] = raw_creds + for (cred_key, cred) in self._data.items(): + raw_key = dict(cred_key) + raw_cred = json.loads(cred.to_json()) + raw_creds.append({'key': raw_key, 'credential': raw_cred}) + self._locked_json_write(raw_data) - def _get_all_credential_keys(self): - """Gets all the registered credential keys in the multistore. + def _get_all_credential_keys(self): + """Gets all the registered credential keys in the multistore. - Returns: - A list of dictionaries corresponding to all the keys currently registered - """ - return [dict(key) for key in self._data.keys()] + Returns: + A list of dictionaries corresponding to all the keys currently + registered + """ + return [dict(key) for key in self._data.keys()] - def _get_credential(self, key): - """Get a credential from the multistore. + def _get_credential(self, key): + """Get a credential from the multistore. - The multistore must be locked. + The multistore must be locked. - Args: - key: The key used to retrieve the credential + Args: + key: The key used to retrieve the credential - Returns: - The credential specified or None if not present - """ - return self._data.get(key, None) + Returns: + The credential specified or None if not present + """ + return self._data.get(key, None) - def _update_credential(self, key, cred): - """Update a credential and write the multistore. + def _update_credential(self, key, cred): + """Update a credential and write the multistore. - This must be called when the multistore is locked. + This must be called when the multistore is locked. - Args: - key: The key used to retrieve the credential - cred: The OAuth2Credential to update/set - """ - self._data[key] = cred - self._write() + Args: + key: The key used to retrieve the credential + cred: The OAuth2Credential to update/set + """ + self._data[key] = cred + self._write() - def _delete_credential(self, key): - """Delete a credential and write the multistore. + def _delete_credential(self, key): + """Delete a credential and write the multistore. - This must be called when the multistore is locked. + This must be called when the multistore is locked. - Args: - key: The key used to retrieve the credential - """ - try: - del self._data[key] - except KeyError: - pass - self._write() + Args: + key: The key used to retrieve the credential + """ + try: + del self._data[key] + except KeyError: + pass + self._write() - def _get_storage(self, key): - """Get a Storage object to get/set a credential. + def _get_storage(self, key): + """Get a Storage object to get/set a credential. - This Storage is a 'view' into the multistore. + This Storage is a 'view' into the multistore. - Args: - key: The key used to retrieve the credential + Args: + key: The key used to retrieve the credential - Returns: - A Storage object that can be used to get/set this cred - """ - return self._Storage(self, key) + Returns: + A Storage object that can be used to get/set this cred + """ + return self._Storage(self, key) diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index d1d1d895..8d3dc652 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -18,8 +18,6 @@ This credentials class is implemented on top of rsa library. """ import base64 -import json -import six import time from pyasn1.codec.ber import decoder @@ -28,112 +26,108 @@ import rsa from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI +from oauth2client._helpers import _json_encode +from oauth2client._helpers import _to_bytes +from oauth2client._helpers import _urlsafe_b64encode from oauth2client import util from oauth2client.client import AssertionCredentials class _ServiceAccountCredentials(AssertionCredentials): - """Class representing a service account (signed JWT) credential.""" + """Class representing a service account (signed JWT) credential.""" - MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds + MAX_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds - def __init__(self, service_account_id, service_account_email, private_key_id, - private_key_pkcs8_text, scopes, user_agent=None, - token_uri=GOOGLE_TOKEN_URI, revoke_uri=GOOGLE_REVOKE_URI, - **kwargs): + def __init__(self, service_account_id, service_account_email, + private_key_id, private_key_pkcs8_text, scopes, + user_agent=None, token_uri=GOOGLE_TOKEN_URI, + revoke_uri=GOOGLE_REVOKE_URI, **kwargs): - super(_ServiceAccountCredentials, self).__init__( - None, user_agent=user_agent, token_uri=token_uri, revoke_uri=revoke_uri) + super(_ServiceAccountCredentials, self).__init__( + None, user_agent=user_agent, token_uri=token_uri, + revoke_uri=revoke_uri) - self._service_account_id = service_account_id - self._service_account_email = service_account_email - self._private_key_id = private_key_id - self._private_key = _get_private_key(private_key_pkcs8_text) - self._private_key_pkcs8_text = private_key_pkcs8_text - self._scopes = util.scopes_to_string(scopes) - self._user_agent = user_agent - self._token_uri = token_uri - self._revoke_uri = revoke_uri - self._kwargs = kwargs + self._service_account_id = service_account_id + self._service_account_email = service_account_email + self._private_key_id = private_key_id + self._private_key = _get_private_key(private_key_pkcs8_text) + self._private_key_pkcs8_text = private_key_pkcs8_text + self._scopes = util.scopes_to_string(scopes) + self._user_agent = user_agent + self._token_uri = token_uri + self._revoke_uri = revoke_uri + self._kwargs = kwargs - def _generate_assertion(self): - """Generate the assertion that will be used in the request.""" + def _generate_assertion(self): + """Generate the assertion that will be used in the request.""" - header = { - 'alg': 'RS256', - 'typ': 'JWT', - 'kid': self._private_key_id - } + header = { + 'alg': 'RS256', + 'typ': 'JWT', + 'kid': self._private_key_id + } - now = int(time.time()) - payload = { - 'aud': self._token_uri, - 'scope': self._scopes, - 'iat': now, - 'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS, - 'iss': self._service_account_email - } - payload.update(self._kwargs) + now = int(time.time()) + payload = { + 'aud': self._token_uri, + 'scope': self._scopes, + 'iat': now, + 'exp': now + _ServiceAccountCredentials.MAX_TOKEN_LIFETIME_SECS, + 'iss': self._service_account_email + } + payload.update(self._kwargs) - assertion_input = (_urlsafe_b64encode(header) + b'.' + - _urlsafe_b64encode(payload)) + first_segment = _urlsafe_b64encode(_json_encode(header)) + second_segment = _urlsafe_b64encode(_json_encode(payload)) + assertion_input = first_segment + b'.' + second_segment - # Sign the assertion. - rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, 'SHA-256') - signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=') + # Sign the assertion. + rsa_bytes = rsa.pkcs1.sign(assertion_input, self._private_key, + 'SHA-256') + signature = base64.urlsafe_b64encode(rsa_bytes).rstrip(b'=') - return assertion_input + b'.' + signature + return assertion_input + b'.' + signature - def sign_blob(self, blob): - # Ensure that it is bytes - try: - blob = blob.encode('utf-8') - except AttributeError: - pass - return (self._private_key_id, - rsa.pkcs1.sign(blob, self._private_key, 'SHA-256')) + def sign_blob(self, blob): + # Ensure that it is bytes + blob = _to_bytes(blob, encoding='utf-8') + return (self._private_key_id, + rsa.pkcs1.sign(blob, self._private_key, 'SHA-256')) - @property - def service_account_email(self): - return self._service_account_email + @property + def service_account_email(self): + return self._service_account_email - @property - def serialization_data(self): - return { - 'type': 'service_account', - 'client_id': self._service_account_id, - 'client_email': self._service_account_email, - 'private_key_id': self._private_key_id, - 'private_key': self._private_key_pkcs8_text - } + @property + def serialization_data(self): + return { + 'type': 'service_account', + 'client_id': self._service_account_id, + 'client_email': self._service_account_email, + 'private_key_id': self._private_key_id, + 'private_key': self._private_key_pkcs8_text + } - def create_scoped_required(self): - return not self._scopes + def create_scoped_required(self): + return not self._scopes - def create_scoped(self, scopes): - return _ServiceAccountCredentials(self._service_account_id, - self._service_account_email, - self._private_key_id, - self._private_key_pkcs8_text, - scopes, - user_agent=self._user_agent, - token_uri=self._token_uri, - revoke_uri=self._revoke_uri, - **self._kwargs) - - -def _urlsafe_b64encode(data): - return base64.urlsafe_b64encode( - json.dumps(data, separators=(',', ':')).encode('UTF-8')).rstrip(b'=') + def create_scoped(self, scopes): + return _ServiceAccountCredentials(self._service_account_id, + self._service_account_email, + self._private_key_id, + self._private_key_pkcs8_text, + scopes, + user_agent=self._user_agent, + token_uri=self._token_uri, + revoke_uri=self._revoke_uri, + **self._kwargs) def _get_private_key(private_key_pkcs8_text): - """Get an RSA private key object from a pkcs8 representation.""" - - if not isinstance(private_key_pkcs8_text, six.binary_type): - private_key_pkcs8_text = private_key_pkcs8_text.encode('ascii') - der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') - asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) - return rsa.PrivateKey.load_pkcs1( - asn1_private_key.getComponentByName('privateKey').asOctets(), - format='DER') + """Get an RSA private key object from a pkcs8 representation.""" + private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text) + der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') + asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) + return rsa.PrivateKey.load_pkcs1( + asn1_private_key.getComponentByName('privateKey').asOctets(), + format='DER') diff --git a/oauth2client/tools.py b/oauth2client/tools.py index 3c729031..629866b1 100644 --- a/oauth2client/tools.py +++ b/oauth2client/tools.py @@ -21,9 +21,6 @@ the same directory. from __future__ import print_function -__author__ = 'jcgregorio@google.com (Joe Gregorio)' -__all__ = ['argparser', 'run_flow', 'run', 'message_if_missing'] - import logging import socket import sys @@ -36,6 +33,9 @@ from oauth2client import client from oauth2client import util +__author__ = 'jcgregorio@google.com (Joe Gregorio)' +__all__ = ['argparser', 'run_flow', 'message_if_missing'] + _CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0 To make this sample run you will need to populate the client_secrets.json file @@ -47,22 +47,24 @@ with information from the APIs Console . """ + def _CreateArgumentParser(): - try: - import argparse - except ImportError: - return None - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument('--auth_host_name', default='localhost', - help='Hostname when running a local web server.') - parser.add_argument('--noauth_local_webserver', action='store_true', - default=False, help='Do not run a local web server.') - parser.add_argument('--auth_host_port', default=[8080, 8090], type=int, - nargs='*', help='Port web server should listen on.') - parser.add_argument('--logging_level', default='ERROR', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], - help='Set the logging level of detail.') - return parser + try: + import argparse + except ImportError: + return None + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument('--auth_host_name', default='localhost', + help='Hostname when running a local web server.') + parser.add_argument('--noauth_local_webserver', action='store_true', + default=False, help='Do not run a local web server.') + parser.add_argument('--auth_host_port', default=[8080, 8090], type=int, + nargs='*', help='Port web server should listen on.') + parser.add_argument( + '--logging_level', default='ERROR', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='Set the logging level of detail.') + return parser # argparser is an ArgumentParser that contains command-line options expected # by tools.run(). Pass it in as part of the 'parents' argument to your own @@ -71,187 +73,172 @@ argparser = _CreateArgumentParser() class ClientRedirectServer(BaseHTTPServer.HTTPServer): - """A server to handle OAuth 2.0 redirects back to localhost. + """A server to handle OAuth 2.0 redirects back to localhost. - Waits for a single request and parses the query parameters - into query_params and then stops serving. - """ - query_params = {} + Waits for a single request and parses the query parameters + into query_params and then stops serving. + """ + query_params = {} class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): - """A handler for OAuth 2.0 redirects back to localhost. + """A handler for OAuth 2.0 redirects back to localhost. - Waits for a single request and parses the query parameters - into the servers query_params and then stops serving. - """ - - def do_GET(self): - """Handle a GET request. - - Parses the query parameters and prints a message - if the flow has completed. Note that we can't detect - if an error occurred. + Waits for a single request and parses the query parameters + into the servers query_params and then stops serving. """ - self.send_response(200) - self.send_header("Content-type", "text/html") - self.end_headers() - query = self.path.split('?', 1)[-1] - query = dict(urllib.parse.parse_qsl(query)) - self.server.query_params = query - self.wfile.write(b"Authentication Status") - self.wfile.write(b"

The authentication flow has completed.

") - self.wfile.write(b"") - def log_message(self, format, *args): - """Do not log messages to stdout while running as command line program.""" + def do_GET(self): + """Handle a GET request. + + Parses the query parameters and prints a message + if the flow has completed. Note that we can't detect + if an error occurred. + """ + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + query = self.path.split('?', 1)[-1] + query = dict(urllib.parse.parse_qsl(query)) + self.server.query_params = query + self.wfile.write( + b"Authentication Status") + self.wfile.write( + b"

The authentication flow has completed.

") + self.wfile.write(b"") + + def log_message(self, format, *args): + """Do not log messages to stdout while running as cmd. line program.""" @util.positional(3) def run_flow(flow, storage, flags, http=None): - """Core code for a command-line application. + """Core code for a command-line application. - The ``run()`` function is called from your application and runs - through all the steps to obtain credentials. It takes a ``Flow`` - argument and attempts to open an authorization server page in the - user's default web browser. The server asks the user to grant your - application access to the user's data. If the user grants access, - the ``run()`` function returns new credentials. The new credentials - are also stored in the ``storage`` argument, which updates the file - associated with the ``Storage`` object. + The ``run()`` function is called from your application and runs + through all the steps to obtain credentials. It takes a ``Flow`` + argument and attempts to open an authorization server page in the + user's default web browser. The server asks the user to grant your + application access to the user's data. If the user grants access, + the ``run()`` function returns new credentials. The new credentials + are also stored in the ``storage`` argument, which updates the file + associated with the ``Storage`` object. - It presumes it is run from a command-line application and supports the - following flags: + It presumes it is run from a command-line application and supports the + following flags: - ``--auth_host_name`` (string, default: ``localhost``) - Host name to use when running a local web server to handle - redirects during OAuth authorization. + ``--auth_host_name`` (string, default: ``localhost``) + Host name to use when running a local web server to handle + redirects during OAuth authorization. - ``--auth_host_port`` (integer, default: ``[8080, 8090]``) - Port to use when running a local web server to handle redirects - during OAuth authorization. Repeat this option to specify a list - of values. + ``--auth_host_port`` (integer, default: ``[8080, 8090]``) + Port to use when running a local web server to handle redirects + during OAuth authorization. Repeat this option to specify a list + of values. - ``--[no]auth_local_webserver`` (boolean, default: ``True``) - Run a local web server to handle redirects during OAuth authorization. + ``--[no]auth_local_webserver`` (boolean, default: ``True``) + Run a local web server to handle redirects during OAuth + authorization. + The tools module defines an ``ArgumentParser`` the already contains the + flag definitions that ``run()`` requires. You can pass that + ``ArgumentParser`` to your ``ArgumentParser`` constructor:: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + parents=[tools.argparser]) + flags = parser.parse_args(argv) + Args: + flow: Flow, an OAuth 2.0 Flow to step through. + storage: Storage, a ``Storage`` to store the credential in. + flags: ``argparse.Namespace``, The command-line flags. This is the + object returned from calling ``parse_args()`` on + ``argparse.ArgumentParser`` as described above. + http: An instance of ``httplib2.Http.request`` or something that + acts like it. - The tools module defines an ``ArgumentParser`` the already contains the flag - definitions that ``run()`` requires. You can pass that ``ArgumentParser`` to your - ``ArgumentParser`` constructor:: + Returns: + Credentials, the obtained credential. + """ + logging.getLogger().setLevel(getattr(logging, flags.logging_level)) + if not flags.noauth_local_webserver: + success = False + port_number = 0 + for port in flags.auth_host_port: + port_number = port + try: + httpd = ClientRedirectServer((flags.auth_host_name, port), + ClientRedirectHandler) + except socket.error: + pass + else: + success = True + break + flags.noauth_local_webserver = not success + if not success: + print('Failed to start a local webserver listening ' + 'on either port 8080') + print('or port 8090. Please check your firewall settings and locally') + print('running programs that may be blocking or using those ports.') + print() + print('Falling back to --noauth_local_webserver and continuing with') + print('authorization.') + print() - parser = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - parents=[tools.argparser]) - flags = parser.parse_args(argv) - - Args: - flow: Flow, an OAuth 2.0 Flow to step through. - storage: Storage, a ``Storage`` to store the credential in. - flags: ``argparse.Namespace``, The command-line flags. This is the - object returned from calling ``parse_args()`` on - ``argparse.ArgumentParser`` as described above. - http: An instance of ``httplib2.Http.request`` or something that - acts like it. - - Returns: - Credentials, the obtained credential. - """ - logging.getLogger().setLevel(getattr(logging, flags.logging_level)) - if not flags.noauth_local_webserver: - success = False - port_number = 0 - for port in flags.auth_host_port: - port_number = port - try: - httpd = ClientRedirectServer((flags.auth_host_name, port), - ClientRedirectHandler) - except socket.error: - pass - else: - success = True - break - flags.noauth_local_webserver = not success - if not success: - print('Failed to start a local webserver listening on either port 8080') - print('or port 9090. Please check your firewall settings and locally') - print('running programs that may be blocking or using those ports.') - print() - print('Falling back to --noauth_local_webserver and continuing with') - print('authorization.') - print() - - if not flags.noauth_local_webserver: - oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number) - else: - oauth_callback = client.OOB_CALLBACK_URN - flow.redirect_uri = oauth_callback - authorize_url = flow.step1_get_authorize_url() - - if flags.short_url: - try: - from googleapiclient.discovery import build - service = build('urlshortener', 'v1', http=http) - url_result = service.url().insert(body={'longUrl': authorize_url}, - key=u'AIzaSyBlmgbii8QfJSYmC9VTMOfqrAt5Vj5wtzE').execute() - authorize_url = url_result['id'] - except: - pass - - if not flags.noauth_local_webserver: - import webbrowser - webbrowser.open(authorize_url, new=1, autoraise=True) - print('Your browser has been opened to visit:') - print() - print(' ' + authorize_url) - print() - print('If your browser is on a different machine then exit and re-run this') - print('after creating a file called nobrowser.txt in the same path as GAM.') - print() - else: - print('Go to the following link in your browser:') - print() - print(' ' + authorize_url) - print() - - code = None - if not flags.noauth_local_webserver: - httpd.handle_request() - if 'error' in httpd.query_params: - sys.exit('Authentication request was rejected.') - if 'code' in httpd.query_params: - code = httpd.query_params['code'] + if not flags.noauth_local_webserver: + oauth_callback = 'http://%s:%s/' % (flags.auth_host_name, port_number) else: - print('Failed to find "code" in the query parameters of the redirect.') - sys.exit('Try running with --noauth_local_webserver.') - else: - code = input('Enter verification code: ').strip() + oauth_callback = client.OOB_CALLBACK_URN + flow.redirect_uri = oauth_callback + authorize_url = flow.step1_get_authorize_url() - try: - credential = flow.step2_exchange(code, http=http) - except client.FlowExchangeError as e: - sys.exit('Authentication has failed: %s' % e) + if not flags.noauth_local_webserver: + import webbrowser + webbrowser.open(authorize_url, new=1, autoraise=True) + print('Your browser has been opened to visit:') + print() + print(' ' + authorize_url) + print() + print('If your browser is on a different machine then ' + 'exit and re-run this') + print('application with the command-line parameter ') + print() + print(' --noauth_local_webserver') + print() + else: + print('Go to the following link in your browser:') + print() + print(' ' + authorize_url) + print() - storage.put(credential) - credential.set_store(storage) - print('Authentication successful.') + code = None + if not flags.noauth_local_webserver: + httpd.handle_request() + if 'error' in httpd.query_params: + sys.exit('Authentication request was rejected.') + if 'code' in httpd.query_params: + code = httpd.query_params['code'] + else: + print('Failed to find "code" in the query parameters ' + 'of the redirect.') + sys.exit('Try running with --noauth_local_webserver.') + else: + code = input('Enter verification code: ').strip() - return credential + try: + credential = flow.step2_exchange(code, http=http) + except client.FlowExchangeError as e: + sys.exit('Authentication has failed: %s' % e) + + storage.put(credential) + credential.set_store(storage) + print('Authentication successful.') + + return credential def message_if_missing(filename): - """Helpful message to display if the CLIENT_SECRETS file is missing.""" - - return _CLIENT_SECRETS_MESSAGE % filename - -try: - from oauth2client.old_run import run - from oauth2client.old_run import FLAGS -except ImportError: - def run(*args, **kwargs): - raise NotImplementedError( - 'The gflags library must be installed to use tools.run(). ' - 'Please install gflags or preferrably switch to using ' - 'tools.run_flow().') + """Helpful message to display if the CLIENT_SECRETS file is missing.""" + return _CLIENT_SECRETS_MESSAGE % filename diff --git a/oauth2client/util.py b/oauth2client/util.py index a706f026..352afd84 100644 --- a/oauth2client/util.py +++ b/oauth2client/util.py @@ -17,6 +17,16 @@ """Common utility library.""" +import functools +import inspect +import logging +import sys +import types + +import six +from six.moves import urllib + + __author__ = [ 'rafek@google.com (Rafe Kaplan)', 'guido@google.com (Guido van Rossum)', @@ -29,16 +39,6 @@ __all__ = [ 'POSITIONAL_IGNORE', ] -import functools -import inspect -import logging -import sys -import types - -import six -from six.moves import urllib - - logger = logging.getLogger(__name__) POSITIONAL_WARNING = 'WARNING' @@ -49,153 +49,178 @@ POSITIONAL_SET = frozenset([POSITIONAL_WARNING, POSITIONAL_EXCEPTION, positional_parameters_enforcement = POSITIONAL_WARNING + def positional(max_positional_args): - """A decorator to declare that only the first N arguments my be positional. + """A decorator to declare that only the first N arguments my be positional. - This decorator makes it easy to support Python 3 style keyword-only - parameters. For example, in Python 3 it is possible to write:: + This decorator makes it easy to support Python 3 style keyword-only + parameters. For example, in Python 3 it is possible to write:: - def fn(pos1, *, kwonly1=None, kwonly1=None): - ... + def fn(pos1, *, kwonly1=None, kwonly1=None): + ... - All named parameters after ``*`` must be a keyword:: + All named parameters after ``*`` must be a keyword:: - fn(10, 'kw1', 'kw2') # Raises exception. - fn(10, kwonly1='kw1') # Ok. + fn(10, 'kw1', 'kw2') # Raises exception. + fn(10, kwonly1='kw1') # Ok. - Example - ^^^^^^^ + Example + ^^^^^^^ - To define a function like above, do:: + To define a function like above, do:: - @positional(1) - def fn(pos1, kwonly1=None, kwonly2=None): - ... + @positional(1) + def fn(pos1, kwonly1=None, kwonly2=None): + ... - If no default value is provided to a keyword argument, it becomes a required - keyword argument:: + If no default value is provided to a keyword argument, it becomes a + required keyword argument:: - @positional(0) - def fn(required_kw): - ... + @positional(0) + def fn(required_kw): + ... - This must be called with the keyword parameter:: + This must be called with the keyword parameter:: - fn() # Raises exception. - fn(10) # Raises exception. - fn(required_kw=10) # Ok. + fn() # Raises exception. + fn(10) # Raises exception. + fn(required_kw=10) # Ok. - When defining instance or class methods always remember to account for - ``self`` and ``cls``:: + When defining instance or class methods always remember to account for + ``self`` and ``cls``:: - class MyClass(object): + class MyClass(object): - @positional(2) - def my_method(self, pos1, kwonly1=None): - ... + @positional(2) + def my_method(self, pos1, kwonly1=None): + ... - @classmethod - @positional(2) - def my_method(cls, pos1, kwonly1=None): - ... + @classmethod + @positional(2) + def my_method(cls, pos1, kwonly1=None): + ... - The positional decorator behavior is controlled by - ``util.positional_parameters_enforcement``, which may be set to - ``POSITIONAL_EXCEPTION``, ``POSITIONAL_WARNING`` or - ``POSITIONAL_IGNORE`` to raise an exception, log a warning, or do - nothing, respectively, if a declaration is violated. + The positional decorator behavior is controlled by + ``util.positional_parameters_enforcement``, which may be set to + ``POSITIONAL_EXCEPTION``, ``POSITIONAL_WARNING`` or + ``POSITIONAL_IGNORE`` to raise an exception, log a warning, or do + nothing, respectively, if a declaration is violated. - Args: - max_positional_arguments: Maximum number of positional arguments. All - parameters after the this index must be keyword only. + Args: + max_positional_arguments: Maximum number of positional arguments. All + parameters after the this index must be + keyword only. - Returns: - A decorator that prevents using arguments after max_positional_args from - being used as positional parameters. + Returns: + A decorator that prevents using arguments after max_positional_args + from being used as positional parameters. - Raises: - TypeError if a key-word only argument is provided as a positional - parameter, but only if util.positional_parameters_enforcement is set to - POSITIONAL_EXCEPTION. + Raises: + TypeError: if a key-word only argument is provided as a positional + parameter, but only if + util.positional_parameters_enforcement is set to + POSITIONAL_EXCEPTION. + """ - """ - def positional_decorator(wrapped): - @functools.wraps(wrapped) - def positional_wrapper(*args, **kwargs): - if len(args) > max_positional_args: - plural_s = '' - if max_positional_args != 1: - plural_s = 's' - message = '%s() takes at most %d positional argument%s (%d given)' % ( - wrapped.__name__, max_positional_args, plural_s, len(args)) - if positional_parameters_enforcement == POSITIONAL_EXCEPTION: - raise TypeError(message) - elif positional_parameters_enforcement == POSITIONAL_WARNING: - logger.warning(message) - else: # IGNORE - pass - return wrapped(*args, **kwargs) - return positional_wrapper + def positional_decorator(wrapped): + @functools.wraps(wrapped) + def positional_wrapper(*args, **kwargs): + if len(args) > max_positional_args: + plural_s = '' + if max_positional_args != 1: + plural_s = 's' + message = ('%s() takes at most %d positional ' + 'argument%s (%d given)' % ( + wrapped.__name__, max_positional_args, + plural_s, len(args))) + if positional_parameters_enforcement == POSITIONAL_EXCEPTION: + raise TypeError(message) + elif positional_parameters_enforcement == POSITIONAL_WARNING: + logger.warning(message) + else: # IGNORE + pass + return wrapped(*args, **kwargs) + return positional_wrapper - if isinstance(max_positional_args, six.integer_types): - return positional_decorator - else: - args, _, _, defaults = inspect.getargspec(max_positional_args) - return positional(len(args) - len(defaults))(max_positional_args) + if isinstance(max_positional_args, six.integer_types): + return positional_decorator + else: + args, _, _, defaults = inspect.getargspec(max_positional_args) + return positional(len(args) - len(defaults))(max_positional_args) def scopes_to_string(scopes): - """Converts scope value to a string. + """Converts scope value to a string. - If scopes is a string then it is simply passed through. If scopes is an - iterable then a string is returned that is all the individual scopes - concatenated with spaces. + If scopes is a string then it is simply passed through. If scopes is an + iterable then a string is returned that is all the individual scopes + concatenated with spaces. - Args: - scopes: string or iterable of strings, the scopes. + Args: + scopes: string or iterable of strings, the scopes. - Returns: - The scopes formatted as a single string. - """ - if isinstance(scopes, six.string_types): - return scopes - else: - return ' '.join(scopes) + Returns: + The scopes formatted as a single string. + """ + if isinstance(scopes, six.string_types): + return scopes + else: + return ' '.join(scopes) + + +def string_to_scopes(scopes): + """Converts stringifed scope value to a list. + + If scopes is a list then it is simply passed through. If scopes is an + string then a list of each individual scope is returned. + + Args: + scopes: a string or iterable of strings, the scopes. + + Returns: + The scopes in a list. + """ + if not scopes: + return [] + if isinstance(scopes, six.string_types): + return scopes.split(' ') + else: + return scopes def dict_to_tuple_key(dictionary): - """Converts a dictionary to a tuple that can be used as an immutable key. + """Converts a dictionary to a tuple that can be used as an immutable key. - The resulting key is always sorted so that logically equivalent dictionaries - always produce an identical tuple for a key. + The resulting key is always sorted so that logically equivalent + dictionaries always produce an identical tuple for a key. - Args: - dictionary: the dictionary to use as the key. + Args: + dictionary: the dictionary to use as the key. - Returns: - A tuple representing the dictionary in it's naturally sorted ordering. - """ - return tuple(sorted(dictionary.items())) + Returns: + A tuple representing the dictionary in it's naturally sorted ordering. + """ + return tuple(sorted(dictionary.items())) def _add_query_parameter(url, name, value): - """Adds a query parameter to a url. + """Adds a query parameter to a url. - Replaces the current value if it already exists in the URL. + Replaces the current value if it already exists in the URL. - Args: - url: string, url to add the query parameter to. - name: string, query parameter name. - value: string, query parameter value. + Args: + url: string, url to add the query parameter to. + name: string, query parameter name. + value: string, query parameter value. - Returns: - Updated query parameter. Does not update the url if value is None. - """ - if value is None: - return url - else: - parsed = list(urllib.parse.urlparse(url)) - q = dict(urllib.parse.parse_qsl(parsed[4])) - q[name] = value - parsed[4] = urllib.parse.urlencode(q) - return urllib.parse.urlunparse(parsed) + Returns: + Updated query parameter. Does not update the url if value is None. + """ + if value is None: + return url + else: + parsed = list(urllib.parse.urlparse(url)) + q = dict(urllib.parse.parse_qsl(parsed[4])) + q[name] = value + parsed[4] = urllib.parse.urlencode(q) + return urllib.parse.urlunparse(parsed) diff --git a/oauth2client/xsrfutil.py b/oauth2client/xsrfutil.py index 5739dcf5..685eb46e 100644 --- a/oauth2client/xsrfutil.py +++ b/oauth2client/xsrfutil.py @@ -15,104 +15,94 @@ """Helper methods for creating & verifying XSRF tokens.""" +import base64 +import binascii +import hmac +import six +import time + +from oauth2client._helpers import _to_bytes +from oauth2client import util + __authors__ = [ '"Doug Coker" ', '"Joe Gregorio" ', ] - -import base64 -import hmac -import time - -import six -from oauth2client import util - - # Delimiter character DELIMITER = b':' - # 1 hour in seconds -DEFAULT_TIMEOUT_SECS = 1*60*60 - - -def _force_bytes(s): - if isinstance(s, bytes): - return s - s = str(s) - if isinstance(s, six.text_type): - return s.encode('utf-8') - return s +DEFAULT_TIMEOUT_SECS = 60 * 60 @util.positional(2) -def generate_token(key, user_id, action_id="", when=None): - """Generates a URL-safe token for the given user, action, time tuple. +def generate_token(key, user_id, action_id='', when=None): + """Generates a URL-safe token for the given user, action, time tuple. - Args: - key: secret key to use. - user_id: the user ID of the authenticated user. - action_id: a string identifier of the action they requested - authorization for. - when: the time in seconds since the epoch at which the user was - authorized for this action. If not set the current time is used. + Args: + key: secret key to use. + user_id: the user ID of the authenticated user. + action_id: a string identifier of the action they requested + authorization for. + when: the time in seconds since the epoch at which the user was + authorized for this action. If not set the current time is used. - Returns: - A string XSRF protection token. - """ - when = _force_bytes(when or int(time.time())) - digester = hmac.new(_force_bytes(key)) - digester.update(_force_bytes(user_id)) - digester.update(DELIMITER) - digester.update(_force_bytes(action_id)) - digester.update(DELIMITER) - digester.update(when) - digest = digester.digest() + Returns: + A string XSRF protection token. + """ + digester = hmac.new(_to_bytes(key, encoding='utf-8')) + digester.update(_to_bytes(str(user_id), encoding='utf-8')) + digester.update(DELIMITER) + digester.update(_to_bytes(action_id, encoding='utf-8')) + digester.update(DELIMITER) + when = _to_bytes(str(when or int(time.time())), encoding='utf-8') + digester.update(when) + digest = digester.digest() - token = base64.urlsafe_b64encode(digest + DELIMITER + when) - return token + token = base64.urlsafe_b64encode(digest + DELIMITER + when) + return token @util.positional(3) def validate_token(key, token, user_id, action_id="", current_time=None): - """Validates that the given token authorizes the user for the action. + """Validates that the given token authorizes the user for the action. - Tokens are invalid if the time of issue is too old or if the token - does not match what generateToken outputs (i.e. the token was forged). + Tokens are invalid if the time of issue is too old or if the token + does not match what generateToken outputs (i.e. the token was forged). - Args: - key: secret key to use. - token: a string of the token generated by generateToken. - user_id: the user ID of the authenticated user. - action_id: a string identifier of the action they requested - authorization for. + Args: + key: secret key to use. + token: a string of the token generated by generateToken. + user_id: the user ID of the authenticated user. + action_id: a string identifier of the action they requested + authorization for. - Returns: - A boolean - True if the user is authorized for the action, False - otherwise. - """ - if not token: - return False - try: - decoded = base64.urlsafe_b64decode(token) - token_time = int(decoded.split(DELIMITER)[-1]) - except (TypeError, ValueError): - return False - if current_time is None: - current_time = time.time() - # If the token is too old it's not valid. - if current_time - token_time > DEFAULT_TIMEOUT_SECS: - return False + Returns: + A boolean - True if the user is authorized for the action, False + otherwise. + """ + if not token: + return False + try: + decoded = base64.urlsafe_b64decode(token) + token_time = int(decoded.split(DELIMITER)[-1]) + except (TypeError, ValueError, binascii.Error): + return False + if current_time is None: + current_time = time.time() + # If the token is too old it's not valid. + if current_time - token_time > DEFAULT_TIMEOUT_SECS: + return False - # The given token should match the generated one with the same time. - expected_token = generate_token(key, user_id, action_id=action_id, - when=token_time) - if len(token) != len(expected_token): - return False + # The given token should match the generated one with the same time. + expected_token = generate_token(key, user_id, action_id=action_id, + when=token_time) + if len(token) != len(expected_token): + return False - # Perform constant time comparison to avoid timing attacks - different = 0 - for x, y in zip(bytearray(token), bytearray(expected_token)): - different |= x ^ y - return not different + # Perform constant time comparison to avoid timing attacks + different = 0 + for x, y in zip(bytearray(token), bytearray(expected_token)): + different |= x ^ y + return not different diff --git a/passlib/__init__.py b/passlib/__init__.py index aeea1fc2..0d2dfb2b 100644 --- a/passlib/__init__.py +++ b/passlib/__init__.py @@ -1,3 +1,3 @@ -"""passlib - suite of password hashing & generation routinges""" +"""passlib - suite of password hashing & generation routines""" -__version__ = '1.6.2' +__version__ = '1.6.5' diff --git a/passlib/_setup/docdist.py b/passlib/_setup/docdist.py index dadb4b53..19c4dc12 100644 --- a/passlib/_setup/docdist.py +++ b/passlib/_setup/docdist.py @@ -1,4 +1,4 @@ -"custom command to build doc.zip file" +"""custom command to build doc.zip file""" #============================================================================= # imports #============================================================================= diff --git a/passlib/_setup/stamp.py b/passlib/_setup/stamp.py index dfa62bc4..8a68658d 100644 --- a/passlib/_setup/stamp.py +++ b/passlib/_setup/stamp.py @@ -1,4 +1,4 @@ -"update version string during build" +"""update version string during build""" #============================================================================= # imports #============================================================================= @@ -21,7 +21,7 @@ def get_command_class(opts, name): return opts['cmdclass'].get(name) or Distribution().get_command_class(name) def stamp_source(base_dir, version, dry_run=False): - "update version string in passlib dist" + """update version string in passlib dist""" path = os.path.join(base_dir, "passlib", "__init__.py") with open(path) as fh: input = fh.read() diff --git a/passlib/apache.py b/passlib/apache.py index 8497ca2e..516deeb8 100644 --- a/passlib/apache.py +++ b/passlib/apache.py @@ -188,7 +188,7 @@ class _CommonFile(object): @property def mtime(self): - "modify time when last loaded (if bound to a local file)" + """modify time when last loaded (if bound to a local file)""" return self._mtime #=================================================================== @@ -240,13 +240,13 @@ class _CommonFile(object): return True def load_string(self, data): - "Load state from unicode or bytes string, replacing current state" + """Load state from unicode or bytes string, replacing current state""" data = to_bytes(data, self.encoding, "data") self._mtime = 0 self._load_lines(BytesIO(data)) def _load_lines(self, lines): - "load from sequence of lists" + """load from sequence of lists""" # XXX: found reference that "#" comment lines may be supported by # htpasswd, should verify this, and figure out how to handle them. # if true, this would also affect what can be stored in user field. @@ -262,15 +262,15 @@ class _CommonFile(object): if key not in records: records[key] = value - def _parse_record(cls, record, lineno): # pragma: no cover - abstract method - "parse line of file into (key, value) pair" + def _parse_record(self, record, lineno): # pragma: no cover - abstract method + """parse line of file into (key, value) pair""" raise NotImplementedError("should be implemented in subclass") #=================================================================== # saving #=================================================================== def _autosave(self): - "subclass helper to call save() after any changes" + """subclass helper to call save() after any changes""" if self.autosave and self._path: self.save() @@ -289,26 +289,26 @@ class _CommonFile(object): self.__class__.__name__) def to_string(self): - "Export current state as a string of bytes" + """Export current state as a string of bytes""" return join_bytes(self._iter_lines()) def _iter_lines(self): - "iterator yielding lines of database" + """iterator yielding lines of database""" return (self._render_record(key,value) for key,value in iteritems(self._records)) - def _render_record(cls, key, value): # pragma: no cover - abstract method - "given key/value pair, encode as line of file" + def _render_record(self, key, value): # pragma: no cover - abstract method + """given key/value pair, encode as line of file""" raise NotImplementedError("should be implemented in subclass") #=================================================================== # field encoding #=================================================================== def _encode_user(self, user): - "user-specific wrapper for _encode_field()" + """user-specific wrapper for _encode_field()""" return self._encode_field(user, "user") def _encode_realm(self, realm): # pragma: no cover - abstract method - "realm-specific wrapper for _encode_field()" + """realm-specific wrapper for _encode_field()""" return self._encode_field(realm, "realm") def _encode_field(self, value, param="field"): @@ -370,19 +370,39 @@ class _CommonFile(object): # htpasswd editing #============================================================================= -# FIXME: apr_md5_crypt technically the default only for windows, netware and tpf. -# TODO: find out if htpasswd's "crypt" mode is a crypt() *call* or just des_crypt implementation. -# if the former, we can support anything supported by passlib.hosts.host_context, -# allowing more secure hashes than apr_md5_crypt to be used. -# could perhaps add this behavior as an option to the constructor. +#: default CryptContext used by HtpasswdFile +# TODO: update this to support everything in host_context (where available), +# and note in the documentation that the default is no longer guaranteed to be portable +# across platforms. # c.f. http://httpd.apache.org/docs/2.2/programs/htpasswd.html htpasswd_context = CryptContext([ - "apr_md5_crypt", # man page notes supported everywhere, default on Windows, Netware, TPF - "des_crypt", # man page notes server does NOT support this on Windows, Netware, TPF - "ldap_sha1", # man page notes only for transitioning <-> ldap - "plaintext" # man page notes server ONLY supports this on Windows, Netware, TPF + # man page notes supported everywhere; is default on Windows, Netware, TPF + "apr_md5_crypt", + + # [added in passlib 1.6.3] + # apache requires host crypt() support; but can generate natively + # (as of https://bz.apache.org/bugzilla/show_bug.cgi?id=49288) + "bcrypt", + + # [added in passlib 1.6.3] + # apache requires host crypt() support; and can't generate natively + "sha256_crypt", + "sha512_crypt", + + # man page notes apache does NOT support this on Windows, Netware, TPF + "des_crypt", + + # man page notes intended only for transitioning htpasswd <-> ldap + "ldap_sha1", + + # man page notes apache ONLY supports this on Windows, Netware, TPF + "plaintext" ]) +#: scheme that will be used when 'portable' is requested. +portable_scheme = "apr_md5_crypt" + + class HtpasswdFile(_CommonFile): """class for reading & writing Htpasswd files. @@ -444,13 +464,23 @@ class HtpasswdFile(_CommonFile): :type default_scheme: str :param default_scheme: Optionally specify default scheme to use when encoding new passwords. - Must be one of ``"apr_md5_crypt"``, ``"des_crypt"``, ``"ldap_sha1"``, - ``"plaintext"``. It defaults to ``"apr_md5_crypt"``. + May be any of ``"bcrypt"``, ``"sha256_crypt"``, ``"apr_md5_crypt"``, ``"des_crypt"``, + ``"ldap_sha1"``, ``"plaintext"``. It defaults to ``"apr_md5_crypt"``. + + .. note:: + + Some hashes are only supported by apache / htpasswd on certain operating systems + (e.g. bcrypt on BSD, sha256_crypt on linux). To get the strongest + hash that's still portable, applications can specify ``default_scheme="portable"``. .. versionadded:: 1.6 This keyword was previously named ``default``. That alias has been deprecated, and will be removed in Passlib 1.8. + .. versionchanged:: 1.6.3 + + Added support for ``"bcrypt"``, ``"sha256_crypt"``, and ``"portable"``. + :type context: :class:`~passlib.context.CryptContext` :param context: :class:`!CryptContext` instance used to encrypt @@ -464,7 +494,7 @@ class HtpasswdFile(_CommonFile): This option may be used to add support for non-standard hash formats to an htpasswd file. However, the resulting file - will probably not be usuable by another application, + will probably not be usable by another application, and particularly not by Apache. :param autoload: @@ -546,6 +576,8 @@ class HtpasswdFile(_CommonFile): DeprecationWarning, stacklevel=2) default_scheme = kwds.pop("default") if default_scheme: + if default_scheme == "portable": + default_scheme = portable_scheme context = context.copy(default=default_scheme) self.context = context super(HtpasswdFile, self).__init__(path, **kwds) @@ -566,7 +598,7 @@ class HtpasswdFile(_CommonFile): #=================================================================== def users(self): - "Return list of all users in database" + """Return list of all users in database""" return [self._decode_field(user) for user in self._records] ##def has_user(self, user): @@ -605,7 +637,7 @@ class HtpasswdFile(_CommonFile): @deprecated_method(deprecated="1.6", removed="1.8", replacement="set_password") def update(self, user, password): - "set password for user" + """set password for user""" return self.set_password(user, password) def get_hash(self, user): @@ -624,7 +656,7 @@ class HtpasswdFile(_CommonFile): @deprecated_method(deprecated="1.6", removed="1.8", replacement="get_hash") def find(self, user): - "return hash for user" + """return hash for user""" return self.get_hash(user) # XXX: rename to something more explicit, like delete_user()? @@ -673,7 +705,7 @@ class HtpasswdFile(_CommonFile): @deprecated_method(deprecated="1.6", removed="1.8", replacement="check_password") def verify(self, user, password): - "verify password for user" + """verify password for user""" return self.check_password(user, password) #=================================================================== @@ -931,7 +963,7 @@ class HtdigestFile(_CommonFile): @deprecated_method(deprecated="1.6", removed="1.8", replacement="set_password") def update(self, user, realm, password): - "set password for user" + """set password for user""" return self.set_password(user, realm, password) # XXX: rename to something more explicit, like get_hash()? @@ -957,7 +989,7 @@ class HtdigestFile(_CommonFile): @deprecated_method(deprecated="1.6", removed="1.8", replacement="get_hash") def find(self, user, realm): - "return hash for user" + """return hash for user""" return self.get_hash(user, realm) # XXX: rename to something more explicit, like delete_user()? @@ -1025,7 +1057,7 @@ class HtdigestFile(_CommonFile): @deprecated_method(deprecated="1.6", removed="1.8", replacement="check_password") def verify(self, user, realm, password): - "verify password for user" + """verify password for user""" return self.check_password(user, realm, password) #=================================================================== diff --git a/passlib/apps.py b/passlib/apps.py index 96308a4d..ceb5e1ff 100644 --- a/passlib/apps.py +++ b/passlib/apps.py @@ -77,12 +77,12 @@ custom_app_context = LazyCryptContext( all__vary_rounds = 0.1, # set a good starting point for rounds selection - sha512_crypt__min_rounds = 60000, - sha256_crypt__min_rounds = 80000, + sha512_crypt__min_rounds = 535000, + sha256_crypt__min_rounds = 535000, # if the admin user category is selected, make a much stronger hash, - admin__sha512_crypt__min_rounds = 120000, - admin__sha256_crypt__min_rounds = 160000, + admin__sha512_crypt__min_rounds = 1024000, + admin__sha256_crypt__min_rounds = 1024000, ) #============================================================================= @@ -132,7 +132,7 @@ def _iter_ldap_crypt_schemes(): return ('ldap_' + name for name in unix_crypt_schemes) def _iter_ldap_schemes(): - "helper which iterates over supported std ldap schemes" + """helper which iterates over supported std ldap schemes""" return chain(std_ldap_schemes, _iter_ldap_crypt_schemes()) ldap_context = LazyCryptContext(_iter_ldap_schemes()) @@ -159,7 +159,7 @@ postgres_context = LazyCryptContext(["postgres_md5"]) # phpass & variants #============================================================================= def _create_phpass_policy(**kwds): - "helper to choose default alg based on bcrypt availability" + """helper to choose default alg based on bcrypt availability""" kwds['default'] = 'bcrypt' if hash.bcrypt.has_backend() else 'phpass' return kwds diff --git a/passlib/context.py b/passlib/context.py index 4f7ec130..fd228ffd 100644 --- a/passlib/context.py +++ b/passlib/context.py @@ -22,7 +22,7 @@ from passlib.utils import rng, tick, to_bytes, deprecated_method, \ to_unicode, splitcomma from passlib.utils.compat import bytes, iteritems, num_types, \ PY2, PY3, PY_MIN_32, unicode, SafeConfigParser, \ - NativeStringIO, BytesIO, base_string_types + NativeStringIO, BytesIO, base_string_types, native_string_types # local __all__ = [ 'CryptContext', @@ -40,7 +40,7 @@ _UNSET = object() # TODO: merge the following helpers into _CryptConfig def _coerce_vary_rounds(value): - "parse vary_rounds string to percent as [0,1) float, or integer" + """parse vary_rounds string to percent as [0,1) float, or integer""" if value.endswith("%"): # XXX: deprecate this in favor of raw float? return float(value.rstrip("%"))*.01 @@ -77,7 +77,7 @@ class CryptPolicy(object): """ .. deprecated:: 1.6 This class has been deprecated, and will be removed in Passlib 1.8. - All of it's functionality has been rolled into :class:`CryptContext`. + All of its functionality has been rolled into :class:`CryptContext`. This class previously stored the configuration options for the CryptContext class. In the interest of interface simplification, @@ -642,7 +642,7 @@ class _CryptRecord(object): @property def _errprefix(self): - "string used to identify record in error messages" + """string used to identify record in error messages""" handler = self.handler category = self.category if category: @@ -657,7 +657,7 @@ class _CryptRecord(object): # rounds generation & limits - used by encrypt & deprecation code #=================================================================== def _init_rounds_options(self, mn, mx, df, vr): - "parse options and compile efficient generate_rounds function" + """parse options and compile efficient generate_rounds function""" #---------------------------------------------------- # extract hard limits from handler itself #---------------------------------------------------- @@ -669,7 +669,7 @@ class _CryptRecord(object): hmx = getattr(handler, "max_rounds", None) def check_against_handler(value, name): - "issue warning if value outside handler limits" + """issue warning if value outside handler limits""" if hmn is not None and value < hmn: warn("%s: %s value is below handler minimum %d: %d" % (self._errprefix, name, hmn, value), PasslibConfigWarning) @@ -721,7 +721,7 @@ class _CryptRecord(object): # is calculated, so that proportion vr values are scaled against # the effective default. def clip(value): - "clip value to intersection of policy + handler limits" + """clip value to intersection of policy + handler limits""" if mn is not None and value < mn: value = mn if hmn is not None and value < hmn: @@ -799,7 +799,7 @@ class _CryptRecord(object): # encrypt() / genconfig() #=================================================================== def _init_encrypt_and_genconfig(self): - "initialize genconfig/encrypt wrapper methods" + """initialize genconfig/encrypt wrapper methods""" settings = self.settings handler = self.handler @@ -817,17 +817,17 @@ class _CryptRecord(object): self.encrypt = handler.encrypt def genconfig(self, **kwds): - "wrapper for handler.genconfig() which adds custom settings/rounds" + """wrapper for handler.genconfig() which adds custom settings/rounds""" self._prepare_settings(kwds) return self.handler.genconfig(**kwds) def encrypt(self, secret, **kwds): - "wrapper for handler.encrypt() which adds custom settings/rounds" + """wrapper for handler.encrypt() which adds custom settings/rounds""" self._prepare_settings(kwds) return self.handler.encrypt(secret, **kwds) def _prepare_settings(self, kwds): - "add default values to settings for encrypt & genconfig" + """add default values to settings for encrypt & genconfig""" # load in default values for any settings if kwds: for k,v in iteritems(self.settings): @@ -869,7 +869,7 @@ class _CryptRecord(object): # of handler.verify() def _init_verify(self, mvt): - "initialize verify() wrapper - implements min_verify_time" + """initialize verify() wrapper - implements min_verify_time""" if mvt: assert isinstance(mvt, (int,float)) and mvt > 0, "CryptPolicy should catch this" self._min_verify_time = mvt @@ -878,7 +878,7 @@ class _CryptRecord(object): self.verify = self.handler.verify def verify(self, secret, hash, **context): - "verify helper - adds min_verify_time delay" + """verify helper - adds min_verify_time delay""" mvt = self._min_verify_time assert mvt > 0, "wrapper should have been replaced for mvt=0" start = tick() @@ -974,7 +974,7 @@ class _CryptConfig(object): """parses, validates, and stores CryptContext config this is a helper used internally by CryptContext to handle - parsing, validation, and serialization of it's config options. + parsing, validation, and serialization of its config options. split out from the main class, but not made public since that just complicates interface too much (c.f. CryptPolicy) @@ -1024,7 +1024,7 @@ class _CryptConfig(object): """initialize .handlers and .schemes attributes""" handlers = [] schemes = [] - if isinstance(data, str): + if isinstance(data, native_string_types): data = splitcomma(data) for elem in data or (): # resolve elem -> handler & scheme @@ -1032,7 +1032,7 @@ class _CryptConfig(object): handler = elem scheme = handler.name _validate_handler_name(scheme) - elif isinstance(elem, str): + elif isinstance(elem, native_string_types): handler = get_crypt_handler(elem) scheme = handler.name else: @@ -1120,7 +1120,7 @@ class _CryptConfig(object): raise KeyError("%r option not allowed in CryptContext " "configuration" % (key,)) # coerce strings for certain fields (e.g. min_rounds uses ints) - if isinstance(value, str): + if isinstance(value, native_string_types): func = _coerce_scheme_options.get(key) if func: value = func(value) @@ -1131,12 +1131,12 @@ class _CryptConfig(object): if key == "default": if hasattr(value, "name"): value = value.name - elif not isinstance(value, str): + elif not isinstance(value, native_string_types): raise ExpectedTypeError(value, "str", "default") if schemes and value not in schemes: raise KeyError("default scheme not found in policy") elif key == "deprecated": - if isinstance(value, str): + if isinstance(value, native_string_types): value = splitcomma(value) elif not isinstance(value, (list,tuple)): raise ExpectedTypeError(value, "str or seq", "deprecated") @@ -1147,7 +1147,7 @@ class _CryptConfig(object): elif schemes: # make sure list of deprecated schemes is subset of configured schemes for scheme in value: - if not isinstance(scheme, str): + if not isinstance(scheme, native_string_types): raise ExpectedTypeError(value, "str", "deprecated element") if scheme not in schemes: raise KeyError("deprecated scheme not found " @@ -1167,7 +1167,8 @@ class _CryptConfig(object): #--------------------------------------------------------------- def get_context_optionmap(self, key, _default={}): """return dict mapping category->value for specific context option. - (treat retval as readonly). + + .. warning:: treat return value as readonly! """ return self._context_options.get(key, _default) @@ -1195,7 +1196,8 @@ class _CryptConfig(object): #--------------------------------------------------------------- def _get_scheme_optionmap(self, scheme, category, default={}): """return all options for (scheme,category) combination - (treat return as readonly) + + .. warning:: treat return value as readonly! """ try: return self._scheme_options[scheme][category] @@ -1281,7 +1283,7 @@ class _CryptConfig(object): "cannot be deprecated" % cat) def default_scheme(self, category): - "return default scheme for specific category" + """return default scheme for specific category""" defaults = self._default_schemes try: return defaults[category] @@ -1293,7 +1295,7 @@ class _CryptConfig(object): return defaults[None] def is_deprecated_with_flag(self, scheme, category): - "is scheme deprecated under particular category?" + """is scheme deprecated under particular category?""" depmap = self.get_context_optionmap("deprecated") def test(cat): source = depmap.get(cat, depmap.get(None)) @@ -1339,7 +1341,7 @@ class _CryptConfig(object): """return composite dict of options for given scheme + category. this is currently a private method, though some variant - of it's output may eventually be made public. + of its output may eventually be made public. given a scheme & category, it returns two things: a set of all the keyword options to pass to the _CryptRecord constructor, @@ -1370,7 +1372,7 @@ class _CryptConfig(object): return kwds, has_cat_options def get_record(self, scheme, category): - "return record for specific scheme & category (cached)" + """return record for specific scheme & category (cached)""" # NOTE: this is part of the critical path shared by # all of CryptContext's PasswordHash methods, # hence all the caching and error checking. @@ -1382,12 +1384,12 @@ class _CryptConfig(object): pass # type check - if category is not None and not isinstance(category, str): + if category is not None and not isinstance(category, native_string_types): if PY2 and isinstance(category, unicode): # for compatibility with unicode-centric py2 apps return self.get_record(scheme, category.encode("utf-8")) raise ExpectedTypeError(category, "str or None", "category") - if scheme is not None and not isinstance(scheme, str): + if scheme is not None and not isinstance(scheme, native_string_types): raise ExpectedTypeError(scheme, "str or None", "scheme") # if scheme=None, @@ -1550,7 +1552,7 @@ class CryptContext(object): #=================================================================== @classmethod def _norm_source(cls, source): - "internal helper - accepts string, dict, or context" + """internal helper - accepts string, dict, or context""" if isinstance(source, dict): return cls(**source) elif isinstance(source, cls): @@ -1669,7 +1671,7 @@ class CryptContext(object): return other def replace(self, **kwds): - "deprecated alias of :meth:`copy`" + """deprecated alias of :meth:`copy`""" warn("CryptContext().replace() has been deprecated in Passlib 1.6, " "and will be removed in Passlib 1.8, " "it has been renamed to CryptContext().copy()", @@ -1752,7 +1754,7 @@ class CryptContext(object): #=================================================================== @staticmethod def _parse_ini_stream(stream, section, filename): - "helper read INI from stream, extract passlib section as dict" + """helper read INI from stream, extract passlib section as dict""" # NOTE: this expects a unicode stream under py3, # and a utf-8 bytes stream under py2, # allowing the resulting dict to always use native strings. @@ -1769,7 +1771,7 @@ class CryptContext(object): This function is a wrapper for :meth:`load` which loads a configuration string from the local file *path*, - instead of an in-memory source. It's behavior and options + instead of an in-memory source. Its behavior and options are otherwise identical to :meth:`!load` when provided with an INI-formatted string. @@ -1812,7 +1814,7 @@ class CryptContext(object): * another :class:`!CryptContext` object. - this will export a snapshot of it's configuration + this will export a snapshot of its configuration using :meth:`to_dict`. :type update: bool @@ -1900,7 +1902,7 @@ class CryptContext(object): def _parse_config_key(ckey): """helper used to parse ``cat__scheme__option`` keys into a tuple""" # split string into 1-3 parts - assert isinstance(ckey, str) + assert isinstance(ckey, native_string_types) parts = ckey.replace(".","__").split("__") count = len(parts) if count == 1: @@ -2019,7 +2021,7 @@ class CryptContext(object): # and then decide whether to expose ability as deprecated_schemes(), # is_deprecated(), or a just add a schemes(deprecated=True) flag. def _is_deprecated_scheme(self, scheme, category=None): - "helper used by unittests to check if scheme is deprecated" + """helper used by unittests to check if scheme is deprecated""" return self._get_record(scheme, category).deprecated def default_scheme(self, category=None, resolve=False): @@ -2092,7 +2094,7 @@ class CryptContext(object): "CryptContext instance") def _get_unregistered_handlers(self): - "check if any handlers in this context aren't in the global registry" + """check if any handlers in this context aren't in the global registry""" return tuple(handler for handler in self._config.handlers if not _is_handler_registered(handler)) @@ -2101,7 +2103,7 @@ class CryptContext(object): #=================================================================== @staticmethod def _render_config_key(key): - "convert 3-part config key to single string" + """convert 3-part config key to single string""" cat, scheme, option = key if cat: return "%s__%s__%s" % (cat, scheme or "context", option) @@ -2112,7 +2114,7 @@ class CryptContext(object): @staticmethod def _render_ini_value(key, value): - "render value to string suitable for INI file" + """render value to string suitable for INI file""" # convert lists to comma separated lists # (mainly 'schemes' & 'deprecated') if isinstance(value, (list,tuple)): @@ -2125,7 +2127,7 @@ class CryptContext(object): else: value = str(value) - assert isinstance(value, str), \ + assert isinstance(value, native_string_types), \ "expected string for key: %r %r" % (key, value) # escape any percent signs. @@ -2167,7 +2169,7 @@ class CryptContext(object): for key, value in self._config.iter_config(resolve)) def _write_to_parser(self, parser, section): - "helper to write to ConfigParser instance" + """helper to write to ConfigParser instance""" render_key = self._render_config_key render_value = self._render_ini_value parser.add_section(section) @@ -2240,7 +2242,7 @@ class CryptContext(object): # which are optimized for the specific (scheme,category) configuration. # # The record objects are cached inside the _CryptConfig - # instance stored in self._config, and are retreived + # instance stored in self._config, and are retrieved # via get_record() and identify_record(). # # _get_record() and _identify_record() are references @@ -2248,7 +2250,7 @@ class CryptContext(object): # stored in CryptContext for speed. def _get_or_identify_record(self, hash, scheme=None, category=None): - "return record based on scheme, or failing that, by identifying hash" + """return record based on scheme, or failing that, by identifying hash""" if scheme: if not isinstance(hash, base_string_types): raise ExpectedStringError(hash, "hash") @@ -2354,7 +2356,7 @@ class CryptContext(object): :param \*\*settings: All additional keywords are passed to the appropriate handler, - and should match it's :attr:`~passlib.ifc.PasswordHash.setting_kwds`. + and should match its :attr:`~passlib.ifc.PasswordHash.setting_kwds`. :returns: A configuration string suitable for passing to :meth:`~CryptContext.genhash`, @@ -2398,7 +2400,7 @@ class CryptContext(object): :param \*\*kwds: All additional keywords are passed to the appropriate handler, - and should match it's :attr:`~passlib.ifc.PasswordHash.context_kwds`. + and should match its :attr:`~passlib.ifc.PasswordHash.context_kwds`. :returns: The secret as encoded by the specified algorithm and options. @@ -2526,7 +2528,7 @@ class CryptContext(object): :param \*\*kwds: All additional keywords are passed to the appropriate handler, - and should match it's :attr:`~passlib.ifc.PasswordHash.context_kwds`. + and should match its :attr:`~passlib.ifc.PasswordHash.context_kwds`. :returns: ``True`` if the password matched the hash, else ``False``. @@ -2627,9 +2629,9 @@ class LazyCryptContext(CryptContext): """CryptContext subclass which doesn't load handlers until needed. This is a subclass of CryptContext which takes in a set of arguments - exactly like CryptContext, but won't load any handlers - (or even parse it's arguments) until - the first time one of it's methods is accessed. + exactly like CryptContext, but won't import any handlers + (or even parse its arguments) until + the first time one of its methods is accessed. :arg schemes: The first positional argument can be a list of schemes, or omitted, @@ -2666,6 +2668,12 @@ class LazyCryptContext(CryptContext): but using :func:`!onload()` to provide dynamic configuration at *application-run* time. + .. note:: + This class is only useful if you're referencing handler objects by name, + and don't want them imported until runtime. If you want to have the config + validated before your application runs, or are passing in already-imported + handler instances, you should use :class:`CryptContext` instead. + .. versionadded:: 1.4 """ _lazy_kwds = None diff --git a/passlib/exc.py b/passlib/exc.py index 8d872a71..b5a12759 100644 --- a/passlib/exc.py +++ b/passlib/exc.py @@ -39,6 +39,15 @@ class PasswordSizeError(ValueError): # this also prevents a glibc crypt segfault issue, detailed here ... # http://www.openwall.com/lists/oss-security/2011/11/15/1 + +class PasslibSecurityError(RuntimeError): + """ + Error raised if critical security issue is detected + (e.g. an attempt is made to use a vulnerable version of a bcrypt backend). + + .. versionadded:: 1.6.3 + """ + #============================================================================= # warnings #============================================================================= @@ -86,7 +95,7 @@ class PasslibRuntimeWarning(PasslibWarning): """Warning issued when something unexpected happens during runtime. The fact that it's a warning instead of an error means Passlib - was able to correct for the issue, but that it's anonmalous enough + was able to correct for the issue, but that it's anomalous enough that the developers would love to hear under what conditions it occurred. .. versionadded:: 1.6 @@ -116,7 +125,7 @@ def _get_name(handler): # generic helpers #------------------------------------------------------------------------ def type_name(value): - "return pretty-printed string containing name of value's type" + """return pretty-printed string containing name of value's type""" cls = value.__class__ if cls.__module__ and cls.__module__ not in ["__builtin__", "builtins"]: return "%s.%s" % (cls.__module__, cls.__name__) @@ -126,26 +135,26 @@ def type_name(value): return cls.__name__ def ExpectedTypeError(value, expected, param): - "error message when param was supposed to be one type, but found another" + """error message when param was supposed to be one type, but found another""" # NOTE: value is never displayed, since it may sometimes be a password. name = type_name(value) return TypeError("%s must be %s, not %s" % (param, expected, name)) def ExpectedStringError(value, param): - "error message when param was supposed to be unicode or bytes" + """error message when param was supposed to be unicode or bytes""" return ExpectedTypeError(value, "unicode or bytes", param) #------------------------------------------------------------------------ # encrypt/verify parameter errors #------------------------------------------------------------------------ def MissingDigestError(handler=None): - "raised when verify() method gets passed config string instead of hash" + """raised when verify() method gets passed config string instead of hash""" name = _get_name(handler) return ValueError("expected %s hash, got %s config string instead" % (name, name)) def NullPasswordError(handler=None): - "raised by OS crypt() supporting hashes, which forbid NULLs in password" + """raised by OS crypt() supporting hashes, which forbid NULLs in password""" name = _get_name(handler) return ValueError("%s does not allow NULL bytes in password" % name) @@ -153,25 +162,25 @@ def NullPasswordError(handler=None): # errors when parsing hashes #------------------------------------------------------------------------ def InvalidHashError(handler=None): - "error raised if unrecognized hash provided to handler" + """error raised if unrecognized hash provided to handler""" return ValueError("not a valid %s hash" % _get_name(handler)) def MalformedHashError(handler=None, reason=None): - "error raised if recognized-but-malformed hash provided to handler" + """error raised if recognized-but-malformed hash provided to handler""" text = "malformed %s hash" % _get_name(handler) if reason: text = "%s (%s)" % (text, reason) return ValueError(text) def ZeroPaddedRoundsError(handler=None): - "error raised if hash was recognized but contained zero-padded rounds field" + """error raised if hash was recognized but contained zero-padded rounds field""" return MalformedHashError(handler, "zero-padded rounds") #------------------------------------------------------------------------ # settings / hash component errors #------------------------------------------------------------------------ def ChecksumSizeError(handler, raw=False): - "error raised if hash was recognized, but checksum was wrong size" + """error raised if hash was recognized, but checksum was wrong size""" # TODO: if handler.use_defaults is set, this came from app-provided value, # not from parsing a hash string, might want different error msg. checksum_size = handler.checksum_size diff --git a/passlib/ext/django/models.py b/passlib/ext/django/models.py index 6c4d245a..f82e3994 100644 --- a/passlib/ext/django/models.py +++ b/passlib/ext/django/models.py @@ -61,7 +61,7 @@ def _apply_patch(): FORMS_PATH = "django.contrib.auth.forms" # - # import UNUSUABLE_PASSWORD and is_password_usuable() helpers + # import UNUSABLE_PASSWORD and is_password_usable() helpers # (providing stubs for older django versions) # if VERSION < (1,4): @@ -72,7 +72,7 @@ def _apply_patch(): from django.contrib.auth.models import UNUSABLE_PASSWORD def is_password_usable(encoded): - return encoded is not None and encoded != UNUSABLE_PASSWORD + return (encoded is not None and encoded != UNUSABLE_PASSWORD) def is_valid_secret(secret): return secret is not None @@ -128,7 +128,7 @@ def _apply_patch(): # @_manager.monkeypatch(USER_PATH) def set_password(user, password): - "passlib replacement for User.set_password()" + """passlib replacement for User.set_password()""" if is_valid_secret(password): # NOTE: pulls _get_category from module globals cat = _get_category(user) @@ -138,7 +138,7 @@ def _apply_patch(): @_manager.monkeypatch(USER_PATH) def check_password(user, password): - "passlib replacement for User.check_password()" + """passlib replacement for User.check_password()""" hash = user.password if not is_valid_secret(password) or not is_password_usable(hash): return False @@ -160,8 +160,8 @@ def _apply_patch(): @_manager.monkeypatch(HASHERS_PATH, enable=has_hashers) @_manager.monkeypatch(MODELS_PATH) def check_password(password, encoded, setter=None, preferred="default"): - "passlib replacement for check_password()" - # XXX: this currently ignores "preferred" keyword, since it's purpose + """passlib replacement for check_password()""" + # XXX: this currently ignores "preferred" keyword, since its purpose # was for hash migration, and that's handled by the context. if not is_valid_secret(password) or not is_password_usable(encoded): return False @@ -178,7 +178,7 @@ def _apply_patch(): @_manager.monkeypatch(HASHERS_PATH) @_manager.monkeypatch(MODELS_PATH) def make_password(password, salt=None, hasher="default"): - "passlib replacement for make_password()" + """passlib replacement for make_password()""" if not is_valid_secret(password): return make_unusable_password() if hasher == "default": @@ -187,17 +187,22 @@ def _apply_patch(): scheme = hasher_to_passlib_name(hasher) kwds = dict(scheme=scheme) handler = password_context.handler(scheme) - # NOTE: django make specify an empty string for the salt, - # even if scheme doesn't accept a salt. we omit keyword - # in that case. - if salt is not None and (salt or 'salt' in handler.setting_kwds): - kwds['salt'] = salt + if "salt" in handler.setting_kwds: + if hasher.startswith("unsalted_"): + # Django 1.4.6+ uses a separate 'unsalted_sha1' hasher for "sha1$$digest", + # but passlib just reuses it's "sha1" handler ("sha1$salt$digest"). To make + # this work, have to explicitly tell the sha1 handler to use an empty salt. + kwds['salt'] = '' + elif salt: + # Django make_password() autogenerates a salt if salt is bool False (None / ''), + # so we only pass the keyword on if there's actually a fixed salt. + kwds['salt'] = salt return password_context.encrypt(password, **kwds) @_manager.monkeypatch(HASHERS_PATH) @_manager.monkeypatch(FORMS_PATH) def get_hasher(algorithm="default"): - "passlib replacement for get_hasher()" + """passlib replacement for get_hasher()""" if algorithm == "default": scheme = None else: @@ -214,7 +219,7 @@ def _apply_patch(): @_manager.monkeypatch(HASHERS_PATH) @_manager.monkeypatch(FORMS_PATH) def identify_hasher(encoded): - "passlib helper to identify hasher from encoded password" + """passlib helper to identify hasher from encoded password""" handler = password_context.identify(encoded, resolve=True, required=True) algorithm = None diff --git a/passlib/ext/django/utils.py b/passlib/ext/django/utils.py index 161212b4..863f11e1 100644 --- a/passlib/ext/django/utils.py +++ b/passlib/ext/django/utils.py @@ -123,14 +123,14 @@ DJANGO_PASSLIB_PREFIX = "django_" _other_django_hashes = ["hex_md5"] def passlib_to_hasher_name(passlib_name): - "convert passlib handler name -> hasher name" + """convert passlib handler name -> hasher name""" handler = get_crypt_handler(passlib_name) if hasattr(handler, "django_name"): return handler.django_name return PASSLIB_HASHER_PREFIX + passlib_name def hasher_to_passlib_name(hasher_name): - "convert hasher name -> passlib handler name" + """convert hasher name -> passlib handler name""" if hasher_name.startswith(PASSLIB_HASHER_PREFIX): return hasher_name[len(PASSLIB_HASHER_PREFIX):] if hasher_name == "unsalted_sha1": @@ -186,7 +186,9 @@ class _HasherWrapper(object): _translate_kwds = dict(checksum="hash", rounds="iterations") def safe_summary(self, encoded): - from django.contrib.auth.hashers import mask_hash, _, SortedDict + from django.contrib.auth.hashers import mask_hash + from django.utils.translation import ugettext_noop as _ + from django.utils.datastructures import SortedDict handler = self.passlib_handler items = [ # since this is user-facing, we're reporting passlib's name, @@ -252,14 +254,14 @@ def get_passlib_hasher(handler, algorithm=None): return hasher def _get_hasher(algorithm): - "wrapper to call django.contrib.auth.hashers:get_hasher()" + """wrapper to call django.contrib.auth.hashers:get_hasher()""" import sys module = sys.modules.get("passlib.ext.django.models") if module is None: # we haven't patched django, so just import directly from django.contrib.auth.hashers import get_hasher else: - # we've patched django, so have to use patch manager to retreive + # we've patched django, so have to use patch manager to retrieve # original get_hasher() function... get_hasher = module._manager.getorig("django.contrib.auth.hashers:get_hasher") return get_hasher(algorithm) @@ -364,7 +366,7 @@ def _get_hasher(algorithm): _UNSET = object() class _PatchManager(object): - "helper to manage monkeypatches and run sanity checks" + """helper to manage monkeypatches and run sanity checks""" # NOTE: this could easily use a dict interface, # but keeping it distinct to make clear that it's not a dict, @@ -383,7 +385,7 @@ class _PatchManager(object): __bool__ = __nonzero__ = lambda self: bool(self._state) def _import_path(self, path): - "retrieve obj and final attribute name from resource path" + """retrieve obj and final attribute name from resource path""" name, attr = path.split(":") obj = __import__(name, fromlist=[attr], level=0) while '.' in attr: @@ -393,7 +395,7 @@ class _PatchManager(object): @staticmethod def _is_same_value(left, right): - "check if two values are the same (stripping method wrappers, etc)" + """check if two values are the same (stripping method wrappers, etc)""" return get_method_function(left) == get_method_function(right) #=================================================================== @@ -404,11 +406,11 @@ class _PatchManager(object): return getattr(obj, attr, default) def get(self, path, default=None): - "return current value for path" + """return current value for path""" return self._get_path(path, default) def getorig(self, path, default=None): - "return original (unpatched) value for path" + """return original (unpatched) value for path""" try: value, _= self._state[path] except KeyError: @@ -439,7 +441,7 @@ class _PatchManager(object): setattr(obj, attr, value) def patch(self, path, value): - "monkeypatch object+attr at to have , stores original" + """monkeypatch object+attr at to have , stores original""" assert value != _UNSET current = self._get_path(path) try: @@ -461,7 +463,7 @@ class _PatchManager(object): ## self.patch(path, value) def monkeypatch(self, parent, name=None, enable=True): - "function decorator which patches function of same name in " + """function decorator which patches function of same name in """ def builder(func): if enable: sep = "." if ":" in parent else ":" diff --git a/passlib/handlers/bcrypt.py b/passlib/handlers/bcrypt.py index 42f0eca1..b0f4de0e 100644 --- a/passlib/handlers/bcrypt.py +++ b/passlib/handlers/bcrypt.py @@ -24,11 +24,11 @@ try: except ImportError: # pragma: no cover _bcrypt = None try: - from bcryptor.engine import Engine as bcryptor_engine + import bcryptor as _bcryptor except ImportError: # pragma: no cover - bcryptor_engine = None + _bcryptor = None # pkg -from passlib.exc import PasslibHashWarning +from passlib.exc import PasslibHashWarning, PasslibSecurityWarning, PasslibSecurityError from passlib.utils import bcrypt64, safe_crypt, repeat_string, to_bytes, \ classproperty, rng, getrandstr, test_crypt, to_unicode from passlib.utils.compat import bytes, b, u, uascii_to_str, unicode, str_to_uascii @@ -53,8 +53,43 @@ IDENT_2 = u("$2$") IDENT_2A = u("$2a$") IDENT_2X = u("$2x$") IDENT_2Y = u("$2y$") +IDENT_2B = u("$2b$") _BNULL = b('\x00') +def _detect_pybcrypt(): + """ + internal helper which tries to distinguish pybcrypt vs bcrypt. + + :returns: + True if cext-based py-bcrypt, + False if ffi-based bcrypt, + None if 'bcrypt' module not found. + + .. versionchanged:: 1.6.3 + + Now assuming bcrypt installed, unless py-bcrypt explicitly detected. + Previous releases assumed py-bcrypt by default. + + Making this change since py-bcrypt is (apparently) unmaintained and static, + whereas bcrypt is being actively maintained, and it's internal structure may shift. + """ + # NOTE: this is also used by the unittests. + + # check for module. + try: + import bcrypt + except ImportError: + return None + + # py-bcrypt has a "._bcrypt.__version__" attribute (confirmed for v0.1 - 0.4), + # which bcrypt lacks (confirmed for v1.0 - 2.0) + # "._bcrypt" alone isn't sufficient, since bcrypt 2.0 now has that attribute. + try: + from bcrypt._bcrypt import __version__ + except ImportError: + return False + return True + #============================================================================= # handler #============================================================================= @@ -85,9 +120,11 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. If specified, it must be one of the following: * ``"2"`` - the first revision of BCrypt, which suffers from a minor security flaw and is generally not used anymore. - * ``"2a"`` - latest revision of the official BCrypt algorithm, and the current default. + * ``"2a"`` - some implementations suffered from a very rare security flaw. + current default for compatibility purposes. * ``"2y"`` - format specific to the *crypt_blowfish* BCrypt implementation, identical to ``"2a"`` in all but name. + * ``"2b"`` - latest revision of the official BCrypt algorithm (will be default in Passlib 1.7). :type relaxed: bool :param relaxed: @@ -107,6 +144,10 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. .. versionchanged:: 1.6 Added a pure-python backend. + + .. versionchanged:: 1.6.3 + + Added support for ``"2b"`` variant. """ #=================================================================== @@ -120,8 +161,9 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. #--HasManyIdents-- default_ident = IDENT_2A - ident_values = (IDENT_2, IDENT_2A, IDENT_2X, IDENT_2Y) - ident_aliases = {u("2"): IDENT_2, u("2a"): IDENT_2A, u("2y"): IDENT_2Y} + ident_values = (IDENT_2, IDENT_2A, IDENT_2X, IDENT_2Y, IDENT_2B) + ident_aliases = {u("2"): IDENT_2, u("2a"): IDENT_2A, u("2y"): IDENT_2Y, + u("2b"): IDENT_2B} #--HasSalt-- min_salt_size = max_salt_size = 22 @@ -161,19 +203,10 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. self.checksum or u('')) return uascii_to_str(hash) - def _get_config(self, ident=None): - "internal helper to prepare config string for backends" - if ident is None: - ident = self.ident - if ident == IDENT_2Y: - # none of passlib's backends suffered from crypt_blowfish's - # buggy "2a" hash, which means we can safely implement - # crypt_blowfish's "2y" hash by passing "2a" to the backends. - ident = IDENT_2A - else: - # no backends currently support 2x, but that should have - # been caught earlier in from_string() - assert ident != IDENT_2X + # NOTE: this should be kept separate from to_string() + # so that bcrypt_sha256() can still use it, while overriding to_string() + def _get_config(self, ident): + """internal helper to prepare config string for backends""" config = u("%s%02d$%s") % (ident, self.rounds, self.salt) return uascii_to_str(config) @@ -197,7 +230,7 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. @classmethod def normhash(cls, hash): - "helper to normalize hash, correcting any bcrypt padding bits" + """helper to normalize hash, correcting any bcrypt padding bits""" if cls.identify(hash): return cls.from_string(hash).to_string() else: @@ -241,17 +274,100 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. #=================================================================== backends = ("bcrypt", "pybcrypt", "bcryptor", "os_crypt", "builtin") + # backend workaround detection + _has_wraparound_bug = False + _lacks_20_support = False + _lacks_2y_support = False + _lacks_2b_support = False + + @classmethod + def set_backend(cls, *a, **k): + backend = super(bcrypt, cls).set_backend(*a, **k) + cls._scan_backend(backend) + return backend + + @classmethod + def _scan_backend(cls, backend): + """ + check for known bugs & feature support once backend is loaded + """ + # check for cryptblowfish 8bit bug (fixed in 2y/2b); + # even though it's not known to be present in any of passlib's backends. + # this is treated as FATAL, because it can easily result in seriously malformed hashes, + # and we can't correct for it ourselves. + # test cases from + # NOTE: reference hash taken from above url, and is the incorrectly generate 2x hash. + if cls.verify(u("\xA3"), + "$2a$05$/OK.fbVrR/bpIqNJ5ianF.CE5elHaaO4EbggVDjb8P19RukzXSM3e"): + raise PasslibSecurityError( + "passlib.hash.bcrypt: Your installation of the %r backend is vulnerable to " + "the crypt_blowfish 8-bit bug (CVE-2011-2483), " + "and should be upgraded or replaced with another backend." % backend) + + # check for bsd wraparound bug (fixed in 2b) + # this is treated as a warning, because it's rare in the field, + # and pybcrypt (as of 2015-7-21) is unpatched, but some people may be stuck with it. + # test cases from + # NOTE: reference hash is of password "0"*72 + # NOTE: if in future we need to deliberately create hashes which have this bug, + # can use something like 'hashpw(repeat_string(secret[:((1+secret) % 256) or 1]), 72)' + cls._has_wraparound_bug = False + if cls.verify(("0123456789"*26)[:255], + "$2a$04$R1lJ2gkNaoPGdafE.H.16.nVyh2niHsGJhayOHLMiXlI45o8/DU.6"): + warn("passlib.hash.bcrypt: Your installation of the %r backend is vulnerable to " + "the bsd wraparound bug, " + "and should be upgraded or replaced with another backend " + "(this warning will be fatal under passlib 1.7)" % backend) + cls._has_wraparound_bug = True + + def _detect_lacks_variant(ident, refhash): + """helper to detect if backend *lacks* support for specified bcrypt variant""" + assert refhash.startswith(ident) + # NOTE: can't use cls.verify() directly or we have recursion error + try: + result = cls.verify("test", refhash) + except (ValueError, _bcryptor.engine.SaltError if _bcryptor else ValueError): + # backends without support will throw various errors about unrecognized version + # pybcrypt, bcrypt -- raises ValueError + # bcryptor -- raises bcryptor.engine.SaltError + log.debug("%r backend lacks %r support", backend, ident) + return True + assert result, "%r backend %r check failed" % (backend, ident) + return False + + # check for native 2 support + # NOTE: have to clear workaround first, so verify() doesn't enable it during detection. + cls._lacks_20_support = False + cls._lacks_20_support = _detect_lacks_variant("$2$", "$2$04$5BJqKfqMQvV7nS.yUguNcu" + "RfMMOXK0xPWavM7pOzjEi5ze5T1k8/S") + + # TODO: check for 2x support + + # check for native 2y support + cls._lacks_2y_support = False + cls._lacks_2y_support = _detect_lacks_variant("$2y$", "$2y$04$5BJqKfqMQvV7nS.yUguNcu" + "eVirQqDBGaLXSqj.rs.pZPlNR0UX/HK") + + # check for native 2b support + cls._lacks_2b_support = False + cls._lacks_2b_support = _detect_lacks_variant("$2b$", "$2b$04$5BJqKfqMQvV7nS.yUguNcu" + "eVirQqDBGaLXSqj.rs.pZPlNR0UX/HK") + + # sanity check + assert cls._lacks_2b_support or not cls._has_wraparound_bug, \ + "sanity check failed: %r backend supports $2b$ but has wraparound bug" % backend + @classproperty def _has_backend_bcrypt(cls): - return _bcrypt is not None and hasattr(_bcrypt, "_ffi") + return _bcrypt is not None and not _detect_pybcrypt() @classproperty def _has_backend_pybcrypt(cls): - return _bcrypt is not None and not hasattr(_bcrypt, "_ffi") + return _bcrypt is not None and _detect_pybcrypt() @classproperty def _has_backend_bcryptor(cls): - return bcryptor_engine is not None + return _bcryptor is not None @classproperty def _has_backend_builtin(cls): @@ -271,51 +387,93 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. @classmethod def _no_backends_msg(cls): - return "no bcrypt backends available - please install py-bcrypt" + return "no bcrypt backends available -- recommend you install one (e.g. 'pip install bcrypt')" def _calc_checksum(self, secret): - "common backend code" + """common backend code""" + + # make sure it's unicode if isinstance(secret, unicode): secret = secret.encode("utf-8") - if _BNULL in secret: - # NOTE: especially important to forbid NULLs for bcrypt, since many - # backends (bcryptor, bcrypt) happily accept them, and then - # silently truncate the password at first NULL they encounter! - raise uh.exc.NullPasswordError(self) - return self._calc_checksum_backend(secret) - def _calc_checksum_os_crypt(self, secret): - config = self._get_config() + # NOTE: especially important to forbid NULLs for bcrypt, since many + # backends (bcryptor, bcrypt) happily accept them, and then + # silently truncate the password at first NULL they encounter! + if _BNULL in secret: + raise uh.exc.NullPasswordError(self) + + # ensure backend is loaded before workaround detection + self.get_backend() + + # protect from wraparound bug by truncating secret before handing it to the backend. + # bcrypt only uses first 72 bytes anyways. + if self._has_wraparound_bug and len(secret) >= 255: + secret = secret[:72] + + # special case handling for variants (ordered most common first) + ident = self.ident + if ident == IDENT_2A: + # fall through and use backend w/o hacks + pass + + elif ident == IDENT_2B: + if self._lacks_2b_support: + # handle $2b$ hash format even if backend is too old. + # have it generate a 2A digest, then return it as a 2B hash. + ident = IDENT_2A + + elif ident == IDENT_2Y: + if self._lacks_2y_support: + # handle $2y$ hash format (not supported by BSDs, being phased out on others) + # have it generate a 2A digest, then return it as a 2Y hash. + ident = IDENT_2A + + elif ident == IDENT_2: + if self._lacks_20_support: + # handle legacy $2$ format (not supported by most backends except BSD os_crypt) + # we can fake $2$ behavior using the $2a$ algorithm + # by repeating the password until it's at least 72 chars in length. + if secret: + secret = repeat_string(secret, 72) + ident = IDENT_2A + + elif ident == IDENT_2X: + + # NOTE: shouldn't get here. + # XXX: could check if backend does actually offer 'support' + raise RuntimeError("$2x$ hashes not currently supported by passlib") + + else: + raise AssertionError("unexpected ident value: %r" % ident) + + # invoke backend + config = self._get_config(ident) + return self._calc_checksum_backend(secret, config) + + def _calc_checksum_os_crypt(self, secret, config): hash = safe_crypt(secret, config) if hash: assert hash.startswith(config) and len(hash) == len(config)+31 return hash[-31:] else: - # NOTE: it's unlikely any other backend will be available, - # but checking before we bail, just in case. - for name in self.backends: - if name != "os_crypt" and self.has_backend(name): - func = getattr(self, "_calc_checksum_" + name) - return func(secret) + # NOTE: Have to raise this error because python3's crypt.crypt() only accepts unicode. + # This means it can't handle any passwords that aren't either unicode + # or utf-8 encoded bytes. However, hashing a password with an alternate + # encoding should be a pretty rare edge case; if user needs it, they can just + # install bcrypt backend. + # XXX: is this the right error type to raise? + # maybe have safe_crypt() not swallow UnicodeDecodeError, and have handlers + # like sha256_crypt trap it if they have alternate method of handling them? raise uh.exc.MissingBackendError( - "password can't be handled by os_crypt, " - "recommend installing py-bcrypt.", + "non-utf8 encoded passwords can't be handled by crypt.crypt() under python3, " + "recommend running `pip install bcrypt`.", ) - def _calc_checksum_bcrypt(self, secret): + def _calc_checksum_bcrypt(self, secret, config): # bcrypt behavior: # hash must be ascii bytes # secret must be bytes # returns bytes - if self.ident == IDENT_2: - # bcrypt doesn't support $2$ hashes; but we can fake $2$ behavior - # using the $2a$ algorithm, by repeating the password until - # it's at least 72 chars in length. - if secret: - secret = repeat_string(secret, 72) - config = self._get_config(IDENT_2A) - else: - config = self._get_config() if isinstance(config, unicode): config = config.encode("ascii") hash = _bcrypt.hashpw(secret, config) @@ -323,37 +481,27 @@ class bcrypt(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.HasManyBackends, uh. assert isinstance(hash, bytes) return hash[-31:].decode("ascii") - def _calc_checksum_pybcrypt(self, secret): + def _calc_checksum_pybcrypt(self, secret, config): # py-bcrypt behavior: # py2: unicode secret/hash encoded as ascii bytes before use, # bytes taken as-is; returns ascii bytes. # py3: unicode secret encoded as utf-8 bytes, # hash encoded as ascii bytes, returns ascii unicode. - config = self._get_config() hash = _bcrypt.hashpw(secret, config) assert hash.startswith(config) and len(hash) == len(config)+31 return str_to_uascii(hash[-31:]) - def _calc_checksum_bcryptor(self, secret): + def _calc_checksum_bcryptor(self, secret, config): # bcryptor behavior: # py2: unicode secret/hash encoded as ascii bytes before use, # bytes taken as-is; returns ascii bytes. # py3: not supported - if self.ident == IDENT_2: - # bcryptor doesn't support $2$ hashes; but we can fake $2$ behavior - # using the $2a$ algorithm, by repeating the password until - # it's at least 72 chars in length. - if secret: - secret = repeat_string(secret, 72) - config = self._get_config(IDENT_2A) - else: - config = self._get_config() - hash = bcryptor_engine(False).hash_key(secret, config) + hash = _bcryptor.engine.Engine(False).hash_key(secret, config) assert hash.startswith(config) and len(hash) == len(config)+31 return str_to_uascii(hash[-31:]) - def _calc_checksum_builtin(self, secret): - chk = _builtin_bcrypt(secret, self.ident.strip("$"), + def _calc_checksum_builtin(self, secret, config): + chk = _builtin_bcrypt(secret, config[1:config.index("$", 1)], self.salt.encode("ascii"), self.rounds) return chk.decode("ascii") @@ -435,19 +583,21 @@ class bcrypt_sha256(bcrypt): return uascii_to_str(hash) def _calc_checksum(self, secret): - # NOTE: this bypasses bcrypt's _calc_checksum, - # so has to take care of all it's issues, such as secret encoding. - if isinstance(secret, unicode): - secret = secret.encode("utf-8") # NOTE: can't use digest directly, since bcrypt stops at first NULL. # NOTE: bcrypt doesn't fully mix entropy for bytes 55-72 of password # (XXX: citation needed), so we don't want key to be > 55 bytes. # thus, have to use base64 (44 bytes) rather than hex (64 bytes). + # XXX: it's later come out that 55-72 may be ok, so later revision of bcrypt_sha256 + # may switch to hex encoding, since it's simpler to implement elsewhere. + if isinstance(secret, unicode): + secret = secret.encode("utf-8") key = b64encode(sha256(secret).digest()) - return self._calc_checksum_backend(key) + + # hand result off to normal bcrypt algorithm + return super(bcrypt_sha256, self)._calc_checksum(key) # patch set_backend so it modifies bcrypt class, not this one... - # else it would clobber our _calc_checksum() wrapper above. + # else the bcrypt.set_backend() tests will call the wrong class. @classmethod def set_backend(cls, *args, **kwds): return bcrypt.set_backend(*args, **kwds) diff --git a/passlib/handlers/cisco.py b/passlib/handlers/cisco.py index b1d25b51..1588e80d 100644 --- a/passlib/handlers/cisco.py +++ b/passlib/handlers/cisco.py @@ -109,7 +109,7 @@ class cisco_type7(uh.GenericHandler): will be issued instead. Correctable errors include ``salt`` values that are out of range. - Note that while this class outputs digests in upper-case hexidecimal, + Note that while this class outputs digests in upper-case hexadecimal, it will accept lower-case as well. This class also provides the following additional method: @@ -156,7 +156,7 @@ class cisco_type7(uh.GenericHandler): self.salt = self._norm_salt(salt) def _norm_salt(self, salt): - "the salt for this algorithm is an integer 0-52, not a string" + """the salt for this algorithm is an integer 0-52, not a string""" # XXX: not entirely sure that values >15 are valid, so for # compatibility we don't output those values, but we do accept them. if salt is None: @@ -206,7 +206,7 @@ class cisco_type7(uh.GenericHandler): @classmethod def _cipher(cls, data, salt): - "xor static key against data - encrypts & decrypts" + """xor static key against data - encrypts & decrypts""" key = cls._key key_size = len(key) return join_byte_values( diff --git a/passlib/handlers/des_crypt.py b/passlib/handlers/des_crypt.py index 1699e1d7..dc28783a 100644 --- a/passlib/handlers/des_crypt.py +++ b/passlib/handlers/des_crypt.py @@ -40,7 +40,7 @@ def _crypt_secret_to_key(secret): for i, c in enumerate(secret[:8])) def _raw_des_crypt(secret, salt): - "pure-python backed for des_crypt" + """pure-python backed for des_crypt""" assert len(salt) == 2 # NOTE: some OSes will accept non-HASH64 characters in the salt, @@ -73,7 +73,7 @@ def _raw_des_crypt(secret, salt): return h64big.encode_int64(result) def _bsdi_secret_to_key(secret): - "covert secret to DES key used by bsdi_crypt" + """covert secret to DES key used by bsdi_crypt""" key_value = _crypt_secret_to_key(secret) idx = 8 end = len(secret) @@ -85,7 +85,7 @@ def _bsdi_secret_to_key(secret): return key_value def _raw_bsdi_crypt(secret, rounds, salt): - "pure-python backend for bsdi_crypt" + """pure-python backend for bsdi_crypt""" # decode salt try: diff --git a/passlib/handlers/digests.py b/passlib/handlers/digests.py index f1a21bde..402c6702 100644 --- a/passlib/handlers/digests.py +++ b/passlib/handlers/digests.py @@ -24,10 +24,10 @@ __all__ = [ ] #============================================================================= -# helpers for hexidecimal hashes +# helpers for hexadecimal hashes #============================================================================= class HexDigestHash(uh.StaticHandler): - "this provides a template for supporting passwords stored as plain hexidecimal hashes" + """this provides a template for supporting passwords stored as plain hexadecimal hashes""" #=================================================================== # class attrs #=================================================================== @@ -60,7 +60,7 @@ def create_hex_hash(hash, digest_name, module=__name__): __module__=module, # so ABCMeta won't clobber it _hash_func=staticmethod(hash), # sometimes it's a function, sometimes not. so wrap it. checksum_size=h.digest_size*2, - __doc__="""This class implements a plain hexidecimal %s hash, and follows the :ref:`password-hash-api`. + __doc__="""This class implements a plain hexadecimal %s hash, and follows the :ref:`password-hash-api`. It supports no optional or contextual keywords. """ % (digest_name,) @@ -106,7 +106,7 @@ class htdigest(uh.PasswordHash): @classmethod def _norm_hash(cls, hash): - "normalize hash to native string, and validate it" + """normalize hash to native string, and validate it""" hash = to_native_str(hash, param="hash") if len(hash) != 32: raise uh.exc.MalformedHashError(cls, "wrong size") diff --git a/passlib/handlers/django.py b/passlib/handlers/django.py index cdb853b8..59574a13 100644 --- a/passlib/handlers/django.py +++ b/passlib/handlers/django.py @@ -270,7 +270,7 @@ class django_pbkdf2_sha256(DjangoVariableHash): :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 20000, but must be within ``range(1,1<<32)``. + Defaults to 29000, but must be within ``range(1,1<<32)``. :type relaxed: bool :param relaxed: @@ -323,7 +323,7 @@ class django_pbkdf2_sha1(django_pbkdf2_sha256): :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 60000, but must be within ``range(1,1<<32)``. + Defaults to 131000, but must be within ``range(1,1<<32)``. :type relaxed: bool :param relaxed: diff --git a/passlib/handlers/fshp.py b/passlib/handlers/fshp.py index 6efc782e..920283f1 100644 --- a/passlib/handlers/fshp.py +++ b/passlib/handlers/fshp.py @@ -40,7 +40,7 @@ class fshp(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): :param rounds: Optional number of rounds to use. - Defaults to 100000, must be between 1 and 4294967295, inclusive. + Defaults to 480000, must be between 1 and 4294967295, inclusive. :param variant: Optionally specifies variant of FSHP to use. @@ -79,7 +79,7 @@ class fshp(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): #--HasRounds-- # FIXME: should probably use different default rounds # based on the variant. setting for default variant (sha256) for now. - default_rounds = 100000 # current passlib default, FSHP uses 4096 + default_rounds = 480000 # current passlib default, FSHP uses 4096 min_rounds = 1 # set by FSHP max_rounds = 4294967295 # 32-bit integer limit - not set by FSHP rounds_cost = "linear" diff --git a/passlib/handlers/ldap_digests.py b/passlib/handlers/ldap_digests.py index a25a3946..fb378c06 100644 --- a/passlib/handlers/ldap_digests.py +++ b/passlib/handlers/ldap_digests.py @@ -38,7 +38,7 @@ __all__ = [ # ldap helpers #============================================================================= class _Base64DigestHelper(uh.StaticHandler): - "helper for ldap_md5 / ldap_sha1" + """helper for ldap_md5 / ldap_sha1""" # XXX: could combine this with hex digests in digests.py ident = None # required - prefix identifier @@ -48,7 +48,7 @@ class _Base64DigestHelper(uh.StaticHandler): @classproperty def _hash_prefix(cls): - "tell StaticHandler to strip ident from checksum" + """tell StaticHandler to strip ident from checksum""" return cls.ident def _calc_checksum(self, secret): @@ -58,7 +58,7 @@ class _Base64DigestHelper(uh.StaticHandler): return b64encode(chk).decode("ascii") class _SaltedBase64DigestHelper(uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): - "helper for ldap_salted_md5 / ldap_salted_sha1" + """helper for ldap_salted_md5 / ldap_salted_sha1""" setting_kwds = ("salt", "salt_size") checksum_chars = uh.PADDED_BASE64_CHARS diff --git a/passlib/handlers/md5_crypt.py b/passlib/handlers/md5_crypt.py index 642316e8..bb28b168 100644 --- a/passlib/handlers/md5_crypt.py +++ b/passlib/handlers/md5_crypt.py @@ -191,7 +191,7 @@ def _raw_md5_crypt(pwd, salt, use_apr=False): # handler #============================================================================= class _MD5_Common(uh.HasSalt, uh.GenericHandler): - "common code for md5_crypt and apr_md5_crypt" + """common code for md5_crypt and apr_md5_crypt""" #=================================================================== # class attrs #=================================================================== diff --git a/passlib/handlers/misc.py b/passlib/handlers/misc.py index e7f8fe1a..a89ac72a 100644 --- a/passlib/handlers/misc.py +++ b/passlib/handlers/misc.py @@ -37,7 +37,7 @@ class unix_fallback(uh.StaticHandler): all passwords will be allowed through if the hash is an empty string. .. deprecated:: 1.6 - This has been deprecated due to it's "wildcard" feature, + This has been deprecated due to its "wildcard" feature, and will be removed in Passlib 1.8. Use :class:`unix_disabled` instead. """ name = "unix_fallback" diff --git a/passlib/handlers/mssql.py b/passlib/handlers/mssql.py index 1d892732..d50100fc 100644 --- a/passlib/handlers/mssql.py +++ b/passlib/handlers/mssql.py @@ -64,7 +64,7 @@ BIDENT = b("0x0100") UIDENT = u("0x0100") def _ident_mssql(hash, csize, bsize): - "common identify for mssql 2000/2005" + """common identify for mssql 2000/2005""" if isinstance(hash, unicode): if len(hash) == csize and hash.startswith(UIDENT): return True @@ -78,7 +78,7 @@ def _ident_mssql(hash, csize, bsize): return False def _parse_mssql(hash, csize, bsize, handler): - "common parser for mssql 2000/2005; returns 4 byte salt + checksum" + """common parser for mssql 2000/2005; returns 4 byte salt + checksum""" if isinstance(hash, unicode): if len(hash) == csize and hash.startswith(UIDENT): try: diff --git a/passlib/handlers/oracle.py b/passlib/handlers/oracle.py index b8265201..3cd3ba18 100644 --- a/passlib/handlers/oracle.py +++ b/passlib/handlers/oracle.py @@ -113,7 +113,7 @@ class oracle11(uh.HasSalt, uh.GenericHandler): :param salt: Optional salt string. If not specified, one will be autogenerated (this is recommended). - If specified, it must be 20 hexidecimal characters. + If specified, it must be 20 hexadecimal characters. :type relaxed: bool :param relaxed: diff --git a/passlib/handlers/pbkdf2.py b/passlib/handlers/pbkdf2.py index cadbbea8..fd5fbad4 100644 --- a/passlib/handlers/pbkdf2.py +++ b/passlib/handlers/pbkdf2.py @@ -28,7 +28,7 @@ __all__ = [ # #============================================================================= class Pbkdf2DigestHandler(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): - "base class for various pbkdf2_{digest} algorithms" + """base class for various pbkdf2_{digest} algorithms""" #=================================================================== # class attrs #=================================================================== @@ -84,7 +84,7 @@ class Pbkdf2DigestHandler(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.Gen return pbkdf2(secret, self.salt, self.rounds, self.checksum_size, self._prf) def create_pbkdf2_hash(hash_name, digest_size, rounds=12000, ident=None, module=__name__): - "create new Pbkdf2DigestHandler subclass for a specific hash" + """create new Pbkdf2DigestHandler subclass for a specific hash""" name = 'pbkdf2_' + hash_name if ident is None: ident = u("$pbkdf2-%s$") % (hash_name,) @@ -135,9 +135,9 @@ def create_pbkdf2_hash(hash_name, digest_size, rounds=12000, ident=None, module= #------------------------------------------------------------------------ # derived handlers #------------------------------------------------------------------------ -pbkdf2_sha1 = create_pbkdf2_hash("sha1", 20, 60000, ident=u("$pbkdf2$")) -pbkdf2_sha256 = create_pbkdf2_hash("sha256", 32, 20000) -pbkdf2_sha512 = create_pbkdf2_hash("sha512", 64, 19000) +pbkdf2_sha1 = create_pbkdf2_hash("sha1", 20, 131000, ident=u("$pbkdf2$")) +pbkdf2_sha256 = create_pbkdf2_hash("sha256", 32, 29000) +pbkdf2_sha512 = create_pbkdf2_hash("sha512", 64, 25000) ldap_pbkdf2_sha1 = uh.PrefixWrapper("ldap_pbkdf2_sha1", pbkdf2_sha1, "{PBKDF2}", "$pbkdf2$", ident=True) ldap_pbkdf2_sha256 = uh.PrefixWrapper("ldap_pbkdf2_sha256", pbkdf2_sha256, "{PBKDF2-SHA256}", "$pbkdf2-sha256$", ident=True) diff --git a/passlib/handlers/phpass.py b/passlib/handlers/phpass.py index 45bd9a64..7db32b0e 100644 --- a/passlib/handlers/phpass.py +++ b/passlib/handlers/phpass.py @@ -42,12 +42,12 @@ class phpass(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.GenericHandler): :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 17, must be between 7 and 30, inclusive. + Defaults to 19, must be between 7 and 30, inclusive. This value is logarithmic, the actual number of iterations used will be :samp:`2**{rounds}`. :type ident: str :param ident: - phpBB3 uses ``H`` instead of ``P`` for it's identifier, + phpBB3 uses ``H`` instead of ``P`` for its identifier, this may be set to ``H`` in order to generate phpBB3 compatible hashes. it defaults to ``P``. @@ -75,7 +75,7 @@ class phpass(uh.HasManyIdents, uh.HasRounds, uh.HasSalt, uh.GenericHandler): salt_chars = uh.HASH64_CHARS #--HasRounds-- - default_rounds = 17 + default_rounds = 19 min_rounds = 7 max_rounds = 30 rounds_cost = "log2" diff --git a/passlib/handlers/scram.py b/passlib/handlers/scram.py index 1c5f9e87..02133cd6 100644 --- a/passlib/handlers/scram.py +++ b/passlib/handlers/scram.py @@ -15,7 +15,7 @@ from passlib.exc import PasslibHashWarning from passlib.utils import ab64_decode, ab64_encode, consteq, saslprep, \ to_native_str, xor_bytes, splitcomma from passlib.utils.compat import b, bytes, bascii_to_str, iteritems, \ - PY3, u, unicode + PY3, u, unicode, native_string_types from passlib.utils.pbkdf2 import pbkdf2, get_prf, norm_hash_name import passlib.utils.handlers as uh # local @@ -49,7 +49,7 @@ class scram(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 20000, but must be within ``range(1,1<<32)``. + Defaults to 100000, but must be within ``range(1,1<<32)``. :type algs: list of strings :param algs: @@ -102,7 +102,7 @@ class scram(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): max_salt_size = 1024 #--HasRounds-- - default_rounds = 20000 + default_rounds = 100000 min_rounds = 1 max_rounds = 2**32-1 rounds_cost = "linear" @@ -317,7 +317,7 @@ class scram(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): return checksum def _norm_algs(self, algs): - "normalize algs parameter" + """normalize algs parameter""" # determine default algs value if algs is None: # derive algs list from checksum (if present). @@ -332,7 +332,7 @@ class scram(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): raise RuntimeError("checksum & algs kwds are mutually exclusive") # parse args value - if isinstance(algs, str): + if isinstance(algs, native_string_types): algs = splitcomma(algs) algs = sorted(norm_hash_name(alg, 'iana') for alg in algs) if any(len(alg)>9 for alg in algs): @@ -348,7 +348,7 @@ class scram(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler): @classmethod def _bind_needs_update(cls, **settings): - "generate a deprecation detector for CryptContext to use" + """generate a deprecation detector for CryptContext to use""" # generate deprecation hook which marks hashes as deprecated # if they don't support a superset of current algs. algs = frozenset(cls(use_defaults=True, **settings).algs) diff --git a/passlib/handlers/sha1_crypt.py b/passlib/handlers/sha1_crypt.py index 885c67fc..b243bc04 100644 --- a/passlib/handlers/sha1_crypt.py +++ b/passlib/handlers/sha1_crypt.py @@ -47,7 +47,7 @@ class sha1_crypt(uh.HasManyBackends, uh.HasRounds, uh.HasSalt, uh.GenericHandler :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 64000, must be between 1 and 4294967295, inclusive. + Defaults to 480000, must be between 1 and 4294967295, inclusive. :type relaxed: bool :param relaxed: @@ -77,7 +77,7 @@ class sha1_crypt(uh.HasManyBackends, uh.HasRounds, uh.HasSalt, uh.GenericHandler salt_chars = uh.HASH64_CHARS #--HasRounds-- - default_rounds = 64000 # current passlib default + default_rounds = 480000 # current passlib default min_rounds = 1 # really, this should be higher. max_rounds = 4294967295 # 32-bit integer limit rounds_cost = "linear" diff --git a/passlib/handlers/sha2_crypt.py b/passlib/handlers/sha2_crypt.py index c4faaad3..3c3dd660 100644 --- a/passlib/handlers/sha2_crypt.py +++ b/passlib/handlers/sha2_crypt.py @@ -240,7 +240,7 @@ _UZERO = u("0") class _SHA2_Common(uh.HasManyBackends, uh.HasRounds, uh.HasSalt, uh.GenericHandler): - "class containing common code shared by sha256_crypt & sha512_crypt" + """class containing common code shared by sha256_crypt & sha512_crypt""" #=================================================================== # class attrs #=================================================================== @@ -374,7 +374,7 @@ class sha256_crypt(_SHA2_Common): :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 110000, must be between 1000 and 999999999, inclusive. + Defaults to 535000, must be between 1000 and 999999999, inclusive. :type implicit_rounds: bool :param implicit_rounds: @@ -402,7 +402,7 @@ class sha256_crypt(_SHA2_Common): ident = u("$5$") checksum_size = 43 # NOTE: using 25/75 weighting of builtin & os_crypt backends - default_rounds = 110000 + default_rounds = 535000 #=================================================================== # backends @@ -435,7 +435,7 @@ class sha512_crypt(_SHA2_Common): :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 100000, must be between 1000 and 999999999, inclusive. + Defaults to 656000, must be between 1000 and 999999999, inclusive. :type implicit_rounds: bool :param implicit_rounds: @@ -465,7 +465,7 @@ class sha512_crypt(_SHA2_Common): checksum_size = 86 _cdb_use_512 = True # NOTE: using 25/75 weighting of builtin & os_crypt backends - default_rounds = 100000 + default_rounds = 656000 #=================================================================== # backend diff --git a/passlib/handlers/sun_md5_crypt.py b/passlib/handlers/sun_md5_crypt.py index 41d3331b..a6a966d3 100644 --- a/passlib/handlers/sun_md5_crypt.py +++ b/passlib/handlers/sun_md5_crypt.py @@ -82,7 +82,7 @@ _XY_ROUNDS = [ del xr def raw_sun_md5_crypt(secret, rounds, salt): - "given secret & salt, return encoded sun-md5-crypt checksum" + """given secret & salt, return encoded sun-md5-crypt checksum""" global MAGIC_HAMLET assert isinstance(secret, bytes) assert isinstance(salt, bytes) @@ -193,7 +193,7 @@ class sun_md5_crypt(uh.HasRounds, uh.HasSalt, uh.GenericHandler): :type rounds: int :param rounds: Optional number of rounds to use. - Defaults to 5500, must be between 0 and 4294963199, inclusive. + Defaults to 34000, must be between 0 and 4294963199, inclusive. :type bare_salt: bool :param bare_salt: @@ -231,7 +231,7 @@ class sun_md5_crypt(uh.HasRounds, uh.HasSalt, uh.GenericHandler): max_salt_size = None salt_chars = uh.HASH64_CHARS - default_rounds = 5500 # current passlib default + default_rounds = 34000 # current passlib default min_rounds = 0 max_rounds = 4294963199 ##2**32-1-4096 # XXX: ^ not sure what it does if past this bound... does 32 int roll over? diff --git a/passlib/handlers/windows.py b/passlib/handlers/windows.py index 3bc3e4f9..3f911be0 100644 --- a/passlib/handlers/windows.py +++ b/passlib/handlers/windows.py @@ -40,7 +40,7 @@ class lmhash(uh.HasEncodingContext, uh.StaticHandler): calculating digest. It defaults to ``cp437``, the most common encoding encountered. - Note that while this class outputs digests in lower-case hexidecimal, + Note that while this class outputs digests in lower-case hexadecimal, it will accept upper-case as well. """ #=================================================================== @@ -116,7 +116,7 @@ class nthash(uh.StaticHandler): The :meth:`~passlib.ifc.PasswordHash.encrypt` and :meth:`~passlib.ifc.PasswordHash.genconfig` methods accept no optional keywords. - Note that while this class outputs lower-case hexidecimal digests, + Note that while this class outputs lower-case hexadecimal digests, it will accept upper-case digests as well. """ #=================================================================== @@ -228,7 +228,7 @@ class msdcc(uh.HasUserContext, uh.StaticHandler): This keyword is case-insensitive, and should contain just the username (e.g. ``Administrator``, not ``SOMEDOMAIN\\Administrator``). - Note that while this class outputs lower-case hexidecimal digests, + Note that while this class outputs lower-case hexadecimal digests, it will accept upper-case digests as well. """ name = "msdcc" diff --git a/passlib/hosts.py b/passlib/hosts.py index f6eb0076..7df3efd2 100644 --- a/passlib/hosts.py +++ b/passlib/hosts.py @@ -71,7 +71,7 @@ if has_crypt: # and can be introspected and used much more flexibly. def _iter_os_crypt_schemes(): - "helper which iterates over supported os_crypt schemes" + """helper which iterates over supported os_crypt schemes""" found = False for name in unix_crypt_schemes: handler = get_crypt_handler(name) diff --git a/passlib/ifc.py b/passlib/ifc.py index 908890aa..5e4e7d40 100644 --- a/passlib/ifc.py +++ b/passlib/ifc.py @@ -26,7 +26,7 @@ else: # return None def create_with_metaclass(meta): - "class decorator that re-creates class using metaclass" + """class decorator that re-creates class using metaclass""" # have to do things this way since abc not present in py25, # and py2/py3 have different ways of doing metaclasses. def builder(cls): @@ -84,13 +84,13 @@ class PasswordHash(object): @classmethod @abstractmethod def encrypt(cls, secret, **setting_and_context_kwds): # pragma: no cover -- abstract method - "encrypt secret, returning resulting hash" + """encrypt secret, returning resulting hash""" raise NotImplementedError("must be implemented by subclass") @classmethod @abstractmethod def verify(cls, secret, hash, **context_kwds): # pragma: no cover -- abstract method - "verify secret against hash, returns True/False" + """verify secret against hash, returns True/False""" raise NotImplementedError("must be implemented by subclass") #=================================================================== @@ -99,19 +99,19 @@ class PasswordHash(object): @classmethod @abstractmethod def identify(cls, hash): # pragma: no cover -- abstract method - "check if hash belongs to this scheme, returns True/False" + """check if hash belongs to this scheme, returns True/False""" raise NotImplementedError("must be implemented by subclass") @classmethod @abstractmethod def genconfig(cls, **setting_kwds): # pragma: no cover -- abstract method - "compile settings into a configuration string for genhash()" + """compile settings into a configuration string for genhash()""" raise NotImplementedError("must be implemented by subclass") @classmethod @abstractmethod def genhash(cls, secret, config, **context_kwds): # pragma: no cover -- abstract method - "generated hash for secret, using settings from config/hash string" + """generated hash for secret, using settings from config/hash string""" raise NotImplementedError("must be implemented by subclass") #=================================================================== diff --git a/passlib/registry.py b/passlib/registry.py index 938bc5ec..1f07940b 100644 --- a/passlib/registry.py +++ b/passlib/registry.py @@ -9,6 +9,7 @@ from warnings import warn # pkg from passlib.exc import ExpectedTypeError, PasslibWarning from passlib.utils import is_crypt_handler +from passlib.utils.compat import native_string_types # local __all__ = [ "register_crypt_handler_path", @@ -262,7 +263,8 @@ def register_crypt_handler(handler, force=False, _attr=None): name = handler.name _validate_handler_name(name) if _attr and _attr != name: - raise ValueError("handlers must be stored only under their own name") + raise ValueError("handlers must be stored only under their own name (%r != %r)" % + (_attr, name)) # check for existing handler other = _handlers.get(name) @@ -310,7 +312,7 @@ def get_crypt_handler(name, default=_UNSET): pass # normalize name (and if changed, check dict again) - assert isinstance(name, str), "name must be str instance" + assert isinstance(name, native_string_types), "name must be str instance" alt = name.replace("-","_").lower() if alt != name: warn("handler names should be lower-case, and use underscores instead " @@ -338,7 +340,7 @@ def get_crypt_handler(name, default=_UNSET): mod = __import__(modname, fromlist=[modattr], level=0) # first check if importing module triggered register_crypt_handler(), - # (this is discouraged due to it's magical implicitness) + # (this is discouraged due to its magical implicitness) handler = _handlers.get(name) if handler: # XXX: issue deprecation warning here? @@ -394,7 +396,7 @@ def _unload_handler_name(name, locations=True): used only by the unittests. if loaded handler is found with specified name, it's removed. - if path to lazy load handler is found, its' removed. + if path to lazy load handler is found, it's removed. missing names are a noop. diff --git a/passlib/tests/_test_bad_register.py b/passlib/tests/_test_bad_register.py index 26cc6bbb..f0683fcc 100644 --- a/passlib/tests/_test_bad_register.py +++ b/passlib/tests/_test_bad_register.py @@ -1,4 +1,4 @@ -"helper for method in test_registry.py" +"""helper for method in test_registry.py""" from passlib.registry import register_crypt_handler import passlib.utils.handlers as uh diff --git a/passlib/tests/test_apache.py b/passlib/tests/test_apache.py index 68a992ff..1785a3ed 100644 --- a/passlib/tests/test_apache.py +++ b/passlib/tests/test_apache.py @@ -11,6 +11,7 @@ import time # site # pkg from passlib import apache +from passlib.exc import MissingBackendError from passlib.utils.compat import irange, unicode from passlib.tests.utils import TestCase, get_file, set_file, catch_warnings, ensure_mtime_changed from passlib.utils.compat import b, bytes, u @@ -18,7 +19,7 @@ from passlib.utils.compat import b, bytes, u log = getLogger(__name__) def backdate_file_mtime(path, offset=10): - "backdate file's mtime by specified amount" + """backdate file's mtime by specified amount""" # NOTE: this is used so we can test code which detects mtime changes, # without having to actually *pause* for that long. atime = os.path.getatime(path) @@ -29,7 +30,7 @@ def backdate_file_mtime(path, offset=10): # htpasswd #============================================================================= class HtpasswdFileTest(TestCase): - "test HtpasswdFile class" + """test HtpasswdFile class""" descriptionPrefix = "HtpasswdFile" # sample with 4 users @@ -54,8 +55,17 @@ class HtpasswdFileTest(TestCase): sample_dup = b('user1:pass1\nuser1:pass2\n') + # sample with bcrypt & sha256_crypt hashes + sample_05 = b('user2:2CHkkwa2AtqGs\n' + 'user3:{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=\n' + 'user4:pass4\n' + 'user1:$apr1$t4tc7jTh$GPIWVUo8sQKJlUdV8V5vu0\n' + 'user5:$2a$12$yktDxraxijBZ360orOyCOePFGhuis/umyPNJoL5EbsLk.s6SWdrRO\n' + 'user6:$5$rounds=110000$cCRp/xUUGVgwR4aP$' + 'p0.QKFS5qLNRqw1/47lXYiAcgIjJK.WjCO8nrEKuUK.\n') + def test_00_constructor_autoload(self): - "test constructor autoload" + """test constructor autoload""" # check with existing file path = self.mktemp() set_file(path, self.sample_01) @@ -97,7 +107,7 @@ class HtpasswdFileTest(TestCase): self.assertFalse(ht.mtime) def test_01_delete(self): - "test delete()" + """test delete()""" ht = apache.HtpasswdFile.from_string(self.sample_01) self.assertTrue(ht.delete("user1")) # should delete both entries self.assertTrue(ht.delete("user2")) @@ -121,7 +131,7 @@ class HtpasswdFileTest(TestCase): self.assertEqual(get_file(path), b("user2:pass2\n")) def test_02_set_password(self): - "test set_password()" + """test set_password()""" ht = apache.HtpasswdFile.from_string( self.sample_01, default_scheme="plaintext") self.assertTrue(ht.set_password("user2", "pass2x")) @@ -156,8 +166,29 @@ class HtpasswdFileTest(TestCase): ht.set_password("user1", "pass2") self.assertEqual(get_file(path), b("user1:pass2\n")) + def test_02_set_password_default_scheme(self): + """test set_password() -- default_scheme""" + + def check(scheme): + ht = apache.HtpasswdFile(default_scheme=scheme) + ht.set_password("user1", "pass1") + return ht.context.identify(ht.get_hash("user1")) + + # explicit scheme + self.assertEqual(check("sha256_crypt"), "sha256_crypt") + self.assertEqual(check("des_crypt"), "des_crypt") + + # unknown scheme + self.assertRaises(KeyError, check, "xxx") + + # portable alias + self.assertEqual(check("portable"), apache.portable_scheme) + + # default -- currently same as portable, will be host-specific under passlib 1.7. + self.assertEqual(check(None), "apr_md5_crypt") + def test_03_users(self): - "test users()" + """test users()""" ht = apache.HtpasswdFile.from_string(self.sample_01) ht.set_password("user5", "pass5") ht.delete("user3") @@ -166,14 +197,23 @@ class HtpasswdFileTest(TestCase): "user3"]) def test_04_check_password(self): - "test check_password()" - ht = apache.HtpasswdFile.from_string(self.sample_01) - self.assertRaises(TypeError, ht.check_password, 1, 'pass5') - self.assertTrue(ht.check_password("user5","pass5") is None) - for i in irange(1,5): + """test check_password()""" + ht = apache.HtpasswdFile.from_string(self.sample_05) + self.assertRaises(TypeError, ht.check_password, 1, 'pass9') + self.assertTrue(ht.check_password("user9","pass9") is None) + + # users 1..6 of sample_01 run through all the main hash formats, + # to make sure they're recognized. + for i in irange(1, 7): i = str(i) - self.assertTrue(ht.check_password("user"+i, "pass"+i)) - self.assertTrue(ht.check_password("user"+i, "pass5") is False) + try: + self.assertTrue(ht.check_password("user"+i, "pass"+i)) + self.assertTrue(ht.check_password("user"+i, "pass9") is False) + except MissingBackendError: + if i == "5": + # user5 uses bcrypt, which is apparently not available right now + continue + raise self.assertRaises(ValueError, ht.check_password, "user:", "pass") @@ -183,7 +223,7 @@ class HtpasswdFileTest(TestCase): self.assertFalse(ht.verify("user1", "pass2")) def test_05_load(self): - "test load()" + """test load()""" # setup empty file path = self.mktemp() set_file(path, "") @@ -220,7 +260,7 @@ class HtpasswdFileTest(TestCase): # NOTE: load_string() tested via from_string(), which is used all over this file def test_06_save(self): - "test save()" + """test save()""" # load from file path = self.mktemp() set_file(path, self.sample_01) @@ -242,7 +282,7 @@ class HtpasswdFileTest(TestCase): self.assertEqual(get_file(path), b("user1:pass1\n")) def test_07_encodings(self): - "test 'encoding' kwd" + """test 'encoding' kwd""" # test bad encodings cause failure in constructor self.assertRaises(ValueError, apache.HtpasswdFile, encoding="utf-16") @@ -262,7 +302,7 @@ class HtpasswdFileTest(TestCase): self.assertEqual(ht.users(), [ u("user\u00e6") ]) def test_08_get_hash(self): - "test get_hash()" + """test get_hash()""" ht = apache.HtpasswdFile.from_string(self.sample_01) self.assertEqual(ht.get_hash("user3"), b("{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=")) self.assertEqual(ht.get_hash("user4"), b("pass4")) @@ -272,7 +312,7 @@ class HtpasswdFileTest(TestCase): self.assertEqual(ht.find("user4"), b("pass4")) def test_09_to_string(self): - "test to_string" + """test to_string""" # check with known sample ht = apache.HtpasswdFile.from_string(self.sample_01) @@ -305,7 +345,7 @@ class HtpasswdFileTest(TestCase): # htdigest #============================================================================= class HtdigestFileTest(TestCase): - "test HtdigestFile class" + """test HtdigestFile class""" descriptionPrefix = "HtdigestFile" # sample with 4 users @@ -330,7 +370,7 @@ class HtdigestFileTest(TestCase): sample_04_latin1 = b('user\xe6:realm\xe6:549d2a5f4659ab39a80dac99e159ab19\n') def test_00_constructor_autoload(self): - "test constructor autoload" + """test constructor autoload""" # check with existing file path = self.mktemp() set_file(path, self.sample_01) @@ -348,7 +388,7 @@ class HtdigestFileTest(TestCase): # NOTE: default_realm option checked via other tests. def test_01_delete(self): - "test delete()" + """test delete()""" ht = apache.HtdigestFile.from_string(self.sample_01) self.assertTrue(ht.delete("user1", "realm")) self.assertTrue(ht.delete("user2", "realm")) @@ -377,7 +417,7 @@ class HtdigestFileTest(TestCase): self.assertEqual(get_file(path), self.sample_02) def test_02_set_password(self): - "test update()" + """test update()""" ht = apache.HtdigestFile.from_string(self.sample_01) self.assertTrue(ht.set_password("user2", "realm", "pass2x")) self.assertFalse(ht.set_password("user5", "realm", "pass5")) @@ -405,7 +445,7 @@ class HtdigestFileTest(TestCase): # TODO: test set_password autosave def test_03_users(self): - "test users()" + """test users()""" ht = apache.HtdigestFile.from_string(self.sample_01) ht.set_password("user5", "realm", "pass5") ht.delete("user3", "realm") @@ -415,7 +455,7 @@ class HtdigestFileTest(TestCase): self.assertRaises(TypeError, ht.users, 1) def test_04_check_password(self): - "test check_password()" + """test check_password()""" ht = apache.HtdigestFile.from_string(self.sample_01) self.assertRaises(TypeError, ht.check_password, 1, 'realm', 'pass5') self.assertRaises(TypeError, ht.check_password, 'user', 1, 'pass5') @@ -440,7 +480,7 @@ class HtdigestFileTest(TestCase): self.assertRaises(ValueError, ht.check_password, "user:", "realm", "pass") def test_05_load(self): - "test load()" + """test load()""" # setup empty file path = self.mktemp() set_file(path, "") @@ -481,7 +521,7 @@ class HtdigestFileTest(TestCase): self.assertEqual(ha.to_string(), b("")) def test_06_save(self): - "test save()" + """test save()""" # load from file path = self.mktemp() set_file(path, self.sample_01) @@ -503,7 +543,7 @@ class HtdigestFileTest(TestCase): self.assertEqual(get_file(path), hb.to_string()) def test_07_realms(self): - "test realms() & delete_realm()" + """test realms() & delete_realm()""" ht = apache.HtdigestFile.from_string(self.sample_01) self.assertEqual(ht.delete_realm("x"), 0) @@ -514,7 +554,7 @@ class HtdigestFileTest(TestCase): self.assertEqual(ht.to_string(), b("")) def test_08_get_hash(self): - "test get_hash()" + """test get_hash()""" ht = apache.HtdigestFile.from_string(self.sample_01) self.assertEqual(ht.get_hash("user3", "realm"), "a500bb8c02f6a9170ae46af10c898744") self.assertEqual(ht.get_hash("user4", "realm"), "ab7b5d5f28ccc7666315f508c7358519") @@ -524,7 +564,7 @@ class HtdigestFileTest(TestCase): self.assertEqual(ht.find("user4", "realm"), "ab7b5d5f28ccc7666315f508c7358519") def test_09_encodings(self): - "test encoding parameter" + """test encoding parameter""" # test bad encodings cause failure in constructor self.assertRaises(ValueError, apache.HtdigestFile, encoding="utf-16") @@ -539,7 +579,7 @@ class HtdigestFileTest(TestCase): self.assertEqual(ht.users(u("realm\u00e6")), [ u("user\u00e6") ]) def test_10_to_string(self): - "test to_string()" + """test to_string()""" # check sample ht = apache.HtdigestFile.from_string(self.sample_01) diff --git a/passlib/tests/test_apps.py b/passlib/tests/test_apps.py index 421f83b0..c3d68204 100644 --- a/passlib/tests/test_apps.py +++ b/passlib/tests/test_apps.py @@ -15,7 +15,7 @@ from passlib.tests.utils import TestCase # test predefined app contexts #============================================================================= class AppsTest(TestCase): - "perform general tests to make sure contexts work" + """perform general tests to make sure contexts work""" # NOTE: these tests are not really comprehensive, # since they would do little but duplicate # the presets in apps.py diff --git a/passlib/tests/test_context.py b/passlib/tests/test_context.py index cdd8746b..d87bfa18 100644 --- a/passlib/tests/test_context.py +++ b/passlib/tests/test_context.py @@ -178,7 +178,7 @@ sha512_crypt__min_rounds = 45000 # constructors #=================================================================== def test_01_constructor(self): - "test class constructor" + """test class constructor""" # test blank constructor works correctly ctx = CryptContext() @@ -200,8 +200,12 @@ sha512_crypt__min_rounds = 45000 ctx = CryptContext(**self.sample_3_dict) self.assertEqual(ctx.to_dict(), self.sample_3_dict) + # test unicode scheme names (issue 54) + ctx = CryptContext(schemes=[u("sha256_crypt")]) + self.assertEqual(ctx.schemes(), ("sha256_crypt",)) + def test_02_from_string(self): - "test from_string() constructor" + """test from_string() constructor""" # test sample 1 unicode ctx = CryptContext.from_string(self.sample_1_unicode) self.assertEqual(ctx.to_dict(), self.sample_1_dict) @@ -231,7 +235,7 @@ sha512_crypt__min_rounds = 45000 self.sample_1_unicode, section="fakesection") def test_03_from_path(self): - "test from_path() constructor" + """test from_path() constructor""" # make sure sample files exist if not os.path.exists(self.sample_1_path): raise RuntimeError("can't find data file: %r" % self.sample_1_path) @@ -258,7 +262,7 @@ sha512_crypt__min_rounds = 45000 self.sample_1_path, section="fakesection") def test_04_copy(self): - "test copy() method" + """test copy() method""" cc1 = CryptContext(**self.sample_1_dict) # overlay sample 2 onto copy @@ -287,7 +291,7 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(cc4.to_dict(), self.sample_12_dict) def test_09_repr(self): - "test repr()" + """test repr()""" cc1 = CryptContext(**self.sample_1_dict) self.assertRegex(repr(cc1), "^$") @@ -295,9 +299,9 @@ sha512_crypt__min_rounds = 45000 # modifiers #=================================================================== def test_10_load(self): - "test load() / load_path() method" + """test load() / load_path() method""" # NOTE: load() is the workhorse that handles all policy parsing, - # compilation, and validation. most of it's features are tested + # compilation, and validation. most of its features are tested # elsewhere, since all the constructors and modifiers are just # wrappers for it. @@ -338,7 +342,7 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(ctx.to_dict(), self.sample_2_dict) def test_11_load_rollback(self): - "test load() errors restore old state" + """test load() errors restore old state""" # create initial context cc = CryptContext(["des_crypt", "sha256_crypt"], sha256_crypt__default_rounds=5000, @@ -362,7 +366,7 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(cc.to_string(), result) def test_12_update(self): - "test update() method" + """test update() method""" # empty overlay ctx = CryptContext(**self.sample_1_dict) @@ -399,7 +403,7 @@ sha512_crypt__min_rounds = 45000 # option parsing #=================================================================== def test_20_options(self): - "test basic option parsing" + """test basic option parsing""" def parse(**kwds): return CryptContext(**kwds).to_dict() @@ -475,7 +479,7 @@ sha512_crypt__min_rounds = 45000 all__salt="xx") def test_21_schemes(self): - "test 'schemes' context option parsing" + """test 'schemes' context option parsing""" # schemes can be empty cc = CryptContext(schemes=None) @@ -511,7 +515,7 @@ sha512_crypt__min_rounds = 45000 admin__context__schemes=["md5_crypt"]) def test_22_deprecated(self): - "test 'deprecated' context option parsing" + """test 'deprecated' context option parsing""" def getdep(ctx, category=None): return [name for name in ctx.schemes() if ctx._is_deprecated_scheme(name, category)] @@ -603,7 +607,7 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(getdep(cc, "admin"), []) def test_23_default(self): - "test 'default' context option parsing" + """test 'default' context option parsing""" # anything allowed if no schemes self.assertEqual(CryptContext(default="md5_crypt").to_dict(), @@ -640,7 +644,7 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(ctx.default_scheme("admin"), "md5_crypt") def test_24_vary_rounds(self): - "test 'vary_rounds' hash option parsing" + """test 'vary_rounds' hash option parsing""" def parse(v): return CryptContext(all__vary_rounds=v).to_dict()['all__vary_rounds'] @@ -659,7 +663,7 @@ sha512_crypt__min_rounds = 45000 # inspection & serialization #=================================================================== def test_30_schemes(self): - "test schemes() method" + """test schemes() method""" # NOTE: also checked under test_21 # test empty @@ -677,7 +681,7 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(ctx.schemes(), ()) def test_31_default_scheme(self): - "test default_scheme() method" + """test default_scheme() method""" # NOTE: also checked under test_23 # test empty @@ -700,7 +704,7 @@ sha512_crypt__min_rounds = 45000 # categories tested under test_23 def test_32_handler(self): - "test handler() method" + """test handler() method""" # default for empty ctx = CryptContext() @@ -729,7 +733,7 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(ctx.handler(category=u("admin")), hash.md5_crypt) def test_33_options(self): - "test internal _get_record_options() method" + """test internal _get_record_options() method""" def options(ctx, scheme, category=None): return ctx._config._get_record_options_with_flag(scheme, category)[0] @@ -804,14 +808,14 @@ sha512_crypt__min_rounds = 45000 )) def test_34_to_dict(self): - "test to_dict() method" + """test to_dict() method""" # NOTE: this is tested all throughout this test case. ctx = CryptContext(**self.sample_1_dict) self.assertEqual(ctx.to_dict(), self.sample_1_dict) self.assertEqual(ctx.to_dict(resolve=True), self.sample_1_resolved_dict) def test_35_to_string(self): - "test to_string() method" + """test to_string() method""" # create ctx and serialize ctx = CryptContext(**self.sample_1_dict) @@ -834,7 +838,6 @@ sha512_crypt__min_rounds = 45000 self.assertEqual(other, dump.replace("[passlib]","[password-security]")) # test unmanaged handler warning - from passlib import hash from passlib.tests.test_utils_handlers import UnsaltedHash ctx3 = CryptContext([UnsaltedHash, "md5_crypt"]) dump = ctx3.to_string() @@ -852,7 +855,7 @@ sha512_crypt__min_rounds = 45000 ] def test_40_basic(self): - "test basic encrypt/identify/verify functionality" + """test basic encrypt/identify/verify functionality""" handlers = [hash.md5_crypt, hash.des_crypt, hash.bsdi_crypt] cc = CryptContext(handlers, bsdi_crypt__default_rounds=5) @@ -878,7 +881,7 @@ sha512_crypt__min_rounds = 45000 self.assertRaises(ValueError, cc.genhash, 'secret', cc.genconfig(), scheme="des_crypt") def test_41_genconfig(self): - "test genconfig() method" + """test genconfig() method""" cc = CryptContext(schemes=["md5_crypt", "phpass"], phpass__ident="H", phpass__default_rounds=7, @@ -927,7 +930,7 @@ sha512_crypt__min_rounds = 45000 def test_42_genhash(self): - "test genhash() method" + """test genhash() method""" #-------------------------------------------------------------- # border cases @@ -960,7 +963,7 @@ sha512_crypt__min_rounds = 45000 def test_43_encrypt(self): - "test encrypt() method" + """test encrypt() method""" cc = CryptContext(**self.sample_4_dict) # hash specific settings @@ -1015,7 +1018,7 @@ sha512_crypt__min_rounds = 45000 def test_44_identify(self): - "test identify() border cases" + """test identify() border cases""" handlers = ["md5_crypt", "des_crypt", "bsdi_crypt"] cc = CryptContext(handlers, bsdi_crypt__default_rounds=5) @@ -1041,7 +1044,7 @@ sha512_crypt__min_rounds = 45000 self.assertRaises(TypeError, cc.identify, None, category=1) def test_45_verify(self): - "test verify() scheme kwd" + """test verify() scheme kwd""" handlers = ["md5_crypt", "des_crypt", "bsdi_crypt"] cc = CryptContext(handlers, bsdi_crypt__default_rounds=5) @@ -1087,7 +1090,7 @@ sha512_crypt__min_rounds = 45000 self.assertRaises(TypeError, cc.verify, 'secret', refhash, category=1) def test_46_needs_update(self): - "test needs_update() method" + """test needs_update() method""" cc = CryptContext(**self.sample_4_dict) # check deprecated scheme @@ -1167,7 +1170,7 @@ sha512_crypt__min_rounds = 45000 self.assertRaises(TypeError, cc.needs_update, refhash, category=1) def test_47_verify_and_update(self): - "test verify_and_update()" + """test verify_and_update()""" cc = CryptContext(**self.sample_4_dict) # create some hashes @@ -1227,7 +1230,7 @@ sha512_crypt__min_rounds = 45000 # genconfig(). it's assumed encrypt() takes the same codepath. def test_50_rounds_limits(self): - "test rounds limits" + """test rounds limits""" cc = CryptContext(schemes=["sha256_crypt"], all__min_rounds=2000, all__max_rounds=3000, @@ -1344,7 +1347,7 @@ sha512_crypt__min_rounds = 45000 self.assertRaises(TypeError, CryptContext, "sha256_crypt", all__default_rounds=bad) def test_51_linear_vary_rounds(self): - "test linear vary rounds" + """test linear vary rounds""" cc = CryptContext(schemes=["sha256_crypt"], all__min_rounds=1995, all__max_rounds=2005, @@ -1376,7 +1379,7 @@ sha512_crypt__min_rounds = 45000 self.assert_rounds_range(c2, "sha256_crypt", 1995, 2005) def test_52_log2_vary_rounds(self): - "test log2 vary rounds" + """test log2 vary rounds""" cc = CryptContext(schemes=["bcrypt"], all__min_rounds=15, all__max_rounds=25, @@ -1415,7 +1418,7 @@ sha512_crypt__min_rounds = 45000 self.assert_rounds_range(c2, "bcrypt", 15, 21) def assert_rounds_range(self, context, scheme, lower, upper): - "helper to check vary_rounds covers specified range" + """helper to check vary_rounds covers specified range""" # NOTE: this runs enough times the min and max *should* be hit, # though there's a faint chance it will randomly fail. handler = context.handler(scheme) @@ -1432,7 +1435,7 @@ sha512_crypt__min_rounds = 45000 # feature tests #=================================================================== def test_60_min_verify_time(self): - "test verify() honors min_verify_time" + """test verify() honors min_verify_time""" delta = .05 if TICK_RESOLUTION >= delta/10: raise self.skipTest("timer not accurate enough") @@ -1441,7 +1444,7 @@ sha512_crypt__min_rounds = 45000 max_delay = 8*delta class TimedHash(uh.StaticHandler): - "psuedo hash that takes specified amount of time" + """psuedo hash that takes specified amount of time""" name = "timed_hash" delay = 0 @@ -1498,7 +1501,7 @@ sha512_crypt__min_rounds = 45000 self.assertRaises(ValueError, CryptContext, min_verify_time=-1) def test_61_autodeprecate(self): - "test deprecated='auto' is handled correctly" + """test deprecated='auto' is handled correctly""" def getstate(ctx, category=None): return [ctx._is_deprecated_scheme(scheme, category) for scheme in ctx.schemes()] @@ -1533,7 +1536,7 @@ sha512_crypt__min_rounds = 45000 # handler deprecation detectors #=================================================================== def test_62_bcrypt_update(self): - "test verify_and_update / needs_update corrects bcrypt padding" + """test verify_and_update / needs_update corrects bcrypt padding""" # see issue 25. bcrypt = hash.bcrypt @@ -1554,7 +1557,7 @@ sha512_crypt__min_rounds = 45000 self.assertTrue(new_hash and new_hash != BAD1) def test_63_bsdi_crypt_update(self): - "test verify_and_update / needs_update corrects bsdi even rounds" + """test verify_and_update / needs_update corrects bsdi even rounds""" even_hash = '_Y/../cG0zkJa6LY6k4c' odd_hash = '_Z/..TgFg0/ptQtpAgws' secret = 'test' @@ -1588,7 +1591,7 @@ class LazyCryptContextTest(TestCase): self.addCleanup(unload_handler_name, "dummy_2") def test_kwd_constructor(self): - "test plain kwds" + """test plain kwds""" self.assertFalse(has_crypt_handler("dummy_2")) register_crypt_handler_path("dummy_2", "passlib.tests.test_context") diff --git a/passlib/tests/test_context_deprecated.py b/passlib/tests/test_context_deprecated.py index db0c49d9..df9e1b3d 100644 --- a/passlib/tests/test_context_deprecated.py +++ b/passlib/tests/test_context_deprecated.py @@ -41,7 +41,7 @@ log = getLogger(__name__) # #============================================================================= class CryptPolicyTest(TestCase): - "test CryptPolicy object" + """test CryptPolicy object""" # TODO: need to test user categories w/in all this @@ -220,7 +220,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt r"the method.*hash_needs_update.*is deprecated") def test_00_constructor(self): - "test CryptPolicy() constructor" + """test CryptPolicy() constructor""" policy = CryptPolicy(**self.sample_config_1pd) self.assertEqual(policy.to_dict(), self.sample_config_1pd) @@ -260,7 +260,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt default='md5_crypt') def test_01_from_path_simple(self): - "test CryptPolicy.from_path() constructor" + """test CryptPolicy.from_path() constructor""" # NOTE: this is separate so it can also run under GAE # test preset stored in existing file @@ -272,7 +272,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertRaises(EnvironmentError, CryptPolicy.from_path, path + 'xxx') def test_01_from_path(self): - "test CryptPolicy.from_path() constructor with encodings" + """test CryptPolicy.from_path() constructor with encodings""" path = self.mktemp() # test "\n" linesep @@ -292,7 +292,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertEqual(policy.to_dict(), self.sample_config_1pd) def test_02_from_string(self): - "test CryptPolicy.from_string() constructor" + """test CryptPolicy.from_string() constructor""" # test "\n" linesep policy = CryptPolicy.from_string(self.sample_config_1s) self.assertEqual(policy.to_dict(), self.sample_config_1pd) @@ -317,7 +317,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertEqual(policy.to_dict(), self.sample_config_4pd) def test_03_from_source(self): - "test CryptPolicy.from_source() constructor" + """test CryptPolicy.from_source() constructor""" # pass it a path policy = CryptPolicy.from_source(self.sample_config_1s_path) self.assertEqual(policy.to_dict(), self.sample_config_1pd) @@ -339,7 +339,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertRaises(TypeError, CryptPolicy.from_source, []) def test_04_from_sources(self): - "test CryptPolicy.from_sources() constructor" + """test CryptPolicy.from_sources() constructor""" # pass it empty list self.assertRaises(ValueError, CryptPolicy.from_sources, []) @@ -358,7 +358,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertEqual(policy.to_dict(), self.sample_config_123pd) def test_05_replace(self): - "test CryptPolicy.replace() constructor" + """test CryptPolicy.replace() constructor""" p1 = CryptPolicy(**self.sample_config_1pd) @@ -375,7 +375,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertEqual(p3.to_dict(), self.sample_config_123pd) def test_06_forbidden(self): - "test CryptPolicy() forbidden kwds" + """test CryptPolicy() forbidden kwds""" # salt not allowed to be set self.assertRaises(KeyError, CryptPolicy, @@ -397,7 +397,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt # reading #=================================================================== def test_10_has_schemes(self): - "test has_schemes() method" + """test has_schemes() method""" p1 = CryptPolicy(**self.sample_config_1pd) self.assertTrue(p1.has_schemes()) @@ -406,7 +406,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertTrue(not p3.has_schemes()) def test_11_iter_handlers(self): - "test iter_handlers() method" + """test iter_handlers() method""" p1 = CryptPolicy(**self.sample_config_1pd) s = self.sample_config_1prd['schemes'] @@ -416,7 +416,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertEqual(list(p3.iter_handlers()), []) def test_12_get_handler(self): - "test get_handler() method" + """test get_handler() method""" p1 = CryptPolicy(**self.sample_config_1pd) @@ -431,7 +431,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertIs(p1.get_handler(), hash.md5_crypt) def test_13_get_options(self): - "test get_options() method" + """test get_options() method""" p12 = CryptPolicy(**self.sample_config_12pd) @@ -470,7 +470,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt )) def test_14_handler_is_deprecated(self): - "test handler_is_deprecated() method" + """test handler_is_deprecated() method""" pa = CryptPolicy(**self.sample_config_1pd) pb = CryptPolicy(**self.sample_config_5pd) @@ -500,7 +500,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt self.assertTrue(pc.handler_is_deprecated("des_crypt", "user")) def test_15_min_verify_time(self): - "test get_min_verify_time() method" + """test get_min_verify_time() method""" # silence deprecation warnings for min verify time warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -524,20 +524,20 @@ admin__context__deprecated = des_crypt, bsdi_crypt # serialization #=================================================================== def test_20_iter_config(self): - "test iter_config() method" + """test iter_config() method""" p5 = CryptPolicy(**self.sample_config_5pd) self.assertEqual(dict(p5.iter_config()), self.sample_config_5pd) self.assertEqual(dict(p5.iter_config(resolve=True)), self.sample_config_5prd) self.assertEqual(dict(p5.iter_config(ini=True)), self.sample_config_5pid) def test_21_to_dict(self): - "test to_dict() method" + """test to_dict() method""" p5 = CryptPolicy(**self.sample_config_5pd) self.assertEqual(p5.to_dict(), self.sample_config_5pd) self.assertEqual(p5.to_dict(resolve=True), self.sample_config_5prd) def test_22_to_string(self): - "test to_string() method" + """test to_string() method""" pa = CryptPolicy(**self.sample_config_5pd) s = pa.to_string() # NOTE: can't compare string directly, ordering etc may not match pb = CryptPolicy.from_string(s) @@ -554,7 +554,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt # CryptContext #============================================================================= class CryptContextTest(TestCase): - "test CryptContext class" + """test CryptContext class""" descriptionPrefix = "CryptContext" def setUp(self): @@ -571,7 +571,7 @@ class CryptContextTest(TestCase): # constructor #=================================================================== def test_00_constructor(self): - "test constructor" + """test constructor""" # create crypt context using handlers cc = CryptContext([hash.md5_crypt, hash.bsdi_crypt, hash.des_crypt]) c,b,a = cc.policy.iter_handlers() @@ -600,7 +600,7 @@ class CryptContextTest(TestCase): self.assertRaises(TypeError, CryptContext, policy='x') def test_01_replace(self): - "test replace()" + """test replace()""" cc = CryptContext(["md5_crypt", "bsdi_crypt", "des_crypt"]) self.assertIs(cc.policy.get_handler(), hash.md5_crypt) @@ -617,7 +617,7 @@ class CryptContextTest(TestCase): self.assertIs(cc3.policy.get_handler(), hash.bsdi_crypt) def test_02_no_handlers(self): - "test no handlers" + """test no handlers""" # check constructor... cc = CryptContext() @@ -653,7 +653,7 @@ class CryptContextTest(TestCase): ) def test_12_hash_needs_update(self): - "test hash_needs_update() method" + """test hash_needs_update() method""" cc = CryptContext(**self.sample_policy_1) # check deprecated scheme @@ -672,7 +672,7 @@ class CryptContextTest(TestCase): # border cases #=================================================================== def test_30_nonstring_hash(self): - "test non-string hash values cause error" + """test non-string hash values cause error""" # # test hash=None or some other non-string causes TypeError # and that explicit-scheme code path behaves the same. @@ -716,7 +716,7 @@ class LazyCryptContextTest(TestCase): warnings.filterwarnings("ignore", ".*(CryptPolicy|context\.policy).*(has|have) been deprecated.*") def test_kwd_constructor(self): - "test plain kwds" + """test plain kwds""" self.assertFalse(has_crypt_handler("dummy_2")) register_crypt_handler_path("dummy_2", "passlib.tests.test_context") @@ -730,7 +730,7 @@ class LazyCryptContextTest(TestCase): self.assertTrue(has_crypt_handler("dummy_2", True)) def test_callable_constructor(self): - "test create_policy() hook, returning CryptPolicy" + """test create_policy() hook, returning CryptPolicy""" self.assertFalse(has_crypt_handler("dummy_2")) register_crypt_handler_path("dummy_2", "passlib.tests.test_context") diff --git a/passlib/tests/test_ext_django.py b/passlib/tests/test_ext_django.py index 33a67385..d89b2b4b 100644 --- a/passlib/tests/test_ext_django.py +++ b/passlib/tests/test_ext_django.py @@ -2,7 +2,10 @@ #============================================================================= # imports #============================================================================= +# NOTE: double __future__ is workaround for py2.5.0 bug, +# per https://bitbucket.org/ecollins/passlib/issues/58#comment-20589295 from __future__ import with_statement +from __future__ import absolute_import # core import logging; log = logging.getLogger(__name__) import sys @@ -76,7 +79,7 @@ if has_django: from django.contrib.auth.models import User class FakeUser(User): - "mock user object for use in testing" + """mock user object for use in testing""" # NOTE: this mainly just overrides .save() to test commit behavior. @memoized_property @@ -109,7 +112,27 @@ def create_mock_setter(): # work up stock django config #============================================================================= sample_hashes = {} # override sample hashes used in test cases -if DJANGO_VERSION >= (1,6): +if DJANGO_VERSION >= (1,8): + stock_config = django16_context.to_dict() + stock_config.update( + deprecated="auto", + django_pbkdf2_sha1__default_rounds=20000, + django_pbkdf2_sha256__default_rounds=20000, + ) + sample_hashes.update( + django_pbkdf2_sha256=("not a password", "pbkdf2_sha256$20000$arJ31mmmlSmO$XNBTUKe4UCUGPeHTmXpYjaKmJaDGAsevd0LWvBtzP18="), + ) +elif DJANGO_VERSION >= (1,7): + stock_config = django16_context.to_dict() + stock_config.update( + deprecated="auto", + django_pbkdf2_sha1__default_rounds=15000, + django_pbkdf2_sha256__default_rounds=15000, + ) + sample_hashes.update( + django_pbkdf2_sha256=("not a password", "pbkdf2_sha256$15000$xb2YnidpItz1$uHvLChIjUDc5HVUfQnE6lDMbgkTAiSYknGCtjuX4AVo="), + ) +elif DJANGO_VERSION >= (1,6): stock_config = django16_context.to_dict() stock_config.update( deprecated="auto", @@ -139,7 +162,7 @@ else: # test utils #============================================================================= class _ExtensionSupport(object): - "support funcs for loading/unloading extension" + """support funcs for loading/unloading extension""" #=================================================================== # support funcs #=================================================================== @@ -182,7 +205,7 @@ class _ExtensionSupport(object): # verify current patch state #=================================================================== def assert_unpatched(self): - "test that django is in unpatched state" + """test that django is in unpatched state""" # make sure we aren't currently patched mod = sys.modules.get("passlib.ext.django.models") self.assertFalse(mod and mod._patched, "patch should not be enabled") @@ -199,7 +222,7 @@ class _ExtensionSupport(object): (obj, attr, source)) def assert_patched(self, context=None): - "helper to ensure django HAS been patched, and is using specified config" + """helper to ensure django HAS been patched, and is using specified config""" # make sure we're currently patched mod = sys.modules.get("passlib.ext.django.models") self.assertTrue(mod and mod._patched, "patch should have been enabled") @@ -225,9 +248,8 @@ class _ExtensionSupport(object): # load / unload the extension (and verify it worked) #=================================================================== _config_keys = ["PASSLIB_CONFIG", "PASSLIB_CONTEXT", "PASSLIB_GET_CATEGORY"] - def load_extension(self, check=True, **kwds): - "helper to load extension with specified config & patch django" + """helper to load extension with specified config & patch django""" self.unload_extension() if check: config = kwds.get("PASSLIB_CONFIG") or kwds.get("PASSLIB_CONTEXT") @@ -239,7 +261,7 @@ class _ExtensionSupport(object): self.assert_patched(context=config) def unload_extension(self): - "helper to remove patches and unload extension" + """helper to remove patches and unload extension""" # remove patches and unload module mod = sys.modules.get("passlib.ext.django.models") if mod: @@ -275,7 +297,7 @@ class _ExtensionTest(TestCase, _ExtensionSupport): # extension tests #============================================================================= class DjangoBehaviorTest(_ExtensionTest): - "tests model to verify it matches django's behavior" + """tests model to verify it matches django's behavior""" descriptionPrefix = "verify django behavior" patched = False config = stock_config @@ -593,7 +615,9 @@ class DjangoBehaviorTest(_ExtensionTest): if testcase.is_disabled_handler: continue if not has_active_backend(handler): - assert scheme == "django_bcrypt" + # TODO: move this above get_handler_case(), + # and omit MissingBackendError check. + assert scheme in ["django_bcrypt", "django_bcrypt_sha256"], "%r scheme should always have active backend" % scheme continue try: secret, hash = sample_hashes[scheme] @@ -680,7 +704,7 @@ class DjangoBehaviorTest(_ExtensionTest): self.assertEqual(name, scheme) class ExtensionBehaviorTest(DjangoBehaviorTest): - "test model to verify passlib.ext.django conforms to it" + """test model to verify passlib.ext.django conforms to it""" descriptionPrefix = "verify extension behavior" patched = True config = dict( @@ -700,7 +724,7 @@ class DjangoExtensionTest(_ExtensionTest): # monkeypatch testing #=================================================================== def test_00_patch_control(self): - "test set_django_password_context patch/unpatch" + """test set_django_password_context patch/unpatch""" # check config="disabled" self.load_extension(PASSLIB_CONFIG="disabled", check=False) @@ -726,7 +750,7 @@ class DjangoExtensionTest(_ExtensionTest): self.unload_extension() def test_01_overwrite_detection(self): - "test detection of foreign monkeypatching" + """test detection of foreign monkeypatching""" # NOTE: this sets things up, and spot checks two methods, # this should be enough to verify patch manager is working. # TODO: test unpatch behavior honors flag. @@ -756,7 +780,7 @@ class DjangoExtensionTest(_ExtensionTest): models.check_password = orig def test_02_handler_wrapper(self): - "test Hasher-compatible handler wrappers" + """test Hasher-compatible handler wrappers""" if not has_django14: raise self.skipTest("Django >= 1.4 not installed") from passlib.ext.django.utils import get_passlib_hasher @@ -795,7 +819,7 @@ class DjangoExtensionTest(_ExtensionTest): # PASSLIB_CONFIG settings #=================================================================== def test_11_config_disabled(self): - "test PASSLIB_CONFIG='disabled'" + """test PASSLIB_CONFIG='disabled'""" # test config=None (deprecated) with self.assertWarningList("PASSLIB_CONFIG=None is deprecated"): self.load_extension(PASSLIB_CONFIG=None, check=False) @@ -806,7 +830,7 @@ class DjangoExtensionTest(_ExtensionTest): self.assert_unpatched() def test_12_config_presets(self): - "test PASSLIB_CONFIG=''" + """test PASSLIB_CONFIG=''""" # test django presets self.load_extension(PASSLIB_CONTEXT="django-default", check=False) if DJANGO_VERSION >= (1,6): @@ -824,7 +848,7 @@ class DjangoExtensionTest(_ExtensionTest): self.assert_patched(django14_context) def test_13_config_defaults(self): - "test PASSLIB_CONFIG default behavior" + """test PASSLIB_CONFIG default behavior""" # check implicit default from passlib.ext.django.utils import PASSLIB_DEFAULT default = CryptContext.from_string(PASSLIB_DEFAULT) @@ -840,7 +864,7 @@ class DjangoExtensionTest(_ExtensionTest): self.assert_patched(PASSLIB_DEFAULT) def test_14_config_invalid(self): - "test PASSLIB_CONFIG type checks" + """test PASSLIB_CONFIG type checks""" update_settings(PASSLIB_CONTEXT=123, PASSLIB_CONFIG=UNSET) self.assertRaises(TypeError, __import__, 'passlib.ext.django.models') @@ -852,7 +876,7 @@ class DjangoExtensionTest(_ExtensionTest): # PASSLIB_GET_CATEGORY setting #=================================================================== def test_21_category_setting(self): - "test PASSLIB_GET_CATEGORY parameter" + """test PASSLIB_GET_CATEGORY parameter""" # define config where rounds can be used to detect category config = dict( schemes = ["sha256_crypt"], @@ -863,7 +887,7 @@ class DjangoExtensionTest(_ExtensionTest): from passlib.hash import sha256_crypt def run(**kwds): - "helper to take in user opts, return rounds used in password" + """helper to take in user opts, return rounds used in password""" user = FakeUser(**kwds) user.set_password("stub") return sha256_crypt.from_string(user.password).rounds @@ -920,14 +944,42 @@ class ContextWithHook(CryptContext): # hack up the some of the real django tests to run w/ extension loaded, # to ensure we mimic their behavior. -if has_django14: - from passlib.tests.utils import patchAttr - if DJANGO_VERSION >= (1,6): - from django.contrib.auth.tests import test_hashers as _thmod +# however, the django tests were moved out of the package, and into a source-only location +# as of django 1.7. so we disable tests from that point on unless test-runner specifies +test_hashers_mod = None +hashers_skip_msg = None +if TEST_MODE(max="quick"): + hashers_skip_msg = "requires >= 'default' test mode" +elif DJANGO_VERSION >= (1, 7): + import os + import sys + source_path = os.environ.get("PASSLIB_TESTS_DJANGO_SOURCE_PATH") + if source_path: + if not os.path.exists(source_path): + raise EnvironmentError("django source path not found: %r" % source_path) + if not all(os.path.exists(os.path.join(source_path, name)) + for name in ["django", "tests"]): + raise EnvironmentError("invalid django source path: %r" % source_path) + log.info("using django tests from source path: %r", source_path) + tests_path = os.path.join(source_path, "tests") + sys.path.insert(0, tests_path) + from auth_tests import test_hashers as test_hashers_mod + sys.path.remove(tests_path) else: - from django.contrib.auth.tests import hashers as _thmod + hashers_skip_msg = "requires PASSLIB_TESTS_DJANGO_SOURCE_PATH to be set for django 1.7+" +elif DJANGO_VERSION >= (1, 6): + from django.contrib.auth.tests import test_hashers as test_hashers_mod +elif DJANGO_VERSION >= (1, 4): + from django.contrib.auth.tests import hashers as test_hashers_mod +else: + hashers_skip_msg = "requires django 1.4+ to be present" - class HashersTest(_thmod.TestUtilsHashPass, _ExtensionSupport): +# hack up the some of the real django tests to run w/ extension loaded, +# to ensure we mimic their behavior. +if test_hashers_mod: + from passlib.tests.utils import patchAttr + + class HashersTest(test_hashers_mod.TestUtilsHashPass, _ExtensionSupport): """run django's hasher unittests against passlib's extension and workalike implementations""" def setUp(self): @@ -942,17 +994,17 @@ if has_django14: "check_password", "identify_hasher", "get_hasher"]: - patchAttr(self, _thmod, attr, getattr(hashers, attr)) + patchAttr(self, test_hashers_mod, attr, getattr(hashers, attr)) - # django 1.5 tests expect empty django_des_crypt salt field - if DJANGO_VERSION > (1,4): + # django 1.4 tests expect empty django_des_crypt salt field + if DJANGO_VERSION >= (1,4): from passlib.hash import django_des_crypt patchAttr(self, django_des_crypt, "use_duplicate_salt", False) # hack: need password_context to keep up to date with hasher.iterations if DJANGO_VERSION >= (1,6): def update_hook(self): - rounds = _thmod.get_hasher("pbkdf2_sha256").iterations + rounds = test_hashers_mod.get_hasher("pbkdf2_sha256").iterations self.update( django_pbkdf2_sha256__min_rounds=rounds, django_pbkdf2_sha256__default_rounds=rounds, @@ -967,9 +1019,12 @@ if has_django14: def tearDown(self): self.unload_extension() super(HashersTest, self).tearDown() +else: + class HashersTest(TestCase): - HashersTest = skipUnless(TEST_MODE("default"), - "requires >= 'default' test mode")(HashersTest) + def test_external_django_hasher_tests(self): + """external django hasher tests""" + raise self.skipTest(hashers_skip_msg) #============================================================================= # eof diff --git a/passlib/tests/test_handlers.py b/passlib/tests/test_handlers.py index d300a848..142e9e0d 100644 --- a/passlib/tests/test_handlers.py +++ b/passlib/tests/test_handlers.py @@ -30,7 +30,7 @@ UPASS_TABLE = u("t\u00e1\u0411\u2113\u0259") PASS_TABLE_UTF8 = b('t\xc3\xa1\xd0\x91\xe2\x84\x93\xc9\x99') # utf-8 def get_handler_case(scheme): - "return HandlerCase instance for scheme, used by other tests" + """return HandlerCase instance for scheme, used by other tests""" from passlib.registry import get_crypt_handler handler = get_crypt_handler(scheme) if hasattr(handler, "backends") and not hasattr(handler, "wrapped") and handler.name != "django_bcrypt_sha256": @@ -122,7 +122,7 @@ class bigcrypt_test(HandlerCase): # bsdi crypt #============================================================================= class _bsdi_crypt_test(HandlerCase): - "test BSDiCrypt algorithm" + """test BSDiCrypt algorithm""" handler = hash.bsdi_crypt known_correct_hashes = [ @@ -291,7 +291,7 @@ class cisco_type7_test(HandlerCase): ] def test_90_decode(self): - "test cisco_type7.decode()" + """test cisco_type7.decode()""" from passlib.utils import to_unicode, to_bytes handler = self.handler @@ -305,7 +305,7 @@ class cisco_type7_test(HandlerCase): '0958EDC8A9F495F6F8A5FD', 'ascii') def test_91_salt(self): - "test salt value border cases" + """test salt value border cases""" handler = self.handler self.assertRaises(TypeError, handler, salt=None) handler(salt=None, use_defaults=True) @@ -348,7 +348,7 @@ class crypt16_test(HandlerCase): # des crypt #============================================================================= class _des_crypt_test(HandlerCase): - "test des-crypt algorithm" + """test des-crypt algorithm""" handler = hash.des_crypt secret_size = 8 @@ -396,7 +396,7 @@ des_crypt_os_crypt_test, des_crypt_builtin_test = \ # fshp #============================================================================= class fshp_test(HandlerCase): - "test fshp algorithm" + """test fshp algorithm""" handler = hash.fshp known_correct_hashes = [ @@ -449,7 +449,7 @@ class fshp_test(HandlerCase): ] def test_90_variant(self): - "test variant keyword" + """test variant keyword""" handler = self.handler kwds = dict(salt=b('a'), rounds=1) @@ -546,7 +546,7 @@ class htdigest_test(UserHandlerMixin, HandlerCase): raise self.skipTest("test case doesn't support 'realm' keyword") def populate_context(self, secret, kwds): - "insert username into kwds" + """insert username into kwds""" if isinstance(secret, tuple): secret, user, realm = secret else: @@ -702,7 +702,7 @@ ldap_sha1_crypt_os_crypt_test, = _ldap_sha1_crypt_test.create_backend_cases(["os class ldap_pbkdf2_test(TestCase): def test_wrappers(self): - "test ldap pbkdf2 wrappers" + """test ldap pbkdf2 wrappers""" self.assertTrue( hash.ldap_pbkdf2_sha1.verify( @@ -768,7 +768,7 @@ class lmhash_test(EncodingHandlerMixin, HandlerCase): ] def test_90_raw(self): - "test lmhash.raw() method" + """test lmhash.raw() method""" from binascii import unhexlify from passlib.utils.compat import str_to_bascii lmhash = self.handler @@ -1134,7 +1134,7 @@ class mysql323_test(HandlerCase): ] def test_90_whitespace(self): - "check whitespace is ignored per spec" + """check whitespace is ignored per spec""" h = self.do_encrypt("mypass") h2 = self.do_encrypt("my pass") self.assertEqual(h, h2) @@ -1575,7 +1575,7 @@ class scram_test(HandlerCase): warnings.filterwarnings("ignore", r"norm_hash_name\(\): unknown hash") def test_90_algs(self): - "test parsing of 'algs' setting" + """test parsing of 'algs' setting""" defaults = dict(salt=b('A')*10, rounds=1000) def parse(algs, **kwds): for k in defaults: @@ -1605,7 +1605,7 @@ class scram_test(HandlerCase): checksum={"sha-1": b("\x00"*20)}) def test_90_checksums(self): - "test internal parsing of 'checksum' keyword" + """test internal parsing of 'checksum' keyword""" # check non-bytes checksum values are rejected self.assertRaises(TypeError, self.handler, use_defaults=True, checksum={'sha-1': u('X')*20}) @@ -1617,7 +1617,7 @@ class scram_test(HandlerCase): # XXX: anything else that's not tested by the other code already? def test_91_extract_digest_info(self): - "test scram.extract_digest_info()" + """test scram.extract_digest_info()""" edi = self.handler.extract_digest_info # return appropriate value or throw KeyError @@ -1635,7 +1635,7 @@ class scram_test(HandlerCase): self.assertRaises(ValueError, edi, c, "ddd") def test_92_extract_digest_algs(self): - "test scram.extract_digest_algs()" + """test scram.extract_digest_algs()""" eda = self.handler.extract_digest_algs self.assertEqual(eda('$scram$4096$QSXCR.Q6sek8bf92$' @@ -1653,10 +1653,9 @@ class scram_test(HandlerCase): ["sha-1","sha-256","sha-512"]) def test_93_derive_digest(self): - "test scram.derive_digest()" + """test scram.derive_digest()""" # NOTE: this just does a light test, since derive_digest # is used by encrypt / verify, and is tested pretty well via those. - hash = self.handler.derive_digest # check various encodings of password work. @@ -1679,7 +1678,7 @@ class scram_test(HandlerCase): self.assertRaises(TypeError, hash, "IX", u('\x01'), 1000, 'md5') def test_94_saslprep(self): - "test encrypt/verify use saslprep" + """test encrypt/verify use saslprep""" # NOTE: this just does a light test that saslprep() is being # called in various places, relying in saslpreps()'s tests # to verify full normalization behavior. @@ -1699,7 +1698,7 @@ class scram_test(HandlerCase): self.assertRaises(ValueError, self.do_verify, u("\uFDD0"), h) def test_95_context_algs(self): - "test handling of 'algs' in context object" + """test handling of 'algs' in context object""" handler = self.handler from passlib.context import CryptContext c1 = CryptContext(["scram"], scram__algs="sha1,md5") @@ -1715,7 +1714,7 @@ class scram_test(HandlerCase): self.assertTrue(c2.needs_update(h)) def test_96_full_verify(self): - "test verify(full=True) flag" + """test verify(full=True) flag""" def vpart(s, h): return self.handler.verify(s, h) def vfull(s, h): @@ -1803,7 +1802,6 @@ sha1_crypt_os_crypt_test, sha1_crypt_builtin_test = \ # NOTE: all roundup hashes use PrefixWrapper, # so there's nothing natively to test. # so we just have a few quick cases... -from passlib.handlers import roundup class RoundupTest(TestCase): @@ -2125,7 +2123,6 @@ class sun_md5_crypt_test(HandlerCase): ("solaris", True), ("freebsd|openbsd|netbsd|linux|darwin", False), ] - def do_verify(self, secret, hash): # override to fake error for "$..." hash strings listed in known_config. # these have to be hash strings, in order to test bare salt issue. @@ -2162,7 +2159,7 @@ class unix_disabled_test(HandlerCase): super(unix_disabled_test, self).test_76_hash_border() def test_90_special(self): - "test marker option & special behavior" + """test marker option & special behavior""" handler = self.handler # preserve hash if provided @@ -2194,16 +2191,16 @@ class unix_fallback_test(HandlerCase): warnings.filterwarnings("ignore", "'unix_fallback' is deprecated") def test_90_wildcard(self): - "test enable_wildcard flag" + """test enable_wildcard flag""" h = self.handler self.assertTrue(h.verify('password','', enable_wildcard=True)) self.assertFalse(h.verify('password','')) - for c in ("!*x"): + for c in "!*x": self.assertFalse(h.verify('password',c, enable_wildcard=True)) self.assertFalse(h.verify('password',c)) def test_91_preserves_existing(self): - "test preserves existing disabled hash" + """test preserves existing disabled hash""" handler = self.handler # use marker if no hash diff --git a/passlib/tests/test_handlers_bcrypt.py b/passlib/tests/test_handlers_bcrypt.py index b12759d1..d92b1ffe 100644 --- a/passlib/tests/test_handlers_bcrypt.py +++ b/passlib/tests/test_handlers_bcrypt.py @@ -70,6 +70,20 @@ class _bcrypt_test(HandlerCase): (b('\xa3'), '$2y$05$/OK.fbVrR/bpIqNJ5ianF.Sa7shbm4.OzKpvFnX1pQLmQW96oUlCq'), + # + # bsd wraparound bug (fixed in 2b) + # + + # NOTE: if backend is vulnerable, password will hash the same as '0'*72 + # ("$2a$04$R1lJ2gkNaoPGdafE.H.16.nVyh2niHsGJhayOHLMiXlI45o8/DU.6"), + # rather than same as ("0123456789"*8)[:72] + # 255 should be sufficient, but checking + (("0123456789"*26)[:254], '$2a$04$R1lJ2gkNaoPGdafE.H.16.1MKHPvmKwryeulRe225LKProWYwt9Oi'), + (("0123456789"*26)[:255], '$2a$04$R1lJ2gkNaoPGdafE.H.16.1MKHPvmKwryeulRe225LKProWYwt9Oi'), + (("0123456789"*26)[:256], '$2a$04$R1lJ2gkNaoPGdafE.H.16.1MKHPvmKwryeulRe225LKProWYwt9Oi'), + (("0123456789"*26)[:257], '$2a$04$R1lJ2gkNaoPGdafE.H.16.1MKHPvmKwryeulRe225LKProWYwt9Oi'), + + # # from py-bcrypt tests # @@ -88,6 +102,11 @@ class _bcrypt_test(HandlerCase): # ensures utf-8 used for unicode (UPASS_TABLE, '$2a$05$Z17AXnnlpzddNUvnC6cZNOSwMA/8oNiKnHTHTwLlBijfucQQlHjaG'), + + # ensure 2b support + (UPASS_TABLE, + '$2b$05$Z17AXnnlpzddNUvnC6cZNOSwMA/8oNiKnHTHTwLlBijfucQQlHjaG'), + ] if TEST_MODE("full"): @@ -116,7 +135,7 @@ class _bcrypt_test(HandlerCase): known_unidentified_hashes = [ # invalid minor version - "$2b$12$EXRkfkdmXnagzds2SSitu.MW9.gAVqa9eLS1//RYtYCmB1eLHg.9q", + "$2f$12$EXRkfkdmXnagzds2SSitu.MW9.gAVqa9eLS1//RYtYCmB1eLHg.9q", "$2`$12$EXRkfkdmXnagzds2SSitu.MW9.gAVqa9eLS1//RYtYCmB1eLHg.9q", ] @@ -156,6 +175,7 @@ class _bcrypt_test(HandlerCase): self.addCleanup(os.environ.__delitem__, key) os.environ[key] = "enabled" super(_bcrypt_test, self).setUp() + warnings.filterwarnings("ignore", ".*backend is vulnerable to the bsd wraparound bug.*") def populate_settings(self, kwds): # builtin is still just way too slow. @@ -167,11 +187,13 @@ class _bcrypt_test(HandlerCase): # fuzz testing #=================================================================== def os_supports_ident(self, hash): - "check if OS crypt is expected to support given ident" + """check if OS crypt is expected to support given ident""" if hash is None: return True # most OSes won't support 2x/2y # XXX: definitely not the BSDs, but what about the linux variants? + # XXX: replace this all with 'handler._lacks_2{x}_support' feature detection? + # could even just do call to safe_crypt(ident + salt) and see what we get from passlib.handlers.bcrypt import IDENT_2X, IDENT_2Y if hash.startswith(IDENT_2X) or hash.startswith(IDENT_2Y): return False @@ -179,21 +201,22 @@ class _bcrypt_test(HandlerCase): def fuzz_verifier_bcrypt(self): # test against bcrypt, if available - from passlib.handlers.bcrypt import IDENT_2, IDENT_2A, IDENT_2X, IDENT_2Y + from passlib.handlers.bcrypt import IDENT_2, IDENT_2A, IDENT_2B, IDENT_2X, IDENT_2Y, _detect_pybcrypt from passlib.utils import to_native_str, to_bytes try: import bcrypt except ImportError: return - if not hasattr(bcrypt, "_ffi"): + if _detect_pybcrypt(): return def check_bcrypt(secret, hash): - "bcrypt" + """bcrypt""" secret = to_bytes(secret, self.fuzz_password_encoding) - #if hash.startswith(IDENT_2Y): - # hash = IDENT_2A + hash[4:] - if hash.startswith(IDENT_2): - # bcryptor doesn't support $2$ hashes; but we can fake it + if hash.startswith(IDENT_2B): + # bcrypt <1.1 lacks 2b support + hash = IDENT_2A + hash[4:] + elif hash.startswith(IDENT_2): + # bcrypt doesn't support $2$ hashes; but we can fake it # using the $2a$ algorithm, by repeating the password until # it's 72 chars in length. hash = IDENT_2A + hash[3:] @@ -203,23 +226,25 @@ class _bcrypt_test(HandlerCase): try: return bcrypt.hashpw(secret, hash) == hash except ValueError: - raise ValueError("bcrypt rejected hash: %r" % (hash,)) + raise ValueError("bcrypt rejected hash: %r (secret=%r)" % (hash, secret)) return check_bcrypt def fuzz_verifier_pybcrypt(self): # test against py-bcrypt, if available - from passlib.handlers.bcrypt import IDENT_2, IDENT_2A, IDENT_2X, IDENT_2Y + from passlib.handlers.bcrypt import IDENT_2, IDENT_2A, IDENT_2B, IDENT_2X, IDENT_2Y, _detect_pybcrypt from passlib.utils import to_native_str try: import bcrypt except ImportError: return - if hasattr(bcrypt, "_ffi"): + if not _detect_pybcrypt(): return def check_pybcrypt(secret, hash): - "pybcrypt" + """pybcrypt""" secret = to_native_str(secret, self.fuzz_password_encoding) - if hash.startswith(IDENT_2Y): + if len(secret) > 200: # vulnerable to wraparound bug + secret = secret[:200] + if hash.startswith((IDENT_2B, IDENT_2Y)): hash = IDENT_2A + hash[4:] try: return bcrypt.hashpw(secret, hash) == hash @@ -229,16 +254,16 @@ class _bcrypt_test(HandlerCase): def fuzz_verifier_bcryptor(self): # test against bcryptor, if available - from passlib.handlers.bcrypt import IDENT_2, IDENT_2A, IDENT_2Y + from passlib.handlers.bcrypt import IDENT_2, IDENT_2A, IDENT_2Y, IDENT_2B from passlib.utils import to_native_str try: from bcryptor.engine import Engine except ImportError: return def check_bcryptor(secret, hash): - "bcryptor" + """bcryptor""" secret = to_native_str(secret, self.fuzz_password_encoding) - if hash.startswith(IDENT_2Y): + if hash.startswith((IDENT_2B, IDENT_2Y)): hash = IDENT_2A + hash[4:] elif hash.startswith(IDENT_2): # bcryptor doesn't support $2$ hashes; but we can fake it @@ -297,7 +322,7 @@ class _bcrypt_test(HandlerCase): ] def test_90_bcrypt_padding(self): - "test passlib correctly handles bcrypt padding bits" + """test passlib correctly handles bcrypt padding bits""" self.require_TEST_MODE("full") # # prevents reccurrence of issue 25 (https://code.google.com/p/passlib/issues/detail?id=25) @@ -327,10 +352,12 @@ class _bcrypt_test(HandlerCase): self.assertEqual(hash, "$2a$05$" + "." * 22) # - # make sure genhash() corrects input + # test public methods against good & bad hashes # samples = self.known_incorrect_padding for pwd, bad, good in samples: + + # make sure genhash() corrects bad configs, leaves good unchanged with self.assertWarningList([corr_desc]): self.assertEqual(bcrypt.genhash(pwd, bad), good) with self.assertWarningList([]): @@ -437,6 +464,7 @@ class _bcrypt_sha256_test(HandlerCase): self.addCleanup(os.environ.__delitem__, key) os.environ[key] = "enabled" super(_bcrypt_sha256_test, self).setUp() + warnings.filterwarnings("ignore", ".*backend is vulnerable to the bsd wraparound bug.*") def populate_settings(self, kwds): # builtin is still just way too slow. diff --git a/passlib/tests/test_handlers_django.py b/passlib/tests/test_handlers_django.py index 2d516ae2..248be3ac 100644 --- a/passlib/tests/test_handlers_django.py +++ b/passlib/tests/test_handlers_django.py @@ -43,7 +43,7 @@ class _DjangoHelper(object): return None from django.contrib.auth.models import check_password def verify_django(secret, hash): - "django/check_password" + """django/check_password""" if (1,4) <= DJANGO_VERSION < (1,6) and not secret: return "skip" if self.handler.name == "django_bcrypt" and hash.startswith("bcrypt$$2y$"): @@ -57,7 +57,7 @@ class _DjangoHelper(object): return verify_django def test_90_django_reference(self): - "run known correct hashes through Django's check_password()" + """run known correct hashes through Django's check_password()""" from passlib.tests.test_ext_django import DJANGO_VERSION # check_password() not added until 1.0 min_django_version = max(self.min_django_version, (1,0)) @@ -81,7 +81,7 @@ class _DjangoHelper(object): django_has_encoding_glitch = False def test_91_django_generation(self): - "test against output of Django's make_password()" + """test against output of Django's make_password()""" from passlib.tests.test_ext_django import DJANGO_VERSION # make_password() not added until 1.4 min_django_version = max(self.min_django_version, (1,4)) @@ -106,7 +106,7 @@ class _DjangoHelper(object): self.assertFalse(self.do_verify(other, hash)) class django_disabled_test(HandlerCase): - "test django_disabled" + """test django_disabled""" handler = hash.django_disabled is_disabled_handler = True @@ -123,7 +123,7 @@ class django_disabled_test(HandlerCase): ] class django_des_crypt_test(HandlerCase, _DjangoHelper): - "test django_des_crypt" + """test django_des_crypt""" handler = hash.django_des_crypt secret_size = 8 @@ -164,7 +164,7 @@ class django_des_crypt_test(HandlerCase, _DjangoHelper): ] class django_salted_md5_test(HandlerCase, _DjangoHelper): - "test django_salted_md5" + """test django_salted_md5""" handler = hash.django_salted_md5 django_has_encoding_glitch = True @@ -204,7 +204,7 @@ class django_salted_md5_test(HandlerCase, _DjangoHelper): return randintgauss(lower, upper, default, default*.5) class django_salted_sha1_test(HandlerCase, _DjangoHelper): - "test django_salted_sha1" + """test django_salted_sha1""" handler = hash.django_salted_sha1 django_has_encoding_glitch = True @@ -236,7 +236,7 @@ class django_salted_sha1_test(HandlerCase, _DjangoHelper): fuzz_setting_salt_size = get_method_function(django_salted_md5_test.fuzz_setting_salt_size) class django_pbkdf2_sha256_test(HandlerCase, _DjangoHelper): - "test django_pbkdf2_sha256" + """test django_pbkdf2_sha256""" handler = hash.django_pbkdf2_sha256 min_django_version = (1,4) @@ -251,7 +251,7 @@ class django_pbkdf2_sha256_test(HandlerCase, _DjangoHelper): ] class django_pbkdf2_sha1_test(HandlerCase, _DjangoHelper): - "test django_pbkdf2_sha1" + """test django_pbkdf2_sha1""" handler = hash.django_pbkdf2_sha1 min_django_version = (1,4) @@ -266,7 +266,7 @@ class django_pbkdf2_sha1_test(HandlerCase, _DjangoHelper): ] class django_bcrypt_test(HandlerCase, _DjangoHelper): - "test django_bcrypt" + """test django_bcrypt""" handler = hash.django_bcrypt secret_size = 72 min_django_version = (1,4) @@ -303,7 +303,7 @@ django_bcrypt_test = skipUnless(hash.bcrypt.has_backend(), "no bcrypt backends available")(django_bcrypt_test) class django_bcrypt_sha256_test(HandlerCase, _DjangoHelper): - "test django_bcrypt_sha256" + """test django_bcrypt_sha256""" handler = hash.django_bcrypt_sha256 min_django_version = (1,6) forbidden_characters = None diff --git a/passlib/tests/test_hosts.py b/passlib/tests/test_hosts.py index b01a108b..c1572ae5 100644 --- a/passlib/tests/test_hosts.py +++ b/passlib/tests/test_hosts.py @@ -17,7 +17,7 @@ from passlib.tests.utils import TestCase # test predefined app contexts #============================================================================= class HostsTest(TestCase): - "perform general tests to make sure contexts work" + """perform general tests to make sure contexts work""" # NOTE: these tests are not really comprehensive, # since they would do little but duplicate # the presets in apps.py diff --git a/passlib/tests/test_registry.py b/passlib/tests/test_registry.py index 27c5c5c6..cef255bb 100644 --- a/passlib/tests/test_registry.py +++ b/passlib/tests/test_registry.py @@ -12,7 +12,7 @@ import warnings import sys # site # pkg -from passlib import hash, registry +from passlib import hash, registry, exc from passlib.registry import register_crypt_handler, register_crypt_handler_path, \ get_crypt_handler, list_crypt_handlers, _unload_handler_name as unload_handler_name import passlib.utils.handlers as uh @@ -40,14 +40,23 @@ dummy_x = 1 #============================================================================= class RegistryTest(TestCase): - descriptionPrefix = "passlib registry" + descriptionPrefix = "passlib.registry" - def tearDown(self): - for name in ("dummy_0", "dummy_1", "dummy_x", "dummy_bad"): - unload_handler_name(name) + def setUp(self): + super(RegistryTest, self).setUp() + + # backup registry state & restore it after test. + locations = dict(registry._locations) + handlers = dict(registry._handlers) + def restore(): + registry._locations.clear() + registry._locations.update(locations) + registry._handlers.clear() + registry._handlers.update(handlers) + self.addCleanup(restore) def test_hash_proxy(self): - "test passlib.hash proxy object" + """test passlib.hash proxy object""" # check dir works dir(hash) @@ -80,7 +89,7 @@ class RegistryTest(TestCase): self.assertRaises(ValueError, setattr, hash, "dummy_1x", dummy_1) def test_register_crypt_handler_path(self): - "test register_crypt_handler_path()" + """test register_crypt_handler_path()""" # NOTE: this messes w/ internals of registry, shouldn't be used publically. paths = registry._locations @@ -116,6 +125,7 @@ class RegistryTest(TestCase): # check lazy load w/ wrong name fails register_crypt_handler_path('alt_dummy_0', __name__) self.assertRaises(ValueError, get_crypt_handler, "alt_dummy_0") + unload_handler_name("alt_dummy_0") # TODO: check lazy load which calls register_crypt_handler (warning should be issued) sys.modules.pop("passlib.tests._test_bad_register", None) @@ -127,7 +137,7 @@ class RegistryTest(TestCase): self.assertIs(h, tbr.alt_dummy_bad) def test_register_crypt_handler(self): - "test register_crypt_handler()" + """test register_crypt_handler()""" self.assertRaises(TypeError, register_crypt_handler, {}) @@ -158,7 +168,7 @@ class RegistryTest(TestCase): self.assertTrue('dummy_1' in list_crypt_handlers()) def test_get_crypt_handler(self): - "test get_crypt_handler()" + """test get_crypt_handler()""" class dummy_1(uh.StaticHandler): name = "dummy_1" @@ -189,7 +199,7 @@ class RegistryTest(TestCase): self.assertIs(get_crypt_handler(name, None), None) def test_list_crypt_handlers(self): - "test list_crypt_handlers()" + """test list_crypt_handlers()""" from passlib.registry import list_crypt_handlers # check system & private names aren't returned @@ -197,17 +207,25 @@ class RegistryTest(TestCase): passlib.hash.__dict__["_fake"] = "dummy" # so behavior seen under py2x also for name in list_crypt_handlers(): self.assertFalse(name.startswith("_"), "%r: " % name) + unload_handler_name("_fake") def test_handlers(self): - "verify we have tests for all handlers" + """verify we have tests for all builtin handlers""" from passlib.registry import list_crypt_handlers from passlib.tests.test_handlers import get_handler_case for name in list_crypt_handlers(): + # skip some wrappers that don't need independant testing if name.startswith("ldap_") and name[5:] in list_crypt_handlers(): continue if name in ["roundup_plaintext"]: continue - self.assertTrue(get_handler_case(name)) + # check the remaining ones all have a handler + try: + self.assertTrue(get_handler_case(name)) + except exc.MissingBackendError: + if name in ["bcrypt", "bcrypt_sha256"]: # expected to fail on some setups + continue + raise #============================================================================= # eof diff --git a/passlib/tests/test_utils.py b/passlib/tests/test_utils.py index 67834a03..936b7397 100644 --- a/passlib/tests/test_utils.py +++ b/passlib/tests/test_utils.py @@ -22,12 +22,12 @@ def hb(source): # byte funcs #============================================================================= class MiscTest(TestCase): - "tests various parts of utils module" + """tests various parts of utils module""" # NOTE: could test xor_bytes(), but it's exercised well enough by pbkdf2 test def test_compat(self): - "test compat's lazymodule" + """test compat's lazymodule""" from passlib.utils import compat # "" self.assertRegex(repr(compat), @@ -58,7 +58,7 @@ class MiscTest(TestCase): @deprecated_function(deprecated="1.6", removed="1.8") def test_func(*args): - "test docstring" + """test docstring""" return args self.assertTrue(".. deprecated::" in test_func.__doc__) @@ -91,7 +91,7 @@ class MiscTest(TestCase): self.assertIs(prop.im_func, prop.__func__) def test_getrandbytes(self): - "test getrandbytes()" + """test getrandbytes()""" from passlib.utils import getrandbytes, rng def f(*a,**k): return getrandbytes(rng, *a, **k) @@ -104,7 +104,7 @@ class MiscTest(TestCase): self.assertNotEqual(a, b) def test_getrandstr(self): - "test getrandstr()" + """test getrandstr()""" from passlib.utils import getrandstr, rng def f(*a,**k): return getrandstr(rng, *a, **k) @@ -141,7 +141,7 @@ class MiscTest(TestCase): self.assertEqual(len(generate_password(15)), 15) def test_is_crypt_context(self): - "test is_crypt_context()" + """test is_crypt_context()""" from passlib.utils import is_crypt_context from passlib.context import CryptContext cc = CryptContext(["des_crypt"]) @@ -149,7 +149,7 @@ class MiscTest(TestCase): self.assertFalse(not is_crypt_context(cc)) def test_genseed(self): - "test genseed()" + """test genseed()""" import random from passlib.utils import genseed rng = random.Random(genseed()) @@ -163,7 +163,7 @@ class MiscTest(TestCase): rng.seed(genseed(rng)) def test_crypt(self): - "test crypt.crypt() wrappers" + """test crypt.crypt() wrappers""" from passlib.utils import has_crypt, safe_crypt, test_crypt # test everything is disabled @@ -220,7 +220,7 @@ class MiscTest(TestCase): mod._crypt = orig def test_consteq(self): - "test consteq()" + """test consteq()""" # NOTE: this test is kind of over the top, but that's only because # this is used for the critical task of comparing hashes for equality. from passlib.utils import consteq @@ -304,7 +304,7 @@ class MiscTest(TestCase): ## ##print ", ".join(str(c) for c in [run] + times) def test_saslprep(self): - "test saslprep() unicode normalizer" + """test saslprep() unicode normalizer""" self.require_stringprep() from passlib.utils import saslprep as sp @@ -395,10 +395,10 @@ class MiscTest(TestCase): # byte/unicode helpers #============================================================================= class CodecTest(TestCase): - "tests bytes/unicode helpers in passlib.utils" + """tests bytes/unicode helpers in passlib.utils""" def test_bytes(self): - "test b() helper, bytes and native str type" + """test b() helper, bytes and native str type""" if PY3: import builtins self.assertIs(bytes, builtins.bytes) @@ -414,7 +414,7 @@ class CodecTest(TestCase): self.assertEqual(b('\x00\xff'), "\x00\xff") def test_to_bytes(self): - "test to_bytes()" + """test to_bytes()""" from passlib.utils import to_bytes # check unicode inputs @@ -443,7 +443,7 @@ class CodecTest(TestCase): self.assertRaises(TypeError, to_bytes, None) def test_to_unicode(self): - "test to_unicode()" + """test to_unicode()""" from passlib.utils import to_unicode # check unicode inputs @@ -465,7 +465,7 @@ class CodecTest(TestCase): self.assertRaises(TypeError, to_unicode, None) def test_to_native_str(self): - "test to_native_str()" + """test to_native_str()""" from passlib.utils import to_native_str # test plain ascii @@ -496,7 +496,7 @@ class CodecTest(TestCase): self.assertRaises(TypeError, to_native_str, None, 'ascii') def test_is_ascii_safe(self): - "test is_ascii_safe()" + """test is_ascii_safe()""" from passlib.utils import is_ascii_safe self.assertTrue(is_ascii_safe(b("\x00abc\x7f"))) self.assertTrue(is_ascii_safe(u("\x00abc\x7f"))) @@ -504,7 +504,7 @@ class CodecTest(TestCase): self.assertFalse(is_ascii_safe(u("\x00abc\x80"))) def test_is_same_codec(self): - "test is_same_codec()" + """test is_same_codec()""" from passlib.utils import is_same_codec self.assertTrue(is_same_codec(None, None)) @@ -523,7 +523,7 @@ class CodecTest(TestCase): # base64engine #============================================================================= class Base64EngineTest(TestCase): - "test standalone parts of Base64Engine" + """test standalone parts of Base64Engine""" # NOTE: most Base64Engine testing done via _Base64Test subclasses below. def test_constructor(self): @@ -547,7 +547,7 @@ class Base64EngineTest(TestCase): self.assertRaises(ValueError, ab64_decode, "abcde") class _Base64Test(TestCase): - "common tests for all Base64Engine instances" + """common tests for all Base64Engine instances""" #=================================================================== # class attrs #=================================================================== @@ -566,14 +566,14 @@ class _Base64Test(TestCase): # helper to generate bytemap-specific strings def m(self, *offsets): - "generate byte string from offsets" + """generate byte string from offsets""" return join_bytes(self.engine.bytemap[o:o+1] for o in offsets) #=================================================================== # test encode_bytes #=================================================================== def test_encode_bytes(self): - "test encode_bytes() against reference inputs" + """test encode_bytes() against reference inputs""" engine = self.engine encode = engine.encode_bytes for raw, encoded in self.encoded_data: @@ -581,7 +581,7 @@ class _Base64Test(TestCase): self.assertEqual(result, encoded, "encode %r:" % (raw,)) def test_encode_bytes_bad(self): - "test encode_bytes() with bad input" + """test encode_bytes() with bad input""" engine = self.engine encode = engine.encode_bytes self.assertRaises(TypeError, encode, u('\x00')) @@ -591,7 +591,7 @@ class _Base64Test(TestCase): # test decode_bytes #=================================================================== def test_decode_bytes(self): - "test decode_bytes() against reference inputs" + """test decode_bytes() against reference inputs""" engine = self.engine decode = engine.decode_bytes for raw, encoded in self.encoded_data: @@ -599,7 +599,7 @@ class _Base64Test(TestCase): self.assertEqual(result, raw, "decode %r:" % (encoded,)) def test_decode_bytes_padding(self): - "test decode_bytes() ignores padding bits" + """test decode_bytes() ignores padding bits""" bchr = (lambda v: bytes([v])) if PY3 else chr engine = self.engine m = self.m @@ -626,7 +626,7 @@ class _Base64Test(TestCase): "%d/2 bits:" % i) def test_decode_bytes_bad(self): - "test decode_bytes() with bad input" + """test decode_bytes() with bad input""" engine = self.engine decode = engine.decode_bytes @@ -645,7 +645,7 @@ class _Base64Test(TestCase): # encode_bytes+decode_bytes #=================================================================== def test_codec(self): - "test encode_bytes/decode_bytes against random data" + """test encode_bytes/decode_bytes against random data""" engine = self.engine from passlib.utils import getrandbytes, getrandstr saw_zero = False @@ -691,7 +691,7 @@ class _Base64Test(TestCase): self.assertEqual(result, encoded) def test_repair_unused(self): - "test repair_unused()" + """test repair_unused()""" # NOTE: this test relies on encode_bytes() always returning clear # padding bits - which should be ensured by test vectors. from passlib.utils import rng, getrandstr @@ -739,7 +739,7 @@ class _Base64Test(TestCase): ] def test_encode_transposed_bytes(self): - "test encode_transposed_bytes()" + """test encode_transposed_bytes()""" engine = self.engine for result, input, offsets in self.transposed + self.transposed_dups: tmp = engine.encode_transposed_bytes(input, offsets) @@ -749,7 +749,7 @@ class _Base64Test(TestCase): self.assertRaises(TypeError, engine.encode_transposed_bytes, u("a"), []) def test_decode_transposed_bytes(self): - "test decode_transposed_bytes()" + """test decode_transposed_bytes()""" engine = self.engine for input, result, offsets in self.transposed: tmp = engine.encode_bytes(input) @@ -757,7 +757,7 @@ class _Base64Test(TestCase): self.assertEqual(out, result) def test_decode_transposed_bytes_bad(self): - "test decode_transposed_bytes() fails if map is a one-way" + """test decode_transposed_bytes() fails if map is a one-way""" engine = self.engine for input, _, offsets in self.transposed_dups: tmp = engine.encode_bytes(input) @@ -768,7 +768,7 @@ class _Base64Test(TestCase): # test 6bit handling #=================================================================== def check_int_pair(self, bits, encoded_pairs): - "helper to check encode_intXX & decode_intXX functions" + """helper to check encode_intXX & decode_intXX functions""" engine = self.engine encode = getattr(engine, "encode_int%s" % bits) decode = getattr(engine, "decode_int%s" % bits) @@ -795,7 +795,7 @@ class _Base64Test(TestCase): self.assertRaises(TypeError, decode, None) # do random testing. - from passlib.utils import getrandbytes, getrandstr + from passlib.utils import getrandstr for i in irange(100): # generate random value, encode, and then decode value = random.randint(0, upper-1) @@ -844,7 +844,7 @@ class _Base64Test(TestCase): else m(63,63,63,63, 63,63,63,63, 63,63,15))]) def test_encoded_ints(self): - "test against reference integer encodings" + """test against reference integer encodings""" if not self.encoded_ints: raise self.skipTests("none defined for class") engine = self.engine @@ -863,7 +863,7 @@ class _Base64Test(TestCase): from passlib.utils import h64, h64big class H64_Test(_Base64Test): - "test H64 codec functions" + """test H64 codec functions""" engine = h64 descriptionPrefix = "h64 codec" @@ -888,7 +888,7 @@ class H64_Test(_Base64Test): ] class H64Big_Test(_Base64Test): - "test H64Big codec functions" + """test H64Big codec functions""" engine = h64big descriptionPrefix = "h64big codec" diff --git a/passlib/tests/test_utils_crypto.py b/passlib/tests/test_utils_crypto.py index 9284e22a..0784ef3d 100644 --- a/passlib/tests/test_utils_crypto.py +++ b/passlib/tests/test_utils_crypto.py @@ -31,7 +31,7 @@ def hb(source): # test assorted crypto helpers #============================================================================= class CryptoTest(TestCase): - "test various crypto functions" + """test various crypto functions""" ndn_formats = ["hashlib", "iana"] ndn_values = [ @@ -48,7 +48,7 @@ class CryptoTest(TestCase): ] def test_norm_hash_name(self): - "test norm_hash_name()" + """test norm_hash_name()""" from itertools import chain from passlib.utils.pbkdf2 import norm_hash_name, _nhn_hash_names @@ -123,7 +123,7 @@ class DesTest(TestCase): ] def test_01_expand(self): - "test expand_des_key()" + """test expand_des_key()""" from passlib.utils.des import expand_des_key, shrink_des_key, \ _KDATA_MASK, INT_56_MASK @@ -147,7 +147,7 @@ class DesTest(TestCase): self.assertRaises(ValueError, expand_des_key, b("\x00")*6) def test_02_shrink(self): - "test shrink_des_key()" + """test shrink_des_key()""" from passlib.utils.des import expand_des_key, shrink_des_key, \ INT_64_MASK from passlib.utils import random, getrandbytes @@ -172,13 +172,13 @@ class DesTest(TestCase): self.assertRaises(ValueError, shrink_des_key, b("\x00")*7) def _random_parity(self, key): - "randomize parity bits" + """randomize parity bits""" from passlib.utils.des import _KDATA_MASK, _KPARITY_MASK, INT_64_MASK from passlib.utils import rng return (key & _KDATA_MASK) | (rng.randint(0,INT_64_MASK) & _KPARITY_MASK) def test_03_encrypt_bytes(self): - "test des_encrypt_block()" + """test des_encrypt_block()""" from passlib.utils.des import (des_encrypt_block, shrink_des_key, _pack64, _unpack64) @@ -224,8 +224,8 @@ class DesTest(TestCase): self.assertRaises(ValueError, des_encrypt_block, stub, stub, 0, rounds=0) def test_04_encrypt_ints(self): - "test des_encrypt_int_block()" - from passlib.utils.des import (des_encrypt_int_block, shrink_des_key) + """test des_encrypt_int_block()""" + from passlib.utils.des import des_encrypt_int_block # run through test vectors for key, plaintext, correct in self.des_test_vectors: @@ -285,7 +285,7 @@ class _MD4_Test(TestCase): ] def test_md4_update(self): - "test md4 update" + """test md4 update""" from passlib.utils.md4 import md4 h = md4(b('')) self.assertEqual(h.hexdigest(), "31d6cfe0d16ae931b73c59d7e0c089c0") @@ -302,21 +302,21 @@ class _MD4_Test(TestCase): self.assertEqual(h.hexdigest(), "d79e1c308aa5bbcdeea8ed63df412da9") def test_md4_hexdigest(self): - "test md4 hexdigest()" + """test md4 hexdigest()""" from passlib.utils.md4 import md4 for input, hex in self.vectors: out = md4(input).hexdigest() self.assertEqual(out, hex) def test_md4_digest(self): - "test md4 digest()" + """test md4 digest()""" from passlib.utils.md4 import md4 for input, hex in self.vectors: out = bascii_to_str(hexlify(md4(input).digest())) self.assertEqual(out, hex) def test_md4_copy(self): - "test md4 copy()" + """test md4 copy()""" from passlib.utils.md4 import md4 h = md4(b('abc')) @@ -342,7 +342,7 @@ MD4_Builtin_Test = skipUnless(TEST_MODE("full") or not has_native_md4, # test PBKDF1 support #============================================================================= class Pbkdf1_Test(TestCase): - "test kdf helpers" + """test kdf helpers""" descriptionPrefix = "pbkdf1" pbkdf1_tests = [ @@ -369,14 +369,14 @@ class Pbkdf1_Test(TestCase): ) def test_known(self): - "test reference vectors" + """test reference vectors""" from passlib.utils.pbkdf2 import pbkdf1 for secret, salt, rounds, keylen, digest, correct in self.pbkdf1_tests: result = pbkdf1(secret, salt, rounds, keylen, digest) self.assertEqual(result, correct) def test_border(self): - "test border cases" + """test border cases""" from passlib.utils.pbkdf2 import pbkdf1 def helper(secret=b('secret'), salt=b('salt'), rounds=1, keylen=1, hash='md5'): return pbkdf1(secret, salt, rounds, keylen, hash) @@ -402,7 +402,7 @@ class Pbkdf1_Test(TestCase): # test PBKDF2 support #============================================================================= class _Pbkdf2_Test(TestCase): - "test pbkdf2() support" + """test pbkdf2() support""" _disable_m2crypto = False def setUp(self): @@ -533,7 +533,7 @@ class _Pbkdf2_Test(TestCase): ] def test_known(self): - "test reference vectors" + """test reference vectors""" from passlib.utils.pbkdf2 import pbkdf2 for row in self.pbkdf2_test_vectors: correct, secret, salt, rounds, keylen = row[:5] @@ -542,7 +542,7 @@ class _Pbkdf2_Test(TestCase): self.assertEqual(result, correct) def test_border(self): - "test border cases" + """test border cases""" from passlib.utils.pbkdf2 import pbkdf2 def helper(secret=b('password'), salt=b('salt'), rounds=1, keylen=None, prf="hmac-sha1"): return pbkdf2(secret, salt, rounds, keylen, prf) @@ -568,7 +568,7 @@ class _Pbkdf2_Test(TestCase): self.assertRaises(TypeError, helper, prf=5) def test_default_keylen(self): - "test keylen==None" + """test keylen==None""" from passlib.utils.pbkdf2 import pbkdf2 def helper(secret=b('password'), salt=b('salt'), rounds=1, keylen=None, prf="hmac-sha1"): return pbkdf2(secret, salt, rounds, keylen, prf) @@ -576,7 +576,7 @@ class _Pbkdf2_Test(TestCase): self.assertEqual(len(helper(prf='hmac-sha256')), 32) def test_custom_prf(self): - "test custom prf function" + """test custom prf function""" from passlib.utils.pbkdf2 import pbkdf2 def prf(key, msg): return hashlib.md5(key+msg+b('fooey')).digest() diff --git a/passlib/tests/test_utils_handlers.py b/passlib/tests/test_utils_handlers.py index 5191111d..0b5522dc 100644 --- a/passlib/tests/test_utils_handlers.py +++ b/passlib/tests/test_utils_handlers.py @@ -18,7 +18,7 @@ from passlib.utils import getrandstr, JYTHON, rng from passlib.utils.compat import b, bytes, bascii_to_str, str_to_uascii, \ uascii_to_str, unicode, PY_MAX_25, SUPPORTS_DIR_METHOD import passlib.utils.handlers as uh -from passlib.tests.utils import HandlerCase, TestCase, catch_warnings +from passlib.tests.utils import HandlerCase, TestCase, catch_warnings, patchAttr from passlib.utils.compat import u, PY3 # module log = getLogger(__name__) @@ -27,7 +27,7 @@ log = getLogger(__name__) # utils #============================================================================= def _makelang(alphabet, size): - "generate all strings of given size using alphabet" + """generate all strings of given size using alphabet""" def helper(size): if size < 2: for char in alphabet: @@ -42,13 +42,15 @@ def _makelang(alphabet, size): # test GenericHandler & associates mixin classes #============================================================================= class SkeletonTest(TestCase): - "test hash support classes" + """test hash support classes""" + + patchAttr = patchAttr #=================================================================== # StaticHandler #=================================================================== def test_00_static_handler(self): - "test StaticHandler class" + """test StaticHandler class""" class d1(uh.StaticHandler): name = "d1" @@ -94,7 +96,7 @@ class SkeletonTest(TestCase): self.assertEqual(d1.encrypt('s', flag=True), '_b') def test_01_calc_checksum_hack(self): - "test StaticHandler legacy attr" + """test StaticHandler legacy attr""" # release 1.5 StaticHandler required genhash(), # not _calc_checksum, be implemented. we have backward compat wrapper, # this tests that it works. @@ -103,7 +105,7 @@ class SkeletonTest(TestCase): name = "d1" @classmethod - def identify(self, hash): + def identify(cls, hash): if not hash or len(hash) != 40: return False try: @@ -111,7 +113,6 @@ class SkeletonTest(TestCase): except ValueError: return False return True - @classmethod def genhash(cls, secret, hash): if secret is None: @@ -121,7 +122,6 @@ class SkeletonTest(TestCase): if hash is not None and not cls.identify(hash): raise ValueError("invalid hash") return hashlib.sha1(b("xyz") + secret).hexdigest() - @classmethod def verify(cls, secret, hash): if hash is None: @@ -144,9 +144,8 @@ class SkeletonTest(TestCase): # GenericHandler & mixins #=================================================================== def test_10_identify(self): - "test GenericHandler.identify()" + """test GenericHandler.identify()""" class d1(uh.GenericHandler): - @classmethod def from_string(cls, hash): if isinstance(hash, bytes): @@ -180,7 +179,7 @@ class SkeletonTest(TestCase): del d1.ident def test_11_norm_checksum(self): - "test GenericHandler checksum handling" + """test GenericHandler checksum handling""" # setup helpers class d1(uh.GenericHandler): name = 'd1' @@ -216,7 +215,7 @@ class SkeletonTest(TestCase): self.assertIs(norm_checksum(u('zzzz')), None) def test_12_norm_checksum_raw(self): - "test GenericHandler + HasRawChecksum mixin" + """test GenericHandler + HasRawChecksum mixin""" class d1(uh.HasRawChecksum, uh.GenericHandler): name = 'd1' checksum_size = 4 @@ -236,7 +235,7 @@ class SkeletonTest(TestCase): self.assertIs(norm_checksum(b('0')*4), None) def test_20_norm_salt(self): - "test GenericHandler + HasSalt mixin" + """test GenericHandler + HasSalt mixin""" # setup helpers class d1(uh.HasSalt, uh.GenericHandler): name = 'd1' @@ -312,7 +311,7 @@ class SkeletonTest(TestCase): # TODO: test HasRawSalt mixin def test_30_norm_rounds(self): - "test GenericHandler + HasRounds mixin" + """test GenericHandler + HasRounds mixin""" # setup helpers class d1(uh.HasRounds, uh.GenericHandler): name = 'd1' @@ -359,7 +358,7 @@ class SkeletonTest(TestCase): self.assertRaises(TypeError, norm_rounds, use_defaults=True) def test_40_backends(self): - "test GenericHandler + HasManyBackends mixin" + """test GenericHandler + HasManyBackends mixin""" class d1(uh.HasManyBackends, uh.GenericHandler): name = 'd1' setting_kwds = () @@ -412,7 +411,7 @@ class SkeletonTest(TestCase): self.assertRaises(ValueError, d1.has_backend, 'c') def test_50_norm_ident(self): - "test GenericHandler + HasManyIdents" + """test GenericHandler + HasManyIdents""" # setup helpers class d1(uh.HasManyIdents, uh.GenericHandler): name = 'd1' @@ -458,7 +457,7 @@ class SkeletonTest(TestCase): # but way work correctly for some hashes #=================================================================== def test_91_parsehash(self): - "test parsehash()" + """test parsehash()""" # NOTE: this just tests some existing GenericHandler classes from passlib import hash @@ -514,7 +513,7 @@ class SkeletonTest(TestCase): )) def test_92_bitsize(self): - "test bitsize()" + """test bitsize()""" # NOTE: this just tests some existing GenericHandler classes from passlib import hash @@ -527,10 +526,14 @@ class SkeletonTest(TestCase): {'checksum': 186, 'salt': 132}) # linear rounds + # NOTE: +3 comes from int(math.log(.1,2)), + # where 0.1 = 10% = default allowed variation in rounds + self.patchAttr(hash.sha256_crypt, "default_rounds", 1 << (14 + 3)) self.assertEqual(hash.sha256_crypt.bitsize(), {'checksum': 258, 'rounds': 14, 'salt': 96}) # raw checksum + self.patchAttr(hash.pbkdf2_sha1, "default_rounds", 1 << (13 + 3)) self.assertEqual(hash.pbkdf2_sha1.bitsize(), {'checksum': 160, 'rounds': 13, 'salt': 128}) @@ -546,7 +549,7 @@ class SkeletonTest(TestCase): # PrefixWrapper #============================================================================= class dummy_handler_in_registry(object): - "context manager that inserts dummy handler in registry" + """context manager that inserts dummy handler in registry""" def __init__(self, name): self.name = name self.dummy = type('dummy_' + name, (uh.GenericHandler,), dict( @@ -566,10 +569,10 @@ class dummy_handler_in_registry(object): registry._unload_handler_name(self.name, locations=False) class PrefixWrapperTest(TestCase): - "test PrefixWrapper class" + """test PrefixWrapper class""" def test_00_lazy_loading(self): - "test PrefixWrapper lazy loading of handler" + """test PrefixWrapper lazy loading of handler""" d1 = uh.PrefixWrapper("d1", "ldap_md5", "{XXX}", "{MD5}", lazy=True) # check base state @@ -585,7 +588,7 @@ class PrefixWrapperTest(TestCase): self.assertIs(d1.wrapped, ldap_md5) def test_01_active_loading(self): - "test PrefixWrapper active loading of handler" + """test PrefixWrapper active loading of handler""" d1 = uh.PrefixWrapper("d1", "ldap_md5", "{XXX}", "{MD5}") # check base state @@ -598,7 +601,7 @@ class PrefixWrapperTest(TestCase): self.assertIs(d1.wrapped, ldap_md5) def test_02_explicit(self): - "test PrefixWrapper with explicitly specified handler" + """test PrefixWrapper with explicitly specified handler""" d1 = uh.PrefixWrapper("d1", ldap_md5, "{XXX}", "{MD5}") @@ -696,7 +699,7 @@ class PrefixWrapperTest(TestCase): self.assertEqual(h.ident, None) def test_13_repr(self): - "test repr()" + """test repr()""" h = uh.PrefixWrapper("h2", "md5_crypt", "{XXX}", orig_prefix="$1$") self.assertRegex(repr(h), r"""(?x)^PrefixWrapper\( @@ -707,7 +710,7 @@ class PrefixWrapperTest(TestCase): \)$""") def test_14_bad_hash(self): - "test orig_prefix sanity check" + """test orig_prefix sanity check""" # shoudl throw InvalidHashError if wrapped hash doesn't begin # with orig_prefix. h = uh.PrefixWrapper("h2", "md5_crypt", orig_prefix="$6$") @@ -719,7 +722,7 @@ class PrefixWrapperTest(TestCase): # parts of passlib. they shouldn't be used as actual password schemes. #============================================================================= class UnsaltedHash(uh.StaticHandler): - "test algorithm which lacks a salt" + """test algorithm which lacks a salt""" name = "unsalted_test_hash" checksum_chars = uh.LOWER_HEX_CHARS checksum_size = 40 @@ -731,7 +734,7 @@ class UnsaltedHash(uh.StaticHandler): return str_to_uascii(hashlib.sha1(data).hexdigest()) class SaltedHash(uh.HasSalt, uh.GenericHandler): - "test algorithm with a salt" + """test algorithm with a salt""" name = "salted_test_hash" setting_kwds = ("salt",) diff --git a/passlib/tests/test_win32.py b/passlib/tests/test_win32.py index 9b01752f..6bcdaf56 100644 --- a/passlib/tests/test_win32.py +++ b/passlib/tests/test_win32.py @@ -15,7 +15,7 @@ from passlib.utils.compat import u # #============================================================================= class UtilTest(TestCase): - "test util funcs in passlib.win32" + """test util funcs in passlib.win32""" ##test hashes from http://msdn.microsoft.com/en-us/library/cc245828(v=prot.10).aspx ## among other places diff --git a/passlib/tests/tox_support.py b/passlib/tests/tox_support.py index 2072806d..43170bc4 100644 --- a/passlib/tests/tox_support.py +++ b/passlib/tests/tox_support.py @@ -25,7 +25,7 @@ __all__ = [ TH_PATH = "passlib.tests.test_handlers" def do_hash_tests(*args): - "return list of hash algorithm tests that match regexes" + """return list of hash algorithm tests that match regexes""" if not args: print(TH_PATH) return @@ -44,7 +44,7 @@ def do_hash_tests(*args): return not names def do_preset_tests(name): - "return list of preset test names" + """return list of preset test names""" if name == "django" or name == "django-hashes": do_hash_tests("django_.*_test", "hex_md5_test") if name == "django": @@ -53,7 +53,7 @@ def do_preset_tests(name): raise ValueError("unknown name: %r" % name) def do_setup_gae(path, runtime): - "write fake GAE ``app.yaml`` to current directory so nosegae will work" + """write fake GAE ``app.yaml`` to current directory so nosegae will work""" from passlib.tests.utils import set_file set_file(os.path.join(path, "app.yaml"), """\ application: fake-app diff --git a/passlib/tests/utils.py b/passlib/tests/utils.py index b840aff0..b7e9ca41 100644 --- a/passlib/tests/utils.py +++ b/passlib/tests/utils.py @@ -50,7 +50,7 @@ else: GAE = True def ensure_mtime_changed(path): - "ensure file's mtime has changed" + """ensure file's mtime has changed""" # NOTE: this is hack to deal w/ filesystems whose mtime resolution is >= 1s, # when a test needs to be sure the mtime changed after writing to the file. last = os.path.getmtime(path) @@ -103,14 +103,14 @@ def TEST_MODE(min=None, max=None): # hash object inspection #============================================================================= def has_crypt_support(handler): - "check if host's crypt() supports this natively" + """check if host's crypt() supports this natively""" if hasattr(handler, "orig_prefix"): # ignore wrapper classes return False return 'os_crypt' in getattr(handler, "backends", ()) and handler.has_backend("os_crypt") def has_relaxed_setting(handler): - "check if handler supports 'relaxed' kwd" + """check if handler supports 'relaxed' kwd""" # FIXME: I've been lazy, should probably just add 'relaxed' kwd # to all handlers that derive from GenericHandler @@ -122,7 +122,7 @@ def has_relaxed_setting(handler): uh.GenericHandler) def has_active_backend(handler): - "return active backend for handler, if any" + """return active backend for handler, if any""" if not hasattr(handler, "get_backend"): return "builtin" try: @@ -131,7 +131,7 @@ def has_active_backend(handler): return None def is_default_backend(handler, backend): - "check if backend is the default for source" + """check if backend is the default for source""" try: orig = handler.get_backend() except MissingBackendError: @@ -142,7 +142,12 @@ def is_default_backend(handler, backend): handler.set_backend(orig) class temporary_backend(object): - "temporarily set handler to specific backend" + """ + temporarily set handler to specific backend + """ + + _orig = None + def __init__(self, handler, backend=None): self.handler = handler self.backend = backend @@ -160,19 +165,19 @@ class temporary_backend(object): # misc helpers #============================================================================= def set_file(path, content): - "set file to specified bytes" + """set file to specified bytes""" if isinstance(content, unicode): content = content.encode("utf-8") with open(path, "wb") as fh: fh.write(content) def get_file(path): - "read file as bytes" + """read file as bytes""" with open(path, "rb") as fh: return fh.read() def tonn(source): - "convert native string to non-native string" + """convert native string to non-native string""" if not isinstance(source, str): return source elif PY3: @@ -191,11 +196,11 @@ def limit(value, lower, upper): return value def randintgauss(lower, upper, mu, sigma): - "hack used by fuzz testing" + """hack used by fuzz testing""" return int(limit(rng.normalvariate(mu, sigma), lower, upper)) def quicksleep(delay): - "because time.sleep() doesn't even have 10ms accuracy on some OSes" + """because time.sleep() doesn't even have 10ms accuracy on some OSes""" start = tick() while tick()-start < delay: pass @@ -241,7 +246,7 @@ class TestCase(_TestCase): descriptionPrefix = None def shortDescription(self): - "wrap shortDescription() method to prepend descriptionPrefix" + """wrap shortDescription() method to prepend descriptionPrefix""" desc = super(TestCase, self).shortDescription() prefix = self.descriptionPrefix if prefix: @@ -282,7 +287,7 @@ class TestCase(_TestCase): self.setUpWarnings() def setUpWarnings(self): - "helper to init warning filters before subclass setUp()" + """helper to init warning filters before subclass setUp()""" if self.resetWarningState: ctx = reset_warnings() ctx.__enter__() @@ -445,7 +450,7 @@ class TestCase(_TestCase): # capability tests #=================================================================== def require_stringprep(self): - "helper to skip test if stringprep is missing" + """helper to skip test if stringprep is missing""" from passlib.utils import stringprep if not stringprep: from passlib.utils import _stringprep_missing_reason @@ -453,12 +458,12 @@ class TestCase(_TestCase): _stringprep_missing_reason) def require_TEST_MODE(self, level): - "skip test for all PASSLIB_TEST_MODE values below " + """skip test for all PASSLIB_TEST_MODE values below """ if not TEST_MODE(level): raise self.skipTest("requires >= %r test mode" % level) def require_writeable_filesystem(self): - "skip test if writeable FS not available" + """skip test if writeable FS not available""" if GAE: return self.skipTest("GAE doesn't offer read/write filesystem access") @@ -468,7 +473,7 @@ class TestCase(_TestCase): _mktemp_queue = None def mktemp(self, *args, **kwds): - "create temp file that's cleaned up at end of test" + """create temp file that's cleaned up at end of test""" self.require_writeable_filesystem() fd, path = tempfile.mkstemp(*args, **kwds) os.close(fd) @@ -628,7 +633,7 @@ class HandlerCase(TestCase): @classmethod def iter_known_hashes(cls): - "iterate through known (secret, hash) pairs" + """iterate through known (secret, hash) pairs""" for secret, hash in cls.known_correct_hashes: yield secret, hash for config, secret, hash in cls.known_correct_configs: @@ -637,7 +642,7 @@ class HandlerCase(TestCase): yield secret, hash def get_sample_hash(self): - "test random sample secret/hash pair" + """test random sample secret/hash pair""" known = list(self.iter_known_hashes()) return rng.choice(known) @@ -645,7 +650,7 @@ class HandlerCase(TestCase): # test helpers #--------------------------------------------------------------- def check_verify(self, secret, hash, msg=None, negate=False): - "helper to check verify() outcome, honoring is_disabled_handler" + """helper to check verify() outcome, honoring is_disabled_handler""" result = self.do_verify(secret, hash) self.assertTrue(result is True or result is False, "verify() returned non-boolean value: %r" % (result,)) @@ -672,7 +677,7 @@ class HandlerCase(TestCase): # so that subclasses can fill in defaults and account for other specialized behavior #--------------------------------------------------------------- def populate_settings(self, kwds): - "subclassable method to populate default settings" + """subclassable method to populate default settings""" # use lower rounds settings for certain test modes handler = self.handler if 'rounds' in handler.setting_kwds and 'rounds' not in kwds: @@ -687,35 +692,35 @@ class HandlerCase(TestCase): if getattr(handler, "rounds_cost", None) == "log2": df -= factor else: - df = df//(1<.@*#! \u00AC') def populate_context(self, secret, kwds): - "insert encoding into kwds" + """insert encoding into kwds""" if isinstance(secret, tuple): secret, encoding = secret kwds.setdefault('encoding', encoding) diff --git a/passlib/utils/__init__.py b/passlib/utils/__init__.py index 654124ea..b4c3f907 100644 --- a/passlib/utils/__init__.py +++ b/passlib/utils/__init__.py @@ -130,7 +130,7 @@ class classproperty(object): @property def __func__(self): - "py3 compatible alias" + """py3 compatible alias""" return self.im_func def deprecated_function(msg=None, deprecated=None, removed=None, updoc=True, @@ -221,7 +221,7 @@ class memoized_property(object): @property def __func__(self): - "py3 alias" + """py3 alias""" return self.im_func # works but not used @@ -253,7 +253,7 @@ def consteq(left, right): The purpose of this function is to help prevent timing attacks during digest comparisons: the standard ``==`` operator aborts - after the first mismatched character, causing it's runtime to be + after the first mismatched character, causing its runtime to be proportional to the longest prefix shared by the two inputs. If an attacker is able to predict and control one of the two inputs, repeated queries can be leveraged to reveal information about @@ -456,7 +456,7 @@ def saslprep(source, param="value"): # replace saslprep() with stub when stringprep is missing if stringprep is None: # pragma: no cover -- runtime detection def saslprep(source, param="value"): - "stub for saslprep()" + """stub for saslprep()""" raise NotImplementedError("saslprep() support requires the 'stringprep' " "module, which is " + _stringprep_missing_reason) @@ -503,11 +503,11 @@ add_doc(bytes_to_int, "decode byte string as single big-endian integer") add_doc(int_to_bytes, "encode integer as single big-endian byte string") def xor_bytes(left, right): - "Perform bitwise-xor of two byte strings (must be same size)" + """Perform bitwise-xor of two byte strings (must be same size)""" return int_to_bytes(bytes_to_int(left) ^ bytes_to_int(right), len(left)) def repeat_string(source, size): - "repeat or truncate string, so it has length " + """repeat or truncate string, so it has length """ cur = len(source) if size > cur: mult = (size+cur-1)//cur @@ -519,7 +519,7 @@ _BNULL = b("\x00") _UNULL = u("\x00") def right_pad_string(source, size, pad=None): - "right-pad or truncate string, so it has length " + """right-pad or truncate string, so it has length """ cur = len(source) if size > cur: if pad is None: @@ -535,11 +535,11 @@ _ASCII_TEST_BYTES = b("\x00\n aA:#!\x7f") _ASCII_TEST_UNICODE = _ASCII_TEST_BYTES.decode("ascii") def is_ascii_codec(codec): - "Test if codec is compatible with 7-bit ascii (e.g. latin-1, utf-8; but not utf-16)" + """Test if codec is compatible with 7-bit ascii (e.g. latin-1, utf-8; but not utf-16)""" return _ASCII_TEST_UNICODE.encode(codec) == _ASCII_TEST_BYTES def is_same_codec(left, right): - "Check if two codec names are aliases for same codec" + """Check if two codec names are aliases for same codec""" if left == right: return True if not (left and right): @@ -549,7 +549,7 @@ def is_same_codec(left, right): _B80 = b('\x80')[0] _U80 = u('\x80') def is_ascii_safe(source): - "Check if string (bytes or unicode) contains only 7-bit ascii" + """Check if string (bytes or unicode) contains only 7-bit ascii""" r = _B80 if isinstance(source, bytes) else _U80 return all(c < r for c in source) @@ -656,7 +656,7 @@ add_doc(to_native_str, @deprecated_function(deprecated="1.6", removed="1.7") def to_hash_str(source, encoding="ascii"): # pragma: no cover -- deprecated & unused - "deprecated, use to_native_str() instead" + """deprecated, use to_native_str() instead""" return to_native_str(source, encoding, param="hash") #============================================================================= @@ -671,7 +671,7 @@ class Base64Engine(object): A string of 64 unique characters, which will be used to encode successive 6-bit chunks of data. A character's position within the string should correspond - to it's 6-bit value. + to its 6-bit value. :param big: Whether the encoding should be big-endian (default False). @@ -783,7 +783,7 @@ class Base64Engine(object): @property def charmap(self): - "charmap as unicode" + """charmap as unicode""" return self.bytemap.decode("latin-1") #=================================================================== @@ -811,7 +811,7 @@ class Base64Engine(object): return out def _encode_bytes_little(self, next_value, chunks, tail): - "helper used by encode_bytes() to handle little-endian encoding" + """helper used by encode_bytes() to handle little-endian encoding""" # # output bit layout: # @@ -850,7 +850,7 @@ class Base64Engine(object): yield v2>>4 def _encode_bytes_big(self, next_value, chunks, tail): - "helper used by encode_bytes() to handle big-endian encoding" + """helper used by encode_bytes() to handle big-endian encoding""" # # output bit layout: # @@ -916,7 +916,7 @@ class Base64Engine(object): raise ValueError("invalid character: %r" % (err.args[0],)) def _decode_bytes_little(self, next_value, chunks, tail): - "helper used by decode_bytes() to handle little-endian encoding" + """helper used by decode_bytes() to handle little-endian encoding""" # # input bit layout: # @@ -951,7 +951,7 @@ class Base64Engine(object): yield (v2>>2) | ((v3 & 0xF) << 4) def _decode_bytes_big(self, next_value, chunks, tail): - "helper used by decode_bytes() to handle big-endian encoding" + """helper used by decode_bytes() to handle big-endian encoding""" # # input bit layout: # @@ -993,21 +993,21 @@ class Base64Engine(object): # equivalent char with no padding bits set. def __make_padset(self, bits): - "helper to generate set of valid last chars & bytes" + """helper to generate set of valid last chars & bytes""" pset = set(c for i,c in enumerate(self.bytemap) if not i & bits) pset.update(c for i,c in enumerate(self.charmap) if not i & bits) return frozenset(pset) @memoized_property def _padinfo2(self): - "mask to clear padding bits, and valid last bytes (for strings 2 % 4)" + """mask to clear padding bits, and valid last bytes (for strings 2 % 4)""" # 4 bits of last char unused (lsb for big, msb for little) bits = 15 if self.big else (15<<2) return ~bits, self.__make_padset(bits) @memoized_property def _padinfo3(self): - "mask to clear padding bits, and valid last bytes (for strings 3 % 4)" + """mask to clear padding bits, and valid last bytes (for strings 3 % 4)""" # 2 bits of last char unused (lsb for big, msb for little) bits = 3 if self.big else (3<<4) return ~bits, self.__make_padset(bits) @@ -1072,14 +1072,14 @@ class Base64Engine(object): # transposed encoding/decoding #=================================================================== def encode_transposed_bytes(self, source, offsets): - "encode byte string, first transposing source using offset list" + """encode byte string, first transposing source using offset list""" if not isinstance(source, bytes): raise TypeError("source must be bytes, not %s" % (type(source),)) tmp = join_byte_elems(source[off] for off in offsets) return self.encode_bytes(tmp) def decode_transposed_bytes(self, source, offsets): - "decode byte string, then reverse transposition described by offset list" + """decode byte string, then reverse transposition described by offset list""" # NOTE: if transposition does not use all bytes of source, # the original can't be recovered... and join_byte_elems() will throw # an error because 1+ values in will be None. @@ -1133,7 +1133,7 @@ class Base64Engine(object): #--------------------------------------------------------------- def decode_int6(self, source): - "decode single character -> 6 bit integer" + """decode single character -> 6 bit integer""" if not isinstance(source, bytes): raise TypeError("source must be bytes, not %s" % (type(source),)) if len(source) != 1: @@ -1147,7 +1147,7 @@ class Base64Engine(object): raise ValueError("invalid character") def decode_int12(self, source): - "decodes 2 char string -> 12-bit integer" + """decodes 2 char string -> 12-bit integer""" if not isinstance(source, bytes): raise TypeError("source must be bytes, not %s" % (type(source),)) if len(source) != 2: @@ -1162,7 +1162,7 @@ class Base64Engine(object): raise ValueError("invalid character") def decode_int24(self, source): - "decodes 4 char string -> 24-bit integer" + """decodes 4 char string -> 24-bit integer""" if not isinstance(source, bytes): raise TypeError("source must be bytes, not %s" % (type(source),)) if len(source) != 4: @@ -1216,7 +1216,7 @@ class Base64Engine(object): #--------------------------------------------------------------- def encode_int6(self, value): - "encodes 6-bit integer -> single hash64 character" + """encodes 6-bit integer -> single hash64 character""" if value < 0 or value > 63: raise ValueError("value out of range") if PY3: @@ -1225,7 +1225,7 @@ class Base64Engine(object): return self._encode64(value) def encode_int12(self, value): - "encodes 12-bit integer -> 2 char string" + """encodes 12-bit integer -> 2 char string""" if value < 0 or value > 0xFFF: raise ValueError("value out of range") raw = [value & 0x3f, (value>>6) & 0x3f] @@ -1234,7 +1234,7 @@ class Base64Engine(object): return join_byte_elems(imap(self._encode64, raw)) def encode_int24(self, value): - "encodes 24-bit integer -> 4 char string" + """encodes 24-bit integer -> 4 char string""" if value < 0 or value > 0xFFFFFF: raise ValueError("value out of range") raw = [value & 0x3f, (value>>6) & 0x3f, @@ -1258,7 +1258,7 @@ class Base64Engine(object): #=================================================================== class LazyBase64Engine(Base64Engine): - "Base64Engine which delays initialization until it's accessed" + """Base64Engine which delays initialization until it's accessed""" _lazy_opts = None def __init__(self, *args, **kwds): @@ -1331,6 +1331,7 @@ def ab64_decode(data): try: from crypt import crypt as _crypt except ImportError: # pragma: no cover + _crypt = None has_crypt = False def safe_crypt(secret, hash): return None @@ -1451,13 +1452,13 @@ except NotImplementedError: # pragma: no cover has_urandom = False def genseed(value=None): - "generate prng seed value from system resources" + """generate prng seed value from system resources""" from hashlib import sha512 text = u("%s %s %s %s %.15f %.15f %s") % ( # if caller specified a seed value, mix it in value, - # if caller's seed value was an RNG, mix in bits from it's state + # if caller's seed value was an RNG, mix in bits from its state value.getrandbits(1<<15) if hasattr(value, "getrandbits") else None, # add current process id @@ -1572,7 +1573,7 @@ _handler_attrs = ( ) def is_crypt_handler(obj): - "check if object follows the :ref:`password-hash-api`" + """check if object follows the :ref:`password-hash-api`""" # XXX: change to use isinstance(obj, PasswordHash) under py26+? return all(hasattr(obj, name) for name in _handler_attrs) @@ -1583,7 +1584,7 @@ _context_attrs = ( ) def is_crypt_context(obj): - "check if object appears to be a :class:`~passlib.context.CryptContext` instance" + """check if object appears to be a :class:`~passlib.context.CryptContext` instance""" # XXX: change to use isinstance(obj, CryptContext)? return all(hasattr(obj, name) for name in _context_attrs) @@ -1593,12 +1594,12 @@ def is_crypt_context(obj): ## return hasattr(handler, "set_backend") def has_rounds_info(handler): - "check if handler provides the optional :ref:`rounds information ` attributes" + """check if handler provides the optional :ref:`rounds information ` attributes""" return ('rounds' in handler.setting_kwds and getattr(handler, "min_rounds", None) is not None) def has_salt_info(handler): - "check if handler provides the optional :ref:`salt information ` attributes" + """check if handler provides the optional :ref:`salt information ` attributes""" return ('salt' in handler.setting_kwds and getattr(handler, "min_salt_size", None) is not None) diff --git a/passlib/utils/_blowfish/__init__.py b/passlib/utils/_blowfish/__init__.py index 16b85443..3281be94 100644 --- a/passlib/utils/_blowfish/__init__.py +++ b/passlib/utils/_blowfish/__init__.py @@ -18,11 +18,11 @@ This package contains two submodules: Status ------ -This implementation is usuable, but is an order of magnitude too slow to be -usuable with real security. For "ok" security, BCrypt hashes should have at +This implementation is usable, but is an order of magnitude too slow to be +usable with real security. For "ok" security, BCrypt hashes should have at least 2**11 rounds (as of 2011). Assuming a desired response time <= 100ms, this means a BCrypt implementation should get at least 20 rounds/ms in order -to be both usuable *and* secure. On a 2 ghz cpu, this implementation gets +to be both usable *and* secure. On a 2 ghz cpu, this implementation gets roughly 0.09 rounds/ms under CPython (220x too slow), and 1.9 rounds/ms under PyPy (10x too slow). @@ -55,7 +55,7 @@ from itertools import chain import struct # pkg from passlib.utils import bcrypt64, getrandbytes, rng -from passlib.utils.compat import b, bytes, BytesIO, unicode, u +from passlib.utils.compat import b, bytes, BytesIO, unicode, u, native_string_types from passlib.utils._blowfish.unrolled import BlowfishEngine # local __all__ = [ @@ -98,19 +98,15 @@ def raw_bcrypt(password, ident, salt, log_rounds): #=================================================================== # parse ident - assert isinstance(ident, unicode) - if ident == u('2'): - minor = 0 - elif ident == u('2a'): - minor = 1 - # XXX: how to indicate caller wants to use crypt_blowfish's - # workaround variant of 2a? + assert isinstance(ident, native_string_types) + add_null_padding = True + if ident == u('2a') or ident == u('2y') or ident == u('2b'): + pass + elif ident == u('2'): + add_null_padding = False elif ident == u('2x'): raise ValueError("crypt_blowfish's buggy '2x' hashes are not " "currently supported") - elif ident == u('2y'): - # crypt_blowfish compatibility ident which guarantees compat w/ 2a - minor = 1 else: raise ValueError("unknown ident: %r" % (ident,)) @@ -124,7 +120,7 @@ def raw_bcrypt(password, ident, salt, log_rounds): # prepare password assert isinstance(password, bytes) - if minor > 0: + if add_null_padding: password += BNULL # validate rounds diff --git a/passlib/utils/_blowfish/_gen_files.py b/passlib/utils/_blowfish/_gen_files.py index a5757f2f..a8fca486 100644 --- a/passlib/utils/_blowfish/_gen_files.py +++ b/passlib/utils/_blowfish/_gen_files.py @@ -17,7 +17,7 @@ def varlist(name, count): def indent_block(block, padding): - "ident block of text" + """ident block of text""" lines = block.split("\n") return "\n".join( padding + line if line else "" diff --git a/passlib/utils/_blowfish/base.py b/passlib/utils/_blowfish/base.py index f62aca24..5621e4c9 100644 --- a/passlib/utils/_blowfish/base.py +++ b/passlib/utils/_blowfish/base.py @@ -339,7 +339,7 @@ class BlowfishEngine(object): # blowfish routines #=================================================================== def encipher(self, l, r): - "loop version of blowfish encipher routine" + """loop version of blowfish encipher routine""" P, S = self.P, self.S l ^= P[0] i = 1 @@ -355,7 +355,7 @@ class BlowfishEngine(object): # NOTE: decipher is same as above, just with reversed(P) instead. def expand(self, key_words): - "perform stock Blowfish keyschedule setup" + """perform stock Blowfish keyschedule setup""" assert len(key_words) >= 18, "key_words must be at least as large as P" P, S, encipher = self.P, self.S, self.encipher @@ -379,7 +379,7 @@ class BlowfishEngine(object): # eks-blowfish routines #=================================================================== def eks_salted_expand(self, key_words, salt_words): - "perform EKS' salted version of Blowfish keyschedule setup" + """perform EKS' salted version of Blowfish keyschedule setup""" # NOTE: this is the same as expand(), except for the addition # of the operations involving *salt_words*. @@ -416,7 +416,7 @@ class BlowfishEngine(object): i += 2 def eks_repeated_expand(self, key_words, salt_words, rounds): - "perform rounds stage of EKS keyschedule setup" + """perform rounds stage of EKS keyschedule setup""" expand = self.expand n = 0 while n < rounds: @@ -425,7 +425,7 @@ class BlowfishEngine(object): n += 1 def repeat_encipher(self, l, r, count): - "repeatedly apply encipher operation to a block" + """repeatedly apply encipher operation to a block""" encipher = self.encipher n = 0 while n < count: diff --git a/passlib/utils/compat.py b/passlib/utils/compat.py index a7bb626c..4cf9b81a 100644 --- a/passlib/utils/compat.py +++ b/passlib/utils/compat.py @@ -97,6 +97,7 @@ if PY3: return s.encode("latin-1") base_string_types = (unicode, bytes) + native_string_types = (unicode,) else: unicode = builtins.unicode @@ -111,6 +112,7 @@ else: return s base_string_types = basestring + native_string_types = (basestring,) #============================================================================= # unicode & bytes helpers @@ -253,7 +255,7 @@ else: if PY_MAX_25: _undef = object() def next(itr, default=_undef): - "compat wrapper for next()" + """compat wrapper for next()""" if default is _undef: return itr.next() try: @@ -282,7 +284,7 @@ else: # introspection #============================================================================= def exc_err(): - "return current error object (to avoid try/except syntax change)" + """return current error object (to avoid try/except syntax change)""" return sys.exc_info()[1] if PY3: @@ -291,7 +293,7 @@ else: method_function_attr = "im_func" def get_method_function(func): - "given (potential) method, return underlying function" + """given (potential) method, return underlying function""" return getattr(func, method_function_attr, func) #============================================================================= @@ -366,7 +368,7 @@ else: from types import ModuleType def _import_object(source): - "helper to import object from module; accept format `path.to.object`" + """helper to import object from module; accept format `path.to.object`""" modname, modattr = source.rsplit(".",1) mod = __import__(modname, fromlist=[modattr], level=0) return getattr(mod, modattr) diff --git a/passlib/utils/des.py b/passlib/utils/des.py index def894d3..a2fc2bf3 100644 --- a/passlib/utils/des.py +++ b/passlib/utils/des.py @@ -81,7 +81,7 @@ _KS_MASK = 0xfcfcfcfcffffffff PCXROT = IE3264 = SPE = CF6464 = None def _load_tables(): - "delay loading tables until they are actually needed" + """delay loading tables until they are actually needed""" global PCXROT, IE3264, SPE, CF6464 #--------------------------------------------------------------- @@ -612,7 +612,7 @@ def _unpack56(value): _EXPAND_ITER = irange(49,-7,-7) def expand_des_key(key): - "convert DES from 7 bytes to 8 bytes (by inserting empty parity bits)" + """convert DES from 7 bytes to 8 bytes (by inserting empty parity bits)""" if isinstance(key, bytes): if len(key) != 7: raise ValueError("key must be 7 bytes in size") @@ -631,7 +631,7 @@ def expand_des_key(key): return join_byte_values(((key>>shift) & 0x7f)<<1 for shift in _EXPAND_ITER) def shrink_des_key(key): - "convert DES key from 8 bytes to 7 bytes (by discarding the parity bits)" + """convert DES key from 8 bytes to 7 bytes (by discarding the parity bits)""" if isinstance(key, bytes): if len(key) != 8: raise ValueError("key must be 8 bytes in size") @@ -666,7 +666,7 @@ def des_encrypt_block(key, input, salt=0, rounds=1): :arg salt: Optional 24-bit integer used to mutate the base DES algorithm in a - manner specific to :class:`~passlib.hash.des_crypt` and it's variants. + manner specific to :class:`~passlib.hash.des_crypt` and its variants. The default value ``0`` provides the normal (unsalted) DES behavior. The salt functions as follows: if the ``i``'th bit of ``salt`` is set, @@ -675,7 +675,7 @@ def des_encrypt_block(key, input, salt=0, rounds=1): :arg rounds: Optional number of rounds of to apply the DES key schedule. the default (``rounds=1``) provides the normal DES behavior, - but :class:`~passlib.hash.des_crypt` and it's variants use + but :class:`~passlib.hash.des_crypt` and its variants use alternate rounds values. :raises TypeError: if any of the provided args are of the wrong type. @@ -779,7 +779,7 @@ def des_encrypt_int_block(key, input, salt=0, rounds=1): # NOTE: generation was modified to output two elements at a time, # so that per-round loop could do two passes at once. def _iter_key_schedule(ks_odd): - "given 64-bit key, iterates over the 8 (even,odd) key schedule pairs" + """given 64-bit key, iterates over the 8 (even,odd) key schedule pairs""" for p_even, p_odd in PCXROT: ks_even = _permute(ks_odd, p_even) ks_odd = _permute(ks_even, p_odd) diff --git a/passlib/utils/handlers.py b/passlib/utils/handlers.py index 4d03b3b2..b25d2037 100644 --- a/passlib/utils/handlers.py +++ b/passlib/utils/handlers.py @@ -88,14 +88,14 @@ _UDOLLAR = u("$") _UZERO = u("0") def validate_secret(secret): - "ensure secret has correct type & size" + """ensure secret has correct type & size""" if not isinstance(secret, base_string_types): raise exc.ExpectedStringError(secret, "secret") if len(secret) > MAX_PASSWORD_SIZE: raise exc.PasswordSizeError() def to_unicode_for_identify(hash): - "convert hash to unicode for identify method" + """convert hash to unicode for identify method""" if isinstance(hash, unicode): return hash elif isinstance(hash, bytes): @@ -584,7 +584,7 @@ class GenericHandler(PasswordHash): @staticmethod def _sanitize(value, char=u("*")): - "default method to obscure sensitive fields" + """default method to obscure sensitive fields""" if value is None: return None if isinstance(value, bytes): @@ -606,7 +606,7 @@ class GenericHandler(PasswordHash): (with the extra keyword *checksum*). this method may not work correctly for all hashes, - and may not be available on some few. it's interface may + and may not be available on some few. its interface may change in future releases, if it's kept around at all. :arg hash: hash to parse @@ -634,7 +634,7 @@ class GenericHandler(PasswordHash): @classmethod def bitsize(cls, **kwds): - "[experimental method] return info about bitsizes of hash" + """[experimental method] return info about bitsizes of hash""" try: info = super(GenericHandler, cls).bitsize(**kwds) except AttributeError: @@ -692,7 +692,7 @@ class StaticHandler(GenericHandler): @classmethod def _norm_hash(cls, hash): - "helper for subclasses to normalize case if needed" + """helper for subclasses to normalize case if needed""" return hash def to_string(self): @@ -737,7 +737,7 @@ class StaticHandler(GenericHandler): hash = wrapper_cls.genhash(secret, None, **context) warn("%r should be updated to implement StaticHandler._calc_checksum() " "instead of StaticHandler.genhash(), support for the latter " - "style will be removed in Passlib 1.8" % (cls), + "style will be removed in Passlib 1.8" % cls, DeprecationWarning) return str_to_uascii(hash) @@ -920,7 +920,7 @@ class HasSalt(GenericHandler): Class Attributes ================ - In order for :meth:`!_norm_salt` to do it's job, the following + In order for :meth:`!_norm_salt` to do its job, the following attributes should be provided by the handler subclass: .. attribute:: min_salt_size @@ -986,12 +986,12 @@ class HasSalt(GenericHandler): @classproperty def default_salt_size(cls): - "default salt size (defaults to *max_salt_size*)" + """default salt size (defaults to *max_salt_size*)""" return cls.max_salt_size @classproperty def default_salt_chars(cls): - "charset used to generate new salt strings (defaults to *salt_chars*)" + """charset used to generate new salt strings (defaults to *salt_chars*)""" return cls.salt_chars # private helpers for HasRawSalt, shouldn't be used by subclasses @@ -1082,7 +1082,7 @@ class HasSalt(GenericHandler): @staticmethod def _truncate_salt(salt, mx): # NOTE: some hashes (e.g. bcrypt) has structure within their - # salt string. this provides a method to overide to perform + # salt string. this provides a method to override to perform # the truncation properly return salt[:mx] @@ -1095,7 +1095,7 @@ class HasSalt(GenericHandler): @classmethod def bitsize(cls, salt_size=None, **kwds): - "[experimental method] return info about bitsizes of hash" + """[experimental method] return info about bitsizes of hash""" info = super(HasSalt, cls).bitsize(**kwds) if salt_size is None: salt_size = cls.default_salt_size @@ -1143,7 +1143,7 @@ class HasRounds(GenericHandler): Class Attributes ================ - In order for :meth:`!_norm_rounds` to do it's job, the following + In order for :meth:`!_norm_rounds` to do its job, the following attributes must be provided by the handler subclass: .. attribute:: min_rounds @@ -1259,7 +1259,7 @@ class HasRounds(GenericHandler): @classmethod def bitsize(cls, rounds=None, vary_rounds=.1, **kwds): - "[experimental method] return info about bitsizes of hash" + """[experimental method] return info about bitsizes of hash""" info = super(HasRounds, cls).bitsize(**kwds) # NOTE: this essentially estimates how many bits of "salt" # can be added by varying the rounds value just a little bit. @@ -1448,7 +1448,10 @@ class HasManyBackends(GenericHandler): return name def _calc_checksum_backend(self, secret): - "stub for _calc_checksum_backend(), default backend will be selected first time stub is called" + """ + stub for _calc_checksum_backend(), + the default backend will be selected the first time stub is called. + """ # if we got here, no backend has been loaded; so load default backend assert not self._backend, "set_backend() failed to replace lazy loader" self.set_backend() @@ -1458,7 +1461,7 @@ class HasManyBackends(GenericHandler): return self._calc_checksum_backend(secret) def _calc_checksum(self, secret): - "wrapper for backend, for common code""" + """wrapper for backend, for common code""" return self._calc_checksum_backend(secret) #============================================================================= @@ -1605,13 +1608,13 @@ class PrefixWrapper(object): return list(attrs) def __getattr__(self, attr): - "proxy most attributes from wrapped class (e.g. rounds, salt size, etc)" + """proxy most attributes from wrapped class (e.g. rounds, salt size, etc)""" if attr in self._proxy_attrs: return getattr(self.wrapped, attr) raise AttributeError("missing attribute: %r" % (attr,)) def _unwrap_hash(self, hash): - "given hash belonging to wrapper, return orig version" + """given hash belonging to wrapper, return orig version""" # NOTE: assumes hash has been validated as unicode already prefix = self.prefix if not hash.startswith(prefix): @@ -1620,7 +1623,7 @@ class PrefixWrapper(object): return self.orig_prefix + hash[len(prefix):] def _wrap_hash(self, hash): - "given orig hash; return one belonging to wrapper" + """given orig hash; return one belonging to wrapper""" # NOTE: should usually be native string. # (which does mean extra work under py2, but not py3) if isinstance(hash, bytes): diff --git a/passlib/utils/md4.py b/passlib/utils/md4.py index cdc14939..cd067d9f 100644 --- a/passlib/utils/md4.py +++ b/passlib/utils/md4.py @@ -56,7 +56,7 @@ class md4(object): .. method:: hexdigest - return hexdecimal version of digest + return hexadecimal version of digest """ # FIXME: make this follow hash object PEP better. # FIXME: this isn't threadsafe @@ -146,7 +146,7 @@ class md4(object): ] def _process(self, block): - "process 64 byte block" + """process 64 byte block""" # unpack block into 16 32-bit ints X = struct.unpack("<16I", block) @@ -258,7 +258,7 @@ def _has_native_md4(): # pragma: no cover -- runtime detection if _has_native_md4(): # overwrite md4 class w/ hashlib wrapper def md4(content=None): - "wrapper for hashlib.new('md4')" + """wrapper for hashlib.new('md4')""" return hashlib.new('md4', content or b('')) #============================================================================= diff --git a/passlib/utils/pbkdf2.py b/passlib/utils/pbkdf2.py index 1cd0d8f1..ee245d6c 100644 --- a/passlib/utils/pbkdf2.py +++ b/passlib/utils/pbkdf2.py @@ -142,7 +142,7 @@ _trans_5C = join_byte_values((x ^ 0x5C) for x in irange(256)) _trans_36 = join_byte_values((x ^ 0x36) for x in irange(256)) def _get_hmac_prf(digest): - "helper to return HMAC prf for specific digest" + """helper to return HMAC prf for specific digest""" def tag_wrapper(prf): prf.__name__ = "hmac_" + digest prf.__doc__ = ("hmac_%s(key, msg) -> digest;" @@ -150,7 +150,7 @@ def _get_hmac_prf(digest): digest) if _EVP and digest == "sha1": - # use m2crypto function directly for sha1, since that's it's default digest + # use m2crypto function directly for sha1, since that's its default digest try: result = _EVP.hmac(b('x'),b('y')) except ValueError: # pragma: no cover @@ -199,7 +199,7 @@ def _get_hmac_prf(digest): _prf_cache = {} def _clear_prf_cache(): - "helper for unit tests" + """helper for unit tests""" _prf_cache.clear() def get_prf(name): diff --git a/passlib/win32.py b/passlib/win32.py index 78155976..fd6febe7 100644 --- a/passlib/win32.py +++ b/passlib/win32.py @@ -51,9 +51,9 @@ LM_MAGIC = b("KGS!@#$%") raw_nthash = nthash.raw_nthash def raw_lmhash(secret, encoding="ascii", hex=False): - "encode password using des-based LMHASH algorithm; returns string of raw bytes, or unicode hex" + """encode password using des-based LMHASH algorithm; returns string of raw bytes, or unicode hex""" # NOTE: various references say LMHASH uses the OEM codepage of the host - # for it's encoding. until a clear reference is found, + # for its encoding. until a clear reference is found, # as well as a path for getting the encoding, # letting this default to "ascii" to prevent incorrect hashes # from being made w/o user explicitly choosing an encoding.